wip training

This commit is contained in:
Dobromir Popov
2025-07-24 15:27:32 +03:00
parent b3edd21f1b
commit fa07265a16
4 changed files with 554 additions and 5 deletions

View File

@ -27,6 +27,7 @@ import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from .config import get_config
from .data_provider import DataProvider
@ -202,9 +203,18 @@ class TradingOrchestrator:
# Training tracking
self.last_trained_symbols: Dict[str, datetime] = {}
# INFERENCE DATA STORAGE - Store model inputs and outputs for training
self.inference_history: Dict[str, deque] = {} # {symbol: deque of inference records}
self.max_inference_history = 1000 # Keep last 1000 inference records per symbol
# Initialize inference history for each symbol
for symbol in self.symbols:
self.inference_history[symbol] = deque(maxlen=self.max_inference_history)
# ENHANCED: Real-time Training System Integration
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
# Enable training by default - don't depend on external training system
self.training_enabled: bool = enhanced_rl_training
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
@ -1023,34 +1033,409 @@ class TradingOrchestrator:
return None
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
    """Get predictions from all registered models and store their inputs.

    The shared model input data is collected once per call. Each model in
    ``self.model_registry`` is then queried, and every prediction it returns
    is stored together with the matching input via
    ``_store_inference_data`` so it can be used for training later. After
    all models were queried, ``_trigger_model_training`` is awaited.

    Args:
        symbol: Symbol to get predictions for.

    Returns:
        All predictions returned by the registered models. A model that
        raises is logged and skipped; it does not abort the other models.
    """
    predictions = []
    current_time = datetime.now()

    # Collect input data once; every model reads its own entry from it
    input_data = await self._collect_model_input_data(symbol)

    for model_name, model in self.model_registry.models.items():
        try:
            # Predictions produced by this model during this call
            model_predictions = []
            model_input = None

            if isinstance(model, CNNModelInterface):
                # Get CNN predictions for each timeframe
                cnn_predictions = await self._get_cnn_predictions(model, symbol)
                predictions.extend(cnn_predictions)
                # Every CNN prediction is stored, not just a single one,
                # so CNN models also build up inference history
                model_predictions = cnn_predictions
                model_input = input_data.get('cnn_input')

            elif isinstance(model, RLAgentInterface):
                # Get RL prediction
                rl_prediction = await self._get_rl_prediction(model, symbol)
                if rl_prediction:
                    predictions.append(rl_prediction)
                    model_predictions = [rl_prediction]
                model_input = input_data.get('rl_input')

            else:
                # Generic model interface
                generic_prediction = await self._get_generic_prediction(model, symbol)
                if generic_prediction:
                    predictions.append(generic_prediction)
                    model_predictions = [generic_prediction]
                model_input = input_data.get('generic_input')

            # Store inference data for training (needs both input and output)
            if model_input is not None:
                for prediction in model_predictions:
                    if prediction:
                        self._store_inference_data(
                            symbol, model_name, model_input, prediction, current_time
                        )

        except Exception as e:
            logger.error(f"Error getting prediction from {model_name}: {e}")
            continue

    # Trigger training based on previous inference data
    await self._trigger_model_training(symbol)

    return predictions
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
    """Gather the input data handed to the CNN, RL and generic models.

    Reads the current price, up to 300 OHLCV bars for the 1s/1m/1h/1d
    timeframes and the COB snapshot, derives SMA-20 and RSI from the 1h
    closes when more than 20 bars are available, and builds the three
    model-specific inputs from that.

    Args:
        symbol: Symbol to collect data for.

    Returns:
        Dict with the keys ``cnn_input``, ``rl_input``, ``generic_input``,
        ``timestamp`` and ``symbol``; an empty dict if anything fails.
    """
    try:
        provider = self.data_provider

        # Current market price from the data provider
        current_price = provider.get_current_price(symbol)

        # OHLCV frames per timeframe; missing or empty frames are left out
        ohlcv_data = {}
        for timeframe in ('1s', '1m', '1h', '1d'):
            frame = provider.get_historical_data(symbol, timeframe, limit=300)
            if frame is None or frame.empty:
                continue
            ohlcv_data[timeframe] = frame

        # COB data if available
        cob_data = self.get_cob_snapshot(symbol)

        # Technical indicators are derived from the hourly closes only
        technical_indicators = {}
        hourly = ohlcv_data.get('1h')
        if hourly is not None and len(hourly) > 20:
            technical_indicators['sma_20'] = hourly['close'].rolling(20).mean().iloc[-1]
            technical_indicators['rsi'] = self._calculate_rsi(hourly['close'])

        return {
            'cnn_input': self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators),
            'rl_input': self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators),
            'generic_input': {
                'symbol': symbol,
                'current_price': current_price,
                'ohlcv_data': ohlcv_data,
                'cob_data': cob_data,
                'technical_indicators': technical_indicators,
            },
            'timestamp': datetime.now(),
            'symbol': symbol,
        }

    except Exception as e:
        logger.error(f"Error collecting model input data for {symbol}: {e}")
        return {}
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for CNN models"""
try:
# Create feature matrix from OHLCV data
features = []
# Add OHLCV features for each timeframe
for tf in ['1s', '1m', '1h', '1d']:
if tf in ohlcv_data and not ohlcv_data[tf].empty:
df = ohlcv_data[tf].tail(50) # Last 50 bars
features.extend([
df['close'].pct_change().fillna(0).values,
df['volume'].values / df['volume'].max() if df['volume'].max() > 0 else np.zeros(len(df))
])
# Add technical indicators
for key, value in technical_indicators.items():
if not np.isnan(value):
features.append([value])
# Flatten and pad/truncate to standard size
if features:
feature_array = np.concatenate([np.array(f).flatten() for f in features])
# Pad or truncate to 300 features
if len(feature_array) < 300:
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
else:
feature_array = feature_array[:300]
return feature_array.reshape(1, -1)
else:
return np.zeros((1, 300))
except Exception as e:
logger.error(f"Error preparing CNN input data: {e}")
return np.zeros((1, 300))
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for RL models"""
try:
# Create state representation
state_features = []
# Add price and volume features
if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
df = ohlcv_data['1m'].tail(20)
state_features.extend([
df['close'].pct_change().fillna(0).values,
df['volume'].pct_change().fillna(0).values,
(df['high'] - df['low']) / df['close'] # Volatility proxy
])
# Add technical indicators
for key, value in technical_indicators.items():
if not np.isnan(value):
state_features.append(value)
# Flatten and standardize size
if state_features:
state_array = np.concatenate([np.array(f).flatten() for f in state_features])
# Pad or truncate to expected RL state size
expected_size = 100 # Adjust based on your RL model
if len(state_array) < expected_size:
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
else:
state_array = state_array[:expected_size]
return state_array
else:
return np.zeros(100)
except Exception as e:
logger.error(f"Error preparing RL input data: {e}")
return np.zeros(100)
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
    """Record one model inference in memory and on disk for later training.

    Builds a record holding the model input, the prediction, the price read
    from the data provider at call time and some orchestrator settings. The
    record is appended to ``self.inference_history[symbol]`` when that
    symbol has a history, and is always passed to
    ``_save_inference_to_disk``. Errors are logged, never raised.

    Args:
        symbol: Symbol the prediction was made for.
        model_name: Name of the model that produced the prediction.
        model_input: Input that was fed to the model.
        prediction: Prediction returned by the model.
        timestamp: Time of the inference.
    """
    try:
        # Market price at the time the record is created
        price_at_inference = self.data_provider.get_current_price(symbol)

        # Model input plus a description of its shape and type
        input_section = {
            'raw_input': model_input,
            'input_shape': getattr(model_input, 'shape', None),
            'input_type': str(type(model_input)),
        }

        # Fields copied from the prediction object
        prediction_section = {
            'action': prediction.action,
            'confidence': prediction.confidence,
            'probabilities': prediction.probabilities,
            'timeframe': prediction.timeframe,
        }

        # Market context at prediction time
        context_section = {
            'price': price_at_inference,
            'timestamp': timestamp.isoformat(),
            'symbol': symbol,
        }

        # Model metadata and the orchestrator settings in effect
        metadata_section = {
            'model_metadata': prediction.metadata or {},
            'orchestrator_state': {
                'confidence_threshold': self.confidence_threshold,
                'training_enabled': self.training_enabled,
            },
        }

        inference_record = {
            'timestamp': timestamp,
            'symbol': symbol,
            'model_name': model_name,
            'current_price': price_at_inference,
            'model_input': input_section,
            'prediction': prediction_section,
            'market_context': context_section,
            'metadata': metadata_section,
            # Outcome fields start empty and are meant to be filled later
            'training_outcome': None,
            'outcome_evaluated': False,
        }

        # In-memory history (only for symbols that have a history)
        history = self.inference_history.get(symbol)
        if history is not None:
            history.append(inference_record)
            logger.debug(f"Stored inference data for {model_name} on {symbol}")

        # Persistent copy on disk
        self._save_inference_to_disk(inference_record)

    except Exception as e:
        logger.error(f"Error storing inference data: {e}")
