wip train

Dobromir Popov
2025-07-27 20:34:51 +03:00
parent ff66cb8b79
commit d333681447
4 changed files with 160 additions and 49 deletions


@@ -147,6 +147,38 @@ class ModelStatistics:
self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
def update_training_stats(self, loss: Optional[float] = None, training_duration_ms: Optional[float] = None):
"""Update training statistics"""
current_time = datetime.now()
# Update training timing
self.last_training_time = current_time
self.total_trainings += 1
self.training_times.append(current_time)
# Update training duration
if training_duration_ms is not None:
self.training_durations_ms.append(training_duration_ms)
if self.training_durations_ms:
self.average_training_time_ms = sum(self.training_durations_ms) / len(self.training_durations_ms)
# Calculate training rates
if len(self.training_times) > 1:
time_window = (self.training_times[-1] - self.training_times[0]).total_seconds()
if time_window > 0:
self.training_rate_per_second = len(self.training_times) / time_window
self.training_rate_per_minute = self.training_rate_per_second * 60
# Update loss stats
if loss is not None:
self.current_loss = loss
self.losses.append(loss)
if self.losses:
self.average_loss = sum(self.losses) / len(self.losses)
self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
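The training-related fields touched by update_training_stats are not all visible in this hunk; below is a minimal sketch of that portion of ModelStatistics, reconstructed only from the attributes the new method references. Defaults and the deque window sizes are assumptions, not taken from the source.

from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

@dataclass
class ModelStatisticsSketch:  # illustrative subset, not the full class
    model_name: str
    last_training_time: Optional[datetime] = None
    total_trainings: int = 0
    training_times: deque = field(default_factory=lambda: deque(maxlen=100))         # window size assumed
    training_durations_ms: deque = field(default_factory=lambda: deque(maxlen=100))  # window size assumed
    average_training_time_ms: float = 0.0
    training_rate_per_second: float = 0.0
    training_rate_per_minute: float = 0.0
    current_loss: Optional[float] = None
    losses: deque = field(default_factory=lambda: deque(maxlen=100))                 # window size assumed
    average_loss: Optional[float] = None
    best_loss: Optional[float] = None
    worst_loss: Optional[float] = None

The rates then fall out of simple arithmetic over the timestamp window: events divided by the seconds between the oldest and newest entry, times 60 for the per-minute figure.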
@dataclass
class TradingDecision:
"""Final trading decision from the orchestrator"""
@@ -1374,40 +1406,52 @@ class TradingOrchestrator:
if isinstance(model, CNNModelInterface):
# Get CNN predictions using the pre-built base data
cnn_predictions = await self._get_cnn_predictions(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
predictions.extend(cnn_predictions)
# Update statistics for CNN predictions
if cnn_predictions:
for cnn_pred in cnn_predictions:
self._update_model_statistics(model_name, cnn_pred)
self._update_model_statistics(model_name, cnn_pred, inference_duration_ms=inference_duration_ms)
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
else:
# Still update statistics even if no predictions (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
elif isinstance(model, RLAgentInterface):
# Get RL prediction using the pre-built base data
rl_prediction = await self._get_rl_prediction(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
if rl_prediction:
predictions.append(rl_prediction)
prediction = rl_prediction
# Update statistics for RL prediction
self._update_model_statistics(model_name, prediction)
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
# Store input data for RL
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
else:
# Still update statistics even if no prediction (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
else:
# Generic model interface using the pre-built base data
generic_prediction = await self._get_generic_prediction(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
if generic_prediction:
predictions.append(generic_prediction)
prediction = generic_prediction
# Update statistics for generic prediction
self._update_model_statistics(model_name, prediction)
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
# Store input data for generic model
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
else:
# Still update statistics even if no prediction (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
except Exception as e:
inference_duration_ms = (time.time() - inference_start_time) * 1000
logger.error(f"Error getting prediction from {model_name}: {e}")
# Still update statistics for failed inference
if model_name in self.model_statistics:
self.model_statistics[model_name].update_inference_stats()
# Still update statistics for failed inference (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
continue
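Each branch above (CNN, RL, generic) applies the same measure-then-record pattern. A condensed sketch of that pattern in isolation, with the helper names mirroring the diff and the model call as a stand-in, is:

import logging
import time

logger = logging.getLogger(__name__)

async def _timed_inference(self, model_name, model_call, *args):
    """Illustrative only: time one model call and always record its duration."""
    start = time.time()
    try:
        prediction = await model_call(*args)  # stand-in for _get_rl_prediction / _get_cnn_predictions / ...
        duration_ms = (time.time() - start) * 1000
        # Record stats with or without a prediction so the timing sample is never lost.
        self._update_model_statistics(model_name, prediction, inference_duration_ms=duration_ms)
        return prediction
    except Exception as e:
        duration_ms = (time.time() - start) * 1000
        logger.error(f"Error getting prediction from {model_name}: {e}")
        self._update_model_statistics(model_name, inference_duration_ms=duration_ms)
        return None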
@@ -1417,25 +1461,48 @@ class TradingOrchestrator:
return predictions
def _update_model_statistics(self, model_name: str, prediction: Optional[Prediction] = None, loss: Optional[float] = None):
def _update_model_statistics(self, model_name: str, prediction: Optional[Prediction] = None, loss: Optional[float] = None,
inference_duration_ms: Optional[float] = None):
"""Update statistics for a specific model"""
try:
if model_name not in self.model_statistics:
self.model_statistics[model_name] = ModelStatistics(model_name=model_name)
# Update the statistics
self.model_statistics[model_name].update_inference_stats(prediction, loss)
self.model_statistics[model_name].update_inference_stats(prediction, loss, inference_duration_ms)
# Log statistics periodically (every 10 inferences)
stats = self.model_statistics[model_name]
if stats.total_inferences % 10 == 0:
logger.debug(f"Model {model_name} stats: {stats.total_inferences} inferences, "
f"{stats.inference_rate_per_minute:.1f}/min, "
f"avg: {stats.average_inference_time_ms:.1f}ms, "
f"last: {stats.last_prediction} ({stats.last_confidence:.3f})")
except Exception as e:
logger.error(f"Error updating statistics for {model_name}: {e}")
def _update_model_training_statistics(self, model_name: str, loss: Optional[float] = None,
training_duration_ms: Optional[float] = None):
"""Update training statistics for a specific model"""
try:
if model_name not in self.model_statistics:
self.model_statistics[model_name] = ModelStatistics(model_name=model_name)
# Update the training statistics
self.model_statistics[model_name].update_training_stats(loss, training_duration_ms)
# Log training statistics periodically (every 5 trainings)
stats = self.model_statistics[model_name]
if stats.total_trainings % 5 == 0:
logger.debug(f"Model {model_name} training stats: {stats.total_trainings} trainings, "
f"{stats.training_rate_per_minute:.1f}/min, "
f"avg: {stats.average_training_time_ms:.1f}ms, "
f"loss: {stats.current_loss:.4f}" if stats.current_loss else "loss: N/A")
except Exception as e:
logger.error(f"Error updating training statistics for {model_name}: {e}")
def get_model_statistics(self, model_name: Optional[str] = None) -> Union[Dict[str, ModelStatistics], ModelStatistics, None]:
"""Get statistics for a specific model or all models"""
try:
@@ -1454,9 +1521,15 @@ class TradingOrchestrator:
for model_name, stats in self.model_statistics.items():
summary[model_name] = {
'last_inference_time': stats.last_inference_time.isoformat() if stats.last_inference_time else None,
'last_training_time': stats.last_training_time.isoformat() if stats.last_training_time else None,
'total_inferences': stats.total_inferences,
'total_trainings': stats.total_trainings,
'inference_rate_per_minute': round(stats.inference_rate_per_minute, 2),
'inference_rate_per_second': round(stats.inference_rate_per_second, 4),
'training_rate_per_minute': round(stats.training_rate_per_minute, 2),
'training_rate_per_second': round(stats.training_rate_per_second, 4),
'average_inference_time_ms': round(stats.average_inference_time_ms, 2),
'average_training_time_ms': round(stats.average_training_time_ms, 2),
'current_loss': round(stats.current_loss, 6) if stats.current_loss is not None else None,
'average_loss': round(stats.average_loss, 6) if stats.average_loss is not None else None,
'best_loss': round(stats.best_loss, 6) if stats.best_loss is not None else None,
@@ -1483,18 +1556,26 @@ class TradingOrchestrator:
for model_name, stats in self.model_statistics.items():
if detailed:
logger.info(f"{model_name}:")
logger.info(f" Total inferences: {stats.total_inferences}")
logger.info(f" Total inferences: {stats.total_inferences} (avg: {stats.average_inference_time_ms:.1f}ms)")
logger.info(f" Total trainings: {stats.total_trainings} (avg: {stats.average_training_time_ms:.1f}ms)")
logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min ({stats.inference_rate_per_second:.3f}/sec)")
logger.info(f" Training rate: {stats.training_rate_per_minute:.1f}/min ({stats.training_rate_per_second:.3f}/sec)")
logger.info(f" Last inference: {stats.last_inference_time}")
logger.info(f" Last training: {stats.last_training_time}")
logger.info(f" Current loss: {stats.current_loss:.6f}" if stats.current_loss else " Current loss: N/A")
logger.info(f" Average loss: {stats.average_loss:.6f}" if stats.average_loss else " Average loss: N/A")
logger.info(f" Best loss: {stats.best_loss:.6f}" if stats.best_loss else " Best loss: N/A")
logger.info(f" Last prediction: {stats.last_prediction} ({stats.last_confidence:.3f})" if stats.last_prediction else " Last prediction: N/A")
else:
rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
inf_rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
train_rate_str = f"{stats.training_rate_per_minute:.1f}/min" if stats.total_trainings > 0 else "0/min"
inf_time_str = f"{stats.average_inference_time_ms:.1f}ms" if stats.average_inference_time_ms > 0 else "N/A"
train_time_str = f"{stats.average_training_time_ms:.1f}ms" if stats.average_training_time_ms > 0 else "N/A"
loss_str = f"{stats.current_loss:.4f}" if stats.current_loss else "N/A"
pred_str = f"{stats.last_prediction}({stats.last_confidence:.2f})" if stats.last_prediction else "N/A"
logger.info(f"{model_name}: {stats.total_inferences} inferences, {rate_str}, loss={loss_str}, last={pred_str}")
logger.info(f"{model_name}: Inf: {stats.total_inferences}@{inf_time_str} ({inf_rate_str}) | "
f"Train: {stats.total_trainings}@{train_time_str} ({train_rate_str}) | "
f"Loss: {loss_str} | Last: {pred_str}")
except Exception as e:
logger.error(f"Error logging model statistics: {e}")
@@ -2143,11 +2224,18 @@ class TradingOrchestrator:
batch_size = getattr(model, 'batch_size', 32)
if memory_size >= batch_size:
logger.debug(f"Training {model_name} with {memory_size} experiences")
training_start_time = time.time()
training_loss = model.replay()
training_duration_ms = (time.time() - training_start_time) * 1000
if training_loss is not None and training_loss > 0:
self.update_model_loss(model_name, training_loss)
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}")
self._update_model_training_statistics(model_name, training_loss, training_duration_ms)
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
else:
logger.debug(f"Not enough experiences for {model_name}: {memory_size}/{batch_size}")
return True # Experience added successfully, training will happen later
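The same wrapper shape recurs in the CNN-adapter and generic-model training paths in the hunks below. A minimal sketch of it, with the train call as a stand-in and the helper names taken from the diff:

import time

def _timed_training(self, model_name, train_fn):
    """Illustrative only: time a training call and record stats even without a loss."""
    start = time.time()
    result = train_fn()  # e.g. model.replay() or cnn_adapter.train(epochs=1)
    duration_ms = (time.time() - start) * 1000
    # replay() is assumed to return a float loss; train() a dict with a 'loss' key.
    loss = result if isinstance(result, float) else (result or {}).get('loss')
    if loss is not None and loss > 0:
        self.update_model_loss(model_name, loss)
        self._update_model_training_statistics(model_name, loss, duration_ms)
    else:
        # Timing is still worth recording when no usable loss comes back.
        self._update_model_training_statistics(model_name, training_duration_ms=duration_ms)
    return True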
@@ -2242,12 +2330,19 @@ class TradingOrchestrator:
if hasattr(self.cnn_adapter, 'training_data') and hasattr(self.cnn_adapter, 'batch_size'):
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
training_start_time = time.time()
training_results = self.cnn_adapter.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000
if training_results and 'loss' in training_results:
current_loss = training_results['loss']
self.update_model_loss(model_name, current_loss)
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
else:
logger.debug(f"Not enough samples for CNN training: {len(self.cnn_adapter.training_data)}/{self.cnn_adapter.batch_size}")
return True # Sample added successfully
@@ -2264,12 +2359,19 @@ class TradingOrchestrator:
# Trigger training if batch size is met
if hasattr(model, 'train') and hasattr(model, 'training_data') and hasattr(model, 'batch_size'):
if len(model.training_data) >= model.batch_size:
training_start_time = time.time()
training_results = model.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000
if training_results and 'loss' in training_results:
current_loss = training_results['loss']
self.update_model_loss(model_name, current_loss)
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
return True # Sample added successfully
# Try basic training method for EnhancedCNN
@@ -2412,7 +2514,9 @@ class TradingOrchestrator:
# Use CNN adapter if available
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
try:
cnn_start_time = time.time()
result = self.cnn_adapter.predict(base_data)
cnn_duration_ms = (time.time() - cnn_start_time) * 1000
if result:
# Extract action and probabilities from ModelOutput
action = result.predictions.get('action', 'HOLD')
@@ -2428,7 +2532,7 @@ class TradingOrchestrator:
probabilities=probabilities,
timeframe="multi", # Multi-timeframe prediction
timestamp=datetime.now(),
model_name="enhanced_cnn",
model_name=model.name, # Use the actual model name, not hardcoded "enhanced_cnn"
metadata={
'feature_size': len(base_data.get_feature_vector()),
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],

Binary file not shown.


@@ -2951,35 +2951,34 @@ class CleanTradingDashboard:
'last_training': None,
'inferences_per_second': 0.0,
'trainings_per_second': 0.0,
'prediction_count_24h': 0
'prediction_count_24h': 0,
'average_inference_time_ms': 0.0,
'average_training_time_ms': 0.0
}
try:
if self.orchestrator:
# Get recent predictions for timing analysis
recent_predictions = self.orchestrator.get_recent_model_predictions('ETH/USDT', model_name.lower())
if model_name.lower() in recent_predictions:
predictions = recent_predictions[model_name.lower()]
if predictions:
# Use the new model statistics system
model_stats = self.orchestrator.get_model_statistics(model_name.lower())
if model_stats:
# Last inference time
last_pred = predictions[-1]
timing['last_inference'] = last_pred.get('timestamp', datetime.now())
timing['last_inference'] = model_stats.last_inference_time
# Calculate predictions per second (last 60 seconds)
now = datetime.now()
recent_preds = [p for p in predictions
if (now - p.get('timestamp', now)).total_seconds() <= 60]
timing['inferences_per_second'] = len(recent_preds) / 60.0
# Last training time
timing['last_training'] = model_stats.last_training_time
# 24h prediction count
preds_24h = [p for p in predictions
if (now - p.get('timestamp', now)).total_seconds() <= 86400]
timing['prediction_count_24h'] = len(preds_24h)
# Inference rate per second
timing['inferences_per_second'] = model_stats.inference_rate_per_second
# For training timing, check model-specific training status
if hasattr(self.orchestrator, f'{model_name.lower()}_last_training'):
timing['last_training'] = getattr(self.orchestrator, f'{model_name.lower()}_last_training')
# Training rate per second
timing['trainings_per_second'] = model_stats.training_rate_per_second
# 24h prediction count (approximate from total inferences)
timing['prediction_count_24h'] = model_stats.total_inferences
# Average timing data
timing['average_inference_time_ms'] = model_stats.average_inference_time_ms
timing['average_training_time_ms'] = model_stats.average_training_time_ms
except Exception as e:
logger.debug(f"Error getting timing info for {model_name}: {e}")
@@ -3068,7 +3067,9 @@ class CleanTradingDashboard:
'last_inference': dqn_timing['last_inference'].strftime('%H:%M:%S') if dqn_timing['last_inference'] else 'None',
'last_training': dqn_timing['last_training'].strftime('%H:%M:%S') if dqn_timing['last_training'] else 'None',
'inferences_per_second': f"{dqn_timing['inferences_per_second']:.2f}",
'predictions_24h': dqn_timing['prediction_count_24h']
'predictions_24h': dqn_timing['prediction_count_24h'],
'average_inference_time_ms': f"{dqn_timing.get('average_inference_time_ms', 0):.1f}",
'average_training_time_ms': f"{dqn_timing.get('average_training_time_ms', 0):.1f}"
},
# NEW: Performance metrics for split-second decisions
'performance': self.get_model_performance_metrics().get('dqn', {})
@@ -3143,7 +3144,9 @@ class CleanTradingDashboard:
'last_inference': cnn_timing['last_inference'].strftime('%H:%M:%S') if cnn_timing['last_inference'] else 'None',
'last_training': cnn_timing['last_training'].strftime('%H:%M:%S') if cnn_timing['last_training'] else 'None',
'inferences_per_second': f"{cnn_timing['inferences_per_second']:.2f}",
'predictions_24h': cnn_timing['prediction_count_24h']
'predictions_24h': cnn_timing['prediction_count_24h'],
'average_inference_time_ms': f"{cnn_timing.get('average_inference_time_ms', 0):.1f}",
'average_training_time_ms': f"{cnn_timing.get('average_training_time_ms', 0):.1f}"
},
# NEW: Performance metrics for split-second decisions
'performance': self.get_model_performance_metrics().get('cnn', {})


@@ -823,7 +823,11 @@ class DashboardComponentManager:
html.Br(),
html.Span(f"Rate: {model_info.get('timing', {}).get('inferences_per_second', '0.00')}/s", className="text-success small"),
html.Span(" | ", className="text-muted small"),
html.Span(f"24h: {model_info.get('timing', {}).get('predictions_24h', 0)}", className="text-primary small")
html.Span(f"24h: {model_info.get('timing', {}).get('predictions_24h', 0)}", className="text-primary small"),
html.Br(),
html.Span(f"Avg Inf: {model_info.get('timing', {}).get('average_inference_time_ms', 'N/A')}ms", className="text-info small"),
html.Span(" | ", className="text-muted small"),
html.Span(f"Avg Train: {model_info.get('timing', {}).get('average_training_time_ms', 'N/A')}ms", className="text-warning small")
], className="mb-1"),
# Loss metrics with improvement tracking