fixed CNN training
This commit is contained in:
@ -2184,7 +2184,7 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Clean up memory periodically
|
# Clean up memory periodically
|
||||||
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
|
if len(self.recent_decisions[symbol]) % 20 == 0: # Reduced from 50 to 20
|
||||||
self.model_registry.cleanup_all_models()
|
self.model_registry.cleanup_all_models()
|
||||||
|
|
||||||
return decision
|
return decision
|
||||||
@ -2198,55 +2198,108 @@ class TradingOrchestrator:
|
|||||||
):
|
):
|
||||||
"""Add training samples to models based on current predictions and market conditions"""
|
"""Add training samples to models based on current predictions and market conditions"""
|
||||||
try:
|
try:
|
||||||
if not hasattr(self, "cnn_adapter") or not self.cnn_adapter:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get recent price data to evaluate if predictions would be correct
|
# Get recent price data to evaluate if predictions would be correct
|
||||||
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
|
# Use available methods from data provider
|
||||||
if not recent_prices or len(recent_prices) < 2:
|
try:
|
||||||
return
|
# Try to get recent prices using get_price_at_index
|
||||||
|
recent_prices = []
|
||||||
|
for i in range(10):
|
||||||
|
price = self.data_provider.get_price_at_index(symbol, i, '1m')
|
||||||
|
if price is not None:
|
||||||
|
recent_prices.append(price)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate recent price change
|
if len(recent_prices) < 2:
|
||||||
price_change_pct = (
|
# Fallback: use current price and a small assumed change
|
||||||
(current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
price_change_pct = 0.1 # Assume small positive change
|
||||||
)
|
else:
|
||||||
|
# Calculate recent price change
|
||||||
|
price_change_pct = (
|
||||||
|
(current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get recent prices for {symbol}: {e}")
|
||||||
|
# Fallback: use current price and a small assumed change
|
||||||
|
price_change_pct = 0.1 # Assume small positive change
|
||||||
|
|
||||||
# Add training samples for CNN predictions
|
# Get current position P&L for sophisticated reward calculation
|
||||||
|
current_position_pnl = self._get_current_position_pnl(symbol)
|
||||||
|
has_position = self._has_open_position(symbol)
|
||||||
|
|
||||||
|
# Add training samples for CNN predictions using sophisticated reward system
|
||||||
for prediction in predictions:
|
for prediction in predictions:
|
||||||
if "cnn" in prediction.model_name.lower():
|
if "cnn" in prediction.model_name.lower():
|
||||||
# Determine reward based on prediction accuracy
|
# Calculate sophisticated reward using the new PnL penalty/reward system
|
||||||
reward = 0.0
|
sophisticated_reward, was_correct = self._calculate_sophisticated_reward(
|
||||||
|
predicted_action=prediction.action,
|
||||||
if prediction.action == "BUY" and price_change_pct > 0.1:
|
prediction_confidence=prediction.confidence,
|
||||||
reward = min(
|
price_change_pct=price_change_pct,
|
||||||
price_change_pct * 0.1, 1.0
|
time_diff_minutes=1.0, # Assume 1 minute for now
|
||||||
) # Positive reward for correct BUY
|
has_price_prediction=False,
|
||||||
elif prediction.action == "SELL" and price_change_pct < -0.1:
|
symbol=symbol,
|
||||||
reward = min(
|
has_position=has_position,
|
||||||
abs(price_change_pct) * 0.1, 1.0
|
current_position_pnl=current_position_pnl
|
||||||
) # Positive reward for correct SELL
|
|
||||||
elif prediction.action == "HOLD" and abs(price_change_pct) < 0.1:
|
|
||||||
reward = 0.1 # Small positive reward for correct HOLD
|
|
||||||
else:
|
|
||||||
reward = -0.05 # Small negative reward for incorrect prediction
|
|
||||||
|
|
||||||
# Add training sample
|
|
||||||
self.cnn_adapter.add_training_sample(
|
|
||||||
symbol, prediction.action, reward
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger training if we have enough samples
|
# Create training record for the new training system
|
||||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
training_record = {
|
||||||
training_results = self.cnn_adapter.train(epochs=1)
|
"symbol": symbol,
|
||||||
logger.info(
|
"model_name": prediction.model_name,
|
||||||
f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}"
|
"action": prediction.action,
|
||||||
)
|
"confidence": prediction.confidence,
|
||||||
|
"timestamp": prediction.timestamp,
|
||||||
|
"current_price": current_price,
|
||||||
|
"price_change_pct": price_change_pct,
|
||||||
|
"was_correct": was_correct,
|
||||||
|
"sophisticated_reward": sophisticated_reward,
|
||||||
|
"current_position_pnl": current_position_pnl,
|
||||||
|
"has_position": has_position
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use the new training system instead of old cnn_adapter
|
||||||
|
if hasattr(self, "cnn_model") and self.cnn_model:
|
||||||
|
# Train CNN model directly using the new system
|
||||||
|
training_success = await self._train_cnn_model(
|
||||||
|
model=self.cnn_model,
|
||||||
|
model_name=prediction.model_name,
|
||||||
|
record=training_record,
|
||||||
|
prediction={"action": prediction.action, "confidence": prediction.confidence},
|
||||||
|
reward=sophisticated_reward
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_success:
|
||||||
|
logger.debug(
|
||||||
|
f"CNN training completed: action={prediction.action}, reward={sophisticated_reward:.3f}, "
|
||||||
|
f"price_change={price_change_pct:.2f}%, was_correct={was_correct}, "
|
||||||
|
f"position_pnl={current_position_pnl:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"CNN training failed for {prediction.model_name}")
|
||||||
|
|
||||||
|
# Also try training through model registry if available
|
||||||
|
elif self.model_registry and prediction.model_name in self.model_registry.models:
|
||||||
|
model = self.model_registry.models[prediction.model_name]
|
||||||
|
training_success = await self._train_cnn_model(
|
||||||
|
model=model,
|
||||||
|
model_name=prediction.model_name,
|
||||||
|
record=training_record,
|
||||||
|
prediction={"action": prediction.action, "confidence": prediction.confidence},
|
||||||
|
reward=sophisticated_reward
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_success:
|
||||||
|
logger.debug(
|
||||||
|
f"CNN training via registry completed: {prediction.model_name}, "
|
||||||
|
f"reward={sophisticated_reward:.3f}, was_correct={was_correct}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"CNN training via registry failed for {prediction.model_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding training samples from predictions: {e}")
|
logger.error(f"Error adding training samples from predictions: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||||
"""Get predictions from all registered models with input data storage"""
|
"""Get predictions from all registered models with input data storage"""
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Model Configurations
|
# Model Configurations
|
||||||
# This file contains all model-specific configurations to keep the main config.yaml clean
|
# This file contains all model-specific configurations to keep the main config.yaml clean
|
||||||
|
|
||||||
# Enhanced CNN Configuration ( does not use yml file now)
|
# Enhanced CNN Configuration (cnn model do not use yml config. do not change this)
|
||||||
# cnn:
|
# cnn:
|
||||||
# window_size: 20
|
# window_size: 20
|
||||||
# features: ["open", "high", "low", "close", "volume"]
|
# features: ["open", "high", "low", "close", "volume"]
|
||||||
|
@ -1367,71 +1367,73 @@ class CleanTradingDashboard:
|
|||||||
# Original training metrics callback - temporarily disabled for testing
|
# Original training metrics callback - temporarily disabled for testing
|
||||||
# @self.app.callback(
|
# @self.app.callback(
|
||||||
# Output('training-metrics', 'children'),
|
# Output('training-metrics', 'children'),
|
||||||
# [Input('slow-interval-component', 'n_intervals'),
|
|
||||||
# Input('fast-interval-component', 'n_intervals'), # Add fast interval for testing
|
|
||||||
# Input('refresh-training-metrics-btn', 'n_clicks')] # Add manual refresh button
|
|
||||||
# )
|
|
||||||
# def update_training_metrics(slow_intervals, fast_intervals, n_clicks):
|
|
||||||
# """Update training metrics"""
|
|
||||||
# logger.info(f"update_training_metrics callback triggered with slow_intervals={slow_intervals}, fast_intervals={fast_intervals}, n_clicks={n_clicks}")
|
|
||||||
# try:
|
|
||||||
# # Get toggle states from orchestrator
|
|
||||||
# toggle_states = {}
|
|
||||||
# if self.orchestrator:
|
|
||||||
# # Get all available models dynamically
|
|
||||||
# available_models = self._get_available_models()
|
|
||||||
# logger.info(f"Available models: {list(available_models.keys())}")
|
|
||||||
# for model_name in available_models.keys():
|
|
||||||
# toggle_states[model_name] = self.orchestrator.get_model_toggle_state(model_name)
|
|
||||||
# else:
|
|
||||||
# # Fallback to dashboard dynamic state
|
|
||||||
# toggle_states = {}
|
|
||||||
# for model_name, state in self.model_toggle_states.items():
|
|
||||||
# toggle_states[model_name] = state
|
|
||||||
# # Now using slow-interval-component (10s) - no batching needed
|
|
||||||
#
|
|
||||||
# logger.info(f"Getting training metrics with toggle_states: {toggle_states}")
|
|
||||||
# metrics_data = self._get_training_metrics(toggle_states)
|
|
||||||
# logger.info(f"update_training_metrics callback: got metrics_data type={type(metrics_data)}")
|
|
||||||
# if metrics_data and isinstance(metrics_data, dict):
|
|
||||||
# logger.info(f"Metrics data keys: {list(metrics_data.keys())}")
|
|
||||||
# if 'loaded_models' in metrics_data:
|
|
||||||
# logger.info(f"Loaded models count: {len(metrics_data['loaded_models'])}")
|
|
||||||
# logger.info(f"Loaded model names: {list(metrics_data['loaded_models'].keys())}")
|
|
||||||
# else:
|
|
||||||
# logger.warning("No 'loaded_models' key in metrics_data!")
|
|
||||||
# else:
|
|
||||||
# logger.warning(f"Invalid metrics_data: {metrics_data}")
|
|
||||||
#
|
|
||||||
# logger.info("Formatting training metrics...")
|
|
||||||
# formatted_metrics = self.component_manager.format_training_metrics(metrics_data)
|
|
||||||
# logger.info(f"Formatted metrics type: {type(formatted_metrics)}, length: {len(formatted_metrics) if isinstance(formatted_metrics, list) else 'N/A'}")
|
|
||||||
# return formatted_metrics
|
|
||||||
# except PreventUpdate:
|
|
||||||
# logger.info("PreventUpdate raised in training metrics callback")
|
|
||||||
# raise
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"Error updating training metrics: {e}")
|
|
||||||
# import traceback
|
|
||||||
# logger.error(f"Traceback: {traceback.format_exc()}")
|
|
||||||
# return [html.P(f"Error: {str(e)}", className="text-danger")]
|
|
||||||
|
|
||||||
# Test callback for training metrics
|
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('training-metrics', 'children'),
|
Output('training-metrics', 'children'),
|
||||||
[Input('refresh-training-metrics-btn', 'n_clicks')],
|
[Input('slow-interval-component', 'n_intervals'),
|
||||||
prevent_initial_call=False
|
Input('fast-interval-component', 'n_intervals'), # Add fast interval for testing
|
||||||
|
Input('refresh-training-metrics-btn', 'n_clicks')] # Add manual refresh button
|
||||||
)
|
)
|
||||||
def test_training_metrics_callback(n_clicks):
|
def update_training_metrics(slow_intervals, fast_intervals, n_clicks):
|
||||||
"""Test callback for training metrics"""
|
"""Update training metrics"""
|
||||||
logger.info(f"test_training_metrics_callback triggered with n_clicks={n_clicks}")
|
logger.info(f"update_training_metrics callback triggered with slow_intervals={slow_intervals}, fast_intervals={fast_intervals}, n_clicks={n_clicks}")
|
||||||
try:
|
try:
|
||||||
# Return a simple test message
|
# Get toggle states from orchestrator
|
||||||
return [html.P("Training metrics test - callback is working!", className="text-success")]
|
toggle_states = {}
|
||||||
|
if self.orchestrator:
|
||||||
|
# Get all available models dynamically
|
||||||
|
available_models = self._get_available_models()
|
||||||
|
logger.info(f"Available models: {list(available_models.keys())}")
|
||||||
|
for model_name in available_models.keys():
|
||||||
|
toggle_states[model_name] = self.orchestrator.get_model_toggle_state(model_name)
|
||||||
|
else:
|
||||||
|
# Fallback to dashboard dynamic state
|
||||||
|
toggle_states = {}
|
||||||
|
for model_name, state in self.model_toggle_states.items():
|
||||||
|
toggle_states[model_name] = state
|
||||||
|
# Now using slow-interval-component (10s) - no batching needed
|
||||||
|
|
||||||
|
logger.info(f"Getting training metrics with toggle_states: {toggle_states}")
|
||||||
|
metrics_data = self._get_training_metrics(toggle_states)
|
||||||
|
logger.info(f"update_training_metrics callback: got metrics_data type={type(metrics_data)}")
|
||||||
|
if metrics_data and isinstance(metrics_data, dict):
|
||||||
|
logger.info(f"Metrics data keys: {list(metrics_data.keys())}")
|
||||||
|
if 'loaded_models' in metrics_data:
|
||||||
|
logger.info(f"Loaded models count: {len(metrics_data['loaded_models'])}")
|
||||||
|
logger.info(f"Loaded model names: {list(metrics_data['loaded_models'].keys())}")
|
||||||
|
else:
|
||||||
|
logger.warning("No 'loaded_models' key in metrics_data!")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Invalid metrics_data: {metrics_data}")
|
||||||
|
|
||||||
|
logger.info("Formatting training metrics...")
|
||||||
|
formatted_metrics = self.component_manager.format_training_metrics(metrics_data)
|
||||||
|
logger.info(f"Formatted metrics type: {type(formatted_metrics)}, length: {len(formatted_metrics) if isinstance(formatted_metrics, list) else 'N/A'}")
|
||||||
|
return formatted_metrics
|
||||||
|
except PreventUpdate:
|
||||||
|
logger.info("PreventUpdate raised in training metrics callback")
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in test callback: {e}")
|
logger.error(f"Error updating training metrics: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return [html.P(f"Error: {str(e)}", className="text-danger")]
|
return [html.P(f"Error: {str(e)}", className="text-danger")]
|
||||||
|
|
||||||
|
# Test callback for training metrics (commented out - using real callback now)
|
||||||
|
# @self.app.callback(
|
||||||
|
# Output('training-metrics', 'children'),
|
||||||
|
# [Input('refresh-training-metrics-btn', 'n_clicks')],
|
||||||
|
# prevent_initial_call=False
|
||||||
|
# )
|
||||||
|
# def test_training_metrics_callback(n_clicks):
|
||||||
|
# """Test callback for training metrics"""
|
||||||
|
# logger.info(f"test_training_metrics_callback triggered with n_clicks={n_clicks}")
|
||||||
|
# try:
|
||||||
|
# # Return a simple test message
|
||||||
|
# return [html.P("Training metrics test - callback is working!", className="text-success")]
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error in test callback: {e}")
|
||||||
|
# return [html.P(f"Error: {str(e)}", className="text-danger")]
|
||||||
|
|
||||||
# Manual trading buttons
|
# Manual trading buttons
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('manual-buy-btn', 'children'),
|
Output('manual-buy-btn', 'children'),
|
||||||
|
Reference in New Issue
Block a user