wip training and inference stats
This commit is contained in:
@ -512,18 +512,23 @@ class DataProvider:
|
||||
# Get raw ticks for the target second
|
||||
target_ticks = []
|
||||
|
||||
for tick in self.cob_raw_ticks[symbol]:
|
||||
tick_timestamp = tick['timestamp']
|
||||
# FIXED: Create a copy of the deque to avoid mutation during iteration
|
||||
if symbol in self.cob_raw_ticks:
|
||||
# Create a safe copy of the deque to iterate over
|
||||
ticks_copy = list(self.cob_raw_ticks[symbol])
|
||||
|
||||
# Handle both datetime and float timestamps
|
||||
if isinstance(tick_timestamp, datetime):
|
||||
tick_time = tick_timestamp.timestamp()
|
||||
else:
|
||||
tick_time = float(tick_timestamp)
|
||||
|
||||
# Check if tick is in target second
|
||||
if target_second <= tick_time < target_second + 1:
|
||||
target_ticks.append(tick)
|
||||
for tick in ticks_copy:
|
||||
tick_timestamp = tick['timestamp']
|
||||
|
||||
# Handle both datetime and float timestamps
|
||||
if isinstance(tick_timestamp, datetime):
|
||||
tick_time = tick_timestamp.timestamp()
|
||||
else:
|
||||
tick_time = float(tick_timestamp)
|
||||
|
||||
# Check if tick is in target second
|
||||
if target_second <= tick_time < target_second + 1:
|
||||
target_ticks.append(tick)
|
||||
|
||||
if not target_ticks:
|
||||
return
|
||||
@ -563,7 +568,14 @@ class DataProvider:
|
||||
current_imbalance = self._calculate_cob_imbalance(latest_cob, price_range)
|
||||
|
||||
# Get historical COB data for timeframe calculations
|
||||
historical_cob_data = list(self.cob_raw_ticks[symbol]) if symbol in self.cob_raw_ticks else []
|
||||
# FIXED: Create a safe copy to avoid deque mutation during iteration
|
||||
historical_cob_data = []
|
||||
if symbol in self.cob_raw_ticks:
|
||||
try:
|
||||
historical_cob_data = list(self.cob_raw_ticks[symbol])
|
||||
except Exception as e:
|
||||
logger.debug(f"Error copying COB raw ticks for {symbol}: {e}")
|
||||
historical_cob_data = []
|
||||
|
||||
# Calculate imbalances for different timeframes using COB data
|
||||
imbalances = {
|
||||
@ -4112,12 +4124,18 @@ class DataProvider:
|
||||
target_ticks = []
|
||||
|
||||
# Filter ticks for the target second
|
||||
for tick in self.cob_raw_ticks[symbol]:
|
||||
tick_time = tick.get('timestamp', 0)
|
||||
if isinstance(tick_time, (int, float)):
|
||||
tick_second = int(tick_time)
|
||||
if tick_second == target_second:
|
||||
target_ticks.append(tick)
|
||||
# FIXED: Create a safe copy to avoid deque mutation during iteration
|
||||
if symbol in self.cob_raw_ticks:
|
||||
try:
|
||||
ticks_copy = list(self.cob_raw_ticks[symbol])
|
||||
for tick in ticks_copy:
|
||||
tick_time = tick.get('timestamp', 0)
|
||||
if isinstance(tick_time, (int, float)):
|
||||
tick_second = int(tick_time)
|
||||
if tick_second == target_second:
|
||||
target_ticks.append(tick)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error copying COB raw ticks for {symbol}: {e}")
|
||||
|
||||
if not target_ticks:
|
||||
return
|
||||
|
@ -1125,7 +1125,7 @@ class MultiExchangeCOBProvider:
|
||||
)
|
||||
|
||||
# Store consolidated order book
|
||||
self.consolidated_order_books[symbol] = cob_snapshot
|
||||
self.current_order_book[symbol] = cob_snapshot
|
||||
self.realtime_snapshots[symbol].append(cob_snapshot)
|
||||
|
||||
# Update real-time statistics
|
||||
@ -1294,8 +1294,8 @@ class MultiExchangeCOBProvider:
|
||||
while self.is_streaming:
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
if symbol in self.consolidated_order_books:
|
||||
cob = self.consolidated_order_books[symbol]
|
||||
if symbol in self.current_order_book:
|
||||
cob = self.current_order_book[symbol]
|
||||
|
||||
# Notify bucket update callbacks
|
||||
for callback in self.bucket_update_callbacks:
|
||||
@ -1327,22 +1327,22 @@ class MultiExchangeCOBProvider:
|
||||
|
||||
def get_consolidated_orderbook(self, symbol: str) -> Optional[COBSnapshot]:
|
||||
"""Get current consolidated order book snapshot"""
|
||||
return self.consolidated_order_books.get(symbol)
|
||||
return self.current_order_book.get(symbol)
|
||||
|
||||
def get_price_buckets(self, symbol: str, bucket_count: int = 100) -> Optional[Dict]:
|
||||
"""Get fine-grain price buckets for a symbol"""
|
||||
if symbol not in self.consolidated_order_books:
|
||||
if symbol not in self.current_order_book:
|
||||
return None
|
||||
|
||||
cob = self.consolidated_order_books[symbol]
|
||||
cob = self.current_order_book[symbol]
|
||||
return cob.price_buckets
|
||||
|
||||
def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get breakdown of liquidity by exchange"""
|
||||
if symbol not in self.consolidated_order_books:
|
||||
if symbol not in self.current_order_book:
|
||||
return None
|
||||
|
||||
cob = self.consolidated_order_books[symbol]
|
||||
cob = self.current_order_book[symbol]
|
||||
breakdown = {}
|
||||
|
||||
for exchange in cob.exchanges_active:
|
||||
@ -1386,10 +1386,10 @@ class MultiExchangeCOBProvider:
|
||||
|
||||
def get_market_depth_analysis(self, symbol: str, depth_levels: int = 20) -> Optional[Dict]:
|
||||
"""Get detailed market depth analysis"""
|
||||
if symbol not in self.consolidated_order_books:
|
||||
if symbol not in self.current_order_book:
|
||||
return None
|
||||
|
||||
cob = self.consolidated_order_books[symbol]
|
||||
cob = self.current_order_book[symbol]
|
||||
|
||||
# Analyze depth distribution
|
||||
bid_levels = cob.consolidated_bids[:depth_levels]
|
||||
|
@ -76,6 +76,61 @@ class Prediction:
|
||||
model_name: str # Name of the model that made this prediction
|
||||
metadata: Optional[Dict[str, Any]] = None # Additional model-specific data
|
||||
|
||||
@dataclass
class ModelStatistics:
    """Statistics for tracking model performance and inference metrics.

    Holds running counters plus bounded histories (``deque`` with ``maxlen``)
    of update timestamps, losses and predictions for a single model.
    """
    model_name: str
    last_inference_time: Optional[datetime] = None  # time of the most recent update call
    total_inferences: int = 0  # incremented on every update call, including loss-only ones
    inference_rate_per_minute: float = 0.0
    inference_rate_per_second: float = 0.0
    current_loss: Optional[float] = None
    average_loss: Optional[float] = None  # mean of the retained `losses` window
    best_loss: Optional[float] = None
    worst_loss: Optional[float] = None
    accuracy: Optional[float] = None  # never written by update_inference_stats
    last_prediction: Optional[str] = None
    last_confidence: Optional[float] = None
    inference_times: deque = field(default_factory=lambda: deque(maxlen=100))  # Last 100 inference times
    losses: deque = field(default_factory=lambda: deque(maxlen=100))  # Last 100 losses
    predictions_history: deque = field(default_factory=lambda: deque(maxlen=50))  # Last 50 predictions

    def update_inference_stats(self, prediction: Optional["Prediction"] = None, loss: Optional[float] = None):
        """Record one update and refresh the derived statistics.

        Args:
            prediction: Object exposing ``action``, ``confidence`` and
                ``timestamp``; when given, it becomes the last prediction and
                is appended to ``predictions_history``.
            loss: Loss value; when given, it updates the current, average,
                best and worst loss figures.

        Every call stamps ``datetime.now()``, increments ``total_inferences``
        and recomputes the inference rates, whether or not a prediction or a
        loss was supplied.
        """
        now = datetime.now()

        # Update inference timing
        self.last_inference_time = now
        self.total_inferences += 1
        self.inference_times.append(now)

        # Calculate inference rates over the retained timestamp window.
        # N timestamps delimit N - 1 intervals, so dividing the raw count by
        # the window would overstate the rate (by 2x with two samples).
        sample_count = len(self.inference_times)
        if sample_count > 1:
            window_seconds = (self.inference_times[-1] - self.inference_times[0]).total_seconds()
            if window_seconds > 0:
                self.inference_rate_per_second = (sample_count - 1) / window_seconds
                self.inference_rate_per_minute = self.inference_rate_per_second * 60

        # Update prediction stats
        if prediction is not None:
            self.last_prediction = prediction.action
            self.last_confidence = prediction.confidence
            self.predictions_history.append({
                'action': prediction.action,
                'confidence': prediction.confidence,
                'timestamp': prediction.timestamp
            })

        # Update loss stats
        if loss is not None:
            self.current_loss = loss
            self.losses.append(loss)
            self.average_loss = sum(self.losses) / len(self.losses)
            # best/worst are all-time extremes: they are seeded from the window
            # and afterwards only compared against each new loss.
            self.best_loss = min(self.losses) if self.best_loss is None else min(self.best_loss, loss)
            self.worst_loss = max(self.losses) if self.worst_loss is None else max(self.worst_loss, loss)
|
||||
|
||||
@dataclass
|
||||
class TradingDecision:
|
||||
"""Final trading decision from the orchestrator"""
|
||||
@ -146,6 +201,9 @@ class TradingOrchestrator:
|
||||
self.recent_decisions: Dict[str, List[TradingDecision]] = {} # {symbol: List[TradingDecision]}
|
||||
self.model_performance: Dict[str, Dict[str, Any]] = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
||||
|
||||
# Model statistics tracking
|
||||
self.model_statistics: Dict[str, ModelStatistics] = {} # {model_name: ModelStatistics}
|
||||
|
||||
# Signal rate limiting to prevent spam
|
||||
self.last_signal_time: Dict[str, Dict[str, datetime]] = {} # {symbol: {action: datetime}}
|
||||
self.min_signal_interval = timedelta(seconds=30) # Minimum 30 seconds between same signals
|
||||
@ -619,6 +677,9 @@ class TradingOrchestrator:
|
||||
elif self.model_states[model_name]['best_loss'] is None or current_loss < self.model_states[model_name]['best_loss']:
|
||||
self.model_states[model_name]['best_loss'] = current_loss
|
||||
logger.debug(f"Updated {model_name} loss: current={current_loss:.4f}, best={self.model_states[model_name]['best_loss']:.4f}")
|
||||
|
||||
# Also update model statistics
|
||||
self._update_model_statistics(model_name, loss=current_loss)
|
||||
|
||||
def get_model_training_stats(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get current model training statistics for dashboard display"""
|
||||
@ -1112,6 +1173,11 @@ class TradingOrchestrator:
|
||||
if model.name not in self.model_performance:
|
||||
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
|
||||
# Initialize model statistics tracking
|
||||
if model.name not in self.model_statistics:
|
||||
self.model_statistics[model.name] = ModelStatistics(model_name=model.name)
|
||||
logger.debug(f"Initialized statistics tracking for {model.name}")
|
||||
|
||||
# Initialize last inference storage for this model
|
||||
if model.name not in self.last_inference:
|
||||
self.last_inference[model.name] = None
|
||||
@ -1133,6 +1199,8 @@ class TradingOrchestrator:
|
||||
del self.model_weights[model_name]
|
||||
if model_name in self.model_performance:
|
||||
del self.model_performance[model_name]
|
||||
if model_name in self.model_statistics:
|
||||
del self.model_statistics[model_name]
|
||||
|
||||
self._normalize_weights()
|
||||
logger.info(f"Unregistered {model_name} model")
|
||||
@ -1284,14 +1352,17 @@ class TradingOrchestrator:
|
||||
prediction = None
|
||||
model_input = base_data # Use the same base data for all models
|
||||
|
||||
# Track inference start time for statistics
|
||||
inference_start_time = time.time()
|
||||
|
||||
if isinstance(model, CNNModelInterface):
|
||||
# Get CNN predictions using the pre-built base data
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol, base_data)
|
||||
predictions.extend(cnn_predictions)
|
||||
# Store input data for CNN - store for each prediction
|
||||
# Update statistics for CNN predictions
|
||||
if cnn_predictions:
|
||||
# Store inference data for each CNN prediction
|
||||
for cnn_pred in cnn_predictions:
|
||||
self._update_model_statistics(model_name, cnn_pred)
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
@ -1300,6 +1371,8 @@ class TradingOrchestrator:
|
||||
if rl_prediction:
|
||||
predictions.append(rl_prediction)
|
||||
prediction = rl_prediction
|
||||
# Update statistics for RL prediction
|
||||
self._update_model_statistics(model_name, prediction)
|
||||
# Store input data for RL
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
@ -1309,11 +1382,16 @@ class TradingOrchestrator:
|
||||
if generic_prediction:
|
||||
predictions.append(generic_prediction)
|
||||
prediction = generic_prediction
|
||||
# Update statistics for generic prediction
|
||||
self._update_model_statistics(model_name, prediction)
|
||||
# Store input data for generic model
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
# Still update statistics for failed inference
|
||||
if model_name in self.model_statistics:
|
||||
self.model_statistics[model_name].update_inference_stats()
|
||||
continue
|
||||
|
||||
|
||||
@ -1323,6 +1401,88 @@ class TradingOrchestrator:
|
||||
|
||||
return predictions
|
||||
|
||||
def _update_model_statistics(self, model_name: str, prediction: Optional["Prediction"] = None, loss: Optional[float] = None):
    """Update statistics for a specific model.

    Creates the ``ModelStatistics`` entry in ``self.model_statistics`` on
    first use, forwards ``prediction`` and ``loss`` to its
    ``update_inference_stats`` and emits a debug summary whenever the
    inference counter is a multiple of 10.

    Args:
        model_name: Key into ``self.model_statistics``.
        prediction: Optional prediction to record.
        loss: Optional loss value to record.

    Any exception is logged as an error and swallowed; nothing is returned.
    """
    try:
        if model_name not in self.model_statistics:
            self.model_statistics[model_name] = ModelStatistics(model_name=model_name)

        # Update the statistics
        stats = self.model_statistics[model_name]
        stats.update_inference_stats(prediction, loss)

        # Log statistics periodically (every 10 inferences)
        if stats.total_inferences % 10 == 0:
            # last_confidence stays None until a prediction has been recorded
            # (e.g. loss-only updates); formatting None with ':.3f' raises
            # TypeError, so render a placeholder instead.
            confidence = stats.last_confidence
            confidence_str = f"{confidence:.3f}" if confidence is not None else "N/A"
            logger.debug(f"Model {model_name} stats: {stats.total_inferences} inferences, "
                         f"{stats.inference_rate_per_minute:.1f}/min, "
                         f"last: {stats.last_prediction} ({confidence_str})")

    except Exception as e:
        logger.error(f"Error updating statistics for {model_name}: {e}")
