Improve training on signals; add a save-session button to store all progress
@@ -570,6 +570,21 @@ class CleanTradingDashboard:
            if n_clicks:
                self._clear_session()
            return [html.I(className="fas fa-trash me-1"), "Clear Session"]

        @self.app.callback(
            Output('store-models-btn', 'children'),
            [Input('store-models-btn', 'n_clicks')],
            prevent_initial_call=True
        )
        def handle_store_models(n_clicks):
            """Handle store all models button click"""
            if n_clicks:
                success = self._store_all_models()
                if success:
                    return [html.I(className="fas fa-save me-1"), "Models Stored"]
                else:
                    return [html.I(className="fas fa-exclamation-triangle me-1"), "Store Failed"]
            return [html.I(className="fas fa-save me-1"), "Store All Models"]

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Get current price for symbol"""
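The callback above assumes a component with id 'store-models-btn' exists in the dashboard layout. A minimal, hypothetical sketch of such a button, assuming dash-bootstrap-components and Font Awesome icons (matching the "fas fa-save" classes used in the callback):

# Hypothetical layout fragment; the real layout is defined elsewhere in the dashboard.
import dash_bootstrap_components as dbc
from dash import html

store_models_button = dbc.Button(
    [html.I(className="fas fa-save me-1"), "Store All Models"],
    id='store-models-btn',   # must match the Output/Input ids used by handle_store_models
    color="secondary",
    size="sm",
    className="me-2",
)
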
@@ -2927,6 +2942,10 @@ class CleanTradingDashboard:
            if len(self.recent_decisions) > 200:
                self.recent_decisions = self.recent_decisions[-200:]

            # Train ALL models on the signal (if executed)
            if signal['executed']:
                self._train_all_models_on_signal(signal)

            # Log signal processing
            status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
            logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
@@ -2935,10 +2954,307 @@ class CleanTradingDashboard:
        except Exception as e:
            logger.error(f"Error processing dashboard signal: {e}")

    def _train_dqn_on_signal(self, signal: Dict):
        """Train DQN agent on generated signal - NOT AVAILABLE IN BASIC ORCHESTRATOR"""
        # Basic orchestrator doesn't have DQN features
        return
    def _train_all_models_on_signal(self, signal: Dict):
        """Train ALL models on executed trade signal - Comprehensive training system"""
        try:
            # Get trade outcome for training
            trade_outcome = self._get_trade_outcome_for_training(signal)
            if not trade_outcome:
                return

            # 1. Train DQN model
            self._train_dqn_on_signal(signal, trade_outcome)

            # 2. Train CNN model
            self._train_cnn_on_signal(signal, trade_outcome)

            # 3. Train Transformer model
            self._train_transformer_on_signal(signal, trade_outcome)

            # 4. Train COB RL model
            self._train_cob_rl_on_signal(signal, trade_outcome)

            # 5. Train Decision Fusion model
            self._train_decision_fusion_on_signal(signal, trade_outcome)

            logger.debug(f"Trained all models on {signal['action']} signal with outcome: {trade_outcome['pnl']:.2f}")

        except Exception as e:
            logger.debug(f"Error training models on signal: {e}")

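For illustration, a minimal sketch of the signal dict this fan-out expects (the same shape as the manual_signal dict added later in this commit); the field values here are hypothetical:

# Hypothetical example signal; only signals marked executed reach the training fan-out.
from datetime import datetime

example_signal = {
    'action': 'BUY',            # 'BUY' or 'SELL'
    'symbol': 'ETH/USDT',
    'price': 3010.5,            # execution price
    'confidence': 0.72,
    'executed': True,           # gate checked before _train_all_models_on_signal is called
    'manual': False,
    'timestamp': datetime.now().timestamp(),
}
# dashboard._train_all_models_on_signal(example_signal)
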
    def _get_trade_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
        """Get trade outcome for training - either from completed trade or position change"""
        try:
            # Check if we have a completed trade
            if self.closed_trades:
                latest_trade = self.closed_trades[-1]
                # Verify this trade corresponds to the signal
                if (latest_trade.get('symbol') == signal.get('symbol') and
                    abs(latest_trade.get('entry_time', 0) - signal.get('timestamp', 0)) < 60):  # Within 1 minute
                    return {
                        'pnl': latest_trade.get('pnl', 0),
                        'entry_price': latest_trade.get('entry_price', 0),
                        'exit_price': latest_trade.get('exit_price', 0),
                        'side': latest_trade.get('side', 'UNKNOWN'),
                        'quantity': latest_trade.get('quantity', 0),
                        'duration': latest_trade.get('exit_time', 0) - latest_trade.get('entry_time', 0),
                        'trade_type': 'completed'
                    }

            # If no completed trade, use position change for training
            if self.current_position:
                current_price = self._get_current_price(signal.get('symbol', 'ETH/USDT'))
                if current_price:
                    entry_price = self.current_position.get('price', 0)
                    side = self.current_position.get('side', 'UNKNOWN')
                    size = self.current_position.get('size', 0)

                    if entry_price > 0 and size > 0:
                        # Calculate unrealized P&L
                        if side.upper() == 'LONG':
                            pnl = (current_price - entry_price) * size * self.current_leverage
                        else:  # SHORT
                            pnl = (entry_price - current_price) * size * self.current_leverage

                        return {
                            'pnl': pnl,
                            'entry_price': entry_price,
                            'current_price': current_price,
                            'side': side,
                            'quantity': size,
                            'duration': 0,  # Position still open
                            'trade_type': 'position_change'
                        }

            return None

        except Exception as e:
            logger.debug(f"Error getting trade outcome: {e}")
            return None

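A worked example of the unrealized P&L used for the 'position_change' outcome, with hypothetical numbers:

# LONG example: entry 3000.0, current 3015.0, size 0.1, leverage x10
entry_price, current_price, size, leverage = 3000.0, 3015.0, 0.1, 10
pnl_long = (current_price - entry_price) * size * leverage    # (3015 - 3000) * 0.1 * 10 = 15.0

# SHORT with the same prices loses the same amount
pnl_short = (entry_price - current_price) * size * leverage   # -15.0

assert abs(pnl_long - 15.0) < 1e-9 and abs(pnl_short + 15.0) < 1e-9
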
    def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train DQN agent on executed signal with trade outcome"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return

            # Create training data for DQN
            state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
            action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL

            # Calculate reward based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            reward = pnl * 100  # Scale reward for better learning

            # Create next state (simplified)
            next_state_features = state_features.copy()  # In real implementation, this would be the next market state

            # Store experience in DQN memory
            if hasattr(self.orchestrator.rl_agent, 'remember'):
                self.orchestrator.rl_agent.remember(
                    state_features, action, reward, next_state_features, done=True
                )

            # Trigger training if enough samples
            if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
                if hasattr(self.orchestrator.rl_agent, 'replay'):
                    loss = self.orchestrator.rl_agent.replay(batch_size=32)
                    if loss is not None:
                        logger.debug(f"DQN trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")

        except Exception as e:
            logger.debug(f"Error training DQN on signal: {e}")

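The DQN path above relies only on a remember/replay interface guarded by hasattr checks. A toy illustration of that interface, not the project's actual agent:

# Toy stand-in showing the (state, action, reward, next_state, done) tuple the
# training code stores; the real rl_agent implements its own remember/replay.
import random
from collections import deque

class ToyReplayBuffer:
    def __init__(self, capacity=10000):
        self.memory = deque(maxlen=capacity)

    def remember(self, state, action, reward, next_state, done=False):
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size=32):
        if len(self.memory) < batch_size:
            return None                      # mirrors the "enough samples" guard above
        batch = random.sample(self.memory, batch_size)
        # A real agent would run a gradient step here and return the loss.
        return sum(abs(r) for _, _, r, _, _ in batch) / batch_size
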
    def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train CNN model on executed signal with trade outcome"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                return

            # Create training data for CNN
            symbol = signal.get('symbol', 'ETH/USDT')
            current_price = signal.get('price', 0)

            # Get market features
            market_features = self._get_cnn_features_and_predictions(symbol)
            if not market_features:
                return

            # Create target based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            target = 1.0 if pnl > 0 else 0.0  # Binary classification: profitable vs not

            # Prepare training data
            features = market_features.get('features', [])
            if features:
                # Convert to tensor format (simplified)
                import numpy as np
                feature_tensor = np.array(features, dtype=np.float32)
                target_tensor = np.array([target], dtype=np.float32)

                # Train CNN model (if it has training method)
                if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
                    loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
                    logger.debug(f"CNN trained on signal - loss: {loss:.4f}, target: {target}")

        except Exception as e:
            logger.debug(f"Error training CNN on signal: {e}")

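One assumption worth flagging: feature_tensor above is a single 1-D sample. If cnn_model follows the Keras train_on_batch convention, inputs need a leading batch axis; a hedged sketch of adding it:

# Sketch only: reshape a single sample to a batch of one before train_on_batch,
# assuming a Keras-style model; the project's cnn_model may expect something else.
import numpy as np

features = [0.1, -0.3, 0.7, 0.0]                    # hypothetical feature vector
feature_tensor = np.array(features, dtype=np.float32)
target_tensor = np.array([1.0], dtype=np.float32)   # profitable -> 1.0

x_batch = np.expand_dims(feature_tensor, axis=0)    # shape (1, n_features)
y_batch = np.expand_dims(target_tensor, axis=0)     # shape (1, 1)
# loss = cnn_model.train_on_batch(x_batch, y_batch)
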
    def _train_transformer_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train Transformer model on executed signal with trade outcome"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
                return

            # Create training data for Transformer
            symbol = signal.get('symbol', 'ETH/USDT')
            current_price = signal.get('price', 0)

            # Get comprehensive market state
            market_state = self._get_comprehensive_market_state(symbol, current_price)

            # Create target based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            target_action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL
            target_confidence = signal.get('confidence', 0.5)

            # Prepare training data
            features = list(market_state.values())
            if features:
                import numpy as np
                feature_tensor = np.array(features, dtype=np.float32)
                target_tensor = np.array([target_action, target_confidence], dtype=np.float32)

                # Train Transformer model (if it has training method)
                if hasattr(self.orchestrator.primary_transformer, 'train_on_batch'):
                    loss = self.orchestrator.primary_transformer.train_on_batch(feature_tensor, target_tensor)
                    logger.debug(f"Transformer trained on signal - loss: {loss:.4f}, action: {target_action}")

        except Exception as e:
            logger.debug(f"Error training Transformer on signal: {e}")

    def _train_cob_rl_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train COB RL model on executed signal with trade outcome"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
                return

            # Create training data for COB RL
            symbol = signal.get('symbol', 'ETH/USDT')

            # Get COB features
            cob_features = self._get_cob_features_for_training(symbol, signal.get('price', 0))
            if not cob_features:
                return

            # Create target based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            action = 0 if signal['action'] == 'BUY' else 1
            reward = pnl * 100  # Scale reward

            # Store experience in COB RL memory
            if hasattr(self.orchestrator.cob_rl_agent, 'remember'):
                self.orchestrator.cob_rl_agent.remember(
                    cob_features, action, reward, cob_features, done=True  # Simplified next state
                )

            # Trigger training if enough samples
            if hasattr(self.orchestrator.cob_rl_agent, 'memory') and len(self.orchestrator.cob_rl_agent.memory) > 32:
                if hasattr(self.orchestrator.cob_rl_agent, 'replay'):
                    loss = self.orchestrator.cob_rl_agent.replay(batch_size=32)
                    if loss is not None:
                        logger.debug(f"COB RL trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")

        except Exception as e:
            logger.debug(f"Error training COB RL on signal: {e}")

    def _train_decision_fusion_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train Decision Fusion model on executed signal with trade outcome"""
        try:
            # Decision fusion model combines predictions from all models
            # This would be implemented if there's a decision fusion model available
            if not self.orchestrator or not hasattr(self.orchestrator, 'decision_model'):
                return

            # Create training data for decision fusion
            symbol = signal.get('symbol', 'ETH/USDT')
            current_price = signal.get('price', 0)

            # Get predictions from all models
            model_predictions = {
                'dqn': self._get_dqn_prediction(symbol, current_price),
                'cnn': self._get_cnn_prediction(symbol, current_price),
                'transformer': self._get_transformer_prediction(symbol, current_price),
                'cob_rl': self._get_cob_rl_prediction(symbol, current_price)
            }

            # Create target based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            target = 1.0 if pnl > 0 else 0.0

            # Train decision fusion model (if available)
            if hasattr(self.orchestrator.decision_model, 'train_on_batch'):
                # Prepare training data
                import numpy as np
                prediction_tensor = np.array(list(model_predictions.values()), dtype=np.float32)
                target_tensor = np.array([target], dtype=np.float32)

                loss = self.orchestrator.decision_model.train_on_batch(prediction_tensor, target_tensor)
                logger.debug(f"Decision Fusion trained on signal - loss: {loss:.4f}, target: {target}")

        except Exception as e:
            logger.debug(f"Error training Decision Fusion on signal: {e}")

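For clarity, the fusion input built above is just the four per-model scalars stacked in dict insertion order, trained against a single profitable/not-profitable target; a small hedged example with made-up prediction values:

# Hypothetical values; in the dashboard these come from the _get_*_prediction helpers.
import numpy as np

model_predictions = {
    'dqn': 0.62,
    'cnn': 0.48,
    'transformer': 0.71,
    'cob_rl': 0.55,
}
prediction_tensor = np.array(list(model_predictions.values()), dtype=np.float32)  # shape (4,)
target_tensor = np.array([1.0], dtype=np.float32)                                 # trade was profitable
# loss = decision_model.train_on_batch(prediction_tensor, target_tensor)
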
    def _get_dqn_prediction(self, symbol: str, current_price: float) -> float:
        """Get DQN prediction for decision fusion"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                state_features = self._get_dqn_state_features(symbol, current_price)
                if hasattr(self.orchestrator.rl_agent, 'predict'):
                    return self.orchestrator.rl_agent.predict(state_features)
            return 0.5  # Default neutral prediction
        except Exception:
            return 0.5

    def _get_cnn_prediction(self, symbol: str, current_price: float) -> float:
        """Get CNN prediction for decision fusion"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                market_features = self._get_cnn_features_and_predictions(symbol)
                if market_features and hasattr(self.orchestrator.cnn_model, 'predict'):
                    features = market_features.get('features', [])
                    if features:
                        import numpy as np
                        return self.orchestrator.cnn_model.predict(np.array([features]))
            return 0.5  # Default neutral prediction
        except Exception:
            return 0.5

    def _get_transformer_prediction(self, symbol: str, current_price: float) -> float:
        """Get Transformer prediction for decision fusion"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
                market_state = self._get_comprehensive_market_state(symbol, current_price)
                if hasattr(self.orchestrator.primary_transformer, 'predict'):
                    features = list(market_state.values())
                    if features:
                        import numpy as np
                        return self.orchestrator.primary_transformer.predict(np.array([features]))
            return 0.5  # Default neutral prediction
        except Exception:
            return 0.5

    def _get_cob_rl_prediction(self, symbol: str, current_price: float) -> float:
        """Get COB RL prediction for decision fusion"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                cob_features = self._get_cob_features_for_training(symbol, current_price)
                if cob_features and hasattr(self.orchestrator.cob_rl_agent, 'predict'):
                    import numpy as np
                    return self.orchestrator.cob_rl_agent.predict(np.array([cob_features]))
            return 0.5  # Default neutral prediction
        except Exception:
            return 0.5

    def _execute_manual_trade(self, action: str):
        """Execute manual trading action - ENHANCED with PERSISTENT SIGNAL STORAGE"""
@@ -3097,6 +3413,18 @@ class CleanTradingDashboard:
                self.closed_trades.append(trade_record)
                logger.info(f"Added completed trade to closed_trades: {action} P&L ${leveraged_pnl:.2f} (raw: ${raw_pnl:.2f}, leverage: x{self.current_leverage})")

                # TRAIN ALL MODELS ON MANUAL TRADE OUTCOME
                manual_signal = {
                    'action': action,
                    'price': current_price,
                    'symbol': symbol,
                    'confidence': 1.0,
                    'executed': True,
                    'manual': True,
                    'timestamp': datetime.now().timestamp()
                }
                self._train_all_models_on_signal(manual_signal)

            # MOVE BASE CASE TO POSITIVE/NEGATIVE based on leveraged outcome
            if hasattr(self, 'pending_trade_case_id') and self.pending_trade_case_id:
                try:
@@ -3512,6 +3840,100 @@ class CleanTradingDashboard:
        except Exception as e:
            logger.error(f"Error clearing session: {e}")

    def _store_all_models(self) -> bool:
        """Store all current models to persistent storage"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for model storage")
                return False

            stored_models = []

            # 1. Store DQN model
            if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    if hasattr(self.orchestrator.rl_agent, 'save'):
                        save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session')
                        stored_models.append(('DQN', save_path))
                        logger.info(f"Stored DQN model: {save_path}")
                except Exception as e:
                    logger.warning(f"Failed to store DQN model: {e}")

            # 2. Store CNN model
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                try:
                    if hasattr(self.orchestrator.cnn_model, 'save'):
                        save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session')
                        stored_models.append(('CNN', save_path))
                        logger.info(f"Stored CNN model: {save_path}")
                except Exception as e:
                    logger.warning(f"Failed to store CNN model: {e}")

            # 3. Store Transformer model
            if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
                try:
                    if hasattr(self.orchestrator.primary_transformer, 'save'):
                        save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session')
                        stored_models.append(('Transformer', save_path))
                        logger.info(f"Stored Transformer model: {save_path}")
                except Exception as e:
                    logger.warning(f"Failed to store Transformer model: {e}")

            # 4. Store COB RL model
            if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                try:
                    if hasattr(self.orchestrator.cob_rl_agent, 'save'):
                        save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session')
                        stored_models.append(('COB RL', save_path))
                        logger.info(f"Stored COB RL model: {save_path}")
                except Exception as e:
                    logger.warning(f"Failed to store COB RL model: {e}")

            # 5. Store Decision Fusion model
            if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
                try:
                    if hasattr(self.orchestrator.decision_model, 'save'):
                        save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session')
                        stored_models.append(('Decision Fusion', save_path))
                        logger.info(f"Stored Decision Fusion model: {save_path}")
                except Exception as e:
                    logger.warning(f"Failed to store Decision Fusion model: {e}")

            # 6. Store model metadata and training state
            try:
                import json
                from datetime import datetime

                metadata = {
                    'timestamp': datetime.now().isoformat(),
                    'session_pnl': self.session_pnl,
                    'trade_count': len(self.closed_trades),
                    'stored_models': stored_models,
                    'training_iterations': getattr(self, 'training_iteration', 0),
                    'model_performance': self.get_model_performance_metrics()
                }

                metadata_path = 'models/saved/session_metadata.json'
                with open(metadata_path, 'w') as f:
                    json.dump(metadata, f, indent=2)

                logger.info(f"Stored session metadata: {metadata_path}")

            except Exception as e:
                logger.warning(f"Failed to store metadata: {e}")

            # Log summary
            if stored_models:
                logger.info(f"Successfully stored {len(stored_models)} models: {[name for name, _ in stored_models]}")
                return True
            else:
                logger.warning("No models were stored - no models available or save methods not found")
                return False

        except Exception as e:
            logger.error(f"Error storing models: {e}")
            return False

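The session metadata written above can be read back at startup to report what was stored. A minimal sketch, assuming the same models/saved/session_metadata.json path; this helper is hypothetical and not part of the commit:

# Hypothetical restore-side helper; only reads the metadata file written by _store_all_models.
import json
import os

def load_session_metadata(path='models/saved/session_metadata.json'):
    if not os.path.exists(path):
        return None
    with open(path, 'r') as f:
        metadata = json.load(f)
    # Keys written by _store_all_models: timestamp, session_pnl, trade_count,
    # stored_models, training_iterations, model_performance
    return metadata
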
    def _get_signal_attribute(self, signal, attr_name, default=None):
        """Safely get attribute from signal (handles both dict and dataclass objects)"""
        try: