ACTUAL TRAINING WORKING (WIP)
This commit is contained in:
@ -1107,7 +1107,7 @@ class TradingOrchestrator:
|
||||
if model_input is not None and cnn_predictions:
|
||||
# Store inference data for each CNN prediction
|
||||
for cnn_pred in cnn_predictions:
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time)
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction
|
||||
@ -1118,7 +1118,7 @@ class TradingOrchestrator:
|
||||
# Store input data for RL
|
||||
model_input = input_data.get('rl_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
@ -1129,7 +1129,7 @@ class TradingOrchestrator:
|
||||
# Store input data for generic model
|
||||
model_input = input_data.get('generic_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
@ -1206,15 +1206,21 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error collecting standardized model input data: {e}")
|
||||
return {}
|
||||
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||
"""Store inference data per-model with async file operations and memory optimization"""
|
||||
try:
|
||||
# Only log first few inference records to avoid spam
|
||||
if len(self.inference_history.get(model_name, [])) < 3:
|
||||
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||
|
||||
# Extract symbol from prediction if not provided
|
||||
if symbol is None:
|
||||
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||
|
||||
# Create comprehensive inference record
|
||||
inference_record = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'symbol': symbol,
|
||||
'model_name': model_name,
|
||||
'model_input': model_input,
|
||||
'prediction': {
|
||||
|
Reference in New Issue
Block a user