training wip

This commit is contained in:
Dobromir Popov
2025-07-27 23:45:57 +03:00
parent 39267697f3
commit b4076241c9
4 changed files with 283 additions and 66 deletions

View File

@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
"""
# Convert state to tensor if needed
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state).to(next(self.parameters()).device)
state = torch.FloatTensor(state)
# Move to device
device = next(self.parameters()).device
state = state.to(device)
# Ensure proper shape
if state.dim() == 1:
@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
# Process price direction predictions
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
# Price direction predictions are processed in the agent's act method
# This is just the network forward pass
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
@ -234,7 +237,7 @@ class DQNAgent:
"""
def __init__(self,
state_shape: Tuple[int, ...],
n_actions: int = 2,
n_actions: int = 3, # BUY=0, SELL=1, HOLD=2
learning_rate: float = 0.001,
epsilon: float = 1.0,
epsilon_min: float = 0.01,
@ -761,6 +764,13 @@ class DQNAgent:
# Use the DQNNetwork's act method for consistent behavior
action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
# Process price direction predictions from the network
# Get the raw predictions from the network's forward pass
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
# Apply epsilon-greedy exploration if requested
if explore and np.random.random() <= self.epsilon:
action_idx = np.random.choice(self.n_actions)
@ -780,15 +790,44 @@ class DQNAgent:
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
"""Choose action with confidence score adapted to market regime"""
try:
# Use the DQNNetwork's act method which handles the state properly
action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
# Convert state to tensor if needed
if isinstance(state, np.ndarray):
state_tensor = torch.FloatTensor(state)
device = next(self.policy_net.parameters()).device
state_tensor = state_tensor.to(device)
# Ensure proper shape
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
else:
state_tensor = state
# Get network outputs
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
# Process price direction predictions
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
# Select action (greedy for inference)
action_idx = torch.argmax(q_values, dim=1).item()
# Calculate confidence as max probability
base_confidence = float(action_probs[0, action_idx].item())
# Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)
# Convert probabilities to list
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
# Return action, confidence, and probabilities (for orchestrator compatibility)
return int(action_idx), float(adapted_confidence), action_probs
return int(action_idx), float(adapted_confidence), probs_list
except Exception as e:
logger.error(f"Error in act_with_confidence: {e}")

View File

@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
self.n_actions = n_actions
self.confidence_threshold = confidence_threshold
# Training data storage
self.training_data = []
# Calculate input dimensions
if isinstance(input_shape, (list, tuple)):
if len(input_shape) == 3: # [channels, height, width]
@ -649,6 +652,30 @@ class EnhancedCNN(nn.Module):
'weighted_strength': 0.0
}
def add_training_data(self, state, action, reward, max_samples: int = 1000):
    """
    Add one training sample to the model's training buffer.

    The sample is stored as a dict with the keys 'state', 'action',
    'reward' and 'timestamp' (wall-clock time of insertion). The buffer is
    bounded: once it holds more than ``max_samples`` entries the oldest
    ones are dropped so that only the most recent samples remain.

    Errors are logged rather than raised, so a failed insert never
    interrupts the caller.

    Args:
        state: Input state
        action: Action taken
        reward: Reward received
        max_samples: Maximum number of samples kept in the buffer.
            Defaults to 1000, the limit that was previously hard-coded.
    """
    try:
        self.training_data.append({
            'state': state,
            'action': action,
            'reward': reward,
            'timestamp': time.time()
        })

        # Trim in place rather than rebinding to a sliced copy: this avoids
        # copying the whole buffer on every insert once the cap is reached,
        # and keeps any other reference to the buffer bounded as well.
        excess = len(self.training_data) - max_samples
        if excess > 0:
            del self.training_data[:excess]

    except Exception as e:
        logger.error(f"Error adding training data: {e}")
def save(self, path):
"""Save model weights and architecture"""
os.makedirs(os.path.dirname(path), exist_ok=True)

View File

