wip train
This commit is contained in:
@ -147,6 +147,38 @@ class ModelStatistics:
|
||||
self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
|
||||
self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
|
||||
|
||||
def update_training_stats(self, loss: Optional[float] = None, training_duration_ms: Optional[float] = None):
    """Record one training run and refresh the derived training metrics.

    Every call counts as one training run: it stamps ``last_training_time``,
    increments ``total_trainings`` and appends to ``training_times``.

    Args:
        loss: Loss reported by the run. When ``None`` (the trainer returned
            no loss) all loss statistics are left untouched.
        training_duration_ms: Duration of the run in milliseconds. When
            ``None`` the duration history and its average are left untouched.
    """
    current_time = datetime.now()

    # Update training timing
    self.last_training_time = current_time
    self.total_trainings += 1
    self.training_times.append(current_time)

    # Update training duration (the history is non-empty right after the
    # append, so the average can be recomputed unconditionally)
    if training_duration_ms is not None:
        self.training_durations_ms.append(training_duration_ms)
        self.average_training_time_ms = sum(self.training_durations_ms) / len(self.training_durations_ms)

    # Calculate training rates over the span of the recorded timestamps;
    # a zero-length window is skipped to avoid dividing by zero
    if len(self.training_times) > 1:
        time_window = (self.training_times[-1] - self.training_times[0]).total_seconds()
        if time_window > 0:
            self.training_rate_per_second = len(self.training_times) / time_window
            self.training_rate_per_minute = self.training_rate_per_second * 60

    # Update loss stats. Duration-only calls pass loss=None; returning here
    # keeps min()/max() below from ever comparing a float against None.
    if loss is None:
        return

    self.current_loss = loss
    self.losses.append(loss)
    self.average_loss = sum(self.losses) / len(self.losses)
    self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
    self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
|
||||
|
||||
@dataclass
|
||||
class TradingDecision:
|
||||
"""Final trading decision from the orchestrator"""
|
||||
@ -1374,40 +1406,52 @@ class TradingOrchestrator:
|
||||
if isinstance(model, CNNModelInterface):
|
||||
# Get CNN predictions using the pre-built base data
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol, base_data)
|
||||
inference_duration_ms = (time.time() - inference_start_time) * 1000
|
||||
predictions.extend(cnn_predictions)
|
||||
# Update statistics for CNN predictions
|
||||
if cnn_predictions:
|
||||
for cnn_pred in cnn_predictions:
|
||||
self._update_model_statistics(model_name, cnn_pred)
|
||||
self._update_model_statistics(model_name, cnn_pred, inference_duration_ms=inference_duration_ms)
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||
else:
|
||||
# Still update statistics even if no predictions (for timing)
|
||||
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction using the pre-built base data
|
||||
rl_prediction = await self._get_rl_prediction(model, symbol, base_data)
|
||||
inference_duration_ms = (time.time() - inference_start_time) * 1000
|
||||
if rl_prediction:
|
||||
predictions.append(rl_prediction)
|
||||
prediction = rl_prediction
|
||||
# Update statistics for RL prediction
|
||||
self._update_model_statistics(model_name, prediction)
|
||||
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
|
||||
# Store input data for RL
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
else:
|
||||
# Still update statistics even if no prediction (for timing)
|
||||
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
|
||||
|
||||
else:
|
||||
# Generic model interface using the pre-built base data
|
||||
generic_prediction = await self._get_generic_prediction(model, symbol, base_data)
|
||||
inference_duration_ms = (time.time() - inference_start_time) * 1000
|
||||
if generic_prediction:
|
||||
predictions.append(generic_prediction)
|
||||
prediction = generic_prediction
|
||||
# Update statistics for generic prediction
|
||||
self._update_model_statistics(model_name, prediction)
|
||||
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
|
||||
# Store input data for generic model
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
else:
|
||||
# Still update statistics even if no prediction (for timing)
|
||||
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
|
||||
|
||||
except Exception as e:
|
||||
inference_duration_ms = (time.time() - inference_start_time) * 1000
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
# Still update statistics for failed inference
|
||||
if model_name in self.model_statistics:
|
||||
self.model_statistics[model_name].update_inference_stats()
|
||||
# Still update statistics for failed inference (for timing)
|
||||
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
|
||||
continue
|
||||
|
||||
|
||||
@ -1417,25 +1461,48 @@ class TradingOrchestrator:
|
||||
|
||||
return predictions
|
||||
|
||||
def _update_model_statistics(self, model_name: str, prediction: Optional[Prediction] = None, loss: Optional[float] = None,
                             inference_duration_ms: Optional[float] = None):
    """Update inference statistics for a specific model.

    Creates the model's ``ModelStatistics`` entry on first use, forwards the
    values to ``update_inference_stats`` and emits a debug summary every
    10 inferences. Any exception is logged and swallowed, so statistics
    bookkeeping cannot break the caller.

    Args:
        model_name: Key into ``self.model_statistics``.
        prediction: Prediction produced by the inference, or None for a
            timing-only update (no prediction / failed inference).
        loss: Optional loss value forwarded to the statistics object.
        inference_duration_ms: Optional inference duration in milliseconds.
    """
    try:
        if model_name not in self.model_statistics:
            self.model_statistics[model_name] = ModelStatistics(model_name=model_name)

        # Update the statistics
        self.model_statistics[model_name].update_inference_stats(prediction, loss, inference_duration_ms)

        # Log statistics periodically (every 10 inferences)
        stats = self.model_statistics[model_name]
        if stats.total_inferences % 10 == 0:
            # Timing-only updates can reach this point before any prediction
            # was recorded (presumably leaving last_confidence unset - verify
            # against ModelStatistics); formatting None with ':.3f' would
            # raise and turn this debug line into a spurious error log.
            if stats.last_prediction is not None and stats.last_confidence is not None:
                last_str = f"{stats.last_prediction} ({stats.last_confidence:.3f})"
            else:
                last_str = "N/A"
            logger.debug(f"Model {model_name} stats: {stats.total_inferences} inferences, "
                         f"{stats.inference_rate_per_minute:.1f}/min, "
                         f"avg: {stats.average_inference_time_ms:.1f}ms, "
                         f"last: {last_str}")

    except Exception as e:
        logger.error(f"Error updating statistics for {model_name}: {e}")