|
||||
|
||||
def get_model_statistics(self, model_name: Optional[str] = None) -> Union[Dict[str, ModelStatistics], ModelStatistics, None]:
    """Return statistics for one model, or for every model.

    Args:
        model_name: Name to look up. When omitted (or otherwise falsy) a
            shallow copy of the whole ``self.model_statistics`` mapping is
            returned instead.

    Returns:
        The ``ModelStatistics`` stored under ``model_name`` (``None`` if
        absent), the copied mapping, or ``None`` if an exception was logged.
    """
    try:
        if not model_name:
            # Shallow copy: the mapping is new, the ModelStatistics objects are shared.
            return self.model_statistics.copy()
        return self.model_statistics.get(model_name)
    except Exception as e:
        logger.error(f"Error getting model statistics: {e}")
        return None
|
||||
|
||||
def get_model_statistics_summary(self) -> Dict[str, Dict[str, Any]]:
    """Build a per-model summary of ``self.model_statistics``.

    Timestamps are rendered with ``isoformat()``, floats are rounded and the
    history deques are reduced to their lengths.

    Returns:
        ``{model_name: {...}}``; an empty dict if an exception was logged.
    """
    def _rounded(value, digits):
        # Optional metrics stay None rather than being coerced to a number.
        return None if value is None else round(value, digits)

    try:
        summary = {}
        for name, entry in self.model_statistics.items():
            last_seen = entry.last_inference_time
            summary[name] = {
                'last_inference_time': last_seen.isoformat() if last_seen else None,
                'total_inferences': entry.total_inferences,
                'inference_rate_per_minute': round(entry.inference_rate_per_minute, 2),
                'inference_rate_per_second': round(entry.inference_rate_per_second, 4),
                'current_loss': _rounded(entry.current_loss, 6),
                'average_loss': _rounded(entry.average_loss, 6),
                'best_loss': _rounded(entry.best_loss, 6),
                'worst_loss': _rounded(entry.worst_loss, 6),
                'accuracy': _rounded(entry.accuracy, 4),
                'last_prediction': entry.last_prediction,
                'last_confidence': _rounded(entry.last_confidence, 4),
                'recent_predictions_count': len(entry.predictions_history),
                'recent_losses_count': len(entry.losses)
            }
        return summary
    except Exception as e:
        logger.error(f"Error getting model statistics summary: {e}")
        return {}