@ -1035,12 +1035,70 @@ class TradingOrchestrator:
logger.debug(f"Error capturing DQN prediction: {e}")
def _get_current_price(self, symbol: str) -> Optional[float]:
    """Get the current price for a symbol, falling back through several sources.

    Sources are tried in order and the first strictly positive price wins:
      1. the data provider's ``current_prices`` mapping (if it has one),
      2. the data provider's ``get_current_price`` method (if it has one),
      3. the latest close of freshly fetched 1m, 5m, then 1h candles,
      4. the Binance public ticker endpoint (ETH/USDT and BTC/USDT only).

    Args:
        symbol: Trading pair, e.g. 'ETH/USDT'.

    Returns:
        The current price, or None if no source produced one. A hard-coded
        "typical" price is deliberately NOT returned as a last resort: this
        value is stored as the inference price and used to score
        predictions, so a stale constant would silently corrupt training
        outcomes. Callers must treat None as "price unavailable".
    """
    try:
        provider = self.data_provider

        # 1) Prices already held by the data provider
        if hasattr(provider, 'current_prices') and symbol in provider.current_prices:
            price = provider.current_prices[symbol]
            if price and price > 0:
                return price

        # 2) The data provider's own lookup method
        if hasattr(provider, 'get_current_price'):
            try:
                price = provider.get_current_price(symbol)
                if price and price > 0:
                    return price
            except Exception as dp_error:
                logger.debug(f"Data provider get_current_price failed: {dp_error}")

        # 3) Fresh candles - start with 1m for better reliability
        for timeframe in ('1m', '5m', '1h'):
            try:
                df = provider.get_historical_data(symbol, timeframe, limit=1, refresh=True)
                if df is not None and not df.empty:
                    price = float(df['close'].iloc[-1])
                    if price > 0:
                        logger.debug(f"Got current price for {symbol} from {timeframe}: ${price:.2f}")
                        return price
            except Exception as tf_error:
                logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")

        # 4) External API as last resort (only for the pairs mapped here)
        binance_symbol = {'ETH/USDT': 'ETHUSDT', 'BTC/USDT': 'BTCUSDT'}.get(symbol)
        if binance_symbol is not None:
            try:
                # Local import: requests is only needed on this rare path
                import requests
                response = requests.get(
                    'https://api.binance.com/api/v3/ticker/price',
                    params={'symbol': binance_symbol},
                    timeout=2,  # short timeout: this call blocks the caller
                )
                if response.status_code == 200:
                    price = float(response.json()['price'])
                    if price > 0:
                        logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
                        return price
            except Exception as api_error:
                logger.debug(f"External API failed: {api_error}")

        logger.warning(f"Could not get current price for {symbol} from any source")

    except Exception as e:
        logger.error(f"Error getting current price for {symbol}: {e}")

    return None
