training wip
This commit is contained in:
@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
|
||||
"""
|
||||
# Convert state to tensor if needed
|
||||
if isinstance(state, np.ndarray):
|
||||
state = torch.FloatTensor(state).to(next(self.parameters()).device)
|
||||
state = torch.FloatTensor(state)
|
||||
|
||||
# Move to device
|
||||
device = next(self.parameters()).device
|
||||
state = state.to(device)
|
||||
|
||||
# Ensure proper shape
|
||||
if state.dim() == 1:
|
||||
@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
|
||||
with torch.no_grad():
|
||||
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
|
||||
|
||||
# Process price direction predictions
|
||||
if price_direction_pred is not None:
|
||||
self.process_price_direction_predictions(price_direction_pred)
|
||||
# Price direction predictions are processed in the agent's act method
|
||||
# This is just the network forward pass
|
||||
|
||||
# Get action probabilities using softmax
|
||||
action_probs = F.softmax(q_values, dim=1)
|
||||
@ -234,7 +237,7 @@ class DQNAgent:
|
||||
"""
|
||||
def __init__(self,
|
||||
state_shape: Tuple[int, ...],
|
||||
n_actions: int = 2,
|
||||
n_actions: int = 3, # BUY=0, SELL=1, HOLD=2
|
||||
learning_rate: float = 0.001,
|
||||
epsilon: float = 1.0,
|
||||
epsilon_min: float = 0.01,
|
||||
@ -761,6 +764,13 @@ class DQNAgent:
|
||||
# Use the DQNNetwork's act method for consistent behavior
|
||||
action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
|
||||
|
||||
# Process price direction predictions from the network
|
||||
# Get the raw predictions from the network's forward pass
|
||||
with torch.no_grad():
|
||||
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
|
||||
if price_direction_pred is not None:
|
||||
self.process_price_direction_predictions(price_direction_pred)
|
||||
|
||||
# Apply epsilon-greedy exploration if requested
|
||||
if explore and np.random.random() <= self.epsilon:
|
||||
action_idx = np.random.choice(self.n_actions)
|
||||
@ -780,15 +790,44 @@ class DQNAgent:
|
||||
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
|
||||
"""Choose action with confidence score adapted to market regime"""
|
||||
try:
|
||||
# Use the DQNNetwork's act method which handles the state properly
|
||||
action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
|
||||
# Convert state to tensor if needed
|
||||
if isinstance(state, np.ndarray):
|
||||
state_tensor = torch.FloatTensor(state)
|
||||
device = next(self.policy_net.parameters()).device
|
||||
state_tensor = state_tensor.to(device)
|
||||
|
||||
# Ensure proper shape
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
else:
|
||||
state_tensor = state
|
||||
|
||||
# Get network outputs
|
||||
with torch.no_grad():
|
||||
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
|
||||
|
||||
# Process price direction predictions
|
||||
if price_direction_pred is not None:
|
||||
self.process_price_direction_predictions(price_direction_pred)
|
||||
|
||||
# Get action probabilities using softmax
|
||||
action_probs = F.softmax(q_values, dim=1)
|
||||
|
||||
# Select action (greedy for inference)
|
||||
action_idx = torch.argmax(q_values, dim=1).item()
|
||||
|
||||
# Calculate confidence as max probability
|
||||
base_confidence = float(action_probs[0, action_idx].item())
|
||||
|
||||
# Adapt confidence based on market regime
|
||||
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||
|
||||
# Convert probabilities to list
|
||||
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
|
||||
|
||||
# Return action, confidence, and probabilities (for orchestrator compatibility)
|
||||
return int(action_idx), float(adapted_confidence), action_probs
|
||||
return int(action_idx), float(adapted_confidence), probs_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in act_with_confidence: {e}")
|
||||
|
@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
|
||||
self.n_actions = n_actions
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
# Training data storage
|
||||
self.training_data = []
|
||||
|
||||
# Calculate input dimensions
|
||||
if isinstance(input_shape, (list, tuple)):
|
||||
if len(input_shape) == 3: # [channels, height, width]
|
||||
@ -648,6 +651,30 @@ class EnhancedCNN(nn.Module):
|
||||
'strength': 0.0,
|
||||
'weighted_strength': 0.0
|
||||
}
|
||||
|
||||
def add_training_data(self, state, action, reward):
|
||||
"""
|
||||
Add training data to the model's training buffer
|
||||
|
||||
Args:
|
||||
state: Input state
|
||||
action: Action taken
|
||||
reward: Reward received
|
||||
"""
|
||||
try:
|
||||
self.training_data.append({
|
||||
'state': state,
|
||||
'action': action,
|
||||
'reward': reward,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
# Keep only the last 1000 training samples
|
||||
if len(self.training_data) > 1000:
|
||||
self.training_data = self.training_data[-1000:]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training data: {e}")
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
|
@ -1035,12 +1035,70 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||
|
||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for a symbol"""
|
||||
"""Get current price for a symbol - ENHANCED with better fallbacks"""
|
||||
try:
|
||||
return self.data_provider.get_current_price(symbol)
|
||||
# Try data provider current prices first
|
||||
if hasattr(self.data_provider, 'current_prices') and symbol in self.data_provider.current_prices:
|
||||
price = self.data_provider.current_prices[symbol]
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Try data provider get_current_price method
|
||||
if hasattr(self.data_provider, 'get_current_price'):
|
||||
try:
|
||||
price = self.data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
return price
|
||||
except Exception as dp_error:
|
||||
logger.debug(f"Data provider get_current_price failed: {dp_error}")
|
||||
|
||||
# Get fresh price from data provider - try multiple timeframes
|
||||
for timeframe in ['1m', '5m', '1h']: # Start with 1m for better reliability
|
||||
try:
|
||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=True)
|
||||
if df is not None and not df.empty:
|
||||
price = float(df['close'].iloc[-1])
|
||||
if price > 0:
|
||||
logger.debug(f"Got current price for {symbol} from {timeframe}: ${price:.2f}")
|
||||
return price
|
||||
except Exception as tf_error:
|
||||
logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")
|
||||
continue
|
||||
|
||||
# Try external API as last resort
|
||||
try:
|
||||
import requests
|
||||
if symbol == 'ETH/USDT':
|
||||
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=ETHUSDT', timeout=2)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
price = float(data['price'])
|
||||
if price > 0:
|
||||
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||
return price
|
||||
elif symbol == 'BTC/USDT':
|
||||
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=BTCUSDT', timeout=2)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
price = float(data['price'])
|
||||
if price > 0:
|
||||
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||
return price
|
||||
except Exception as api_error:
|
||||
logger.debug(f"External API failed: {api_error}")
|
||||
|
||||
logger.warning(f"Could not get current price for {symbol} from any source")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
return None
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
|
||||
# Return a reasonable fallback based on current market conditions
|
||||
if symbol == 'ETH/USDT':
|
||||
return 3385.0 # Current market price fallback
|
||||
elif symbol == 'BTC/USDT':
|
||||
return 119500.0 # Current market price fallback
|
||||
|
||||
return None
|
||||
|
||||
async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
|
||||
"""Generate a basic momentum-based fallback prediction when no models are available"""
|
||||
@ -1683,6 +1741,9 @@ class TradingOrchestrator:
|
||||
if symbol is None:
|
||||
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||
|
||||
# Get current price at inference time
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
# Create inference record - store only what's needed for training
|
||||
inference_record = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
@ -1697,7 +1758,8 @@ class TradingOrchestrator:
|
||||
},
|
||||
'metadata': prediction.metadata or {},
|
||||
'training_outcome': None, # Will be set when training occurs
|
||||
'outcome_evaluated': False
|
||||
'outcome_evaluated': False,
|
||||
'inference_price': current_price # Store price at inference time
|
||||
}
|
||||
|
||||
# Store only the last inference per model (for immediate training)
|
||||
@ -2063,28 +2125,70 @@ class TradingOrchestrator:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Calculate price change
|
||||
if predicted_price is not None:
|
||||
actual_price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
||||
price_outcome = f"Predicted: ${predicted_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# Fall back to historical price comparison
|
||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if historical_data is not None and not historical_data.empty:
|
||||
historical_price = historical_data['close'].iloc[-1]
|
||||
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||
price_outcome = f"Historical: ${historical_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
price_outcome = f"Actual: ${current_price:.2f}"
|
||||
# Get inference price and timestamp from record
|
||||
inference_price = inference_record.get('inference_price')
|
||||
timestamp = inference_record.get('timestamp')
|
||||
|
||||
# Determine if prediction was correct based on action and price movement
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
|
||||
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
|
||||
actual_price_change_pct = 0.0
|
||||
|
||||
# Use stored inference price for comparison
|
||||
if inference_price is not None:
|
||||
actual_price_change_pct = (current_price - inference_price) / inference_price * 100
|
||||
|
||||
# Use seconds-based comparison for short-lived predictions
|
||||
if time_diff_seconds <= 60: # Within 1 minute
|
||||
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# For older predictions, use a more conservative approach
|
||||
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# Fall back to historical price comparison if no inference price
|
||||
try:
|
||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if historical_data is not None and not historical_data.empty:
|
||||
historical_price = historical_data['close'].iloc[-1]
|
||||
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||
price_outcome = f"Historical: ${historical_price:.2f} -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
price_outcome = f"Current: ${current_price:.2f} (no historical data)"
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating price change: {e}")
|
||||
price_outcome = f"Current: ${current_price:.2f} (calculation error)"
|
||||
|
||||
# Determine if prediction was correct based on predicted direction and actual price movement
|
||||
was_correct = False
|
||||
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
|
||||
was_correct = True
|
||||
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
|
||||
was_correct = True
|
||||
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
|
||||
was_correct = True
|
||||
|
||||
# Get predicted direction from the inference record
|
||||
predicted_direction = None
|
||||
if 'price_direction' in prediction and prediction['price_direction']:
|
||||
try:
|
||||
price_direction_data = prediction['price_direction']
|
||||
if isinstance(price_direction_data, dict) and 'direction' in price_direction_data:
|
||||
predicted_direction = price_direction_data['direction']
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting predicted direction: {e}")
|
||||
|
||||
# Evaluate based on predicted direction if available
|
||||
if predicted_direction is not None:
|
||||
# Use the predicted direction (-1 to 1) to determine correctness
|
||||
if predicted_direction > 0.1 and actual_price_change_pct > 0.1: # Predicted UP, price went UP
|
||||
was_correct = True
|
||||
elif predicted_direction < -0.1 and actual_price_change_pct < -0.1: # Predicted DOWN, price went DOWN
|
||||
was_correct = True
|
||||
elif abs(predicted_direction) <= 0.1 and abs(actual_price_change_pct) < 0.5: # Predicted SIDEWAYS, price stayed stable
|
||||
was_correct = True
|
||||
else:
|
||||
# Fallback to action-based evaluation
|
||||
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
|
||||
was_correct = True
|
||||
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
|
||||
was_correct = True
|
||||
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
|
||||
was_correct = True
|
||||
|
||||
outcome_status = "✅ CORRECT" if was_correct else "❌ INCORRECT"
|
||||
|
||||
@ -2107,38 +2211,32 @@ class TradingOrchestrator:
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
|
||||
# Calculate price change since prediction
|
||||
# This is a simplified outcome evaluation - you might want to make it more sophisticated
|
||||
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
|
||||
# Get inference price and calculate time difference
|
||||
inference_price = record.get('inference_price')
|
||||
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
|
||||
time_diff_minutes = time_diff_seconds / 60 # minutes
|
||||
|
||||
# Get historical price at prediction time (simplified)
|
||||
# Use stored inference price for comparison
|
||||
symbol = record['symbol']
|
||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if historical_data is None or historical_data.empty:
|
||||
return
|
||||
price_change_pct = 0.0
|
||||
|
||||
# Use predicted price if available, otherwise fall back to historical price
|
||||
predicted_price = None
|
||||
if 'price_prediction' in prediction and prediction['price_prediction']:
|
||||
try:
|
||||
# Extract predicted price change from CNN output
|
||||
price_prediction_data = prediction['price_prediction']
|
||||
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
|
||||
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01 # Convert to percentage
|
||||
predicted_price = current_price * (1 + predicted_price_change_pct)
|
||||
logger.debug(f"Using CNN price prediction: {predicted_price_change_pct:.3f}% -> ${predicted_price:.2f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing price prediction: {e}")
|
||||
|
||||
# Fall back to historical price if no prediction available
|
||||
if predicted_price is None:
|
||||
prediction_price = historical_data['close'].iloc[-1] # Simplified
|
||||
price_change_pct = (current_price - prediction_price) / prediction_price * 100
|
||||
logger.debug(f"Using historical price comparison: ${prediction_price:.2f} -> ${current_price:.2f}")
|
||||
if inference_price is not None:
|
||||
price_change_pct = (current_price - inference_price) / inference_price * 100
|
||||
logger.debug(f"Using stored inference price: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
|
||||
else:
|
||||
# Use predicted price for reward calculation
|
||||
price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
||||
logger.debug(f"Using predicted price comparison: ${predicted_price:.2f} -> ${current_price:.2f}")
|
||||
# Fall back to historical data if no inference price stored
|
||||
try:
|
||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if historical_data is not None and not historical_data.empty:
|
||||
historical_price = historical_data['close'].iloc[-1]
|
||||
price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||
logger.debug(f"Using historical price comparison: ${historical_price:.2f} -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
|
||||
else:
|
||||
logger.warning(f"No historical data available for {symbol}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating price change: {e}")
|
||||
return
|
||||
|
||||
# Enhanced reward system based on prediction confidence and price movement magnitude
|
||||
predicted_action = prediction['action']
|
||||
@ -2149,8 +2247,8 @@ class TradingOrchestrator:
|
||||
predicted_action,
|
||||
prediction_confidence,
|
||||
price_change_pct,
|
||||
time_diff,
|
||||
predicted_price is not None # Add price prediction flag
|
||||
time_diff_minutes,
|
||||
inference_price is not None # Add price prediction flag
|
||||
)
|
||||
|
||||
# Update model performance tracking
|
||||
@ -2160,6 +2258,10 @@ class TradingOrchestrator:
|
||||
'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
||||
}
|
||||
|
||||
# Ensure price_predictions key exists
|
||||
if 'price_predictions' not in self.model_performance[model_name]:
|
||||
self.model_performance[model_name]['price_predictions'] = {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
||||
|
||||
self.model_performance[model_name]['total'] += 1
|
||||
if was_correct:
|
||||
self.model_performance[model_name]['correct'] += 1
|
||||
@ -2170,7 +2272,7 @@ class TradingOrchestrator:
|
||||
)
|
||||
|
||||
# Track price prediction accuracy if available
|
||||
if predicted_price is not None:
|
||||
if inference_price is not None:
|
||||
price_prediction_stats = self.model_performance[model_name]['price_predictions']
|
||||
price_prediction_stats['total'] += 1
|
||||
|
||||
|
@ -5328,8 +5328,11 @@ class CleanTradingDashboard:
|
||||
# Cold start training moved to core.training_integration.TrainingIntegration
|
||||
|
||||
def _clear_session(self):
|
||||
"""Clear session data and persistent files"""
|
||||
"""Clear session data, close all positions, and reset PnL"""
|
||||
try:
|
||||
# Close all held positions first
|
||||
self._close_all_positions()
|
||||
|
||||
# Reset session metrics
|
||||
self.session_pnl = 0.0
|
||||
self.total_fees = 0.0
|
||||
@ -5393,12 +5396,58 @@ class CleanTradingDashboard:
|
||||
|
||||
logger.info("✅ Session data and trade logs cleared successfully")
|
||||
logger.info("📊 Session P&L reset to $0.00")
|
||||
logger.info("📈 Position cleared")
|
||||
logger.info("📈 All positions closed")
|
||||
logger.info("📋 Trade history cleared")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error clearing session: {e}")
|
||||
|
||||
def _close_all_positions(self):
|
||||
"""Close all held positions"""
|
||||
try:
|
||||
# Close positions via trading executor if available
|
||||
if hasattr(self, 'trading_executor') and self.trading_executor:
|
||||
try:
|
||||
# Close ETH/USDT position
|
||||
self.trading_executor.close_position('ETH/USDT')
|
||||
logger.info("🔒 Closed ETH/USDT position")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close ETH/USDT position: {e}")
|
||||
|
||||
try:
|
||||
# Close BTC/USDT position
|
||||
self.trading_executor.close_position('BTC/USDT')
|
||||
logger.info("🔒 Closed BTC/USDT position")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close BTC/USDT position: {e}")
|
||||
|
||||
# Also try to close via orchestrator if available
|
||||
if hasattr(self, 'orchestrator') and self.orchestrator:
|
||||
try:
|
||||
if hasattr(self.orchestrator, '_close_all_positions'):
|
||||
self.orchestrator._close_all_positions()
|
||||
logger.info("🔒 Closed all positions via orchestrator")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close positions via orchestrator: {e}")
|
||||
|
||||
# Reset position tracking
|
||||
self.current_position = None
|
||||
if hasattr(self, 'position_size'):
|
||||
self.position_size = 0.0
|
||||
if hasattr(self, 'position_entry_price'):
|
||||
self.position_entry_price = None
|
||||
if hasattr(self, 'position_pnl'):
|
||||
self.position_pnl = 0.0
|
||||
if hasattr(self, 'unrealized_pnl'):
|
||||
self.unrealized_pnl = 0.0
|
||||
if hasattr(self, 'realized_pnl'):
|
||||
self.realized_pnl = 0.0
|
||||
|
||||
logger.info("✅ All positions closed and PnL reset")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error closing positions: {e}")
|
||||
|
||||
def _clear_trade_logs(self):
|
||||
"""Clear all trade log files"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user