|
||||
|
||||
def log_model_statistics(self, detailed: bool = False):
    """Log current model statistics for monitoring.

    Args:
        detailed: When true, log one line per metric for each model;
            otherwise log a single compact line per model.

    Missing (``None``) metrics are rendered as ``N/A``. Any exception is
    logged as an error and swallowed; nothing is returned.
    """
    def _fmt(value, spec):
        # Test against None rather than truthiness so a legitimate 0.0 is
        # still printed, and a None value never reaches the float formatter.
        return format(value, spec) if value is not None else "N/A"

    try:
        if not self.model_statistics:
            logger.info("No model statistics available")
            return

        logger.info("=== Model Statistics Summary ===")
        for model_name, stats in self.model_statistics.items():
            if detailed:
                logger.info(f"{model_name}:")
                logger.info(f" Total inferences: {stats.total_inferences}")
                logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min ({stats.inference_rate_per_second:.3f}/sec)")
                logger.info(f" Last inference: {stats.last_inference_time}")
                logger.info(f" Current loss: {_fmt(stats.current_loss, '.6f')}")
                logger.info(f" Average loss: {_fmt(stats.average_loss, '.6f')}")
                logger.info(f" Best loss: {_fmt(stats.best_loss, '.6f')}")
                if stats.last_prediction:
                    logger.info(f" Last prediction: {stats.last_prediction} ({_fmt(stats.last_confidence, '.3f')})")
                else:
                    logger.info(" Last prediction: N/A")
            else:
                rate_str = f"{stats.inference_rate_per_minute:.1f}/min"
                loss_str = _fmt(stats.current_loss, '.4f')
                if stats.last_prediction:
                    pred_str = f"{stats.last_prediction}({_fmt(stats.last_confidence, '.2f')})"
                else:
                    pred_str = "N/A"
                logger.info(f"{model_name}: {stats.total_inferences} inferences, {rate_str}, loss={loss_str}, last={pred_str}")

    except Exception as e:
        logger.error(f"Error logging model statistics: {e}")