async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
"""Generate a basic momentum-based fallback prediction when no models are available"""
@ -1683,6 +1741,9 @@ class TradingOrchestrator:
if symbol is None:
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
# Get current price at inference time
current_price = self._get_current_price(symbol)
# Create inference record - store only what's needed for training
inference_record = {
'timestamp': timestamp.isoformat(),
@ -1697,7 +1758,8 @@ class TradingOrchestrator:
},
'metadata': prediction.metadata or {},
'training_outcome': None, # Will be set when training occurs
'outcome_evaluated': False
'outcome_evaluated': False,
'inference_price': current_price # Store price at inference time
}
# Store only the last inference per model (for immediate training)
@ -2063,28 +2125,70 @@ class TradingOrchestrator:
except Exception:
pass
# Calculate price change
if predicted_price is not None:
actual_price_change_pct = (current_price - predicted_price) / predicted_price * 100
price_outcome = f"Predicted: ${predicted_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
# Fall back to historical price comparison
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is not None and not historical_data.empty:
historical_price = historical_data['close'].iloc[-1]
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
price_outcome = f"Historical: ${historical_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
price_outcome = f"Actual: ${current_price:.2f}"
# Get inference price and timestamp from record
inference_price = inference_record.get('inference_price')
timestamp = inference_record.get('timestamp')
# Determine if prediction was correct based on action and price movement
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp)
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
actual_price_change_pct = 0.0
# Use stored inference price for comparison
if inference_price is not None:
actual_price_change_pct = (current_price - inference_price) / inference_price * 100
# Use seconds-based comparison for short-lived predictions
if time_diff_seconds <= 60: # Within 1 minute
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
# For older predictions, use a more conservative approach
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
# Fall back to historical price comparison if no inference price
try:
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is not None and not historical_data.empty:
historical_price = historical_data['close'].iloc[-1]
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
price_outcome = f"Historical: ${historical_price:.2f} -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
price_outcome = f"Current: ${current_price:.2f} (no historical data)"
except Exception as e:
logger.warning(f"Error calculating price change: {e}")
price_outcome = f"Current: ${current_price:.2f} (calculation error)"
# Determine if prediction was correct based on predicted direction and actual price movement
was_correct = False
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
was_correct = True
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
was_correct = True
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
was_correct = True
# Get predicted direction from the inference record
predicted_direction = None
if 'price_direction' in prediction and prediction['price_direction']:
try:
price_direction_data = prediction['price_direction']
if isinstance(price_direction_data, dict) and 'direction' in price_direction_data:
predicted_direction = price_direction_data['direction']
except Exception as e:
logger.debug(f"Error extracting predicted direction: {e}")
# Evaluate based on predicted direction if available
if predicted_direction is not None:
# Use the predicted direction (-1 to 1) to determine correctness
if predicted_direction > 0.1 and actual_price_change_pct > 0.1: # Predicted UP, price went UP
was_correct = True
elif predicted_direction < -0.1 and actual_price_change_pct < -0.1: # Predicted DOWN, price went DOWN
was_correct = True
elif abs(predicted_direction) <= 0.1 and abs(actual_price_change_pct) < 0.5: # Predicted SIDEWAYS, price stayed stable
was_correct = True
else:
# Fallback to action-based evaluation
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
was_correct = True
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
was_correct = True
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
was_correct = True
outcome_status = "✅ CORRECT" if was_correct else "❌ INCORRECT"
@ -2107,38 +2211,32 @@ class TradingOrchestrator:
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp)
# Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
# Get inference price and calculate time difference
inference_price = record.get('inference_price')
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
time_diff_minutes = time_diff_seconds / 60 # minutes
# Get historical price at prediction time (simplified)
# Use stored inference price for comparison
symbol = record['symbol']
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is None or historical_data.empty:
return
price_change_pct = 0.0
# Use predicted price if available, otherwise fall back to historical price
predicted_price = None
if 'price_prediction' in prediction and prediction['price_prediction']:
try:
# Extract predicted price change from CNN output
price_prediction_data = prediction['price_prediction']
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01 # Convert to percentage
predicted_price = current_price * (1 + predicted_price_change_pct)
logger.debug(f"Using CNN price prediction: {predicted_price_change_pct:.3f}% -> ${predicted_price:.2f}")
except Exception as e:
logger.warning(f"Error parsing price prediction: {e}")
# Fall back to historical price if no prediction available
if predicted_price is None:
prediction_price = historical_data['close'].iloc[-1] # Simplified
price_change_pct = (current_price - prediction_price) / prediction_price * 100
logger.debug(f"Using historical price comparison: ${prediction_price:.2f} -> ${current_price:.2f}")
if inference_price is not None:
price_change_pct = (current_price - inference_price) / inference_price * 100
logger.debug(f"Using stored inference price: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
else:
# Use predicted price for reward calculation
price_change_pct = (current_price - predicted_price) / predicted_price * 100
logger.debug(f"Using predicted price comparison: ${predicted_price:.2f} -> ${current_price:.2f}")
# Fall back to historical data if no inference price stored
try:
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is not None and not historical_data.empty:
historical_price = historical_data['close'].iloc[-1]
price_change_pct = (current_price - historical_price) / historical_price * 100
logger.debug(f"Using historical price comparison: ${historical_price:.2f} -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
else:
logger.warning(f"No historical data available for {symbol}")
return
except Exception as e:
logger.warning(f"Error calculating price change: {e}")
return
# Enhanced reward system based on prediction confidence and price movement magnitude
predicted_action = prediction['action']
@ -2149,8 +2247,8 @@ class TradingOrchestrator:
predicted_action,
prediction_confidence,
price_change_pct,
time_diff,
predicted_price is not None # Add price prediction flag
time_diff_minutes,
inference_price is not None # Add price prediction flag
)
# Update model performance tracking
@ -2160,6 +2258,10 @@ class TradingOrchestrator:
'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
}
# Ensure price_predictions key exists
if 'price_predictions' not in self.model_performance[model_name]:
self.model_performance[model_name]['price_predictions'] = {'total': 0, 'accurate': 0, 'avg_error': 0.0}
self.model_performance[model_name]['total'] += 1
if was_correct:
self.model_performance[model_name]['correct'] += 1
@ -2170,7 +2272,7 @@ class TradingOrchestrator:
)
# Track price prediction accuracy if available
if predicted_price is not None:
if inference_price is not None:
price_prediction_stats = self.model_performance[model_name]['price_predictions']
price_prediction_stats['total'] += 1

