training wip
This commit is contained in:
@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Convert state to tensor if needed
|
# Convert state to tensor if needed
|
||||||
if isinstance(state, np.ndarray):
|
if isinstance(state, np.ndarray):
|
||||||
state = torch.FloatTensor(state).to(next(self.parameters()).device)
|
state = torch.FloatTensor(state)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
state = state.to(device)
|
||||||
|
|
||||||
# Ensure proper shape
|
# Ensure proper shape
|
||||||
if state.dim() == 1:
|
if state.dim() == 1:
|
||||||
@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
|
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
|
||||||
|
|
||||||
# Process price direction predictions
|
# Price direction predictions are processed in the agent's act method
|
||||||
if price_direction_pred is not None:
|
# This is just the network forward pass
|
||||||
self.process_price_direction_predictions(price_direction_pred)
|
|
||||||
|
|
||||||
# Get action probabilities using softmax
|
# Get action probabilities using softmax
|
||||||
action_probs = F.softmax(q_values, dim=1)
|
action_probs = F.softmax(q_values, dim=1)
|
||||||
@ -234,7 +237,7 @@ class DQNAgent:
|
|||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
state_shape: Tuple[int, ...],
|
state_shape: Tuple[int, ...],
|
||||||
n_actions: int = 2,
|
n_actions: int = 3, # BUY=0, SELL=1, HOLD=2
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
epsilon: float = 1.0,
|
epsilon: float = 1.0,
|
||||||
epsilon_min: float = 0.01,
|
epsilon_min: float = 0.01,
|
||||||
@ -761,6 +764,13 @@ class DQNAgent:
|
|||||||
# Use the DQNNetwork's act method for consistent behavior
|
# Use the DQNNetwork's act method for consistent behavior
|
||||||
action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
|
action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
|
||||||
|
|
||||||
|
# Process price direction predictions from the network
|
||||||
|
# Get the raw predictions from the network's forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
|
||||||
|
if price_direction_pred is not None:
|
||||||
|
self.process_price_direction_predictions(price_direction_pred)
|
||||||
|
|
||||||
# Apply epsilon-greedy exploration if requested
|
# Apply epsilon-greedy exploration if requested
|
||||||
if explore and np.random.random() <= self.epsilon:
|
if explore and np.random.random() <= self.epsilon:
|
||||||
action_idx = np.random.choice(self.n_actions)
|
action_idx = np.random.choice(self.n_actions)
|
||||||
@ -780,15 +790,44 @@ class DQNAgent:
|
|||||||
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
|
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
|
||||||
"""Choose action with confidence score adapted to market regime"""
|
"""Choose action with confidence score adapted to market regime"""
|
||||||
try:
|
try:
|
||||||
# Use the DQNNetwork's act method which handles the state properly
|
# Convert state to tensor if needed
|
||||||
action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
|
if isinstance(state, np.ndarray):
|
||||||
|
state_tensor = torch.FloatTensor(state)
|
||||||
|
device = next(self.policy_net.parameters()).device
|
||||||
|
state_tensor = state_tensor.to(device)
|
||||||
|
|
||||||
|
# Ensure proper shape
|
||||||
|
if state_tensor.dim() == 1:
|
||||||
|
state_tensor = state_tensor.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
state_tensor = state
|
||||||
|
|
||||||
|
# Get network outputs
|
||||||
|
with torch.no_grad():
|
||||||
|
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
|
||||||
|
|
||||||
|
# Process price direction predictions
|
||||||
|
if price_direction_pred is not None:
|
||||||
|
self.process_price_direction_predictions(price_direction_pred)
|
||||||
|
|
||||||
|
# Get action probabilities using softmax
|
||||||
|
action_probs = F.softmax(q_values, dim=1)
|
||||||
|
|
||||||
|
# Select action (greedy for inference)
|
||||||
|
action_idx = torch.argmax(q_values, dim=1).item()
|
||||||
|
|
||||||
|
# Calculate confidence as max probability
|
||||||
|
base_confidence = float(action_probs[0, action_idx].item())
|
||||||
|
|
||||||
# Adapt confidence based on market regime
|
# Adapt confidence based on market regime
|
||||||
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||||
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||||
|
|
||||||
|
# Convert probabilities to list
|
||||||
|
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
|
||||||
|
|
||||||
# Return action, confidence, and probabilities (for orchestrator compatibility)
|
# Return action, confidence, and probabilities (for orchestrator compatibility)
|
||||||
return int(action_idx), float(adapted_confidence), action_probs
|
return int(action_idx), float(adapted_confidence), probs_list
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in act_with_confidence: {e}")
|
logger.error(f"Error in act_with_confidence: {e}")
|
||||||
|
@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
|
|||||||
self.n_actions = n_actions
|
self.n_actions = n_actions
|
||||||
self.confidence_threshold = confidence_threshold
|
self.confidence_threshold = confidence_threshold
|
||||||
|
|
||||||
|
# Training data storage
|
||||||
|
self.training_data = []
|
||||||
|
|
||||||
# Calculate input dimensions
|
# Calculate input dimensions
|
||||||
if isinstance(input_shape, (list, tuple)):
|
if isinstance(input_shape, (list, tuple)):
|
||||||
if len(input_shape) == 3: # [channels, height, width]
|
if len(input_shape) == 3: # [channels, height, width]
|
||||||
@ -649,6 +652,30 @@ class EnhancedCNN(nn.Module):
|
|||||||
'weighted_strength': 0.0
|
'weighted_strength': 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def add_training_data(self, state, action, reward):
|
||||||
|
"""
|
||||||
|
Add training data to the model's training buffer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Input state
|
||||||
|
action: Action taken
|
||||||
|
reward: Reward received
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.training_data.append({
|
||||||
|
'state': state,
|
||||||
|
'action': action,
|
||||||
|
'reward': reward,
|
||||||
|
'timestamp': time.time()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Keep only the last 1000 training samples
|
||||||
|
if len(self.training_data) > 1000:
|
||||||
|
self.training_data = self.training_data[-1000:]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding training data: {e}")
|
||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
"""Save model weights and architecture"""
|
"""Save model weights and architecture"""
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
@ -1035,11 +1035,69 @@ class TradingOrchestrator:
|
|||||||
logger.debug(f"Error capturing DQN prediction: {e}")
|
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||||
|
|
||||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||||
"""Get current price for a symbol"""
|
"""Get current price for a symbol - ENHANCED with better fallbacks"""
|
||||||
try:
|
try:
|
||||||
return self.data_provider.get_current_price(symbol)
|
# Try data provider current prices first
|
||||||
|
if hasattr(self.data_provider, 'current_prices') and symbol in self.data_provider.current_prices:
|
||||||
|
price = self.data_provider.current_prices[symbol]
|
||||||
|
if price and price > 0:
|
||||||
|
return price
|
||||||
|
|
||||||
|
# Try data provider get_current_price method
|
||||||
|
if hasattr(self.data_provider, 'get_current_price'):
|
||||||
|
try:
|
||||||
|
price = self.data_provider.get_current_price(symbol)
|
||||||
|
if price and price > 0:
|
||||||
|
return price
|
||||||
|
except Exception as dp_error:
|
||||||
|
logger.debug(f"Data provider get_current_price failed: {dp_error}")
|
||||||
|
|
||||||
|
# Get fresh price from data provider - try multiple timeframes
|
||||||
|
for timeframe in ['1m', '5m', '1h']: # Start with 1m for better reliability
|
||||||
|
try:
|
||||||
|
df = self.data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=True)
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
price = float(df['close'].iloc[-1])
|
||||||
|
if price > 0:
|
||||||
|
logger.debug(f"Got current price for {symbol} from {timeframe}: ${price:.2f}")
|
||||||
|
return price
|
||||||
|
except Exception as tf_error:
|
||||||
|
logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try external API as last resort
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
if symbol == 'ETH/USDT':
|
||||||
|
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=ETHUSDT', timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
price = float(data['price'])
|
||||||
|
if price > 0:
|
||||||
|
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||||
|
return price
|
||||||
|
elif symbol == 'BTC/USDT':
|
||||||
|
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=BTCUSDT', timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
price = float(data['price'])
|
||||||
|
if price > 0:
|
||||||
|
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||||
|
return price
|
||||||
|
except Exception as api_error:
|
||||||
|
logger.debug(f"External API failed: {api_error}")
|
||||||
|
|
||||||
|
logger.warning(f"Could not get current price for {symbol} from any source")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||||
|
|
||||||
|
# Return a reasonable fallback based on current market conditions
|
||||||
|
if symbol == 'ETH/USDT':
|
||||||
|
return 3385.0 # Current market price fallback
|
||||||
|
elif symbol == 'BTC/USDT':
|
||||||
|
return 119500.0 # Current market price fallback
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
|
async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
|
||||||
@ -1683,6 +1741,9 @@ class TradingOrchestrator:
|
|||||||
if symbol is None:
|
if symbol is None:
|
||||||
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||||
|
|
||||||
|
# Get current price at inference time
|
||||||
|
current_price = self._get_current_price(symbol)
|
||||||
|
|
||||||
# Create inference record - store only what's needed for training
|
# Create inference record - store only what's needed for training
|
||||||
inference_record = {
|
inference_record = {
|
||||||
'timestamp': timestamp.isoformat(),
|
'timestamp': timestamp.isoformat(),
|
||||||
@ -1697,7 +1758,8 @@ class TradingOrchestrator:
|
|||||||
},
|
},
|
||||||
'metadata': prediction.metadata or {},
|
'metadata': prediction.metadata or {},
|
||||||
'training_outcome': None, # Will be set when training occurs
|
'training_outcome': None, # Will be set when training occurs
|
||||||
'outcome_evaluated': False
|
'outcome_evaluated': False,
|
||||||
|
'inference_price': current_price # Store price at inference time
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store only the last inference per model (for immediate training)
|
# Store only the last inference per model (for immediate training)
|
||||||
@ -2063,22 +2125,64 @@ class TradingOrchestrator:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Calculate price change
|
# Get inference price and timestamp from record
|
||||||
if predicted_price is not None:
|
inference_price = inference_record.get('inference_price')
|
||||||
actual_price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
timestamp = inference_record.get('timestamp')
|
||||||
price_outcome = f"Predicted: ${predicted_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
|
||||||
|
if isinstance(timestamp, str):
|
||||||
|
timestamp = datetime.fromisoformat(timestamp)
|
||||||
|
|
||||||
|
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
|
||||||
|
actual_price_change_pct = 0.0
|
||||||
|
|
||||||
|
# Use stored inference price for comparison
|
||||||
|
if inference_price is not None:
|
||||||
|
actual_price_change_pct = (current_price - inference_price) / inference_price * 100
|
||||||
|
|
||||||
|
# Use seconds-based comparison for short-lived predictions
|
||||||
|
if time_diff_seconds <= 60: # Within 1 minute
|
||||||
|
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||||
else:
|
else:
|
||||||
# Fall back to historical price comparison
|
# For older predictions, use a more conservative approach
|
||||||
|
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||||
|
else:
|
||||||
|
# Fall back to historical price comparison if no inference price
|
||||||
|
try:
|
||||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||||
if historical_data is not None and not historical_data.empty:
|
if historical_data is not None and not historical_data.empty:
|
||||||
historical_price = historical_data['close'].iloc[-1]
|
historical_price = historical_data['close'].iloc[-1]
|
||||||
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
|
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||||
price_outcome = f"Historical: ${historical_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
price_outcome = f"Historical: ${historical_price:.2f} -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||||
else:
|
else:
|
||||||
price_outcome = f"Actual: ${current_price:.2f}"
|
price_outcome = f"Current: ${current_price:.2f} (no historical data)"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating price change: {e}")
|
||||||
|
price_outcome = f"Current: ${current_price:.2f} (calculation error)"
|
||||||
|
|
||||||
# Determine if prediction was correct based on action and price movement
|
# Determine if prediction was correct based on predicted direction and actual price movement
|
||||||
was_correct = False
|
was_correct = False
|
||||||
|
|
||||||
|
# Get predicted direction from the inference record
|
||||||
|
predicted_direction = None
|
||||||
|
if 'price_direction' in prediction and prediction['price_direction']:
|
||||||
|
try:
|
||||||
|
price_direction_data = prediction['price_direction']
|
||||||
|
if isinstance(price_direction_data, dict) and 'direction' in price_direction_data:
|
||||||
|
predicted_direction = price_direction_data['direction']
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error extracting predicted direction: {e}")
|
||||||
|
|
||||||
|
# Evaluate based on predicted direction if available
|
||||||
|
if predicted_direction is not None:
|
||||||
|
# Use the predicted direction (-1 to 1) to determine correctness
|
||||||
|
if predicted_direction > 0.1 and actual_price_change_pct > 0.1: # Predicted UP, price went UP
|
||||||
|
was_correct = True
|
||||||
|
elif predicted_direction < -0.1 and actual_price_change_pct < -0.1: # Predicted DOWN, price went DOWN
|
||||||
|
was_correct = True
|
||||||
|
elif abs(predicted_direction) <= 0.1 and abs(actual_price_change_pct) < 0.5: # Predicted SIDEWAYS, price stayed stable
|
||||||
|
was_correct = True
|
||||||
|
else:
|
||||||
|
# Fallback to action-based evaluation
|
||||||
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
|
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
|
||||||
was_correct = True
|
was_correct = True
|
||||||
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
|
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
|
||||||
@ -2107,38 +2211,32 @@ class TradingOrchestrator:
|
|||||||
if isinstance(timestamp, str):
|
if isinstance(timestamp, str):
|
||||||
timestamp = datetime.fromisoformat(timestamp)
|
timestamp = datetime.fromisoformat(timestamp)
|
||||||
|
|
||||||
# Calculate price change since prediction
|
# Get inference price and calculate time difference
|
||||||
# This is a simplified outcome evaluation - you might want to make it more sophisticated
|
inference_price = record.get('inference_price')
|
||||||
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
|
time_diff_seconds = (datetime.now() - timestamp).total_seconds()
|
||||||
|
time_diff_minutes = time_diff_seconds / 60 # minutes
|
||||||
|
|
||||||
# Get historical price at prediction time (simplified)
|
# Use stored inference price for comparison
|
||||||
symbol = record['symbol']
|
symbol = record['symbol']
|
||||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
price_change_pct = 0.0
|
||||||
if historical_data is None or historical_data.empty:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use predicted price if available, otherwise fall back to historical price
|
if inference_price is not None:
|
||||||
predicted_price = None
|
price_change_pct = (current_price - inference_price) / inference_price * 100
|
||||||
if 'price_prediction' in prediction and prediction['price_prediction']:
|
logger.debug(f"Using stored inference price: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
|
||||||
try:
|
|
||||||
# Extract predicted price change from CNN output
|
|
||||||
price_prediction_data = prediction['price_prediction']
|
|
||||||
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
|
|
||||||
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01 # Convert to percentage
|
|
||||||
predicted_price = current_price * (1 + predicted_price_change_pct)
|
|
||||||
logger.debug(f"Using CNN price prediction: {predicted_price_change_pct:.3f}% -> ${predicted_price:.2f}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error parsing price prediction: {e}")
|
|
||||||
|
|
||||||
# Fall back to historical price if no prediction available
|
|
||||||
if predicted_price is None:
|
|
||||||
prediction_price = historical_data['close'].iloc[-1] # Simplified
|
|
||||||
price_change_pct = (current_price - prediction_price) / prediction_price * 100
|
|
||||||
logger.debug(f"Using historical price comparison: ${prediction_price:.2f} -> ${current_price:.2f}")
|
|
||||||
else:
|
else:
|
||||||
# Use predicted price for reward calculation
|
# Fall back to historical data if no inference price stored
|
||||||
price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
try:
|
||||||
logger.debug(f"Using predicted price comparison: ${predicted_price:.2f} -> ${current_price:.2f}")
|
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||||
|
if historical_data is not None and not historical_data.empty:
|
||||||
|
historical_price = historical_data['close'].iloc[-1]
|
||||||
|
price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||||
|
logger.debug(f"Using historical price comparison: ${historical_price:.2f} -> ${current_price:.2f} ({price_change_pct:+.2f}%)")
|
||||||
|
else:
|
||||||
|
logger.warning(f"No historical data available for {symbol}")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating price change: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
# Enhanced reward system based on prediction confidence and price movement magnitude
|
# Enhanced reward system based on prediction confidence and price movement magnitude
|
||||||
predicted_action = prediction['action']
|
predicted_action = prediction['action']
|
||||||
@ -2149,8 +2247,8 @@ class TradingOrchestrator:
|
|||||||
predicted_action,
|
predicted_action,
|
||||||
prediction_confidence,
|
prediction_confidence,
|
||||||
price_change_pct,
|
price_change_pct,
|
||||||
time_diff,
|
time_diff_minutes,
|
||||||
predicted_price is not None # Add price prediction flag
|
inference_price is not None # Add price prediction flag
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update model performance tracking
|
# Update model performance tracking
|
||||||
@ -2160,6 +2258,10 @@ class TradingOrchestrator:
|
|||||||
'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Ensure price_predictions key exists
|
||||||
|
if 'price_predictions' not in self.model_performance[model_name]:
|
||||||
|
self.model_performance[model_name]['price_predictions'] = {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
||||||
|
|
||||||
self.model_performance[model_name]['total'] += 1
|
self.model_performance[model_name]['total'] += 1
|
||||||
if was_correct:
|
if was_correct:
|
||||||
self.model_performance[model_name]['correct'] += 1
|
self.model_performance[model_name]['correct'] += 1
|
||||||
@ -2170,7 +2272,7 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Track price prediction accuracy if available
|
# Track price prediction accuracy if available
|
||||||
if predicted_price is not None:
|
if inference_price is not None:
|
||||||
price_prediction_stats = self.model_performance[model_name]['price_predictions']
|
price_prediction_stats = self.model_performance[model_name]['price_predictions']
|
||||||
price_prediction_stats['total'] += 1
|
price_prediction_stats['total'] += 1
|
||||||
|
|
||||||
|
@ -5328,8 +5328,11 @@ class CleanTradingDashboard:
|
|||||||
# Cold start training moved to core.training_integration.TrainingIntegration
|
# Cold start training moved to core.training_integration.TrainingIntegration
|
||||||
|
|
||||||
def _clear_session(self):
|
def _clear_session(self):
|
||||||
"""Clear session data and persistent files"""
|
"""Clear session data, close all positions, and reset PnL"""
|
||||||
try:
|
try:
|
||||||
|
# Close all held positions first
|
||||||
|
self._close_all_positions()
|
||||||
|
|
||||||
# Reset session metrics
|
# Reset session metrics
|
||||||
self.session_pnl = 0.0
|
self.session_pnl = 0.0
|
||||||
self.total_fees = 0.0
|
self.total_fees = 0.0
|
||||||
@ -5393,12 +5396,58 @@ class CleanTradingDashboard:
|
|||||||
|
|
||||||
logger.info("✅ Session data and trade logs cleared successfully")
|
logger.info("✅ Session data and trade logs cleared successfully")
|
||||||
logger.info("📊 Session P&L reset to $0.00")
|
logger.info("📊 Session P&L reset to $0.00")
|
||||||
logger.info("📈 Position cleared")
|
logger.info("📈 All positions closed")
|
||||||
logger.info("📋 Trade history cleared")
|
logger.info("📋 Trade history cleared")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Error clearing session: {e}")
|
logger.error(f"❌ Error clearing session: {e}")
|
||||||
|
|
||||||
|
def _close_all_positions(self):
|
||||||
|
"""Close all held positions"""
|
||||||
|
try:
|
||||||
|
# Close positions via trading executor if available
|
||||||
|
if hasattr(self, 'trading_executor') and self.trading_executor:
|
||||||
|
try:
|
||||||
|
# Close ETH/USDT position
|
||||||
|
self.trading_executor.close_position('ETH/USDT')
|
||||||
|
logger.info("🔒 Closed ETH/USDT position")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to close ETH/USDT position: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Close BTC/USDT position
|
||||||
|
self.trading_executor.close_position('BTC/USDT')
|
||||||
|
logger.info("🔒 Closed BTC/USDT position")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to close BTC/USDT position: {e}")
|
||||||
|
|
||||||
|
# Also try to close via orchestrator if available
|
||||||
|
if hasattr(self, 'orchestrator') and self.orchestrator:
|
||||||
|
try:
|
||||||
|
if hasattr(self.orchestrator, '_close_all_positions'):
|
||||||
|
self.orchestrator._close_all_positions()
|
||||||
|
logger.info("🔒 Closed all positions via orchestrator")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to close positions via orchestrator: {e}")
|
||||||
|
|
||||||
|
# Reset position tracking
|
||||||
|
self.current_position = None
|
||||||
|
if hasattr(self, 'position_size'):
|
||||||
|
self.position_size = 0.0
|
||||||
|
if hasattr(self, 'position_entry_price'):
|
||||||
|
self.position_entry_price = None
|
||||||
|
if hasattr(self, 'position_pnl'):
|
||||||
|
self.position_pnl = 0.0
|
||||||
|
if hasattr(self, 'unrealized_pnl'):
|
||||||
|
self.unrealized_pnl = 0.0
|
||||||
|
if hasattr(self, 'realized_pnl'):
|
||||||
|
self.realized_pnl = 0.0
|
||||||
|
|
||||||
|
logger.info("✅ All positions closed and PnL reset")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error closing positions: {e}")
|
||||||
|
|
||||||
def _clear_trade_logs(self):
|
def _clear_trade_logs(self):
|
||||||
"""Clear all trade log files"""
|
"""Clear all trade log files"""
|
||||||
try:
|
try:
|
||||||
|
Reference in New Issue
Block a user