training fixes and enhancements wip
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -16,7 +16,7 @@ models/trading_agent_final.pt.backup
|
||||
*.pt
|
||||
*.backup
|
||||
logs/
|
||||
trade_logs/
|
||||
# trade_logs/
|
||||
*.csv
|
||||
cache/
|
||||
realtime_chart.log
|
||||
|
@ -57,7 +57,10 @@ class DQNAgent:
|
||||
else:
|
||||
# 1D state
|
||||
if isinstance(state_shape, tuple):
|
||||
self.state_dim = state_shape[0]
|
||||
if len(state_shape) == 0:
|
||||
self.state_dim = 1 # Safe default for empty tuple
|
||||
else:
|
||||
self.state_dim = state_shape[0]
|
||||
else:
|
||||
self.state_dim = state_shape
|
||||
|
||||
@ -615,8 +618,8 @@ class DQNAgent:
|
||||
self.recent_actions.append(action)
|
||||
return action
|
||||
else:
|
||||
# Return None to indicate HOLD (don't change position)
|
||||
return None
|
||||
# Return 1 (HOLD) as a safe default if action is None
|
||||
return 1
|
||||
|
||||
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
|
||||
"""Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
|
||||
@ -647,7 +650,10 @@ class DQNAgent:
|
||||
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||
|
||||
return action, adapted_confidence
|
||||
# Always return int, float
|
||||
if action is None:
|
||||
return 1, 0.1
|
||||
return int(action), float(adapted_confidence)
|
||||
|
||||
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
|
||||
"""
|
||||
@ -748,13 +754,29 @@ class DQNAgent:
|
||||
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
||||
experiences = [self.memory[i] for i in indices]
|
||||
|
||||
# Sanitize and stack states and next_states
|
||||
sanitized_states = []
|
||||
sanitized_next_states = []
|
||||
for i, e in enumerate(experiences):
|
||||
try:
|
||||
state = np.asarray(e[0], dtype=np.float32)
|
||||
next_state = np.asarray(e[3], dtype=np.float32)
|
||||
sanitized_states.append(state)
|
||||
sanitized_next_states.append(next_state)
|
||||
except Exception as ex:
|
||||
print(f"[DQNAgent] Bad experience at index {i}: {ex}")
|
||||
continue
|
||||
if not sanitized_states or not sanitized_next_states:
|
||||
print("[DQNAgent] No valid states in replay batch.")
|
||||
return 0.0 # Return float instead of None for consistency
|
||||
states = torch.FloatTensor(np.stack(sanitized_states)).to(self.device)
|
||||
next_states = torch.FloatTensor(np.stack(sanitized_next_states)).to(self.device)
|
||||
|
||||
# Choose appropriate replay method
|
||||
if self.use_mixed_precision:
|
||||
# Convert experiences to tensors for mixed precision
|
||||
states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
|
||||
actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
|
||||
dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)
|
||||
|
||||
# Use mixed precision replay
|
||||
@ -829,29 +851,32 @@ class DQNAgent:
|
||||
|
||||
return loss
|
||||
|
||||
def _replay_standard(self, experiences=None):
|
||||
def _replay_standard(self, *args):
|
||||
"""Standard training step without mixed precision"""
|
||||
try:
|
||||
# Use experiences if provided, otherwise sample from memory
|
||||
if experiences is None:
|
||||
# If memory is too small, skip training
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Sample random mini-batch from memory
|
||||
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
||||
batch = [self.memory[i] for i in indices]
|
||||
experiences = batch
|
||||
|
||||
# Unpack experiences
|
||||
states, actions, rewards, next_states, dones = zip(*experiences)
|
||||
|
||||
# Convert to PyTorch tensors
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(np.array(actions)).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(np.array(dones)).to(self.device)
|
||||
# Support both (experiences,) and (states, actions, rewards, next_states, dones)
|
||||
if len(args) == 1:
|
||||
experiences = args[0]
|
||||
# Use experiences if provided, otherwise sample from memory
|
||||
if experiences is None:
|
||||
# If memory is too small, skip training
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
# Sample random mini-batch from memory
|
||||
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
||||
batch = [self.memory[i] for i in indices]
|
||||
experiences = batch
|
||||
# Unpack experiences
|
||||
states, actions, rewards, next_states, dones = zip(*experiences)
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(np.array(actions)).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(np.array(dones)).to(self.device)
|
||||
elif len(args) == 5:
|
||||
states, actions, rewards, next_states, dones = args
|
||||
else:
|
||||
raise ValueError("Invalid arguments to _replay_standard")
|
||||
|
||||
# Get current Q values
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
|
||||
|
@ -437,13 +437,34 @@ class TradingOrchestrator:
|
||||
|
||||
def predict(self, data=None):
|
||||
try:
|
||||
# ExtremaTrainer provides context features, not a direct prediction
|
||||
# We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
|
||||
if not isinstance(data, str):
|
||||
logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
|
||||
# Handle different data types that might be passed to ExtremaTrainer
|
||||
symbol = None
|
||||
|
||||
if isinstance(data, str):
|
||||
# Direct symbol string
|
||||
symbol = data
|
||||
elif isinstance(data, dict):
|
||||
# Dictionary with symbol information
|
||||
symbol = data.get('symbol')
|
||||
elif isinstance(data, np.ndarray):
|
||||
# Numpy array - extract symbol from metadata or use default
|
||||
# For now, use the first symbol from the model's symbols list
|
||||
if hasattr(self.model, 'symbols') and self.model.symbols:
|
||||
symbol = self.model.symbols[0]
|
||||
else:
|
||||
symbol = 'ETH/USDT' # Default fallback
|
||||
else:
|
||||
# Unknown data type - use default symbol
|
||||
if hasattr(self.model, 'symbols') and self.model.symbols:
|
||||
symbol = self.model.symbols[0]
|
||||
else:
|
||||
symbol = 'ETH/USDT' # Default fallback
|
||||
|
||||
if not symbol:
|
||||
logger.warning(f"ExtremaTrainerInterface.predict could not determine symbol from data: {type(data)}")
|
||||
return None
|
||||
|
||||
features = self.model.get_context_features_for_model(symbol=data)
|
||||
features = self.model.get_context_features_for_model(symbol=symbol)
|
||||
if features is not None and features.size > 0:
|
||||
# The presence of features indicates a signal. We'll return a generic HOLD
|
||||
# with a neutral confidence. This can be refined if ExtremaTrainer provides
|
||||
|
@ -134,8 +134,8 @@ class TrainingIntegration:
|
||||
|
||||
# Store experience in DQN memory
|
||||
dqn_agent = self.orchestrator.dqn_agent
|
||||
if hasattr(dqn_agent, 'store_experience'):
|
||||
dqn_agent.store_experience(
|
||||
if hasattr(dqn_agent, 'remember'):
|
||||
dqn_agent.remember(
|
||||
state=np.array(dqn_state),
|
||||
action=action_idx,
|
||||
reward=reward,
|
||||
@ -145,7 +145,7 @@ class TrainingIntegration:
|
||||
|
||||
# Trigger training if enough experiences
|
||||
if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
|
||||
dqn_agent.replay(batch_size=32)
|
||||
dqn_agent.replay()
|
||||
logger.info("DQN training step completed")
|
||||
|
||||
return True
|
||||
@ -345,7 +345,7 @@ class TrainingIntegration:
|
||||
# Perform training step if agent has replay method
|
||||
if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
|
||||
if len(cob_rl_agent.memory) > 32: # Enough samples to train
|
||||
loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
|
||||
loss = cob_rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
|
||||
return True
|
||||
|
@ -1060,8 +1060,8 @@ class EnhancedRealtimeTrainingSystem:
|
||||
total_loss += loss
|
||||
training_iterations += 1
|
||||
elif hasattr(rl_agent, 'replay'):
|
||||
# Fallback to replay method
|
||||
loss = rl_agent.replay(batch_size=len(batch))
|
||||
# Fallback to replay method - DQNAgent.replay() doesn't accept batch_size parameter
|
||||
loss = rl_agent.replay()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
training_iterations += 1
|
||||
@ -1129,25 +1129,10 @@ class EnhancedRealtimeTrainingSystem:
|
||||
state = combined_features # 2000-dimensional state
|
||||
|
||||
# Store experience in COB RL agent
|
||||
if hasattr(cob_rl_agent, 'store_experience'):
|
||||
experience = {
|
||||
'state': state,
|
||||
'action': action,
|
||||
'reward': reward,
|
||||
'next_state': state, # Will be updated with next observation
|
||||
'done': False,
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.now(),
|
||||
'price': current_price,
|
||||
'cob_features': {
|
||||
'raw_tick_available': raw_tick_matrix is not None,
|
||||
'aggregated_available': aggregated_matrix is not None,
|
||||
'imbalance': combined_features[0] if len(combined_features) > 0 else 0,
|
||||
'spread': combined_features[1] if len(combined_features) > 1 else 0,
|
||||
'liquidity': combined_features[4] if len(combined_features) > 4 else 0
|
||||
}
|
||||
}
|
||||
cob_rl_agent.store_experience(experience)
|
||||
if hasattr(cob_rl_agent, 'remember'):
|
||||
# Use tuple format for DQN agent compatibility
|
||||
experience_tuple = (state, action, reward, state, False) # next_state = current state for now
|
||||
cob_rl_agent.remember(state, action, reward, state, False)
|
||||
training_updates += 1
|
||||
|
||||
# Perform COB RL training if enough experiences
|
||||
|
@ -3123,9 +3123,13 @@ class CleanTradingDashboard:
|
||||
if len(self.recent_decisions) > 200:
|
||||
self.recent_decisions = self.recent_decisions[-200:]
|
||||
|
||||
# Train ALL models on the signal (if executed)
|
||||
# Train ALL models on EVERY prediction result (not just executed ones)
|
||||
# This ensures models learn from all predictions, not just successful trades
|
||||
self._train_all_models_on_prediction(signal)
|
||||
|
||||
# Additional training weight for executed signals
|
||||
if signal['executed']:
|
||||
self._train_all_models_on_signal(signal)
|
||||
self._train_all_models_on_executed_signal(signal)
|
||||
|
||||
# Log signal processing
|
||||
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
||||
@ -3135,33 +3139,118 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing dashboard signal: {e}")
|
||||
|
||||
def _train_all_models_on_signal(self, signal: Dict):
|
||||
"""Train ALL models on executed trade signal - Comprehensive training system"""
|
||||
def _train_all_models_on_prediction(self, signal: Dict):
|
||||
"""Train ALL models on EVERY prediction result - Comprehensive learning system"""
|
||||
try:
|
||||
# Get prediction outcome based on immediate price movement
|
||||
prediction_outcome = self._get_prediction_outcome_for_training(signal)
|
||||
if not prediction_outcome:
|
||||
return
|
||||
|
||||
# 1. Train DQN model on prediction outcome
|
||||
self._train_dqn_on_prediction(signal, prediction_outcome)
|
||||
|
||||
# 2. Train CNN model on prediction outcome
|
||||
self._train_cnn_on_prediction(signal, prediction_outcome)
|
||||
|
||||
# 3. Train Transformer model on prediction outcome
|
||||
self._train_transformer_on_prediction(signal, prediction_outcome)
|
||||
|
||||
# 4. Train COB RL model on prediction outcome
|
||||
self._train_cob_rl_on_prediction(signal, prediction_outcome)
|
||||
|
||||
# 5. Train Decision Fusion model on prediction outcome
|
||||
self._train_decision_fusion_on_prediction(signal, prediction_outcome)
|
||||
|
||||
logger.debug(f"Trained all models on {signal['action']} prediction with outcome: {prediction_outcome['accuracy']:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training models on prediction: {e}")
|
||||
|
||||
def _train_all_models_on_executed_signal(self, signal: Dict):
|
||||
"""Train ALL models on executed trade signal with enhanced weight - Comprehensive training system"""
|
||||
try:
|
||||
# Get trade outcome for training
|
||||
trade_outcome = self._get_trade_outcome_for_training(signal)
|
||||
if not trade_outcome:
|
||||
return
|
||||
|
||||
# 1. Train DQN model
|
||||
self._train_dqn_on_signal(signal, trade_outcome)
|
||||
# Enhanced training weight for executed signals (10x more important)
|
||||
enhanced_outcome = trade_outcome.copy()
|
||||
enhanced_outcome['training_weight'] = 10.0 # 10x weight for executed trades
|
||||
|
||||
# 2. Train CNN model
|
||||
self._train_cnn_on_signal(signal, trade_outcome)
|
||||
# 1. Train DQN model with enhanced weight
|
||||
self._train_dqn_on_executed_signal(signal, enhanced_outcome)
|
||||
|
||||
# 3. Train Transformer model
|
||||
self._train_transformer_on_signal(signal, trade_outcome)
|
||||
# 2. Train CNN model with enhanced weight
|
||||
self._train_cnn_on_executed_signal(signal, enhanced_outcome)
|
||||
|
||||
# 4. Train COB RL model
|
||||
self._train_cob_rl_on_signal(signal, trade_outcome)
|
||||
# 3. Train Transformer model with enhanced weight
|
||||
self._train_transformer_on_executed_signal(signal, enhanced_outcome)
|
||||
|
||||
# 5. Train Decision Fusion model
|
||||
self._train_decision_fusion_on_signal(signal, trade_outcome)
|
||||
# 4. Train COB RL model with enhanced weight
|
||||
self._train_cob_rl_on_executed_signal(signal, enhanced_outcome)
|
||||
|
||||
logger.debug(f"Trained all models on {signal['action']} signal with outcome: {trade_outcome['pnl']:.2f}")
|
||||
# 5. Train Decision Fusion model with enhanced weight
|
||||
self._train_decision_fusion_on_executed_signal(signal, enhanced_outcome)
|
||||
|
||||
logger.info(f"Enhanced training completed on {signal['action']} executed signal with outcome: {trade_outcome['pnl']:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training models on signal: {e}")
|
||||
logger.debug(f"Error training models on executed signal: {e}")
|
||||
|
||||
def _train_all_models_on_signal(self, signal: Dict):
|
||||
"""Legacy method - now redirects to new training system"""
|
||||
self._train_all_models_on_prediction(signal)
|
||||
|
||||
def _get_prediction_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
|
||||
"""Get prediction outcome based on immediate price movement validation"""
|
||||
try:
|
||||
symbol = signal.get('symbol', 'ETH/USDT')
|
||||
action = signal.get('action', 'HOLD')
|
||||
confidence = signal.get('confidence', 0.0)
|
||||
prediction_time = signal.get('timestamp', datetime.now())
|
||||
|
||||
# Get current price to validate prediction
|
||||
current_price = self._get_current_price(symbol)
|
||||
if not current_price:
|
||||
return None
|
||||
|
||||
# Get price at prediction time (or recent price if not available)
|
||||
prediction_price = signal.get('price', current_price)
|
||||
|
||||
# Calculate immediate price movement (within 1-5 minutes)
|
||||
price_change = ((current_price - prediction_price) / prediction_price) * 100
|
||||
|
||||
# Determine if prediction was accurate based on action and price movement
|
||||
prediction_accurate = False
|
||||
if action == 'BUY' and price_change > 0.1: # 0.1% positive movement
|
||||
prediction_accurate = True
|
||||
elif action == 'SELL' and price_change < -0.1: # 0.1% negative movement
|
||||
prediction_accurate = True
|
||||
elif action == 'HOLD' and abs(price_change) < 0.2: # Stable price
|
||||
prediction_accurate = True
|
||||
|
||||
# Calculate accuracy score (0.0 to 1.0)
|
||||
accuracy_score = 0.5 # Base neutral score
|
||||
if prediction_accurate:
|
||||
accuracy_score = min(1.0, 0.5 + (confidence * 0.5)) # Higher confidence = higher score
|
||||
else:
|
||||
accuracy_score = max(0.0, 0.5 - (confidence * 0.5)) # Higher confidence = lower score for wrong predictions
|
||||
|
||||
return {
|
||||
'accuracy': accuracy_score,
|
||||
'price_change': price_change,
|
||||
'prediction_accurate': prediction_accurate,
|
||||
'confidence': confidence,
|
||||
'action': action,
|
||||
'prediction_time': prediction_time,
|
||||
'validation_time': datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting prediction outcome: {e}")
|
||||
return None
|
||||
|
||||
def _get_trade_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
|
||||
"""Get trade outcome for training - either from completed trade or position change"""
|
||||
@ -3213,8 +3302,8 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error getting trade outcome: {e}")
|
||||
return None
|
||||
|
||||
def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Train DQN agent on executed signal with trade outcome"""
|
||||
def _train_dqn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
|
||||
"""Train DQN agent on prediction outcome (every prediction, not just executed trades)"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||
return
|
||||
@ -3223,31 +3312,66 @@ class CleanTradingDashboard:
|
||||
state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
|
||||
action = 0 if signal['action'] == 'BUY' else 1 # 0=BUY, 1=SELL
|
||||
|
||||
# Calculate reward based on trade outcome
|
||||
pnl = trade_outcome.get('pnl', 0)
|
||||
reward = pnl * 100 # Scale reward for better learning
|
||||
|
||||
# Create next state (simplified)
|
||||
next_state_features = state_features.copy() # In real implementation, this would be the next market state
|
||||
# Calculate reward based on prediction accuracy
|
||||
accuracy = prediction_outcome.get('accuracy', 0.5)
|
||||
confidence = signal.get('confidence', 0.5)
|
||||
reward = (accuracy - 0.5) * 2.0 # Convert to [-1, 1] range
|
||||
|
||||
# Store experience in DQN memory
|
||||
if hasattr(self.orchestrator.rl_agent, 'remember'):
|
||||
self.orchestrator.rl_agent.remember(
|
||||
state_features, action, reward, next_state_features, done=True
|
||||
state_features, action, reward, state_features, done=True
|
||||
)
|
||||
|
||||
# Trigger training if enough samples
|
||||
if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
|
||||
if hasattr(self.orchestrator.rl_agent, 'replay'):
|
||||
loss = self.orchestrator.rl_agent.replay(batch_size=32)
|
||||
loss = self.orchestrator.rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")
|
||||
logger.debug(f"DQN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training DQN on signal: {e}")
|
||||
logger.debug(f"Error training DQN on prediction: {e}")
|
||||
|
||||
def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Train CNN model on executed signal with trade outcome"""
|
||||
def _train_dqn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Train DQN agent on executed signal with enhanced weight"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||
return
|
||||
|
||||
# Create training data for DQN
|
||||
state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
|
||||
action = 0 if signal['action'] == 'BUY' else 1 # 0=BUY, 1=SELL
|
||||
|
||||
# Calculate enhanced reward based on trade outcome
|
||||
pnl = trade_outcome.get('pnl', 0)
|
||||
training_weight = trade_outcome.get('training_weight', 1.0)
|
||||
reward = pnl * 100 * training_weight # Enhanced reward for executed trades
|
||||
|
||||
# Store experience in DQN memory with multiple entries for enhanced learning
|
||||
if hasattr(self.orchestrator.rl_agent, 'remember'):
|
||||
# Store multiple copies for enhanced learning
|
||||
for _ in range(int(training_weight)):
|
||||
self.orchestrator.rl_agent.remember(
|
||||
state_features, action, reward, state_features, done=True
|
||||
)
|
||||
|
||||
# Trigger training if enough samples
|
||||
if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
|
||||
if hasattr(self.orchestrator.rl_agent, 'replay'):
|
||||
loss = self.orchestrator.rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.info(f"DQN enhanced training on executed signal - loss: {loss:.4f}, reward: {reward:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training DQN on executed signal: {e}")
|
||||
|
||||
def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Legacy method - redirects to new training system"""
|
||||
self._train_dqn_on_prediction(signal, trade_outcome)
|
||||
|
||||
def _train_cnn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
|
||||
"""Train CNN model on prediction outcome (every prediction, not just executed trades)"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
||||
return
|
||||
@ -3261,25 +3385,64 @@ class CleanTradingDashboard:
|
||||
if not market_features:
|
||||
return
|
||||
|
||||
# Create target based on trade outcome
|
||||
pnl = trade_outcome.get('pnl', 0)
|
||||
target = 1.0 if pnl > 0 else 0.0 # Binary classification: profitable vs not
|
||||
# Create target based on prediction accuracy
|
||||
accuracy = prediction_outcome.get('accuracy', 0.5)
|
||||
target = accuracy # Use accuracy as target (0.0 to 1.0)
|
||||
|
||||
# Prepare training data
|
||||
features = market_features.get('features', [])
|
||||
if features:
|
||||
# Convert to tensor format (simplified)
|
||||
import numpy as np
|
||||
feature_tensor = np.array(features, dtype=np.float32)
|
||||
target_tensor = np.array([target], dtype=np.float32)
|
||||
|
||||
# Train CNN model (if it has training method)
|
||||
# Train CNN model
|
||||
if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
|
||||
loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
|
||||
logger.debug(f"CNN trained on signal - loss: {loss:.4f}, target: {target}")
|
||||
logger.debug(f"CNN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training CNN on signal: {e}")
|
||||
logger.debug(f"Error training CNN on prediction: {e}")
|
||||
|
||||
def _train_cnn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Train CNN model on executed signal with enhanced weight"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
||||
return
|
||||
|
||||
# Create training data for CNN
|
||||
symbol = signal.get('symbol', 'ETH/USDT')
|
||||
current_price = signal.get('price', 0)
|
||||
|
||||
# Get market features
|
||||
market_features = self._get_cnn_features_and_predictions(symbol)
|
||||
if not market_features:
|
||||
return
|
||||
|
||||
# Create target based on trade outcome with enhanced weight
|
||||
pnl = trade_outcome.get('pnl', 0)
|
||||
training_weight = trade_outcome.get('training_weight', 1.0)
|
||||
target = 1.0 if pnl > 0 else 0.0
|
||||
|
||||
# Prepare training data
|
||||
features = market_features.get('features', [])
|
||||
if features:
|
||||
import numpy as np
|
||||
feature_tensor = np.array(features, dtype=np.float32)
|
||||
target_tensor = np.array([target], dtype=np.float32)
|
||||
|
||||
# Train CNN model with multiple passes for enhanced learning
|
||||
if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
|
||||
for _ in range(int(training_weight)):
|
||||
loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
|
||||
logger.info(f"CNN enhanced training on executed signal - loss: {loss:.4f}, pnl: {pnl:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training CNN on executed signal: {e}")
|
||||
|
||||
def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Legacy method - redirects to new training system"""
|
||||
self._train_cnn_on_prediction(signal, trade_outcome)
|
||||
|
||||
def _train_transformer_on_signal(self, signal: Dict, trade_outcome: Dict):
|
||||
"""Train Transformer model on executed signal with trade outcome"""
|
||||
@ -3342,7 +3505,7 @@ class CleanTradingDashboard:
|
||||
# Trigger training if enough samples
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'memory') and len(self.orchestrator.cob_rl_agent.memory) > 32:
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'replay'):
|
||||
loss = self.orchestrator.cob_rl_agent.replay(batch_size=32)
|
||||
loss = self.orchestrator.cob_rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"COB RL trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")
|
||||
|
||||
@ -3999,7 +4162,7 @@ class CleanTradingDashboard:
|
||||
# Cold start training moved to core.training_integration.TrainingIntegration
|
||||
|
||||
def _clear_session(self):
|
||||
"""Clear session data"""
|
||||
"""Clear session data and persistent files"""
|
||||
try:
|
||||
# Reset session metrics
|
||||
self.session_pnl = 0.0
|
||||
@ -4016,11 +4179,96 @@ class CleanTradingDashboard:
|
||||
self.current_position = None
|
||||
self.pending_trade_case_id = None # Clear pending trade tracking
|
||||
|
||||
logger.info("Session data cleared")
|
||||
# Clear persistent trade log files
|
||||
self._clear_trade_logs()
|
||||
|
||||
# Clear orchestrator state if available
|
||||
if hasattr(self, 'orchestrator') and self.orchestrator:
|
||||
self._clear_orchestrator_state()
|
||||
|
||||
logger.info("Session data and trade logs cleared")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing session: {e}")
|
||||
|
||||
def _clear_trade_logs(self):
|
||||
"""Clear all trade log files"""
|
||||
try:
|
||||
import os
|
||||
import glob
|
||||
|
||||
# Clear trade_logs directory
|
||||
trade_logs_dir = "trade_logs"
|
||||
if os.path.exists(trade_logs_dir):
|
||||
# Remove all CSV files in trade_logs
|
||||
csv_files = glob.glob(os.path.join(trade_logs_dir, "*.csv"))
|
||||
for file in csv_files:
|
||||
try:
|
||||
os.remove(file)
|
||||
logger.info(f"Deleted trade log: {file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file}: {e}")
|
||||
|
||||
# Remove any .log files in trade_logs
|
||||
log_files = glob.glob(os.path.join(trade_logs_dir, "*.log"))
|
||||
for file in log_files:
|
||||
try:
|
||||
os.remove(file)
|
||||
logger.info(f"Deleted trade log: {file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file}: {e}")
|
||||
|
||||
# Clear recent log files in logs directory
|
||||
logs_dir = "logs"
|
||||
if os.path.exists(logs_dir):
|
||||
# Remove recent trading logs (keep older system logs)
|
||||
recent_logs = [
|
||||
"enhanced_trading.log",
|
||||
"realtime_rl_cob_trader.log",
|
||||
"simple_cob_dashboard.log",
|
||||
"integrated_rl_cob_system.log",
|
||||
"optimized_cob_system.log"
|
||||
]
|
||||
|
||||
for log_file in recent_logs:
|
||||
log_path = os.path.join(logs_dir, log_file)
|
||||
if os.path.exists(log_path):
|
||||
try:
|
||||
# Truncate the file instead of deleting to preserve file handles
|
||||
with open(log_path, 'w') as f:
|
||||
f.write("") # Clear file content
|
||||
logger.info(f"Cleared log file: {log_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear {log_path}: {e}")
|
||||
|
||||
logger.info("Trade logs cleared successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing trade logs: {e}")
|
||||
|
||||
def _clear_orchestrator_state(self):
|
||||
"""Clear orchestrator state and recent predictions"""
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'recent_decisions'):
|
||||
self.orchestrator.recent_decisions = {}
|
||||
|
||||
if hasattr(self.orchestrator, 'recent_dqn_predictions'):
|
||||
for symbol in self.orchestrator.recent_dqn_predictions:
|
||||
self.orchestrator.recent_dqn_predictions[symbol].clear()
|
||||
|
||||
if hasattr(self.orchestrator, 'recent_cnn_predictions'):
|
||||
for symbol in self.orchestrator.recent_cnn_predictions:
|
||||
self.orchestrator.recent_cnn_predictions[symbol].clear()
|
||||
|
||||
if hasattr(self.orchestrator, 'prediction_accuracy_history'):
|
||||
for symbol in self.orchestrator.prediction_accuracy_history:
|
||||
self.orchestrator.prediction_accuracy_history[symbol].clear()
|
||||
|
||||
logger.info("Orchestrator state cleared")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing orchestrator state: {e}")
|
||||
|
||||
def _store_all_models(self) -> bool:
|
||||
"""Store all current models to persistent storage"""
|
||||
try:
|
||||
@ -6112,7 +6360,7 @@ class CleanTradingDashboard:
|
||||
# Perform training step if agent has replay method
|
||||
if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
|
||||
if len(cob_rl_agent.memory) > 32: # Enough samples to train
|
||||
loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
|
||||
loss = cob_rl_agent.replay()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
|
Reference in New Issue
Block a user