def _save_inference_to_disk(self, inference_record: Dict):
    """Save one inference record as a JSON file.

    The file is written to ``training_data/inference_history`` and named
    ``<symbol>_<model_name>_<YYYYmmdd_HHMMSS>.json``. Path separators in the
    name are replaced, and a numeric suffix is added when a file with the
    same name already exists. Errors are logged, never raised.

    Args:
        inference_record: Record with at least the keys ``timestamp``
            (a datetime), ``symbol`` and ``model_name``.
    """
    try:
        # Serialize first: if the record cannot be converted, nothing is
        # written and no truncated JSON file is left behind
        serializable_record = self._make_json_serializable(inference_record)
        payload = json.dumps(serializable_record, indent=2)

        # Create inference data directory
        inference_dir = Path("training_data/inference_history")
        inference_dir.mkdir(parents=True, exist_ok=True)

        # Create filename with timestamp and model name
        timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
        stem = f"{inference_record['symbol']}_{inference_record['model_name']}_{timestamp_str}"
        # A symbol such as "ETH/USDT" would otherwise point into a
        # sub-directory that does not exist
        stem = stem.replace('/', '-').replace('\\', '-')

        # Records that share symbol, model and second must not overwrite
        # each other
        filepath = inference_dir / f"{stem}.json"
        duplicate = 1
        while filepath.exists():
            filepath = inference_dir / f"{stem}_{duplicate}.json"
            duplicate += 1

        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.debug(f"Saved inference record to disk: {filepath}")

    except Exception as e:
        logger.error(f"Error saving inference to disk: {e}")
