training wip
This commit is contained in:
@ -1458,6 +1458,14 @@ class EnhancedRealtimeTrainingSystem:
|
||||
features_tensor = torch.from_numpy(features).float()
|
||||
targets_tensor = torch.from_numpy(targets).long()
|
||||
|
||||
# FIXED: Move tensors to same device as model
|
||||
device = next(model.parameters()).device
|
||||
features_tensor = features_tensor.to(device)
|
||||
targets_tensor = targets_tensor.to(device)
|
||||
|
||||
# Move criterion to same device as well
|
||||
criterion = criterion.to(device)
|
||||
|
||||
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
||||
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
||||
# This depends on the actual CNN model architecture. Assuming a simple CNN that expects (batch, channels, height, width)
|
||||
@ -1474,7 +1482,18 @@ class EnhancedRealtimeTrainingSystem:
|
||||
# Let's assume the CNN expects 2D input (batch_size, flattened_features)
|
||||
outputs = model(features_tensor)
|
||||
|
||||
loss = criterion(outputs, targets_tensor)
|
||||
# FIXED: Handle case where model returns tuple (extract the logits)
|
||||
if isinstance(outputs, tuple):
|
||||
# Assume the first element is the main output (logits)
|
||||
logits = outputs[0]
|
||||
elif isinstance(outputs, dict):
|
||||
# Handle dictionary output (get main prediction)
|
||||
logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
|
||||
else:
|
||||
# Single tensor output
|
||||
logits = outputs
|
||||
|
||||
loss = criterion(logits, targets_tensor)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@ -1482,8 +1501,122 @@ class EnhancedRealtimeTrainingSystem:
|
||||
return loss.item()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN training: {e}")
|
||||
logger.error(f"RT TRAINING: Error in CNN training: {e}")
|
||||
return 1.0 # Return default loss value in case of error
|
||||
|
||||
def _sample_prioritized_experiences(self) -> List[Dict]:
|
||||
"""Sample prioritized experiences for training"""
|
||||
try:
|
||||
experiences = []
|
||||
|
||||
# Sample from priority buffer first (high-priority experiences)
|
||||
if self.priority_buffer:
|
||||
priority_samples = min(len(self.priority_buffer), self.training_config['batch_size'] // 2)
|
||||
priority_experiences = random.sample(list(self.priority_buffer), priority_samples)
|
||||
experiences.extend(priority_experiences)
|
||||
|
||||
# Sample from regular experience buffer
|
||||
if self.experience_buffer:
|
||||
remaining_samples = self.training_config['batch_size'] - len(experiences)
|
||||
if remaining_samples > 0:
|
||||
regular_samples = min(len(self.experience_buffer), remaining_samples)
|
||||
regular_experiences = random.sample(list(self.experience_buffer), regular_samples)
|
||||
experiences.extend(regular_experiences)
|
||||
|
||||
# Convert experiences to DQN format
|
||||
dqn_experiences = []
|
||||
for exp in experiences:
|
||||
# Create next state by shifting current state (simple approximation)
|
||||
next_state = exp['state'].copy() if hasattr(exp['state'], 'copy') else exp['state']
|
||||
|
||||
# Simple reward based on recent market movement
|
||||
reward = self._calculate_experience_reward(exp)
|
||||
|
||||
# Action mapping: 0=BUY, 1=SELL, 2=HOLD
|
||||
action = self._determine_action_from_experience(exp)
|
||||
|
||||
dqn_exp = {
|
||||
'state': exp['state'],
|
||||
'action': action,
|
||||
'reward': reward,
|
||||
'next_state': next_state,
|
||||
'done': False # Episodes don't really "end" in continuous trading
|
||||
}
|
||||
|
||||
dqn_experiences.append(dqn_exp)
|
||||
|
||||
return dqn_experiences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sampling prioritized experiences: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_experience_reward(self, experience: Dict) -> float:
|
||||
"""Calculate reward for an experience"""
|
||||
try:
|
||||
# Simple reward based on technical indicators and market events
|
||||
reward = 0.0
|
||||
|
||||
# Reward based on market events
|
||||
if experience.get('market_events', 0) > 0:
|
||||
reward += 0.1 # Bonus for learning from market events
|
||||
|
||||
# Reward based on technical indicators
|
||||
tech_indicators = experience.get('technical_indicators', {})
|
||||
if tech_indicators:
|
||||
# Reward for strong momentum
|
||||
momentum = tech_indicators.get('price_momentum', 0)
|
||||
reward += np.tanh(momentum * 10) # Bounded reward
|
||||
|
||||
# Penalize high volatility
|
||||
volatility = tech_indicators.get('volatility', 0)
|
||||
reward -= min(volatility * 5, 0.2) # Penalty for high volatility
|
||||
|
||||
# Reward based on COB features
|
||||
cob_features = experience.get('cob_features', [])
|
||||
if cob_features and len(cob_features) > 0:
|
||||
# Reward for strong order book imbalance
|
||||
imbalance = cob_features[0] if len(cob_features) > 0 else 0
|
||||
reward += abs(imbalance) * 0.1 # Reward for any imbalance signal
|
||||
|
||||
return max(-1.0, min(1.0, reward)) # Clamp to [-1, 1]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error calculating experience reward: {e}")
|
||||
return 0.0
|
||||
|
||||
def _determine_action_from_experience(self, experience: Dict) -> int:
|
||||
"""Determine action from experience data"""
|
||||
try:
|
||||
# Use technical indicators to determine action
|
||||
tech_indicators = experience.get('technical_indicators', {})
|
||||
|
||||
if tech_indicators:
|
||||
momentum = tech_indicators.get('price_momentum', 0)
|
||||
rsi = tech_indicators.get('rsi', 50)
|
||||
|
||||
# Simple logic based on momentum and RSI
|
||||
if momentum > 0.005 and rsi < 70: # Upward momentum, not overbought
|
||||
return 0 # BUY
|
||||
elif momentum < -0.005 and rsi > 30: # Downward momentum, not oversold
|
||||
return 1 # SELL
|
||||
else:
|
||||
return 2 # HOLD
|
||||
|
||||
# Fallback to COB-based action
|
||||
cob_features = experience.get('cob_features', [])
|
||||
if cob_features and len(cob_features) > 0:
|
||||
imbalance = cob_features[0]
|
||||
if imbalance > 0.1:
|
||||
return 0 # BUY (bid imbalance)
|
||||
elif imbalance < -0.1:
|
||||
return 1 # SELL (ask imbalance)
|
||||
|
||||
return 2 # Default to HOLD
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error determining action from experience: {e}")
|
||||
return 2 # Default to HOLD
|
||||
|
||||
def _perform_validation(self):
|
||||
"""Perform validation to track model performance"""
|
||||
@ -1849,23 +1982,34 @@ class EnhancedRealtimeTrainingSystem:
|
||||
current_state = self._build_comprehensive_state()
|
||||
current_price = self._get_current_price_from_data(symbol)
|
||||
|
||||
if current_price is None:
|
||||
# SKIP prediction if price is invalid
|
||||
if current_price is None or current_price <= 0:
|
||||
logger.debug(f"Skipping DQN prediction for {symbol}: invalid price {current_price}")
|
||||
return
|
||||
|
||||
# Use DQN model to predict action (if available)
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
||||
and self.orchestrator.rl_agent):
|
||||
|
||||
# Get Q-values from model
|
||||
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
|
||||
if isinstance(q_values, tuple):
|
||||
action, q_vals = q_values
|
||||
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
|
||||
# Get action from DQN agent
|
||||
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||
|
||||
# Get Q-values by manually calling the model
|
||||
q_values = self._get_dqn_q_values(current_state)
|
||||
|
||||
# Calculate confidence from Q-values
|
||||
if q_values is not None and len(q_values) > 0:
|
||||
# Convert to probabilities and get confidence
|
||||
probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
|
||||
confidence = float(max(probs))
|
||||
q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
|
||||
else:
|
||||
action = q_values
|
||||
confidence = 0.33
|
||||
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
||||
|
||||
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
||||
# Handle case where action is None (HOLD)
|
||||
if action is None:
|
||||
action = 2 # Map None to HOLD action
|
||||
|
||||
else:
|
||||
# Fallback to technical analysis-based prediction
|
||||
@ -1893,8 +2037,8 @@ class EnhancedRealtimeTrainingSystem:
|
||||
if symbol in self.pending_predictions:
|
||||
self.pending_predictions[symbol].append(prediction)
|
||||
|
||||
# Add to recent predictions for display (only if confident enough)
|
||||
if confidence > 0.4:
|
||||
# Add to recent predictions for display (only if confident enough AND valid price)
|
||||
if confidence > 0.4 and current_price > 0:
|
||||
display_prediction = {
|
||||
'timestamp': prediction_time,
|
||||
'price': current_price,
|
||||
@ -1907,11 +2051,44 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
self.last_prediction_time[symbol] = int(current_time)
|
||||
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward DQN prediction: {e}")
|
||||
|
||||
def _get_dqn_q_values(self, state: np.ndarray) -> Optional[np.ndarray]:
|
||||
"""Get Q-values from DQN agent without performing action selection"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||
return None
|
||||
|
||||
rl_agent = self.orchestrator.rl_agent
|
||||
|
||||
# Convert state to tensor
|
||||
if isinstance(state, np.ndarray):
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(rl_agent.device)
|
||||
else:
|
||||
state_tensor = state.unsqueeze(0).to(rl_agent.device)
|
||||
|
||||
# Get Q-values directly from policy network
|
||||
with torch.no_grad():
|
||||
policy_output = rl_agent.policy_net(state_tensor)
|
||||
|
||||
# Handle different output formats
|
||||
if isinstance(policy_output, dict):
|
||||
q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
|
||||
elif isinstance(policy_output, tuple):
|
||||
q_values = policy_output[0] # Assume first element is Q-values
|
||||
else:
|
||||
q_values = policy_output
|
||||
|
||||
# Convert to numpy
|
||||
return q_values.cpu().data.numpy()[0]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting DQN Q-values: {e}")
|
||||
return None
|
||||
|
||||
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
|
||||
"""Generate a CNN prediction for future price direction"""
|
||||
try:
|
||||
@ -1919,9 +2096,15 @@ class EnhancedRealtimeTrainingSystem:
|
||||
current_price = self._get_current_price_from_data(symbol)
|
||||
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
|
||||
|
||||
if current_price is None or len(price_sequence) < 15:
|
||||
# SKIP prediction if price is invalid
|
||||
if current_price is None or current_price <= 0:
|
||||
logger.debug(f"Skipping CNN prediction for {symbol}: invalid price {current_price}")
|
||||
return
|
||||
|
||||
|
||||
if len(price_sequence) < 15:
|
||||
logger.debug(f"Skipping CNN prediction for {symbol}: insufficient data")
|
||||
return
|
||||
|
||||
# Use CNN model to predict direction (if available)
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
|
||||
and self.orchestrator.cnn_model):
|
||||
@ -1974,8 +2157,8 @@ class EnhancedRealtimeTrainingSystem:
|
||||
if symbol in self.pending_predictions:
|
||||
self.pending_predictions[symbol].append(prediction)
|
||||
|
||||
# Add to recent predictions for display (only if confident enough)
|
||||
if confidence > 0.5:
|
||||
# Add to recent predictions for display (only if confident enough AND valid prices)
|
||||
if confidence > 0.5 and current_price > 0 and predicted_price > 0:
|
||||
display_prediction = {
|
||||
'timestamp': prediction_time,
|
||||
'current_price': current_price,
|
||||
@ -1986,7 +2169,7 @@ class EnhancedRealtimeTrainingSystem:
|
||||
if symbol in self.recent_cnn_predictions:
|
||||
self.recent_cnn_predictions[symbol].append(display_prediction)
|
||||
|
||||
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} price=${current_price:.2f} -> ${predicted_price:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward CNN prediction: {e}")
|
||||
@ -2077,8 +2260,24 @@ class EnhancedRealtimeTrainingSystem:
|
||||
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price from real-time data streams"""
|
||||
try:
|
||||
# First, try to get from data provider (most reliable)
|
||||
if self.data_provider:
|
||||
price = self.data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Fallback to internal buffer
|
||||
if len(self.real_time_data['ohlcv_1m']) > 0:
|
||||
return self.real_time_data['ohlcv_1m'][-1]['close']
|
||||
price = self.real_time_data['ohlcv_1m'][-1]['close']
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Fallback to orchestrator price
|
||||
if self.orchestrator:
|
||||
price = self.orchestrator._get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price: {e}")
|
||||
|
Reference in New Issue
Block a user