|
||||
|
||||
|
||||
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||
@ -1831,7 +1991,7 @@ class TradingOrchestrator:
|
||||
return (1.0 if simple_correct else -0.5, simple_correct)
|
||||
|
||||
async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float, sophisticated_reward: float = None):
|
||||
"""Train specific model based on prediction outcome with sophisticated reward system"""
|
||||
"""Universal training for any model based on prediction outcome with sophisticated reward system"""
|
||||
try:
|
||||
model_name = record['model_name']
|
||||
model_input = record['model_input']
|
||||
@ -1840,63 +2000,269 @@ class TradingOrchestrator:
|
||||
# Use sophisticated reward if provided, otherwise fallback to simple reward
|
||||
reward = sophisticated_reward if sophisticated_reward is not None else (1.0 if was_correct else -0.5)
|
||||
|
||||
# Train RL models
|
||||
if 'dqn' in model_name.lower() and self.rl_agent:
|
||||
if hasattr(self.rl_agent, 'add_experience'):
|
||||
action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
|
||||
self.rl_agent.add_experience(
|
||||
state=model_input,
|
||||
action=action_idx,
|
||||
reward=reward,
|
||||
next_state=model_input, # Simplified
|
||||
done=True
|
||||
)
|
||||
logger.debug(f"Added RL training experience: reward={reward:.3f} (sophisticated)")
|
||||
|
||||
# Trigger training and update model state if loss is available
|
||||
if hasattr(self.rl_agent, 'train') and len(getattr(self.rl_agent, 'memory', [])) > 32:
|
||||
training_loss = self.rl_agent.train()
|
||||
if training_loss is not None:
|
||||
self.update_model_loss('dqn', training_loss)
|
||||
logger.debug(f"Updated DQN model state: loss={training_loss:.4f}")
|
||||
|
||||
# Also check for recent losses and update model state
|
||||
if hasattr(self.rl_agent, 'losses') and len(self.rl_agent.losses) > 0:
|
||||
recent_loss = self.rl_agent.losses[-1] # Most recent loss
|
||||
self.update_model_loss('dqn', recent_loss)
|
||||
logger.debug(f"Updated DQN model state from recent loss: {recent_loss:.4f}")
|
||||
# Get the actual model from registry
|
||||
model_interface = None
|
||||
if hasattr(self, 'model_registry') and self.model_registry:
|
||||
model_interface = self.model_registry.models.get(model_name)
|
||||
logger.debug(f"Found model interface {model_name} in registry: {type(model_interface).__name__}")
|
||||
else:
|
||||
logger.debug(f"No model registry available for {model_name}")
|
||||
|
||||
# Train CNN models using adapter
|
||||
elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
# Use the adapter's add_training_sample method
|
||||
actual_action = prediction['action']
|
||||
self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
|
||||
logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward:.3f} (sophisticated)")
|
||||
|
||||
# Trigger training if we have enough samples
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
logger.debug(f"CNN training results: {training_results}")
|
||||
|
||||
# Update model state with training loss
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
self.update_model_loss('cnn', current_loss)
|
||||
logger.debug(f"Updated CNN model state: loss={current_loss:.4f}")
|
||||
if not model_interface:
|
||||
logger.warning(f"Model {model_name} not found in registry, skipping training")
|
||||
return
|
||||
|
||||
# Fallback for raw CNN model
|
||||
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
|
||||
target = 1 if was_correct else 0
|
||||
loss = self.cnn_model.train_on_outcome(model_input, target)
|
||||
logger.debug(f"Trained CNN on outcome: target={target}")
|
||||
# Get the underlying model from the interface
|
||||
underlying_model = getattr(model_interface, 'model', None)
|
||||
if not underlying_model:
|
||||
logger.warning(f"No underlying model found for {model_name}, skipping training")
|
||||
return
|
||||
|
||||
logger.debug(f"Training {model_name} with reward={reward:.3f} (was_correct={was_correct})")
|
||||
logger.debug(f"Model interface type: {type(model_interface).__name__}")
|
||||
logger.debug(f"Underlying model type: {type(underlying_model).__name__}")
|
||||
|
||||
# Debug: Log available training methods on both interface and underlying model
|
||||
interface_methods = []
|
||||
underlying_methods = []
|
||||
|
||||
for method in ['train_on_outcome', 'add_experience', 'remember', 'replay', 'add_training_sample', 'train', 'train_with_reward', 'update_loss']:
|
||||
if hasattr(model_interface, method):
|
||||
interface_methods.append(method)
|
||||
if hasattr(underlying_model, method):
|
||||
underlying_methods.append(method)
|
||||
|
||||
logger.debug(f"Available methods on interface: {interface_methods}")
|
||||
logger.debug(f"Available methods on underlying model: {underlying_methods}")
|
||||
|
||||
training_success = False
|
||||
|
||||
# Try training based on model type and available methods
|
||||
if isinstance(model_interface, RLAgentInterface):
|
||||
# RL Agent Training
|
||||
training_success = await self._train_rl_model(underlying_model, model_name, model_input, prediction, reward)
|
||||
|
||||
# Update model state if loss is returned
|
||||
if loss is not None:
|
||||
self.update_model_loss('cnn', loss)
|
||||
logger.debug(f"Updated CNN model state: loss={loss:.4f}")
|
||||
elif isinstance(model_interface, CNNModelInterface):
|
||||
# CNN Model Training
|
||||
training_success = await self._train_cnn_model(underlying_model, model_name, record, prediction, reward)
|
||||
|
||||
elif 'extrema' in model_name.lower():
|
||||
# Extrema Trainer - doesn't need traditional training
|
||||
logger.debug(f"Extrema trainer {model_name} doesn't require outcome-based training")
|
||||
training_success = True
|
||||
|
||||
elif 'cob_rl' in model_name.lower():
|
||||
# COB RL Model Training
|
||||
training_success = await self._train_cob_rl_model(underlying_model, model_name, model_input, prediction, reward)
|
||||
|
||||
else:
|
||||
# Generic model training
|
||||
training_success = await self._train_generic_model(underlying_model, model_name, model_input, prediction, reward)
|
||||
|
||||
if not training_success:
|
||||
logger.warning(f"Training failed for {model_name} - trying fallback methods")
|
||||
# Try fallback training methods
|
||||
training_success = await self._train_model_fallback(model_name, underlying_model, model_input, prediction, reward)
|
||||
|
||||
if training_success:
|
||||
logger.debug(f"Successfully trained {model_name}")
|
||||
else:
|
||||
logger.warning(f"All training methods failed for {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training model on outcome: {e}")
|
||||
logger.error(f"Error training model {model_name} on outcome: {e}")
|
||||
|
||||
async def _train_rl_model(self, model, model_name: str, model_input, prediction: Dict, reward: float) -> bool:
    """Train RL model (DQN) with experience replay.

    Stores one experience through ``model.remember`` and, once the model's
    ``memory`` holds at least a batch of experiences, calls ``model.replay()``
    and reports a positive loss via ``self.update_model_loss``.

    Args:
        model: Agent object; must expose ``remember`` to be trainable here.
        model_name: Name used in log messages and for loss reporting.
        model_input: Either an object with ``get_feature_vector()`` or a
            ``numpy.ndarray`` used directly as the state.
        prediction: Mapping with an ``'action'`` of SELL, HOLD or BUY.
        reward: Reward stored with the experience.

    Returns:
        True when the experience was stored (and training, if due, ran);
        False for an invalid action, an unusable input, a model without
        ``remember``, or any exception (which is logged).
    """
    try:
        # Convert prediction action to action index
        action_names = ['SELL', 'HOLD', 'BUY']
        action = prediction['action']
        if action not in action_names:
            logger.warning(f"Invalid action {action} for RL training")
            return False
        action_idx = action_names.index(action)

        # Ensure model_input is numpy array
        if hasattr(model_input, 'get_feature_vector'):
            state = model_input.get_feature_vector()
        elif isinstance(model_input, np.ndarray):
            state = model_input
        else:
            logger.warning(f"Cannot convert model_input to state for RL training: {type(model_input)}")
            return False

        if not hasattr(model, 'remember'):
            return False

        # Add experience to memory
        model.remember(
            state=state,
            action=action_idx,
            reward=reward,
            next_state=state,  # Simplified - using same state
            done=True
        )
        logger.debug(f"Added experience to {model_name}: action={action}, reward={reward:.3f}")

        # Trigger training if enough experiences. batch_size is read with a
        # default (32, the same fallback _train_cob_rl_model uses) so an agent
        # without that attribute no longer raises AttributeError *after* the
        # experience was already stored, which reported a failure to the caller.
        memory_size = len(getattr(model, 'memory', []))
        batch_size = getattr(model, 'batch_size', 32)
        if memory_size < batch_size:
            logger.debug(f"Not enough experiences for {model_name}: {memory_size}/{batch_size}")
            return True  # Experience added successfully, training will happen later

        logger.debug(f"Training {model_name} with {memory_size} experiences")
        training_loss = model.replay()
        if training_loss is not None and training_loss > 0:
            self.update_model_loss(model_name, training_loss)
            logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}")
        # NOTE(review): success is assumed to be reported once replay() ran,
        # even when it yields no positive loss - confirm against the intended
        # nesting of this return.
        return True

    except Exception as e:
        logger.error(f"Error training RL model {model_name}: {e}")
        return False