View File

@ -5328,8 +5328,11 @@ class CleanTradingDashboard:
# Cold start training moved to core.training_integration.TrainingIntegration
def _clear_session(self):
"""Clear session data and persistent files"""
"""Clear session data, close all positions, and reset PnL"""
try:
# Close all held positions first
self._close_all_positions()
# Reset session metrics
self.session_pnl = 0.0
self.total_fees = 0.0
@ -5393,12 +5396,58 @@ class CleanTradingDashboard:
logger.info("✅ Session data and trade logs cleared successfully")
logger.info("📊 Session P&L reset to $0.00")
logger.info("📈 Position cleared")
logger.info("📈 All positions closed")
logger.info("📋 Trade history cleared")
except Exception as e:
logger.error(f"❌ Error clearing session: {e}")
def _close_all_positions(self, symbols=('ETH/USDT', 'BTC/USDT')):
    """Close all held positions and reset local position / PnL tracking.

    Each symbol is closed through ``self.trading_executor`` (when present),
    then ``self.orchestrator._close_all_positions()`` is called if the
    orchestrator exposes it. Every close attempt is best-effort: a failure
    is logged and the remaining steps still run. Local tracking attributes
    are reset afterwards regardless of the outcome.

    Args:
        symbols: Trading pairs to close via the trading executor. Defaults
            to the two pairs that were previously hard-coded.
    """
    failed = []
    try:
        # Close positions via trading executor if available
        executor = getattr(self, 'trading_executor', None)
        if executor:
            for symbol in symbols:
                try:
                    # NOTE(review): the return value is not inspected; if
                    # close_position reports failure by returning False
                    # rather than raising, it is still logged as closed.
                    executor.close_position(symbol)
                    logger.info(f"🔒 Closed {symbol} position")
                except Exception as e:
                    failed.append(symbol)
                    logger.warning(f"Failed to close {symbol} position: {e}")

        # Also try to close via orchestrator if available
        orchestrator = getattr(self, 'orchestrator', None)
        if orchestrator and hasattr(orchestrator, '_close_all_positions'):
            try:
                orchestrator._close_all_positions()
                logger.info("🔒 Closed all positions via orchestrator")
            except Exception as e:
                failed.append('orchestrator')
                logger.warning(f"Failed to close positions via orchestrator: {e}")

        # Reset position tracking; optional attributes are only reset when
        # they already exist so no new attributes are created here
        self.current_position = None
        for attr, reset_value in (
            ('position_size', 0.0),
            ('position_entry_price', None),
            ('position_pnl', 0.0),
            ('unrealized_pnl', 0.0),
            ('realized_pnl', 0.0),
        ):
            if hasattr(self, attr):
                setattr(self, attr, reset_value)

        # Only claim success when every close attempt actually succeeded
        if failed:
            logger.warning(
                f"⚠️ Position tracking and PnL reset, but closing failed for: {', '.join(failed)}"
            )
        else:
            logger.info("✅ All positions closed and PnL reset")

    except Exception as e:
        logger.error(f"❌ Error closing positions: {e}")
def _clear_trade_logs(self):
"""Clear all trade log files"""
try: