wip train

This commit is contained in:
Dobromir Popov
2025-07-27 20:34:51 +03:00
parent ff66cb8b79
commit d333681447
4 changed files with 160 additions and 49 deletions

View File

@ -146,6 +146,38 @@ class ModelStatistics:
self.average_loss = sum(self.losses) / len(self.losses)
self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
def update_training_stats(self, loss: Optional[float] = None, training_duration_ms: Optional[float] = None):
    """Record one training run and refresh the derived training metrics.

    Every call counts as a training run, whether or not a loss or a duration
    is supplied.

    Args:
        loss: Loss reported by the training step. ``None`` leaves every loss
            statistic (current/average/best/worst) untouched.
        training_duration_ms: Duration of the training step in milliseconds.
            ``None`` leaves the average training time untouched.
    """
    current_time = datetime.now()

    # Update training timing
    self.last_training_time = current_time
    self.total_trainings += 1
    self.training_times.append(current_time)

    # Update training duration: the average only changes when a new sample arrives.
    if training_duration_ms is not None:
        self.training_durations_ms.append(training_duration_ms)
        if self.training_durations_ms:
            self.average_training_time_ms = (
                sum(self.training_durations_ms) / len(self.training_durations_ms)
            )

    # Calculate training rates: needs at least two timestamps spanning a
    # non-zero interval, otherwise the previous rate values are kept.
    if len(self.training_times) > 1:
        time_window = (self.training_times[-1] - self.training_times[0]).total_seconds()
        if time_window > 0:
            # NOTE(review): this divides the number of timestamps (not the
            # number of intervals) by the window; kept as-is on purpose.
            self.training_rate_per_second = len(self.training_times) / time_window
            self.training_rate_per_minute = self.training_rate_per_second * 60

    # Update loss stats. Everything below is nested under the None check so
    # that a timing-only call can never reach min()/max() with loss=None.
    if loss is not None:
        self.current_loss = loss
        self.losses.append(loss)
        if self.losses:
            self.average_loss = sum(self.losses) / len(self.losses)
            self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
            self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