|
||||
|
||||
async def _train_cnn_model(self, model, model_name: str, record: Dict, prediction: Dict, reward: float) -> bool:
    """Feed one outcome sample to a CNN model and train when a batch is ready.

    The orchestrator's ``cnn_adapter`` is preferred when it is set and the
    model name contains ``cnn``; otherwise the model's own
    ``add_training_sample`` is used if it has one.

    Args:
        model: Underlying model object (used only on the direct path).
        model_name: Name used for routing, logging and loss reporting.
        record: Inference record; its ``'symbol'`` defaults to ``'ETH/USDT'``.
        prediction: Mapping whose ``'action'`` is stored with the sample.
        reward: Reward stored with the sample.

    Returns:
        True when a sample was added, False when no training path applies or
        an exception was logged.
    """
    try:
        uses_adapter = hasattr(self, 'cnn_adapter') and self.cnn_adapter and 'cnn' in model_name.lower()

        if uses_adapter:
            # Preferred path: go through the CNN adapter.
            adapter = self.cnn_adapter
            symbol = record.get('symbol', 'ETH/USDT')
            actual_action = prediction['action']

            adapter.add_training_sample(symbol, actual_action, reward)
            logger.debug(f"Added training sample to CNN adapter: action={actual_action}, reward={reward:.3f}")

            if len(adapter.training_data) < adapter.batch_size:
                logger.debug(f"Not enough samples for CNN training: {len(adapter.training_data)}/{adapter.batch_size}")
                return True  # Sample added successfully

            logger.debug(f"Training CNN with {len(adapter.training_data)} samples")
            results = adapter.train(epochs=1)
            if results and 'loss' in results:
                loss_value = results['loss']
                self.update_model_loss(model_name, loss_value)
                logger.debug(f"CNN training completed: loss={loss_value:.4f}")
            # NOTE(review): success is assumed to be reported once train() ran,
            # even without a 'loss' entry - confirm the intended nesting.
            return True

        if hasattr(model, 'add_training_sample'):
            # Direct path: the model collects its own samples.
            symbol = record.get('symbol', 'ETH/USDT')
            actual_action = prediction['action']
            model.add_training_sample(symbol, actual_action, reward)
            logger.debug(f"Added training sample to {model_name}: action={actual_action}, reward={reward:.3f}")

            can_train = hasattr(model, 'train') and hasattr(model, 'training_data') and hasattr(model, 'batch_size')
            if can_train and len(model.training_data) >= model.batch_size:
                results = model.train(epochs=1)
                if results and 'loss' in results:
                    loss_value = results['loss']
                    self.update_model_loss(model_name, loss_value)
                    logger.debug(f"CNN training completed: loss={loss_value:.4f}")
            return True  # Sample added successfully

        return False

    except Exception as e:
        logger.error(f"Error training CNN model {model_name}: {e}")
        return False
