wip training
This commit is contained in:
@ -1270,8 +1270,11 @@ class TradingOrchestrator:
|
||||
predictions = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# Collect input data for all models
|
||||
input_data = await self._collect_model_input_data(symbol)
|
||||
# Get the standard model input data once for all models
|
||||
base_data = self.data_provider.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
|
||||
return predictions
|
||||
|
||||
# log all registered models
|
||||
logger.debug(f"inferencing registered models: {self.model_registry.models}")
|
||||
@ -1279,40 +1282,35 @@ class TradingOrchestrator:
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
try:
|
||||
prediction = None
|
||||
model_input = None
|
||||
model_input = base_data # Use the same base data for all models
|
||||
|
||||
if isinstance(model, CNNModelInterface):
|
||||
# Get CNN predictions for each timeframe
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
||||
# Get CNN predictions using the pre-built base data
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol, base_data)
|
||||
predictions.extend(cnn_predictions)
|
||||
# Store input data for CNN - store for each prediction
|
||||
model_input = input_data.get('cnn_input')
|
||||
if model_input is not None and cnn_predictions:
|
||||
if cnn_predictions:
|
||||
# Store inference data for each CNN prediction
|
||||
for cnn_pred in cnn_predictions:
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction
|
||||
rl_prediction = await self._get_rl_prediction(model, symbol)
|
||||
# Get RL prediction using the pre-built base data
|
||||
rl_prediction = await self._get_rl_prediction(model, symbol, base_data)
|
||||
if rl_prediction:
|
||||
predictions.append(rl_prediction)
|
||||
prediction = rl_prediction
|
||||
# Store input data for RL
|
||||
model_input = input_data.get('rl_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
generic_prediction = await self._get_generic_prediction(model, symbol)
|
||||
# Generic model interface using the pre-built base data
|
||||
generic_prediction = await self._get_generic_prediction(model, symbol, base_data)
|
||||
if generic_prediction:
|
||||
predictions.append(generic_prediction)
|
||||
prediction = generic_prediction
|
||||
# Store input data for generic model
|
||||
model_input = input_data.get('generic_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
@ -1320,69 +1318,12 @@ class TradingOrchestrator:
|
||||
|
||||
|
||||
|
||||
# Trigger training based on previous inference data
|
||||
await self._trigger_model_training(symbol)
|
||||
# Note: Training is now triggered immediately within each prediction method
|
||||
# when previous inference data exists, rather than after all predictions
|
||||
|
||||
return predictions
|
||||
|
||||
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
    """Collect the standardized input package shared by all models.

    The package is built for the primary symbol ``ETH/USDT`` only: four ETH
    timeframes plus a 1s ``BTC/USDT`` reference series and the current prices
    of both pairs. Any other symbol yields an empty dict.

    Args:
        symbol: Trading pair for which inference is being run.

    Returns:
        A dict with the keys ``cnn_input``, ``rl_input``, ``generic_input``
        and ``standardized_input``. All four keys refer to the *same*
        package dict (no copies are made), so a mutation through one key is
        visible through the others. Returns ``{}`` when ``symbol`` is not
        ``ETH/USDT`` or when any step of the collection raises.
    """
    primary_symbol = 'ETH/USDT'
    reference_symbol = 'BTC/USDT'
    eth_timeframes = ['1s', '1m', '1h', '1d']

    try:
        # Only collect data for ETH (primary symbol) - we inference only for ETH
        if symbol != primary_symbol:
            return {}

        # Collect ETH data for all timeframes; missing or empty frames are
        # simply left out and show up as a lower completeness count below.
        eth_data = {}
        for tf in eth_timeframes:
            df = self.data_provider.get_historical_data(primary_symbol, tf, limit=300)
            if df is not None and not df.empty:
                eth_data[f'ETH_{tf}'] = df

        # Collect BTC 1s reference data
        btc_data = {}
        btc_1s = self.data_provider.get_historical_data(reference_symbol, '1s', limit=300)
        if btc_1s is not None and not btc_1s.empty:
            btc_data['BTC_1s'] = btc_1s

        # Get current prices
        eth_price = self.data_provider.get_current_price(primary_symbol)
        btc_price = self.data_provider.get_current_price(reference_symbol)

        # Create standardized input package
        standardized_input = {
            'timestamp': datetime.now(),
            'primary_symbol': primary_symbol,
            'reference_symbol': reference_symbol,
            'eth_data': eth_data,
            'btc_data': btc_data,
            'current_prices': {
                'ETH': eth_price,
                'BTC': btc_price,
            },
            'data_completeness': {
                'eth_timeframes': len(eth_data),
                'btc_reference': len(btc_data),
                # All ETH timeframes + 1 BTC reference series (4 + 1 = 5);
                # derived so it cannot drift from the timeframe list.
                'total_expected': len(eth_timeframes) + 1,
            },
        }

        # Every model type currently receives the identical package.
        return {
            'cnn_input': standardized_input,
            'rl_input': standardized_input,
            'generic_input': standardized_input,
            'standardized_input': standardized_input,
        }

    except Exception as e:
        logger.error(f"Error collecting standardized model input data: {e}")
        return {}
|
||||
|
||||
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||
"""Store last inference in memory and all inferences to database for future training"""
|
||||
@ -1711,30 +1652,36 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error getting model training data: {e}")
|
||||
return []
|
||||
|
||||
async def _trigger_model_training(self, symbol: str):
|
||||
"""Trigger training for models based on their last inference"""
|
||||
|
||||
async def _trigger_immediate_training_for_model(self, model_name: str, symbol: str):
    """Evaluate one model's previous inference and train on it right away.

    Looks up the record stored for ``model_name`` in ``self.last_inference``
    and, unless it is already flagged ``outcome_evaluated``, passes it with
    the current price of ``symbol`` to ``_evaluate_and_train_on_record``.

    The call is a no-op (with a log line) when training is disabled, when no
    usable record exists for the model, when the record was already
    evaluated, or when the current price cannot be obtained.

    Args:
        model_name: Key of the model in ``self.last_inference``.
        symbol: Trading pair whose current price is used for evaluation.

    Returns:
        None. Exceptions are logged and swallowed, so a training failure
        does not propagate to the caller.
    """
    try:
        # Respect the orchestrator-wide training switch. The attribute is
        # read defensively: an instance without it keeps training enabled.
        if not getattr(self, 'training_enabled', True):
            logger.debug("Training disabled, skipping model training")
            return

        # A missing key and an empty/None record are both treated as
        # "nothing to train on" instead of failing on ``.get`` below.
        inference_record = self.last_inference.get(model_name)
        if not inference_record:
            logger.debug(f"No previous inference data for {model_name}")
            return

        # Skip if already evaluated
        if inference_record.get('outcome_evaluated', False):
            logger.debug(f"Skipping {model_name} - already evaluated")
            return

        # Get current price for outcome evaluation
        current_price = self._get_current_price(symbol)
        if current_price is None:
            logger.warning(f"Cannot get current price for {symbol}, skipping immediate training for {model_name}")
            return

        logger.info(f"Triggering immediate training for {model_name} with current price: {current_price}")

        # Evaluate the previous prediction and train the model immediately
        await self._evaluate_and_train_on_record(inference_record, current_price)

        logger.info(f"Completed immediate training for {model_name}")

    except Exception as e:
        logger.error(f"Error in immediate training for {model_name}: {e}")
|
||||
|
||||
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
||||
"""Evaluate prediction outcome and train model"""
|
||||
@ -1963,15 +1910,16 @@ class TradingOrchestrator:
|
||||
except:
|
||||
return 50.0
|
||||
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from CNN model using FIFO queue data"""
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str, base_data=None) -> List[Prediction]:
|
||||
"""Get predictions from CNN model using pre-built base data"""
|
||||
predictions = []
|
||||
try:
|
||||
# Use FIFO queue data instead of direct data provider calls
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for CNN prediction: {symbol}")
|
||||
return predictions
|
||||
# Use pre-built base data if provided, otherwise build it
|
||||
if base_data is None:
|
||||
base_data = self.data_provider.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for CNN prediction: {symbol}")
|
||||
return predictions
|
||||
|
||||
# Use CNN adapter if available
|
||||
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
@ -2016,10 +1964,9 @@ class TradingOrchestrator:
|
||||
logger.warning(f"CNN adapter failed for {symbol}, trying direct model inference with BaseDataInput")
|
||||
|
||||
try:
|
||||
# Build BaseDataInput with unified multi-timeframe data
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
# Use the already available base_data (no need to rebuild)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for CNN fallback: {symbol}")
|
||||
logger.warning(f"No BaseDataInput available for CNN fallback: {symbol}")
|
||||
return predictions
|
||||
|
||||
# Convert to unified feature vector (7850 features)
|
||||
@ -2080,6 +2027,12 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"CNN fallback inference failed for {symbol}: {e}")
|
||||
# Don't continue with old timeframe-by-timeframe approach
|
||||
|
||||
# Trigger immediate training if previous inference data exists for this model
|
||||
if predictions and model.name in self.last_inference:
|
||||
logger.debug(f"Triggering immediate training for CNN model {model.name} with previous inference data")
|
||||
await self._trigger_immediate_training_for_model(model.name, symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Orch: Error getting CNN predictions: {e}")
|
||||
return predictions
|
||||
@ -2087,20 +2040,21 @@ class TradingOrchestrator:
|
||||
# Note: Removed obsolete _augment_with_cob and _prepare_cnn_input methods
|
||||
# The unified CNN model now handles all timeframes and COB data internally through BaseDataInput
|
||||
|
||||
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from RL agent using FIFO queue data"""
|
||||
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str, base_data=None) -> Optional[Prediction]:
|
||||
"""Get prediction from RL agent using pre-built base data"""
|
||||
try:
|
||||
# Use FIFO queue data to build consistent state
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for RL prediction: {symbol}")
|
||||
return None
|
||||
# Use pre-built base data if provided, otherwise build it
|
||||
if base_data is None:
|
||||
base_data = self.data_provider.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for RL prediction: {symbol}")
|
||||
return None
|
||||
|
||||
# Convert BaseDataInput to RL state format
|
||||
state_features = base_data.get_feature_vector()
|
||||
|
||||
# Get current state for RL agent
|
||||
state = self._get_rl_state(symbol)
|
||||
# Get current state for RL agent using the pre-built base data
|
||||
state = self._get_rl_state(symbol, base_data)
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
@ -2166,20 +2120,26 @@ class TradingOrchestrator:
|
||||
q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
|
||||
self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)
|
||||
|
||||
# Trigger immediate training if previous inference data exists for this model
|
||||
if prediction and model.name in self.last_inference:
|
||||
logger.debug(f"Triggering immediate training for RL model {model.name} with previous inference data")
|
||||
await self._trigger_immediate_training_for_model(model.name, symbol)
|
||||
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting RL prediction: {e}")
|
||||
return None
|
||||
|
||||
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from generic model using unified BaseDataInput"""
|
||||
async def _get_generic_prediction(self, model: ModelInterface, symbol: str, base_data=None) -> Optional[Prediction]:
|
||||
"""Get prediction from generic model using pre-built base data"""
|
||||
try:
|
||||
# Use unified BaseDataInput approach instead of old timeframe-specific method
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for generic prediction: {symbol}")
|
||||
return None
|
||||
# Use pre-built base data if provided, otherwise build it
|
||||
if base_data is None:
|
||||
base_data = self.data_provider.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for generic prediction: {symbol}")
|
||||
return None
|
||||
|
||||
# Convert to feature vector for generic models
|
||||
feature_vector = base_data.get_feature_vector()
|
||||
@ -2237,14 +2197,15 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error getting generic prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get current state for RL agent using unified BaseDataInput"""
|
||||
def _get_rl_state(self, symbol: str, base_data=None) -> Optional[np.ndarray]:
|
||||
"""Get current state for RL agent using pre-built base data"""
|
||||
try:
|
||||
# Use unified BaseDataInput approach
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
|
||||
return None
|
||||
# Use pre-built base data if provided, otherwise build it
|
||||
if base_data is None:
|
||||
base_data = self.data_provider.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
|
||||
return None
|
||||
|
||||
# Get unified feature vector (7850 features including all timeframes and COB data)
|
||||
feature_vector = base_data.get_feature_vector()
|
||||
|
Reference in New Issue
Block a user