|
||||
|
||||
def _update_model_training_statistics(self, model_name: str, loss: Optional[float] = None,
|
||||
training_duration_ms: Optional[float] = None):
|
||||
"""Update training statistics for a specific model"""
|
||||
try:
|
||||
if model_name not in self.model_statistics:
|
||||
self.model_statistics[model_name] = ModelStatistics(model_name=model_name)
|
||||
|
||||
# Update the training statistics
|
||||
self.model_statistics[model_name].update_training_stats(loss, training_duration_ms)
|
||||
|
||||
# Log training statistics periodically (every 5 trainings)
|
||||
stats = self.model_statistics[model_name]
|
||||
if stats.total_trainings % 5 == 0:
|
||||
logger.debug(f"Model {model_name} training stats: {stats.total_trainings} trainings, "
|
||||
f"{stats.training_rate_per_minute:.1f}/min, "
|
||||
f"avg: {stats.average_training_time_ms:.1f}ms, "
|
||||
f"loss: {stats.current_loss:.4f}" if stats.current_loss else "loss: N/A")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating training statistics for {model_name}: {e}")
|
||||
|
||||
def get_model_statistics(self, model_name: Optional[str] = None) -> Union[Dict[str, ModelStatistics], ModelStatistics, None]:
|
||||
"""Get statistics for a specific model or all models"""
|
||||
try:
|
||||
@ -1454,9 +1521,15 @@ class TradingOrchestrator:
|
||||
for model_name, stats in self.model_statistics.items():
|
||||
summary[model_name] = {
|
||||
'last_inference_time': stats.last_inference_time.isoformat() if stats.last_inference_time else None,
|
||||
'last_training_time': stats.last_training_time.isoformat() if stats.last_training_time else None,
|
||||
'total_inferences': stats.total_inferences,
|
||||
'total_trainings': stats.total_trainings,
|
||||
'inference_rate_per_minute': round(stats.inference_rate_per_minute, 2),
|
||||
'inference_rate_per_second': round(stats.inference_rate_per_second, 4),
|
||||
'training_rate_per_minute': round(stats.training_rate_per_minute, 2),
|
||||
'training_rate_per_second': round(stats.training_rate_per_second, 4),
|
||||
'average_inference_time_ms': round(stats.average_inference_time_ms, 2),
|
||||
'average_training_time_ms': round(stats.average_training_time_ms, 2),
|
||||
'current_loss': round(stats.current_loss, 6) if stats.current_loss is not None else None,
|
||||
'average_loss': round(stats.average_loss, 6) if stats.average_loss is not None else None,
|
||||
'best_loss': round(stats.best_loss, 6) if stats.best_loss is not None else None,
|
||||
@ -1483,18 +1556,26 @@ class TradingOrchestrator:
|
||||
for model_name, stats in self.model_statistics.items():
|
||||
if detailed:
|
||||
logger.info(f"{model_name}:")
|
||||
logger.info(f" Total inferences: {stats.total_inferences}")
|
||||
logger.info(f" Total inferences: {stats.total_inferences} (avg: {stats.average_inference_time_ms:.1f}ms)")
|
||||
logger.info(f" Total trainings: {stats.total_trainings} (avg: {stats.average_training_time_ms:.1f}ms)")
|
||||
logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min ({stats.inference_rate_per_second:.3f}/sec)")
|
||||
logger.info(f" Training rate: {stats.training_rate_per_minute:.1f}/min ({stats.training_rate_per_second:.3f}/sec)")
|
||||
logger.info(f" Last inference: {stats.last_inference_time}")
|
||||
logger.info(f" Last training: {stats.last_training_time}")
|
||||
logger.info(f" Current loss: {stats.current_loss:.6f}" if stats.current_loss else " Current loss: N/A")
|
||||
logger.info(f" Average loss: {stats.average_loss:.6f}" if stats.average_loss else " Average loss: N/A")
|
||||
logger.info(f" Best loss: {stats.best_loss:.6f}" if stats.best_loss else " Best loss: N/A")
|
||||
logger.info(f" Last prediction: {stats.last_prediction} ({stats.last_confidence:.3f})" if stats.last_prediction else " Last prediction: N/A")
|
||||
else:
|
||||
rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
|
||||
inf_rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
|
||||
train_rate_str = f"{stats.training_rate_per_minute:.1f}/min" if stats.total_trainings > 0 else "0/min"
|
||||
inf_time_str = f"{stats.average_inference_time_ms:.1f}ms" if stats.average_inference_time_ms > 0 else "N/A"
|
||||
train_time_str = f"{stats.average_training_time_ms:.1f}ms" if stats.average_training_time_ms > 0 else "N/A"
|
||||
loss_str = f"{stats.current_loss:.4f}" if stats.current_loss else "N/A"
|
||||
pred_str = f"{stats.last_prediction}({stats.last_confidence:.2f})" if stats.last_prediction else "N/A"
|
||||
logger.info(f"{model_name}: {stats.total_inferences} inferences, {rate_str}, loss={loss_str}, last={pred_str}")
|
||||
logger.info(f"{model_name}: Inf: {stats.total_inferences}@{inf_time_str} ({inf_rate_str}) | "
|
||||
f"Train: {stats.total_trainings}@{train_time_str} ({train_rate_str}) | "
|
||||
f"Loss: {loss_str} | Last: {pred_str}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging model statistics: {e}")
|
||||
@ -2143,11 +2224,18 @@ class TradingOrchestrator:
|
||||
batch_size = getattr(model, 'batch_size', 32)
|
||||
if memory_size >= batch_size:
|
||||
logger.debug(f"Training {model_name} with {memory_size} experiences")
|
||||
training_start_time = time.time()
|
||||
training_loss = model.replay()
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_loss is not None and training_loss > 0:
|
||||
self.update_model_loss(model_name, training_loss)
|
||||
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}")
|
||||
self._update_model_training_statistics(model_name, training_loss, training_duration_ms)
|
||||
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
else:
|
||||
logger.debug(f"Not enough experiences for {model_name}: {memory_size}/{batch_size}")
|
||||
return True # Experience added successfully, training will happen later
|
||||
@ -2242,12 +2330,19 @@ class TradingOrchestrator:
|
||||
if hasattr(self.cnn_adapter, 'training_data') and hasattr(self.cnn_adapter, 'batch_size'):
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
|
||||
training_start_time = time.time()
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
else:
|
||||
logger.debug(f"Not enough samples for CNN training: {len(self.cnn_adapter.training_data)}/{self.cnn_adapter.batch_size}")
|
||||
return True # Sample added successfully
|
||||
@ -2264,12 +2359,19 @@ class TradingOrchestrator:
|
||||
# Trigger training if batch size is met
|
||||
if hasattr(model, 'train') and hasattr(model, 'training_data') and hasattr(model, 'batch_size'):
|
||||
if len(model.training_data) >= model.batch_size:
|
||||
training_start_time = time.time()
|
||||
training_results = model.train(epochs=1)
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
return True # Sample added successfully
|
||||
|
||||
# Try basic training method for EnhancedCNN
|
||||
@ -2412,7 +2514,9 @@ class TradingOrchestrator:
|
||||
# Use CNN adapter if available
|
||||
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
try:
|
||||
cnn_start_time = time.time()
|
||||
result = self.cnn_adapter.predict(base_data)
|
||||
cnn_duration_ms = (time.time() - cnn_start_time) * 1000
|
||||
if result:
|
||||
# Extract action and probabilities from ModelOutput
|
||||
action = result.predictions.get('action', 'HOLD')
|
||||
@ -2428,7 +2532,7 @@ class TradingOrchestrator:
|
||||
probabilities=probabilities,
|
||||
timeframe="multi", # Multi-timeframe prediction
|
||||
timestamp=datetime.now(),
|
||||
model_name="enhanced_cnn",
|
||||
model_name=model.name, # Use the actual model name, not hardcoded "enhanced_cnn"
|
||||
metadata={
|
||||
'feature_size': len(base_data.get_feature_vector()),
|
||||
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
|
||||
|
Binary file not shown.
@ -2951,35 +2951,34 @@ class CleanTradingDashboard:
|
||||
'last_training': None,
|
||||
'inferences_per_second': 0.0,
|
||||
'trainings_per_second': 0.0,
|
||||
'prediction_count_24h': 0
|
||||
'prediction_count_24h': 0,
|
||||
'average_inference_time_ms': 0.0,
|
||||
'average_training_time_ms': 0.0
|
||||
}
|
||||
|
||||
try:
|
||||
if self.orchestrator:
|
||||
# Get recent predictions for timing analysis
|
||||
recent_predictions = self.orchestrator.get_recent_model_predictions('ETH/USDT', model_name.lower())
|
||||
|
||||
if model_name.lower() in recent_predictions:
|
||||
predictions = recent_predictions[model_name.lower()]
|
||||
if predictions:
|
||||
# Use the new model statistics system
|
||||
model_stats = self.orchestrator.get_model_statistics(model_name.lower())
|
||||
if model_stats:
|
||||
# Last inference time
|
||||
last_pred = predictions[-1]
|
||||
timing['last_inference'] = last_pred.get('timestamp', datetime.now())
|
||||
timing['last_inference'] = model_stats.last_inference_time
|
||||
|
||||
# Calculate predictions per second (last 60 seconds)
|
||||
now = datetime.now()
|
||||
recent_preds = [p for p in predictions
|
||||
if (now - p.get('timestamp', now)).total_seconds() <= 60]
|
||||
timing['inferences_per_second'] = len(recent_preds) / 60.0
|
||||
# Last training time
|
||||
timing['last_training'] = model_stats.last_training_time
|
||||
|
||||
# 24h prediction count
|
||||
preds_24h = [p for p in predictions
|
||||
if (now - p.get('timestamp', now)).total_seconds() <= 86400]
|
||||
timing['prediction_count_24h'] = len(preds_24h)
|
||||
# Inference rate per second
|
||||
timing['inferences_per_second'] = model_stats.inference_rate_per_second
|
||||
|
||||
# For training timing, check model-specific training status
|
||||
if hasattr(self.orchestrator, f'{model_name.lower()}_last_training'):
|
||||
timing['last_training'] = getattr(self.orchestrator, f'{model_name.lower()}_last_training')
|
||||
# Training rate per second
|
||||
timing['trainings_per_second'] = model_stats.training_rate_per_second
|
||||
|
||||
# 24h prediction count (approximate from total inferences)
|
||||
timing['prediction_count_24h'] = model_stats.total_inferences
|
||||
|
||||
# Average timing data
|
||||
timing['average_inference_time_ms'] = model_stats.average_inference_time_ms
|
||||
timing['average_training_time_ms'] = model_stats.average_training_time_ms
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting timing info for {model_name}: {e}")
|
||||
@ -3068,7 +3067,9 @@ class CleanTradingDashboard:
|
||||
'last_inference': dqn_timing['last_inference'].strftime('%H:%M:%S') if dqn_timing['last_inference'] else 'None',
|
||||
'last_training': dqn_timing['last_training'].strftime('%H:%M:%S') if dqn_timing['last_training'] else 'None',
|
||||
'inferences_per_second': f"{dqn_timing['inferences_per_second']:.2f}",
|
||||
'predictions_24h': dqn_timing['prediction_count_24h']
|
||||
'predictions_24h': dqn_timing['prediction_count_24h'],
|
||||
'average_inference_time_ms': f"{dqn_timing.get('average_inference_time_ms', 0):.1f}",
|
||||
'average_training_time_ms': f"{dqn_timing.get('average_training_time_ms', 0):.1f}"
|
||||
},
|
||||
# NEW: Performance metrics for split-second decisions
|
||||
'performance': self.get_model_performance_metrics().get('dqn', {})
|
||||
@ -3143,7 +3144,9 @@ class CleanTradingDashboard:
|
||||
'last_inference': cnn_timing['last_inference'].strftime('%H:%M:%S') if cnn_timing['last_inference'] else 'None',
|
||||
'last_training': cnn_timing['last_training'].strftime('%H:%M:%S') if cnn_timing['last_training'] else 'None',
|
||||
'inferences_per_second': f"{cnn_timing['inferences_per_second']:.2f}",
|
||||
'predictions_24h': cnn_timing['prediction_count_24h']
|
||||
'predictions_24h': cnn_timing['prediction_count_24h'],
|
||||
'average_inference_time_ms': f"{cnn_timing.get('average_inference_time_ms', 0):.1f}",
|
||||
'average_training_time_ms': f"{cnn_timing.get('average_training_time_ms', 0):.1f}"
|
||||
},
|
||||
# NEW: Performance metrics for split-second decisions
|
||||
'performance': self.get_model_performance_metrics().get('cnn', {})
|
||||
|
@ -823,7 +823,11 @@ class DashboardComponentManager:
|
||||
html.Br(),
|
||||
html.Span(f"Rate: {model_info.get('timing', {}).get('inferences_per_second', '0.00')}/s", className="text-success small"),
|
||||
html.Span(" | ", className="text-muted small"),
|
||||
html.Span(f"24h: {model_info.get('timing', {}).get('predictions_24h', 0)}", className="text-primary small")
|
||||
html.Span(f"24h: {model_info.get('timing', {}).get('predictions_24h', 0)}", className="text-primary small"),
|
||||
html.Br(),
|
||||
html.Span(f"Avg Inf: {model_info.get('timing', {}).get('average_inference_time_ms', 'N/A')}ms", className="text-info small"),
|
||||
html.Span(" | ", className="text-muted small"),
|
||||
html.Span(f"Avg Train: {model_info.get('timing', {}).get('average_training_time_ms', 'N/A')}ms", className="text-warning small")
|
||||
], className="mb-1"),
|
||||
|
||||
# Loss metrics with improvement tracking
|
||||
|
Reference in New Issue
Block a user