@dataclass
class TradingDecision:
@ -1374,40 +1406,52 @@ class TradingOrchestrator:
if isinstance(model, CNNModelInterface):
# Get CNN predictions using the pre-built base data
cnn_predictions = await self._get_cnn_predictions(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
predictions.extend(cnn_predictions)
# Update statistics for CNN predictions
if cnn_predictions:
for cnn_pred in cnn_predictions:
self._update_model_statistics(model_name, cnn_pred)
self._update_model_statistics(model_name, cnn_pred, inference_duration_ms=inference_duration_ms)
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
else:
# Still update statistics even if no predictions (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
elif isinstance(model, RLAgentInterface):
# Get RL prediction using the pre-built base data
rl_prediction = await self._get_rl_prediction(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
if rl_prediction:
predictions.append(rl_prediction)
prediction = rl_prediction
# Update statistics for RL prediction
self._update_model_statistics(model_name, prediction)
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
# Store input data for RL
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
else:
# Still update statistics even if no prediction (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
else:
# Generic model interface using the pre-built base data
generic_prediction = await self._get_generic_prediction(model, symbol, base_data)
inference_duration_ms = (time.time() - inference_start_time) * 1000
if generic_prediction:
predictions.append(generic_prediction)
prediction = generic_prediction
# Update statistics for generic prediction
self._update_model_statistics(model_name, prediction)
self._update_model_statistics(model_name, prediction, inference_duration_ms=inference_duration_ms)
# Store input data for generic model
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
else:
# Still update statistics even if no prediction (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
except Exception as e:
inference_duration_ms = (time.time() - inference_start_time) * 1000
logger.error(f"Error getting prediction from {model_name}: {e}")
# Still update statistics for failed inference
if model_name in self.model_statistics:
self.model_statistics[model_name].update_inference_stats()
# Still update statistics for failed inference (for timing)
self._update_model_statistics(model_name, inference_duration_ms=inference_duration_ms)
continue
@ -1417,25 +1461,48 @@ class TradingOrchestrator:
return predictions
def _update_model_statistics(self, model_name: str, prediction: Optional[Prediction] = None, loss: Optional[float] = None,
                             inference_duration_ms: Optional[float] = None):
    """Update inference statistics for a specific model.

    Creates the model's ``ModelStatistics`` entry on first use, forwards the
    sample to ``update_inference_stats`` exactly once, and emits a debug
    summary every 10 inferences. Any exception is logged and swallowed so
    that statistics bookkeeping never breaks the caller.

    Args:
        model_name: Key into ``self.model_statistics``.
        prediction: Prediction produced by the inference, if any.
        loss: Loss associated with the inference, if any.
        inference_duration_ms: Duration of the inference in milliseconds, if
            measured.
    """
    try:
        if model_name not in self.model_statistics:
            self.model_statistics[model_name] = ModelStatistics(model_name=model_name)

        # Update the statistics (single call: one inference, one count)
        stats = self.model_statistics[model_name]
        stats.update_inference_stats(prediction, loss, inference_duration_ms)

        # Log statistics periodically (every 10 inferences)
        if stats.total_inferences % 10 == 0:
            # NOTE(review): this method is also called without a prediction
            # (timing-only updates), so last_confidence is presumably still
            # unset in that case -- confirm against ModelStatistics defaults.
            confidence = stats.last_confidence
            confidence_str = f"{confidence:.3f}" if confidence is not None else "N/A"
            logger.debug(f"Model {model_name} stats: {stats.total_inferences} inferences, "
                         f"{stats.inference_rate_per_minute:.1f}/min, "
                         f"avg: {stats.average_inference_time_ms:.1f}ms, "
                         f"last: {stats.last_prediction} ({confidence_str})")
    except Exception as e:
        logger.error(f"Error updating statistics for {model_name}: {e}")
def _update_model_training_statistics(self, model_name: str, loss: Optional[float] = None,
training_duration_ms: Optional[float] = None):
"""Update training statistics for a specific model"""
try:
if model_name not in self.model_statistics:
self.model_statistics[model_name] = ModelStatistics(model_name=model_name)
# Update the training statistics
self.model_statistics[model_name].update_training_stats(loss, training_duration_ms)
# Log training statistics periodically (every 5 trainings)
stats = self.model_statistics[model_name]
if stats.total_trainings % 5 == 0:
logger.debug(f"Model {model_name} training stats: {stats.total_trainings} trainings, "
f"{stats.training_rate_per_minute:.1f}/min, "
f"avg: {stats.average_training_time_ms:.1f}ms, "
f"loss: {stats.current_loss:.4f}" if stats.current_loss else "loss: N/A")
except Exception as e:
logger.error(f"Error updating training statistics for {model_name}: {e}")
def get_model_statistics(self, model_name: Optional[str] = None) -> Union[Dict[str, ModelStatistics], ModelStatistics, None]:
"""Get statistics for a specific model or all models"""
try:
@ -1454,9 +1521,15 @@ class TradingOrchestrator:
for model_name, stats in self.model_statistics.items():
summary[model_name] = {
'last_inference_time': stats.last_inference_time.isoformat() if stats.last_inference_time else None,
'last_training_time': stats.last_training_time.isoformat() if stats.last_training_time else None,
'total_inferences': stats.total_inferences,
'total_trainings': stats.total_trainings,
'inference_rate_per_minute': round(stats.inference_rate_per_minute, 2),
'inference_rate_per_second': round(stats.inference_rate_per_second, 4),
'training_rate_per_minute': round(stats.training_rate_per_minute, 2),
'training_rate_per_second': round(stats.training_rate_per_second, 4),
'average_inference_time_ms': round(stats.average_inference_time_ms, 2),
'average_training_time_ms': round(stats.average_training_time_ms, 2),
'current_loss': round(stats.current_loss, 6) if stats.current_loss is not None else None,
'average_loss': round(stats.average_loss, 6) if stats.average_loss is not None else None,
'best_loss': round(stats.best_loss, 6) if stats.best_loss is not None else None,
@ -1483,18 +1556,26 @@ class TradingOrchestrator:
for model_name, stats in self.model_statistics.items():
if detailed:
logger.info(f"{model_name}:")
logger.info(f" Total inferences: {stats.total_inferences}")
logger.info(f" Total inferences: {stats.total_inferences} (avg: {stats.average_inference_time_ms:.1f}ms)")
logger.info(f" Total trainings: {stats.total_trainings} (avg: {stats.average_training_time_ms:.1f}ms)")
logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min ({stats.inference_rate_per_second:.3f}/sec)")
logger.info(f" Training rate: {stats.training_rate_per_minute:.1f}/min ({stats.training_rate_per_second:.3f}/sec)")
logger.info(f" Last inference: {stats.last_inference_time}")
logger.info(f" Last training: {stats.last_training_time}")
logger.info(f" Current loss: {stats.current_loss:.6f}" if stats.current_loss else " Current loss: N/A")
logger.info(f" Average loss: {stats.average_loss:.6f}" if stats.average_loss else " Average loss: N/A")
logger.info(f" Best loss: {stats.best_loss:.6f}" if stats.best_loss else " Best loss: N/A")
logger.info(f" Last prediction: {stats.last_prediction} ({stats.last_confidence:.3f})" if stats.last_prediction else " Last prediction: N/A")
else:
rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
inf_rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
train_rate_str = f"{stats.training_rate_per_minute:.1f}/min" if stats.total_trainings > 0 else "0/min"
inf_time_str = f"{stats.average_inference_time_ms:.1f}ms" if stats.average_inference_time_ms > 0 else "N/A"
train_time_str = f"{stats.average_training_time_ms:.1f}ms" if stats.average_training_time_ms > 0 else "N/A"
loss_str = f"{stats.current_loss:.4f}" if stats.current_loss else "N/A"
pred_str = f"{stats.last_prediction}({stats.last_confidence:.2f})" if stats.last_prediction else "N/A"
logger.info(f"{model_name}: {stats.total_inferences} inferences, {rate_str}, loss={loss_str}, last={pred_str}")
logger.info(f"{model_name}: Inf: {stats.total_inferences}@{inf_time_str} ({inf_rate_str}) | "
f"Train: {stats.total_trainings}@{train_time_str} ({train_rate_str}) | "
f"Loss: {loss_str} | Last: {pred_str}")
except Exception as e:
logger.error(f"Error logging model statistics: {e}")
@ -2143,11 +2224,18 @@ class TradingOrchestrator:
batch_size = getattr(model, 'batch_size', 32)
if memory_size >= batch_size:
logger.debug(f"Training {model_name} with {memory_size} experiences")
training_start_time = time.time()
training_loss = model.replay()
training_duration_ms = (time.time() - training_start_time) * 1000
if training_loss is not None and training_loss > 0:
self.update_model_loss(model_name, training_loss)
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}")
self._update_model_training_statistics(model_name, training_loss, training_duration_ms)
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
else:
logger.debug(f"Not enough experiences for {model_name}: {memory_size}/{batch_size}")
return True # Experience added successfully, training will happen later
@ -2242,12 +2330,19 @@ class TradingOrchestrator:
if hasattr(self.cnn_adapter, 'training_data') and hasattr(self.cnn_adapter, 'batch_size'):
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
training_start_time = time.time()
training_results = self.cnn_adapter.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000
if training_results and 'loss' in training_results:
current_loss = training_results['loss']
self.update_model_loss(model_name, current_loss)
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
else:
logger.debug(f"Not enough samples for CNN training: {len(self.cnn_adapter.training_data)}/{self.cnn_adapter.batch_size}")
return True # Sample added successfully
@ -2264,12 +2359,19 @@ class TradingOrchestrator:
# Trigger training if batch size is met
if hasattr(model, 'train') and hasattr(model, 'training_data') and hasattr(model, 'batch_size'):
if len(model.training_data) >= model.batch_size:
training_start_time = time.time()
training_results = model.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000
if training_results and 'loss' in training_results:
current_loss = training_results['loss']
self.update_model_loss(model_name, current_loss)
logger.debug(f"CNN training completed: loss={current_loss:.4f}")
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
return True # Sample added successfully
# Try basic training method for EnhancedCNN
@ -2412,7 +2514,9 @@ class TradingOrchestrator:
# Use CNN adapter if available
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
try:
cnn_start_time = time.time()
result = self.cnn_adapter.predict(base_data)
cnn_duration_ms = (time.time() - cnn_start_time) * 1000
if result:
# Extract action and probabilities from ModelOutput
action = result.predictions.get('action', 'HOLD')
@ -2428,7 +2532,7 @@ class TradingOrchestrator:
probabilities=probabilities,
timeframe="multi", # Multi-timeframe prediction
timestamp=datetime.now(),
model_name="enhanced_cnn",
model_name=model.name, # Use the actual model name, not hardcoded "enhanced_cnn"
metadata={
'feature_size': len(base_data.get_feature_vector()),
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],