|
||||
|
||||
async def _train_cob_rl_model(self, model, model_name: str, model_input, prediction: Dict, reward: float) -> bool:
    """Train COB RL model.

    Stores one experience through ``model.add_experience`` and, when the
    model exposes ``train`` and ``memory`` and the memory holds at least
    ``batch_size`` (default 32) entries, calls ``model.train()`` and reports
    the loss via ``self.update_model_loss``.

    Args:
        model: Agent object; must expose ``add_experience`` to be trainable.
        model_name: Name used in log messages and for loss reporting.
        model_input: Either an object with ``get_feature_vector()`` or a
            ``numpy.ndarray`` used directly as the state.
        prediction: Mapping with an ``'action'`` of SELL, HOLD or BUY.
        reward: Reward stored with the experience.

    Returns:
        True when the experience was stored; False for a model without
        ``add_experience``, an invalid action, an unusable input, or any
        exception (which is logged).
    """
    try:
        # COB RL models might have specific training methods
        if not hasattr(model, 'add_experience'):
            return False

        # Validate the action up front (as _train_rl_model does) instead of
        # letting list.index() raise ValueError and surface as a generic error.
        action_names = ['SELL', 'HOLD', 'BUY']
        action = prediction['action']
        if action not in action_names:
            logger.warning(f"Invalid action {action} for COB RL training")
            return False
        action_idx = action_names.index(action)

        # Ensure model_input is in correct format
        if hasattr(model_input, 'get_feature_vector'):
            state = model_input.get_feature_vector()
        elif isinstance(model_input, np.ndarray):
            state = model_input
        else:
            logger.warning(f"Cannot convert model_input for COB RL training: {type(model_input)}")
            return False

        model.add_experience(
            state=state,
            action=action_idx,
            reward=reward,
            next_state=state,  # same state reused as next_state
            done=True
        )
        logger.debug(f"Added experience to COB RL model: action={action}, reward={reward:.3f}")

        # Trigger training if enough experiences
        if hasattr(model, 'train') and hasattr(model, 'memory'):
            memory_size = len(model.memory) if hasattr(model.memory, '__len__') else 0
            if memory_size >= getattr(model, 'batch_size', 32):
                training_loss = model.train()
                if training_loss is not None:
                    self.update_model_loss(model_name, training_loss)
                    logger.debug(f"COB RL training completed: loss={training_loss:.4f}")

        return True  # Experience added successfully

    except Exception as e:
        logger.error(f"Error training COB RL model {model_name}: {e}")
        return False
