wip training
@@ -27,6 +27,7 @@ import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from .config import get_config
from .data_provider import DataProvider
@@ -202,9 +203,18 @@ class TradingOrchestrator:
        # Training tracking
        self.last_trained_symbols: Dict[str, datetime] = {}

        # INFERENCE DATA STORAGE - Store model inputs and outputs for training
        self.inference_history: Dict[str, deque] = {}  # {symbol: deque of inference records}
        self.max_inference_history = 1000  # Keep last 1000 inference records per symbol

        # Initialize inference history for each symbol
        for symbol in self.symbols:
            self.inference_history[symbol] = deque(maxlen=self.max_inference_history)

        # ENHANCED: Real-time Training System Integration
        self.enhanced_training_system = None  # Will be set to EnhancedRealtimeTrainingSystem if available
        # Enable training by default - don't depend on external training system
        # (replaces the previous `enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE` gate)
        self.training_enabled: bool = enhanced_rl_training

        logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
        logger.info(f"Enhanced RL training: {enhanced_rl_training}")
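A quick sketch of the bounded-history behavior this relies on: deque(maxlen=N) evicts the oldest entry once full, so per-symbol memory stays flat at max_inference_history records.

from collections import deque

history = deque(maxlen=3)  # stands in for inference_history[symbol]
for i in range(5):
    history.append(i)
assert list(history) == [2, 3, 4]  # oldest records were evicted automatically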
@@ -1023,34 +1033,409 @@ class TradingOrchestrator:
        return None

    async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
        """Get predictions from all registered models with input data storage"""
        predictions = []
        current_time = datetime.now()

        # Collect input data for all models
        input_data = await self._collect_model_input_data(symbol)

        for model_name, model in self.model_registry.models.items():
            try:
                prediction = None
                model_input = None

                if isinstance(model, CNNModelInterface):
                    # Get CNN predictions for each timeframe
                    cnn_predictions = await self._get_cnn_predictions(model, symbol)
                    predictions.extend(cnn_predictions)
                    if cnn_predictions:
                        # Track the latest CNN prediction so its inference data is stored below
                        prediction = cnn_predictions[-1]
                    # Store input data for CNN
                    model_input = input_data.get('cnn_input')

                elif isinstance(model, RLAgentInterface):
                    # Get RL prediction
                    rl_prediction = await self._get_rl_prediction(model, symbol)
                    if rl_prediction:
                        predictions.append(rl_prediction)
                        prediction = rl_prediction
                        # Store input data for RL
                        model_input = input_data.get('rl_input')

                else:
                    # Generic model interface
                    generic_prediction = await self._get_generic_prediction(model, symbol)
                    if generic_prediction:
                        predictions.append(generic_prediction)
                        prediction = generic_prediction
                        # Store input data for generic model
                        model_input = input_data.get('generic_input')

                # Store inference data for training
                if prediction and model_input is not None:
                    self._store_inference_data(symbol, model_name, model_input, prediction, current_time)

            except Exception as e:
                logger.error(f"Error getting prediction from {model_name}: {e}")
                continue

        # Trigger training based on previous inference data
        await self._trigger_model_training(symbol)

        return predictions

    async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
        """Collect comprehensive input data for all models"""
        try:
            input_data = {}

            # Get current market data from data provider
            current_price = self.data_provider.get_current_price(symbol)

            # Collect OHLCV data for multiple timeframes
            ohlcv_data = {}
            timeframes = ['1s', '1m', '1h', '1d']
            for tf in timeframes:
                df = self.data_provider.get_historical_data(symbol, tf, limit=300)
                if df is not None and not df.empty:
                    ohlcv_data[tf] = df

            # Collect COB data if available
            cob_data = self.get_cob_snapshot(symbol)

            # Collect technical indicators
            technical_indicators = {}
            if '1h' in ohlcv_data:
                df = ohlcv_data['1h']
                if len(df) > 20:
                    technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
                    technical_indicators['rsi'] = self._calculate_rsi(df['close'])

            # Prepare CNN input
            cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators)

            # Prepare RL input
            rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)

            # Prepare generic input
            generic_input = {
                'symbol': symbol,
                'current_price': current_price,
                'ohlcv_data': ohlcv_data,
                'cob_data': cob_data,
                'technical_indicators': technical_indicators
            }

            input_data = {
                'cnn_input': cnn_input,
                'rl_input': rl_input,
                'generic_input': generic_input,
                'timestamp': datetime.now(),
                'symbol': symbol
            }

            return input_data

        except Exception as e:
            logger.error(f"Error collecting model input data for {symbol}: {e}")
            return {}
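A minimal usage sketch of the collection step; the orchestrator instance and symbol are illustrative, and the key names mirror the dict built above:

import asyncio

async def inspect_model_inputs(orchestrator, symbol):
    data = await orchestrator._collect_model_input_data(symbol)
    if data:  # empty dict on error
        print(sorted(data.keys()))      # cnn_input, generic_input, rl_input, symbol, timestamp
        print(data['cnn_input'].shape)  # (1, 300), see _prepare_cnn_input_data
        print(data['rl_input'].shape)   # (100,), see _prepare_rl_input_data

# asyncio.run(inspect_model_inputs(orchestrator, 'ETH/USDT'))  # hypothetical instance and symbol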
    def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
        """Prepare standardized input data for CNN models"""
        try:
            # Create feature matrix from OHLCV data
            features = []

            # Add OHLCV features for each timeframe
            for tf in ['1s', '1m', '1h', '1d']:
                if tf in ohlcv_data and not ohlcv_data[tf].empty:
                    df = ohlcv_data[tf].tail(50)  # Last 50 bars
                    features.extend([
                        df['close'].pct_change().fillna(0).values,
                        df['volume'].values / df['volume'].max() if df['volume'].max() > 0 else np.zeros(len(df))
                    ])

            # Add technical indicators
            for key, value in technical_indicators.items():
                if not np.isnan(value):
                    features.append([value])

            # Flatten and pad/truncate to standard size
            if features:
                feature_array = np.concatenate([np.array(f).flatten() for f in features])
                # Pad or truncate to 300 features
                if len(feature_array) < 300:
                    feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
                else:
                    feature_array = feature_array[:300]
                return feature_array.reshape(1, -1)
            else:
                return np.zeros((1, 300))

        except Exception as e:
            logger.error(f"Error preparing CNN input data: {e}")
            return np.zeros((1, 300))

    def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
        """Prepare standardized input data for RL models"""
        try:
            # Create state representation
            state_features = []

            # Add price and volume features
            if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
                df = ohlcv_data['1m'].tail(20)
                state_features.extend([
                    df['close'].pct_change().fillna(0).values,
                    df['volume'].pct_change().fillna(0).values,
                    ((df['high'] - df['low']) / df['close']).values  # Volatility proxy
                ])

            # Add technical indicators
            for key, value in technical_indicators.items():
                if not np.isnan(value):
                    state_features.append(value)

            # Flatten and standardize size
            if state_features:
                state_array = np.concatenate([np.array(f).flatten() for f in state_features])
                # Pad or truncate to expected RL state size
                expected_size = 100  # Adjust based on your RL model
                if len(state_array) < expected_size:
                    state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
                else:
                    state_array = state_array[:expected_size]
                return state_array
            else:
                return np.zeros(100)

        except Exception as e:
            logger.error(f"Error preparing RL input data: {e}")
            return np.zeros(100)
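Both prepare methods end with the same pad-or-truncate step. Isolated as a standalone sketch of the pattern:

import numpy as np

def pad_or_truncate(arr: np.ndarray, size: int) -> np.ndarray:
    """Zero-pad on the right, or truncate, so the output always has `size` elements."""
    flat = np.asarray(arr).flatten()
    if len(flat) < size:
        return np.pad(flat, (0, size - len(flat)), 'constant')
    return flat[:size]

assert pad_or_truncate(np.ones(10), 300).shape == (300,)   # padded with zeros
assert pad_or_truncate(np.ones(500), 300).shape == (300,)  # truncated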
    def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
        """Store comprehensive inference data for future training with persistent storage"""
        try:
            # Get current market context for complete replay capability
            current_price = self.data_provider.get_current_price(symbol)

            # Create comprehensive inference record with ALL data needed for model replay
            inference_record = {
                'timestamp': timestamp,
                'symbol': symbol,
                'model_name': model_name,
                'current_price': current_price,

                # Complete model input data
                'model_input': {
                    'raw_input': model_input,
                    'input_shape': model_input.shape if hasattr(model_input, 'shape') else None,
                    'input_type': str(type(model_input))
                },

                # Complete prediction data
                'prediction': {
                    'action': prediction.action,
                    'confidence': prediction.confidence,
                    'probabilities': prediction.probabilities,
                    'timeframe': prediction.timeframe
                },

                # Market context at prediction time
                'market_context': {
                    'price': current_price,
                    'timestamp': timestamp.isoformat(),
                    'symbol': symbol
                },

                # Model metadata
                'metadata': {
                    'model_metadata': prediction.metadata or {},
                    'orchestrator_state': {
                        'confidence_threshold': self.confidence_threshold,
                        'training_enabled': self.training_enabled
                    }
                },

                # Training outcome (will be filled later)
                'training_outcome': None,
                'outcome_evaluated': False
            }

            # Store in memory (inference history)
            if symbol in self.inference_history:
                self.inference_history[symbol].append(inference_record)
                logger.debug(f"Stored inference data for {model_name} on {symbol}")

            # Persistent storage to disk (for long-term training data)
            self._save_inference_to_disk(inference_record)

        except Exception as e:
            logger.error(f"Error storing inference data: {e}")

    def _save_inference_to_disk(self, inference_record: Dict):
        """Save inference record to persistent storage"""
        try:
            # Create inference data directory
            inference_dir = Path("training_data/inference_history")
            inference_dir.mkdir(parents=True, exist_ok=True)

            # Create filename with timestamp and model name; sanitize the symbol,
            # since pair names such as "ETH/USDT" contain characters invalid in filenames
            timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
            safe_symbol = inference_record['symbol'].replace('/', '')
            filename = f"{safe_symbol}_{inference_record['model_name']}_{timestamp_str}.json"
            filepath = inference_dir / filename

            # Convert numpy arrays to lists for JSON serialization
            serializable_record = self._make_json_serializable(inference_record)

            # Save to JSON file
            with open(filepath, 'w') as f:
                json.dump(serializable_record, f, indent=2)

            logger.debug(f"Saved inference record to disk: {filepath}")

        except Exception as e:
            logger.error(f"Error saving inference to disk: {e}")
    def _make_json_serializable(self, obj):
        """Convert object to JSON-serializable format"""
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.integer, np.floating)):
            return obj.item()
        elif isinstance(obj, datetime):
            return obj.isoformat()
        else:
            return obj
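A quick round-trip example (illustrative values) showing why the conversion is needed before json.dump:

import json
from datetime import datetime
import numpy as np

record = {
    'timestamp': datetime(2024, 1, 1, 12, 0),
    'confidence': np.float32(0.73),
    'raw_input': np.zeros((2, 2)),
}
# json.dumps(record) raises TypeError: datetime, np.float32, and ndarray are not serializable.
# After _make_json_serializable: datetime -> ISO string, np.float32 -> float,
# ndarray -> nested lists, so json.dumps succeeds.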
    async def _trigger_model_training(self, symbol: str):
        """Trigger training for models based on previous inference data"""
        try:
            if not self.training_enabled or symbol not in self.inference_history:
                return

            # Get recent inference records
            recent_records = list(self.inference_history[symbol])
            if len(recent_records) < 2:
                return  # Need at least 2 records to compare

            # Get current price for outcome evaluation
            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                return

            # Process records that are old enough to evaluate outcomes
            cutoff_time = datetime.now() - timedelta(minutes=5)  # 5 minutes ago

            for record in recent_records:
                if record['timestamp'] < cutoff_time:
                    await self._evaluate_and_train_on_record(record, current_price)

        except Exception as e:
            logger.error(f"Error triggering model training for {symbol}: {e}")

    async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
        """Evaluate prediction outcome and train model"""
        try:
            model_name = record['model_name']
            prediction = record['prediction']
            timestamp = record['timestamp']

            # Calculate price change since prediction
            # This is a simplified outcome evaluation - you might want to make it more sophisticated
            time_diff = (datetime.now() - timestamp).total_seconds() / 60  # minutes

            # Get historical price at prediction time (simplified: use the most
            # recent 1m close as a proxy rather than matching the exact timestamp)
            symbol = record['symbol']
            historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
            if historical_data is None or historical_data.empty:
                return

            prediction_price = historical_data['close'].iloc[-1]  # Simplified
            price_change_pct = (current_price - prediction_price) / prediction_price * 100

            # Determine if prediction was correct
            predicted_action = prediction['action']
            was_correct = False

            if predicted_action == 'BUY' and price_change_pct > 0.1:  # Price went up
                was_correct = True
            elif predicted_action == 'SELL' and price_change_pct < -0.1:  # Price went down
                was_correct = True
            elif predicted_action == 'HOLD' and abs(price_change_pct) < 0.1:  # Price stayed stable
                was_correct = True

            # Update model performance tracking
            if model_name not in self.model_performance:
                self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}

            self.model_performance[model_name]['total'] += 1
            if was_correct:
                self.model_performance[model_name]['correct'] += 1

            self.model_performance[model_name]['accuracy'] = (
                self.model_performance[model_name]['correct'] /
                self.model_performance[model_name]['total']
            )

            # Train the specific model based on outcome
            await self._train_model_on_outcome(record, was_correct, price_change_pct)

            logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
                         f"({prediction['action']}, {price_change_pct:.2f}% change)")

        except Exception as e:
            logger.error(f"Error evaluating and training on record: {e}")
    async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float):
        """Train specific model based on prediction outcome"""
        try:
            model_name = record['model_name']
            # The record's 'model_input' entry is a dict; the actual array
            # lives under 'raw_input' (see _store_inference_data)
            model_input = record['model_input']['raw_input']
            prediction = record['prediction']

            # Create training signal based on outcome
            reward = 1.0 if was_correct else -0.5

            # Train RL models
            if 'dqn' in model_name.lower() and self.rl_agent:
                if hasattr(self.rl_agent, 'add_experience'):
                    action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
                    self.rl_agent.add_experience(
                        state=model_input,
                        action=action_idx,
                        reward=reward,
                        next_state=model_input,  # Simplified
                        done=True
                    )
                    logger.debug(f"Added RL training experience: reward={reward}")

            # Train CNN models
            elif 'cnn' in model_name.lower() and self.cnn_model:
                if hasattr(self.cnn_model, 'train_on_outcome'):
                    target = 1 if was_correct else 0
                    self.cnn_model.train_on_outcome(model_input, target)
                    logger.debug(f"Trained CNN on outcome: target={target}")

        except Exception as e:
            logger.error(f"Error training model on outcome: {e}")
    def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
        """Calculate RSI indicator"""
        try:
            delta = prices.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
            rs = gain / loss
            rsi = 100 - (100 / (1 + rs))
            return rsi.iloc[-1] if not rsi.empty else 50.0
        except Exception:
            return 50.0  # Neutral RSI on failure
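A worked example for the RSI helper on synthetic data: a strictly rising series has zero average loss, so RS tends to infinity and RSI to 100.

import pandas as pd

prices = pd.Series(range(1, 31), dtype=float)  # 30 monotonically increasing closes
# orchestrator._calculate_rsi(prices) -> 100.0: every delta is a gain, loss = 0, rs = inf.
# Note: a perfectly flat series yields rs = 0/0 = NaN in pandas, which propagates
# to the return value rather than triggering the except fallback.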
    async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
        """Get predictions from CNN model for all timeframes with enhanced COB features"""
        predictions = []