ACTUAL TRAINING WORKING (WIP)
This commit is contained in:
@ -1107,7 +1107,7 @@ class TradingOrchestrator:
|
|||||||
if model_input is not None and cnn_predictions:
|
if model_input is not None and cnn_predictions:
|
||||||
# Store inference data for each CNN prediction
|
# Store inference data for each CNN prediction
|
||||||
for cnn_pred in cnn_predictions:
|
for cnn_pred in cnn_predictions:
|
||||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time)
|
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||||
|
|
||||||
elif isinstance(model, RLAgentInterface):
|
elif isinstance(model, RLAgentInterface):
|
||||||
# Get RL prediction
|
# Get RL prediction
|
||||||
@ -1118,7 +1118,7 @@ class TradingOrchestrator:
|
|||||||
# Store input data for RL
|
# Store input data for RL
|
||||||
model_input = input_data.get('rl_input')
|
model_input = input_data.get('rl_input')
|
||||||
if model_input is not None:
|
if model_input is not None:
|
||||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Generic model interface
|
# Generic model interface
|
||||||
@ -1129,7 +1129,7 @@ class TradingOrchestrator:
|
|||||||
# Store input data for generic model
|
# Store input data for generic model
|
||||||
model_input = input_data.get('generic_input')
|
model_input = input_data.get('generic_input')
|
||||||
if model_input is not None:
|
if model_input is not None:
|
||||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||||
@ -1206,15 +1206,21 @@ class TradingOrchestrator:
|
|||||||
logger.error(f"Error collecting standardized model input data: {e}")
|
logger.error(f"Error collecting standardized model input data: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||||
"""Store inference data per-model with async file operations and memory optimization"""
|
"""Store inference data per-model with async file operations and memory optimization"""
|
||||||
try:
|
try:
|
||||||
# Only log first few inference records to avoid spam
|
# Only log first few inference records to avoid spam
|
||||||
if len(self.inference_history.get(model_name, [])) < 3:
|
if len(self.inference_history.get(model_name, [])) < 3:
|
||||||
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||||
|
|
||||||
|
# Extract symbol from prediction if not provided
|
||||||
|
if symbol is None:
|
||||||
|
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||||
|
|
||||||
# Create comprehensive inference record
|
# Create comprehensive inference record
|
||||||
inference_record = {
|
inference_record = {
|
||||||
'timestamp': timestamp.isoformat(),
|
'timestamp': timestamp.isoformat(),
|
||||||
|
'symbol': symbol,
|
||||||
'model_name': model_name,
|
'model_name': model_name,
|
||||||
'model_input': model_input,
|
'model_input': model_input,
|
||||||
'prediction': {
|
'prediction': {
|
||||||
|
Reference in New Issue
Block a user