|
||||
|
||||
async def _train_generic_model(self, model, model_name: str, model_input, prediction: Dict, reward: float) -> bool:
    """Train a model through whichever generic hook it exposes.

    Hooks are probed in a fixed order and only the first one present is
    used: ``train_with_reward``, then ``update_loss``, then
    ``train_on_outcome``.

    Args:
        model: Underlying model object.
        model_name: Name used in log messages and for loss reporting.
        model_input: Input passed to ``train_with_reward`` / ``train_on_outcome``.
        prediction: Unused here; kept for a uniform trainer signature.
        reward: Reward passed to the hook (or turned into a 0/1 target).

    Returns:
        True when a hook was invoked, False when the model has none of the
        hooks or an exception was logged.
    """
    try:
        # NOTE(review): each branch is assumed to report success once its hook
        # was called, even if the hook returns no loss - confirm the intended
        # nesting of the returns.
        if hasattr(model, 'train_with_reward'):
            reward_loss = model.train_with_reward(model_input, reward)
            if reward_loss is not None:
                self.update_model_loss(model_name, reward_loss)
                logger.debug(f"Generic training completed for {model_name}: loss={reward_loss:.4f}")
            return True

        if hasattr(model, 'update_loss'):
            # The reward itself is handed to the model's update_loss hook.
            model.update_loss(reward)
            logger.debug(f"Updated loss for {model_name}: reward={reward:.3f}")
            return True

        if hasattr(model, 'train_on_outcome'):
            # Positive reward maps to target 1, anything else to 0.
            target = 1 if reward > 0 else 0
            outcome_loss = model.train_on_outcome(model_input, target)
            if outcome_loss is not None:
                self.update_model_loss(model_name, outcome_loss)
                logger.debug(f"Outcome training completed for {model_name}: loss={outcome_loss:.4f}")
            return True

        return False

    except Exception as e:
        logger.error(f"Error training generic model {model_name}: {e}")
        return False
|
||||
|
||||
async def _train_model_fallback(self, model_name: str, model, model_input, prediction: Dict, reward: float) -> bool:
    """Route training to the orchestrator's directly-held model instances.

    The model name is matched case-insensitively, in order, against ``dqn``
    (``self.rl_agent``), ``cnn`` (``self.cnn_model``) and ``cob``
    (``self.cob_rl_agent``); the first match whose instance is set handles
    the training.

    Args:
        model_name: Name used for routing and logging.
        model: Unused here; the instance held on ``self`` is trained instead.
        model_input: Input forwarded to the selected trainer.
        prediction: Prediction mapping forwarded to the selected trainer.
        reward: Reward forwarded to the selected trainer.

    Returns:
        The selected trainer's result, or False when nothing matches or an
        exception was logged.
    """
    try:
        lowered = model_name.lower()

        # Try to access direct model instances for legacy support
        if 'dqn' in lowered and hasattr(self, 'rl_agent') and self.rl_agent:
            return await self._train_rl_model(self.rl_agent, model_name, model_input, prediction, reward)

        if 'cnn' in lowered and hasattr(self, 'cnn_model') and self.cnn_model:
            # The CNN trainer expects a record dict, so build a minimal one.
            fake_record = {'symbol': 'ETH/USDT', 'model_input': model_input}
            return await self._train_cnn_model(self.cnn_model, model_name, fake_record, prediction, reward)

        if 'cob' in lowered and hasattr(self, 'cob_rl_agent') and self.cob_rl_agent:
            return await self._train_cob_rl_model(self.cob_rl_agent, model_name, model_input, prediction, reward)

        return False

    except Exception as e:
        logger.error(f"Error in fallback training for {model_name}: {e}")
        return False
|
||||
|
||||
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
|
||||
"""Calculate RSI indicator"""
|
||||
|
Reference in New Issue
Block a user