def _make_json_serializable(self, obj):
"""Convert object to JSON-serializable format"""
if isinstance(obj, dict):
return {k: self._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._make_json_serializable(item) for item in obj]
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.integer, np.floating)):
return obj.item()
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
async def _trigger_model_training(self, symbol: str):
    """Evaluate matured inference records of a symbol and train on them.

    Does nothing when training is disabled, the symbol has no inference
    history, fewer than two records are stored, or no current price is
    available. Otherwise every stored record older than five minutes is
    passed to ``_evaluate_and_train_on_record`` together with the current
    price. Errors are logged, never raised.

    Args:
        symbol: Symbol whose inference history is processed.
    """
    try:
        if not (self.training_enabled and symbol in self.inference_history):
            return

        # Work on a snapshot of the stored records
        history_snapshot = list(self.inference_history[symbol])
        if len(history_snapshot) < 2:
            return  # Need at least 2 records

        # The current price is the reference for outcome evaluation
        price_now = self.data_provider.get_current_price(symbol)
        if price_now is None:
            return

        # Only records older than five minutes are evaluated
        evaluation_cutoff = datetime.now() - timedelta(minutes=5)
        matured_records = (
            entry for entry in history_snapshot
            if entry['timestamp'] < evaluation_cutoff
        )
        for entry in matured_records:
            await self._evaluate_and_train_on_record(entry, price_now)

    except Exception as e:
        logger.error(f"Error triggering model training for {symbol}: {e}")
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
"""Evaluate prediction outcome and train model"""
try:
model_name = record['model_name']
prediction = record['prediction']
timestamp = record['timestamp']
# Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
# Get historical price at prediction time (simplified)
symbol = record['symbol']
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is None or historical_data.empty:
return
# Find closest price to prediction timestamp
prediction_price = historical_data['close'].iloc[-1] # Simplified
price_change_pct = (current_price - prediction_price) / prediction_price * 100
# Determine if prediction was correct
predicted_action = prediction['action']
was_correct = False
if predicted_action == 'BUY' and price_change_pct > 0.1: # Price went up
was_correct = True
elif predicted_action == 'SELL' and price_change_pct < -0.1: # Price went down
was_correct = True
elif predicted_action == 'HOLD' and abs(price_change_pct) < 0.1: # Price stayed stable
was_correct = True
# Update model performance tracking
if model_name not in self.model_performance:
self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
self.model_performance[model_name]['total'] += 1
if was_correct:
self.model_performance[model_name]['correct'] += 1
self.model_performance[model_name]['accuracy'] = (
self.model_performance[model_name]['correct'] /
self.model_performance[model_name]['total']
)
# Train the specific model based on outcome
await self._train_model_on_outcome(record, was_correct, price_change_pct)
logger.debug(f"Evaluated {model_name} prediction: {'' if was_correct else ''} "
f"({prediction['action']}, {price_change_pct:.2f}% change)")
except Exception as e:
logger.error(f"Error evaluating and training on record: {e}")
async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float):
"""Train specific model based on prediction outcome"""
try:
model_name = record['model_name']
model_input = record['model_input']
prediction = record['prediction']
# Create training signal based on outcome
reward = 1.0 if was_correct else -0.5
# Train RL models
if 'dqn' in model_name.lower() and self.rl_agent:
if hasattr(self.rl_agent, 'add_experience'):
action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
self.rl_agent.add_experience(
state=model_input,
action=action_idx,
reward=reward,
next_state=model_input, # Simplified
done=True
)
logger.debug(f"Added RL training experience: reward={reward}")
# Train CNN models
elif 'cnn' in model_name.lower() and self.cnn_model:
if hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
except Exception as e:
logger.error(f"Error training model on outcome: {e}")
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
"""Calculate RSI indicator"""
try:
delta = prices.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
return rsi.iloc[-1] if not rsi.empty else 50.0
except:
return 50.0
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
predictions = []