Merge branch 'cleanup' of https://git.d-popov.com/popov/gogo2 into cleanup

This commit is contained in:
Dobromir Popov
2025-11-22 20:45:37 +02:00
25 changed files with 2825 additions and 258 deletions

View File

@@ -16,7 +16,7 @@ sys.path.insert(0, str(parent_dir))
from flask import Flask, render_template, request, jsonify, send_file
from dash import Dash, html
import logging
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional, Dict, List, Any
import json
import pandas as pd
@@ -538,6 +538,9 @@ class AnnotationDashboard:
engineio_logger=False
)
self.has_socketio = True
# Pass socketio to training adapter for live trade updates
if self.training_adapter:
self.training_adapter.socketio = self.socketio
logger.info("SocketIO initialized for real-time updates")
except ImportError:
self.socketio = None
@@ -586,6 +589,8 @@ class AnnotationDashboard:
self.annotation_manager = AnnotationManager()
# Use REAL training adapter - NO SIMULATION!
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
# Pass socketio to training adapter for live trade updates
self.training_adapter.socketio = None # Will be set after socketio initialization
# Backtest runner for replaying visible chart with predictions
self.backtest_runner = BacktestRunner()
@@ -626,63 +631,38 @@ class AnnotationDashboard:
if not self.orchestrator:
logger.info("Initializing TradingOrchestrator...")
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
config=self.config
data_provider=self.data_provider
)
self.training_adapter.orchestrator = self.orchestrator
logger.info("TradingOrchestrator initialized")
# Get checkpoint info before loading
checkpoint_info = self._get_best_checkpoint_info(model_name)
# Load the specific model
# Check if the specific model is already initialized
if model_name == 'Transformer':
logger.info("Loading Transformer model...")
self.orchestrator.load_transformer_model()
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer_trainer
# Store checkpoint info in orchestrator for UI access
if checkpoint_info:
self.orchestrator.transformer_checkpoint_info = {
'status': 'loaded',
'filename': checkpoint_info.get('filename', 'unknown'),
'epoch': checkpoint_info.get('epoch', 0),
'loss': checkpoint_info.get('loss', 0.0),
'accuracy': checkpoint_info.get('accuracy', 0.0),
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
logger.info("Transformer model loaded successfully")
logger.info("Checking Transformer model...")
if self.orchestrator.primary_transformer:
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
logger.info("Transformer model loaded successfully")
else:
logger.warning("Transformer model not initialized in orchestrator")
return
elif model_name == 'CNN':
logger.info("Loading CNN model...")
self.orchestrator.load_cnn_model()
self.loaded_models['CNN'] = self.orchestrator.cnn_model
# Store checkpoint info
if checkpoint_info:
self.orchestrator.cnn_checkpoint_info = {
'status': 'loaded',
'filename': checkpoint_info.get('filename', 'unknown'),
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
logger.info("CNN model loaded successfully")
logger.info("Checking CNN model...")
if self.orchestrator.cnn_model:
self.loaded_models['CNN'] = self.orchestrator.cnn_model
logger.info("CNN model loaded successfully")
else:
logger.warning("CNN model not initialized in orchestrator")
return
elif model_name == 'DQN':
logger.info("Loading DQN model...")
self.orchestrator.load_dqn_model()
self.loaded_models['DQN'] = self.orchestrator.dqn_agent
# Store checkpoint info
if checkpoint_info:
self.orchestrator.dqn_checkpoint_info = {
'status': 'loaded',
'filename': checkpoint_info.get('filename', 'unknown'),
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
logger.info("DQN model loaded successfully")
logger.info("Checking DQN model...")
if self.orchestrator.rl_agent:
self.loaded_models['DQN'] = self.orchestrator.rl_agent
logger.info("DQN model loaded successfully")
else:
logger.warning("DQN model not initialized in orchestrator")
return
else:
logger.warning(f"Unknown model name: {model_name}")
@@ -1741,6 +1721,9 @@ class AnnotationDashboard:
# CRITICAL: Get current symbol to filter annotations
current_symbol = data.get('symbol', 'ETH/USDT')
# Get primary timeframe for display (optional)
timeframe = data.get('timeframe', '1m')
# If no specific annotations provided, use all for current symbol
if not annotation_ids:
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
@@ -1769,12 +1752,14 @@ class AnnotationDashboard:
}
})
logger.info(f"Starting REAL training with {len(test_cases)} test cases for model {model_name}")
logger.info(f"Starting REAL training with {len(test_cases)} test cases ({len(annotation_ids)} annotations) for model {model_name} on {timeframe}")
# Start REAL training (NO SIMULATION!)
training_id = self.training_adapter.start_training(
model_name=model_name,
test_cases=test_cases
test_cases=test_cases,
annotation_count=len(annotation_ids),
timeframe=timeframe
)
return jsonify({
@@ -2392,6 +2377,55 @@ class AnnotationDashboard:
except Exception as e:
logger.error(f"Error handling prediction request: {e}")
emit('prediction_error', {'error': str(e)})
@self.socketio.on('prediction_accuracy')
def handle_prediction_accuracy(data):
"""
Handle validated prediction accuracy - trigger incremental training
This is called when frontend validates a prediction against actual candle.
We use this data to incrementally train the model for continuous improvement.
"""
from flask_socketio import emit
try:
timeframe = data.get('timeframe')
timestamp = data.get('timestamp')
predicted = data.get('predicted') # [O, H, L, C, V]
actual = data.get('actual') # [O, H, L, C]
errors = data.get('errors') # {open, high, low, close}
pct_errors = data.get('pctErrors')
direction_correct = data.get('directionCorrect')
accuracy = data.get('accuracy')
if not all([timeframe, timestamp, predicted, actual]):
logger.warning("Incomplete prediction accuracy data received")
return
logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
# Trigger incremental training on this validated prediction
self._train_on_validated_prediction(
timeframe=timeframe,
timestamp=timestamp,
predicted=predicted,
actual=actual,
errors=errors,
direction_correct=direction_correct,
accuracy=accuracy
)
# Send confirmation back to frontend
emit('training_update', {
'status': 'training_triggered',
'timestamp': timestamp,
'accuracy': accuracy,
'message': f'Incremental training triggered on validated prediction'
})
except Exception as e:
logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
emit('training_error', {'error': str(e)})
def _start_live_update_thread(self):
"""Start background thread for live updates"""
@@ -2415,24 +2449,44 @@ class AnnotationDashboard:
for timeframe in ['1s', '1m']:
room = f"{symbol}_{timeframe}"
# Get latest candle
# Get latest candles (need last 2 to determine confirmation status)
try:
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
if candles and len(candles) > 0:
latest_candle = candles[-1]
# Emit chart update
# Determine if candle is confirmed (closed)
# For 1s: candle is confirmed when next candle starts (2s delay)
# For others: candle is confirmed when next candle starts
is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed
# Format timestamp consistently
timestamp = latest_candle.get('timestamp')
if isinstance(timestamp, str):
# Already formatted
formatted_timestamp = timestamp
else:
# Convert to ISO string then format
from datetime import datetime
if isinstance(timestamp, datetime):
formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
else:
formatted_timestamp = str(timestamp)
# Emit chart update with full candle data
self.socketio.emit('chart_update', {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
'timestamp': latest_candle.get('timestamp'),
'open': latest_candle.get('open'),
'high': latest_candle.get('high'),
'low': latest_candle.get('low'),
'close': latest_candle.get('close'),
'volume': latest_candle.get('volume')
}
'timestamp': formatted_timestamp,
'open': float(latest_candle.get('open', 0)),
'high': float(latest_candle.get('high', 0)),
'low': float(latest_candle.get('low', 0)),
'close': float(latest_candle.get('close', 0)),
'volume': float(latest_candle.get('volume', 0))
},
'is_confirmed': is_confirmed, # True if this candle is closed/confirmed
'has_previous': len(candles) >= 2 # True if we have previous candle for validation
}, room=room)
# Get prediction if model is loaded
@@ -2453,6 +2507,144 @@ class AnnotationDashboard:
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
self._live_update_thread.start()
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
actual: list, errors: dict, direction_correct: bool, accuracy: float):
"""
Incrementally train model on validated prediction
This implements online learning where each validated prediction becomes
a training sample, with loss weighting based on prediction accuracy.
"""
try:
if not self.training_adapter:
logger.warning("Training adapter not available for incremental training")
return
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
logger.warning("Transformer model not available for incremental training")
return
# Get the transformer trainer
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
logger.warning("Transformer trainer not available")
return
# Calculate sample weight based on accuracy
# Low accuracy predictions get higher weight (we need to learn from mistakes)
# High accuracy predictions get lower weight (model already knows this)
if accuracy < 50:
sample_weight = 3.0 # Learn hard from bad predictions
elif accuracy < 70:
sample_weight = 2.0 # Moderate learning
elif accuracy < 85:
sample_weight = 1.0 # Normal learning
else:
sample_weight = 0.5 # Light touch-up for good predictions
# Also weight by direction correctness
if not direction_correct:
sample_weight *= 1.5 # Wrong direction is critical - learn more
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
# Create training sample from validated prediction
# We need to fetch the market state at that timestamp
symbol = 'ETH/USDT' # TODO: Get from active trading pair
training_sample = {
'symbol': symbol,
'timestamp': timestamp,
'predicted_candle': predicted, # [O, H, L, C, V]
'actual_candle': actual, # [O, H, L, C]
'errors': errors,
'accuracy': accuracy,
'direction_correct': direction_correct,
'sample_weight': sample_weight
}
# Get market state at that timestamp
try:
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
training_sample['market_state'] = market_state
except Exception as e:
logger.warning(f"Could not fetch market state: {e}")
return
# Convert to transformer batch format
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
if not batch:
logger.warning("Could not convert validated prediction to training batch")
return
# Train on this batch with sample weighting
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
if result:
loss = result.get('total_loss', 0)
candle_accuracy = result.get('candle_accuracy', 0)
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
# Save checkpoint periodically (every 10 incremental steps)
if not hasattr(self, '_incremental_training_steps'):
self._incremental_training_steps = 0
self._incremental_training_steps += 1
if self._incremental_training_steps % 10 == 0:
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
trainer.save_checkpoint(
filepath=None, # Auto-generate path
metadata={
'training_type': 'incremental_online',
'steps': self._incremental_training_steps,
'last_accuracy': accuracy
}
)
except Exception as e:
logger.error(f"Error in incremental training: {e}", exc_info=True)
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
"""Fetch market state at a specific timestamp for training"""
try:
from datetime import datetime
import pandas as pd
# Parse timestamp
ts = pd.Timestamp(timestamp)
# Get historical data for multiple timeframes
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
for tf in ['1s', '1m', '1h']:
try:
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
if df is not None and not df.empty:
# Find data up to (but not including) the target timestamp
df_before = df[df.index < ts]
if not df_before.empty:
recent = df_before.tail(200)
market_state['timeframes'][tf] = {
'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': recent['open'].tolist(),
'high': recent['high'].tolist(),
'low': recent['low'].tolist(),
'close': recent['close'].tolist(),
'volume': recent['volume'].tolist()
}
except Exception as e:
logger.warning(f"Could not fetch {tf} data: {e}")
return market_state
except Exception as e:
logger.error(f"Error fetching market state: {e}")
return {}
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""Get live prediction from model"""
try:
@@ -2471,7 +2663,7 @@ class AnnotationDashboard:
return {
'symbol': symbol,
'timeframe': timeframe,
'timestamp': datetime.now().isoformat(),
'timestamp': datetime.now(timezone.utc).isoformat(),
'action': random.choice(['BUY', 'SELL', 'HOLD']),
'confidence': random.uniform(0.6, 0.95),
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),