training fixes and enhancements (WIP)

Normalize the DQN agent interfaces (remember/replay without batch_size), sanitize replay batches,
make act()/act_with_confidence() always return usable values, handle flexible input types in
ExtremaTrainer predictions, train all models on every prediction (with extra weight for executed
signals), and clear trade logs and orchestrator state on session reset.
.gitignore (vendored)

@@ -16,7 +16,7 @@ models/trading_agent_final.pt.backup
 *.pt
 *.backup
 logs/
-trade_logs/
+# trade_logs/
 *.csv
 cache/
 realtime_chart.log
@@ -57,7 +57,10 @@ class DQNAgent:
         else:
             # 1D state
             if isinstance(state_shape, tuple):
-                self.state_dim = state_shape[0]
+                if len(state_shape) == 0:
+                    self.state_dim = 1  # Safe default for empty tuple
+                else:
+                    self.state_dim = state_shape[0]
             else:
                 self.state_dim = state_shape

@@ -615,8 +618,8 @@ class DQNAgent:
             self.recent_actions.append(action)
             return action
         else:
-            # Return None to indicate HOLD (don't change position)
-            return None
+            # Return 1 (HOLD) as a safe default if action is None
+            return 1

     def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
         """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
@@ -647,7 +650,10 @@ class DQNAgent:
         regime_weight = self.market_regime_weights.get(market_regime, 1.0)
         adapted_confidence = min(base_confidence * regime_weight, 1.0)

-        return action, adapted_confidence
+        # Always return int, float
+        if action is None:
+            return 1, 0.1
+        return int(action), float(adapted_confidence)

     def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
         """
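Note: with this change act() falls back to 1 (HOLD) and act_with_confidence() always returns an (int, float) pair, so callers no longer need None checks. A minimal usage sketch (the agent and state names are illustrative, not part of the diff):

    action, confidence = agent.act_with_confidence(state, market_regime='ranging')
    assert isinstance(action, int)        # 1 == HOLD fallback when no action is chosen
    assert 0.0 <= confidence <= 1.0       # confidence is capped at 1.0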
@@ -748,13 +754,29 @@ class DQNAgent:
         indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
         experiences = [self.memory[i] for i in indices]

+        # Sanitize and stack states and next_states
+        sanitized_states = []
+        sanitized_next_states = []
+        for i, e in enumerate(experiences):
+            try:
+                state = np.asarray(e[0], dtype=np.float32)
+                next_state = np.asarray(e[3], dtype=np.float32)
+                sanitized_states.append(state)
+                sanitized_next_states.append(next_state)
+            except Exception as ex:
+                print(f"[DQNAgent] Bad experience at index {i}: {ex}")
+                continue
+        if not sanitized_states or not sanitized_next_states:
+            print("[DQNAgent] No valid states in replay batch.")
+            return 0.0  # Return float instead of None for consistency
+        states = torch.FloatTensor(np.stack(sanitized_states)).to(self.device)
+        next_states = torch.FloatTensor(np.stack(sanitized_next_states)).to(self.device)
+
         # Choose appropriate replay method
         if self.use_mixed_precision:
             # Convert experiences to tensors for mixed precision
-            states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
             actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
             rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
-            next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
             dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)

             # Use mixed precision replay
@@ -829,29 +851,32 @@ class DQNAgent:

             return loss

-    def _replay_standard(self, experiences=None):
+    def _replay_standard(self, *args):
         """Standard training step without mixed precision"""
         try:
-            # Use experiences if provided, otherwise sample from memory
-            if experiences is None:
-                # If memory is too small, skip training
-                if len(self.memory) < self.batch_size:
-                    return 0.0
-
-                # Sample random mini-batch from memory
-                indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
-                batch = [self.memory[i] for i in indices]
-                experiences = batch
-
-            # Unpack experiences
-            states, actions, rewards, next_states, dones = zip(*experiences)
-
-            # Convert to PyTorch tensors
-            states = torch.FloatTensor(np.array(states)).to(self.device)
-            actions = torch.LongTensor(np.array(actions)).to(self.device)
-            rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
-            next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
-            dones = torch.FloatTensor(np.array(dones)).to(self.device)
+            # Support both (experiences,) and (states, actions, rewards, next_states, dones)
+            if len(args) == 1:
+                experiences = args[0]
+                # Use experiences if provided, otherwise sample from memory
+                if experiences is None:
+                    # If memory is too small, skip training
+                    if len(self.memory) < self.batch_size:
+                        return 0.0
+                    # Sample random mini-batch from memory
+                    indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
+                    batch = [self.memory[i] for i in indices]
+                    experiences = batch
+                # Unpack experiences
+                states, actions, rewards, next_states, dones = zip(*experiences)
+                states = torch.FloatTensor(np.array(states)).to(self.device)
+                actions = torch.LongTensor(np.array(actions)).to(self.device)
+                rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
+                next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
+                dones = torch.FloatTensor(np.array(dones)).to(self.device)
+            elif len(args) == 5:
+                states, actions, rewards, next_states, dones = args
+            else:
+                raise ValueError("Invalid arguments to _replay_standard")

             # Get current Q values
             current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
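Note: a sketch of the calling conventions the reworked _replay_standard now accepts (the agent, experiences, and tensor names below are illustrative, not part of the diff):

    # 1-argument form: None samples a mini-batch from agent.memory; a list of
    # (state, action, reward, next_state, done) tuples is trained on directly.
    loss = agent._replay_standard(None)
    loss = agent._replay_standard(experiences)

    # 5-argument form: pre-built tensors in the fixed order.
    loss = agent._replay_standard(states, actions, rewards, next_states, dones)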
@@ -437,13 +437,34 @@ class TradingOrchestrator:

        def predict(self, data=None):
            try:
-                # ExtremaTrainer provides context features, not a direct prediction
-                # We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
-                if not isinstance(data, str):
-                    logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
+                # Handle different data types that might be passed to ExtremaTrainer
+                symbol = None
+                if isinstance(data, str):
+                    # Direct symbol string
+                    symbol = data
+                elif isinstance(data, dict):
+                    # Dictionary with symbol information
+                    symbol = data.get('symbol')
+                elif isinstance(data, np.ndarray):
+                    # Numpy array - extract symbol from metadata or use default
+                    # For now, use the first symbol from the model's symbols list
+                    if hasattr(self.model, 'symbols') and self.model.symbols:
+                        symbol = self.model.symbols[0]
+                    else:
+                        symbol = 'ETH/USDT'  # Default fallback
+                else:
+                    # Unknown data type - use default symbol
+                    if hasattr(self.model, 'symbols') and self.model.symbols:
+                        symbol = self.model.symbols[0]
+                    else:
+                        symbol = 'ETH/USDT'  # Default fallback
+
+                if not symbol:
+                    logger.warning(f"ExtremaTrainerInterface.predict could not determine symbol from data: {type(data)}")
                    return None

-                features = self.model.get_context_features_for_model(symbol=data)
+                features = self.model.get_context_features_for_model(symbol=symbol)
                if features is not None and features.size > 0:
                    # The presence of features indicates a signal. We'll return a generic HOLD
                    # with a neutral confidence. This can be refined if ExtremaTrainer provides
@@ -134,8 +134,8 @@ class TrainingIntegration:

                # Store experience in DQN memory
                dqn_agent = self.orchestrator.dqn_agent
-                if hasattr(dqn_agent, 'store_experience'):
-                    dqn_agent.store_experience(
+                if hasattr(dqn_agent, 'remember'):
+                    dqn_agent.remember(
                        state=np.array(dqn_state),
                        action=action_idx,
                        reward=reward,
@@ -145,7 +145,7 @@ class TrainingIntegration:

                # Trigger training if enough experiences
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
-                    dqn_agent.replay(batch_size=32)
+                    dqn_agent.replay()
                    logger.info("DQN training step completed")

            return True
@@ -345,7 +345,7 @@ class TrainingIntegration:
            # Perform training step if agent has replay method
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32:  # Enough samples to train
-                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
+                    loss = cob_rl_agent.replay()
                    if loss is not None:
                        logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
            return True
@@ -1060,8 +1060,8 @@ class EnhancedRealtimeTrainingSystem:
                    total_loss += loss
                    training_iterations += 1
                elif hasattr(rl_agent, 'replay'):
-                    # Fallback to replay method
-                    loss = rl_agent.replay(batch_size=len(batch))
+                    # Fallback to replay method - DQNAgent.replay() doesn't accept batch_size parameter
+                    loss = rl_agent.replay()
                    if loss is not None:
                        total_loss += loss
                        training_iterations += 1
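Note: the call sites in this commit drop batch_size because DQNAgent.replay() samples its own mini-batch. A caller that has to support agents with either signature could guard the call defensively; a sketch under that assumption (not part of this commit):

    import inspect

    def safe_replay(agent, batch_size=32):
        """Call agent.replay(), passing batch_size only if the method accepts it."""
        if 'batch_size' in inspect.signature(agent.replay).parameters:
            return agent.replay(batch_size=batch_size)
        return agent.replay()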
@@ -1129,25 +1129,10 @@ class EnhancedRealtimeTrainingSystem:
                state = combined_features  # 2000-dimensional state

                # Store experience in COB RL agent
-                if hasattr(cob_rl_agent, 'store_experience'):
-                    experience = {
-                        'state': state,
-                        'action': action,
-                        'reward': reward,
-                        'next_state': state,  # Will be updated with next observation
-                        'done': False,
-                        'symbol': symbol,
-                        'timestamp': datetime.now(),
-                        'price': current_price,
-                        'cob_features': {
-                            'raw_tick_available': raw_tick_matrix is not None,
-                            'aggregated_available': aggregated_matrix is not None,
-                            'imbalance': combined_features[0] if len(combined_features) > 0 else 0,
-                            'spread': combined_features[1] if len(combined_features) > 1 else 0,
-                            'liquidity': combined_features[4] if len(combined_features) > 4 else 0
-                        }
-                    }
-                    cob_rl_agent.store_experience(experience)
+                if hasattr(cob_rl_agent, 'remember'):
+                    # Use tuple format for DQN agent compatibility
+                    experience_tuple = (state, action, reward, state, False)  # next_state = current state for now
+                    cob_rl_agent.remember(state, action, reward, state, False)
                    training_updates += 1

                # Perform COB RL training if enough experiences
@@ -3123,9 +3123,13 @@ class CleanTradingDashboard:
            if len(self.recent_decisions) > 200:
                self.recent_decisions = self.recent_decisions[-200:]

-            # Train ALL models on the signal (if executed)
+            # Train ALL models on EVERY prediction result (not just executed ones)
+            # This ensures models learn from all predictions, not just successful trades
+            self._train_all_models_on_prediction(signal)
+
+            # Additional training weight for executed signals
            if signal['executed']:
-                self._train_all_models_on_signal(signal)
+                self._train_all_models_on_executed_signal(signal)

            # Log signal processing
            status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
@@ -3135,33 +3139,118 @@ class CleanTradingDashboard:
        except Exception as e:
            logger.error(f"Error processing dashboard signal: {e}")

-    def _train_all_models_on_signal(self, signal: Dict):
-        """Train ALL models on executed trade signal - Comprehensive training system"""
+    def _train_all_models_on_prediction(self, signal: Dict):
+        """Train ALL models on EVERY prediction result - Comprehensive learning system"""
+        try:
+            # Get prediction outcome based on immediate price movement
+            prediction_outcome = self._get_prediction_outcome_for_training(signal)
+            if not prediction_outcome:
+                return
+
+            # 1. Train DQN model on prediction outcome
+            self._train_dqn_on_prediction(signal, prediction_outcome)
+
+            # 2. Train CNN model on prediction outcome
+            self._train_cnn_on_prediction(signal, prediction_outcome)
+
+            # 3. Train Transformer model on prediction outcome
+            self._train_transformer_on_prediction(signal, prediction_outcome)
+
+            # 4. Train COB RL model on prediction outcome
+            self._train_cob_rl_on_prediction(signal, prediction_outcome)
+
+            # 5. Train Decision Fusion model on prediction outcome
+            self._train_decision_fusion_on_prediction(signal, prediction_outcome)
+
+            logger.debug(f"Trained all models on {signal['action']} prediction with outcome: {prediction_outcome['accuracy']:.2f}")
+
+        except Exception as e:
+            logger.debug(f"Error training models on prediction: {e}")
+
+    def _train_all_models_on_executed_signal(self, signal: Dict):
+        """Train ALL models on executed trade signal with enhanced weight - Comprehensive training system"""
        try:
            # Get trade outcome for training
            trade_outcome = self._get_trade_outcome_for_training(signal)
            if not trade_outcome:
                return

-            # 1. Train DQN model
-            self._train_dqn_on_signal(signal, trade_outcome)
-
-            # 2. Train CNN model
-            self._train_cnn_on_signal(signal, trade_outcome)
-
-            # 3. Train Transformer model
-            self._train_transformer_on_signal(signal, trade_outcome)
-
-            # 4. Train COB RL model
-            self._train_cob_rl_on_signal(signal, trade_outcome)
-
-            # 5. Train Decision Fusion model
-            self._train_decision_fusion_on_signal(signal, trade_outcome)
-
-            logger.debug(f"Trained all models on {signal['action']} signal with outcome: {trade_outcome['pnl']:.2f}")
+            # Enhanced training weight for executed signals (10x more important)
+            enhanced_outcome = trade_outcome.copy()
+            enhanced_outcome['training_weight'] = 10.0  # 10x weight for executed trades
+
+            # 1. Train DQN model with enhanced weight
+            self._train_dqn_on_executed_signal(signal, enhanced_outcome)
+
+            # 2. Train CNN model with enhanced weight
+            self._train_cnn_on_executed_signal(signal, enhanced_outcome)
+
+            # 3. Train Transformer model with enhanced weight
+            self._train_transformer_on_executed_signal(signal, enhanced_outcome)
+
+            # 4. Train COB RL model with enhanced weight
+            self._train_cob_rl_on_executed_signal(signal, enhanced_outcome)
+
+            # 5. Train Decision Fusion model with enhanced weight
+            self._train_decision_fusion_on_executed_signal(signal, enhanced_outcome)
+
+            logger.info(f"Enhanced training completed on {signal['action']} executed signal with outcome: {trade_outcome['pnl']:.2f}")

        except Exception as e:
-            logger.debug(f"Error training models on signal: {e}")
+            logger.debug(f"Error training models on executed signal: {e}")
+
+    def _train_all_models_on_signal(self, signal: Dict):
+        """Legacy method - now redirects to new training system"""
+        self._train_all_models_on_prediction(signal)
+
+    def _get_prediction_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
+        """Get prediction outcome based on immediate price movement validation"""
+        try:
+            symbol = signal.get('symbol', 'ETH/USDT')
+            action = signal.get('action', 'HOLD')
+            confidence = signal.get('confidence', 0.0)
+            prediction_time = signal.get('timestamp', datetime.now())
+
+            # Get current price to validate prediction
+            current_price = self._get_current_price(symbol)
+            if not current_price:
+                return None
+
+            # Get price at prediction time (or recent price if not available)
+            prediction_price = signal.get('price', current_price)
+
+            # Calculate immediate price movement (within 1-5 minutes)
+            price_change = ((current_price - prediction_price) / prediction_price) * 100
+
+            # Determine if prediction was accurate based on action and price movement
+            prediction_accurate = False
+            if action == 'BUY' and price_change > 0.1:  # 0.1% positive movement
+                prediction_accurate = True
+            elif action == 'SELL' and price_change < -0.1:  # 0.1% negative movement
+                prediction_accurate = True
+            elif action == 'HOLD' and abs(price_change) < 0.2:  # Stable price
+                prediction_accurate = True
+
+            # Calculate accuracy score (0.0 to 1.0)
+            accuracy_score = 0.5  # Base neutral score
+            if prediction_accurate:
+                accuracy_score = min(1.0, 0.5 + (confidence * 0.5))  # Higher confidence = higher score
+            else:
+                accuracy_score = max(0.0, 0.5 - (confidence * 0.5))  # Higher confidence = lower score for wrong predictions
+
+            return {
+                'accuracy': accuracy_score,
+                'price_change': price_change,
+                'prediction_accurate': prediction_accurate,
+                'confidence': confidence,
+                'action': action,
+                'prediction_time': prediction_time,
+                'validation_time': datetime.now()
+            }
+
+        except Exception as e:
+            logger.debug(f"Error getting prediction outcome: {e}")
+            return None

    def _get_trade_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
        """Get trade outcome for training - either from completed trade or position change"""
@@ -3213,8 +3302,8 @@ class CleanTradingDashboard:
            logger.debug(f"Error getting trade outcome: {e}")
            return None

-    def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
-        """Train DQN agent on executed signal with trade outcome"""
+    def _train_dqn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
+        """Train DQN agent on prediction outcome (every prediction, not just executed trades)"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return
@@ -3223,31 +3312,66 @@ class CleanTradingDashboard:
            state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
            action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL

-            # Calculate reward based on trade outcome
-            pnl = trade_outcome.get('pnl', 0)
-            reward = pnl * 100  # Scale reward for better learning
-
-            # Create next state (simplified)
-            next_state_features = state_features.copy()  # In real implementation, this would be the next market state
+            # Calculate reward based on prediction accuracy
+            accuracy = prediction_outcome.get('accuracy', 0.5)
+            confidence = signal.get('confidence', 0.5)
+            reward = (accuracy - 0.5) * 2.0  # Convert to [-1, 1] range

            # Store experience in DQN memory
            if hasattr(self.orchestrator.rl_agent, 'remember'):
                self.orchestrator.rl_agent.remember(
-                    state_features, action, reward, next_state_features, done=True
+                    state_features, action, reward, state_features, done=True
                )

                # Trigger training if enough samples
                if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
                    if hasattr(self.orchestrator.rl_agent, 'replay'):
-                        loss = self.orchestrator.rl_agent.replay(batch_size=32)
+                        loss = self.orchestrator.rl_agent.replay()
                        if loss is not None:
-                            logger.debug(f"DQN trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")
+                            logger.debug(f"DQN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")

        except Exception as e:
-            logger.debug(f"Error training DQN on signal: {e}")
+            logger.debug(f"Error training DQN on prediction: {e}")

-    def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
-        """Train CNN model on executed signal with trade outcome"""
+    def _train_dqn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
+        """Train DQN agent on executed signal with enhanced weight"""
+        try:
+            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
+                return
+
+            # Create training data for DQN
+            state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
+            action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL
+
+            # Calculate enhanced reward based on trade outcome
+            pnl = trade_outcome.get('pnl', 0)
+            training_weight = trade_outcome.get('training_weight', 1.0)
+            reward = pnl * 100 * training_weight  # Enhanced reward for executed trades
+
+            # Store experience in DQN memory with multiple entries for enhanced learning
+            if hasattr(self.orchestrator.rl_agent, 'remember'):
+                # Store multiple copies for enhanced learning
+                for _ in range(int(training_weight)):
+                    self.orchestrator.rl_agent.remember(
+                        state_features, action, reward, state_features, done=True
+                    )
+
+                # Trigger training if enough samples
+                if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
+                    if hasattr(self.orchestrator.rl_agent, 'replay'):
+                        loss = self.orchestrator.rl_agent.replay()
+                        if loss is not None:
+                            logger.info(f"DQN enhanced training on executed signal - loss: {loss:.4f}, reward: {reward:.2f}")
+
+        except Exception as e:
+            logger.debug(f"Error training DQN on executed signal: {e}")
+
+    def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
+        """Legacy method - redirects to new training system"""
+        self._train_dqn_on_prediction(signal, trade_outcome)
+
+    def _train_cnn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
+        """Train CNN model on prediction outcome (every prediction, not just executed trades)"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                return
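Note: the prediction-outcome scoring above reduces to a symmetric rule: a correct call scores 0.5 + confidence/2, a wrong call 0.5 - confidence/2, and the DQN reward rescales that to [-1, 1]. A standalone sketch of the same arithmetic (the function names here are illustrative, not part of the diff):

    def accuracy_score(correct: bool, confidence: float) -> float:
        # Same rule as _get_prediction_outcome_for_training
        return min(1.0, 0.5 + confidence * 0.5) if correct else max(0.0, 0.5 - confidence * 0.5)

    def dqn_reward(accuracy: float) -> float:
        # Same rule as _train_dqn_on_prediction: map [0, 1] accuracy to [-1, 1]
        return (accuracy - 0.5) * 2.0

    # A correct BUY at confidence 0.8 -> accuracy 0.9 -> reward +0.8;
    # the same prediction when wrong -> accuracy 0.1 -> reward -0.8.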
@@ -3261,25 +3385,64 @@ class CleanTradingDashboard:
            if not market_features:
                return

-            # Create target based on trade outcome
-            pnl = trade_outcome.get('pnl', 0)
-            target = 1.0 if pnl > 0 else 0.0  # Binary classification: profitable vs not
+            # Create target based on prediction accuracy
+            accuracy = prediction_outcome.get('accuracy', 0.5)
+            target = accuracy  # Use accuracy as target (0.0 to 1.0)

            # Prepare training data
            features = market_features.get('features', [])
            if features:
-                # Convert to tensor format (simplified)
                import numpy as np
                feature_tensor = np.array(features, dtype=np.float32)
                target_tensor = np.array([target], dtype=np.float32)

-                # Train CNN model (if it has training method)
+                # Train CNN model
                if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
                    loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
-                    logger.debug(f"CNN trained on signal - loss: {loss:.4f}, target: {target}")
+                    logger.debug(f"CNN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")

        except Exception as e:
-            logger.debug(f"Error training CNN on signal: {e}")
+            logger.debug(f"Error training CNN on prediction: {e}")
+
+    def _train_cnn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
+        """Train CNN model on executed signal with enhanced weight"""
+        try:
+            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
+                return
+
+            # Create training data for CNN
+            symbol = signal.get('symbol', 'ETH/USDT')
+            current_price = signal.get('price', 0)
+
+            # Get market features
+            market_features = self._get_cnn_features_and_predictions(symbol)
+            if not market_features:
+                return
+
+            # Create target based on trade outcome with enhanced weight
+            pnl = trade_outcome.get('pnl', 0)
+            training_weight = trade_outcome.get('training_weight', 1.0)
+            target = 1.0 if pnl > 0 else 0.0
+
+            # Prepare training data
+            features = market_features.get('features', [])
+            if features:
+                import numpy as np
+                feature_tensor = np.array(features, dtype=np.float32)
+                target_tensor = np.array([target], dtype=np.float32)
+
+                # Train CNN model with multiple passes for enhanced learning
+                if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
+                    for _ in range(int(training_weight)):
+                        loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
+                    logger.info(f"CNN enhanced training on executed signal - loss: {loss:.4f}, pnl: {pnl:.2f}")
+
+        except Exception as e:
+            logger.debug(f"Error training CNN on executed signal: {e}")
+
+    def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
+        """Legacy method - redirects to new training system"""
+        self._train_cnn_on_prediction(signal, trade_outcome)

    def _train_transformer_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train Transformer model on executed signal with trade outcome"""
@@ -3342,7 +3505,7 @@ class CleanTradingDashboard:
                # Trigger training if enough samples
                if hasattr(self.orchestrator.cob_rl_agent, 'memory') and len(self.orchestrator.cob_rl_agent.memory) > 32:
                    if hasattr(self.orchestrator.cob_rl_agent, 'replay'):
-                        loss = self.orchestrator.cob_rl_agent.replay(batch_size=32)
+                        loss = self.orchestrator.cob_rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"COB RL trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")

@@ -3999,7 +4162,7 @@ class CleanTradingDashboard:
    # Cold start training moved to core.training_integration.TrainingIntegration

    def _clear_session(self):
-        """Clear session data"""
+        """Clear session data and persistent files"""
        try:
            # Reset session metrics
            self.session_pnl = 0.0
@@ -4016,11 +4179,96 @@ class CleanTradingDashboard:
            self.current_position = None
            self.pending_trade_case_id = None  # Clear pending trade tracking

-            logger.info("Session data cleared")
+            # Clear persistent trade log files
+            self._clear_trade_logs()
+
+            # Clear orchestrator state if available
+            if hasattr(self, 'orchestrator') and self.orchestrator:
+                self._clear_orchestrator_state()
+
+            logger.info("Session data and trade logs cleared")

        except Exception as e:
            logger.error(f"Error clearing session: {e}")
+
+    def _clear_trade_logs(self):
+        """Clear all trade log files"""
+        try:
+            import os
+            import glob
+
+            # Clear trade_logs directory
+            trade_logs_dir = "trade_logs"
+            if os.path.exists(trade_logs_dir):
+                # Remove all CSV files in trade_logs
+                csv_files = glob.glob(os.path.join(trade_logs_dir, "*.csv"))
+                for file in csv_files:
+                    try:
+                        os.remove(file)
+                        logger.info(f"Deleted trade log: {file}")
+                    except Exception as e:
+                        logger.warning(f"Failed to delete {file}: {e}")
+
+                # Remove any .log files in trade_logs
+                log_files = glob.glob(os.path.join(trade_logs_dir, "*.log"))
+                for file in log_files:
+                    try:
+                        os.remove(file)
+                        logger.info(f"Deleted trade log: {file}")
+                    except Exception as e:
+                        logger.warning(f"Failed to delete {file}: {e}")
+
+            # Clear recent log files in logs directory
+            logs_dir = "logs"
+            if os.path.exists(logs_dir):
+                # Remove recent trading logs (keep older system logs)
+                recent_logs = [
+                    "enhanced_trading.log",
+                    "realtime_rl_cob_trader.log",
+                    "simple_cob_dashboard.log",
+                    "integrated_rl_cob_system.log",
+                    "optimized_cob_system.log"
+                ]
+
+                for log_file in recent_logs:
+                    log_path = os.path.join(logs_dir, log_file)
+                    if os.path.exists(log_path):
+                        try:
+                            # Truncate the file instead of deleting to preserve file handles
+                            with open(log_path, 'w') as f:
+                                f.write("")  # Clear file content
+                            logger.info(f"Cleared log file: {log_path}")
+                        except Exception as e:
+                            logger.warning(f"Failed to clear {log_path}: {e}")
+
+            logger.info("Trade logs cleared successfully")
+
+        except Exception as e:
+            logger.error(f"Error clearing trade logs: {e}")
+
+    def _clear_orchestrator_state(self):
+        """Clear orchestrator state and recent predictions"""
+        try:
+            if hasattr(self.orchestrator, 'recent_decisions'):
+                self.orchestrator.recent_decisions = {}
+
+            if hasattr(self.orchestrator, 'recent_dqn_predictions'):
+                for symbol in self.orchestrator.recent_dqn_predictions:
+                    self.orchestrator.recent_dqn_predictions[symbol].clear()
+
+            if hasattr(self.orchestrator, 'recent_cnn_predictions'):
+                for symbol in self.orchestrator.recent_cnn_predictions:
+                    self.orchestrator.recent_cnn_predictions[symbol].clear()
+
+            if hasattr(self.orchestrator, 'prediction_accuracy_history'):
+                for symbol in self.orchestrator.prediction_accuracy_history:
+                    self.orchestrator.prediction_accuracy_history[symbol].clear()
+
+            logger.info("Orchestrator state cleared")
+
+        except Exception as e:
+            logger.error(f"Error clearing orchestrator state: {e}")

    def _store_all_models(self) -> bool:
        """Store all current models to persistent storage"""
        try:
@@ -6112,7 +6360,7 @@ class CleanTradingDashboard:
            # Perform training step if agent has replay method
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32:  # Enough samples to train
-                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
+                    loss = cob_rl_agent.replay()
                    if loss is not None:
                        total_loss += loss
                        loss_count += 1