replay system
core/enhanced_training_integration.py (new file, +775 lines)

@@ -0,0 +1,775 @@
"""
|
||||
Enhanced Training Integration Module
|
||||
|
||||
This module provides comprehensive integration between the training data collection system,
|
||||
CNN training pipeline, RL training pipeline, and your existing infrastructure.
|
||||
|
||||
Key Features:
|
||||
- Real-time integration with existing DataProvider
|
||||
- Coordinated training across CNN and RL models
|
||||
- Automatic outcome validation and profitability tracking
|
||||
- Integration with existing COB RL model
|
||||
- Performance monitoring and optimization
|
||||
- Seamless connection to existing orchestrator and trading executor
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Import existing components
|
||||
from .data_provider import DataProvider
|
||||
from .orchestrator import Orchestrator
|
||||
from .trading_executor import TradingExecutor
|
||||
|
||||
# Import our training system components
|
||||
from .training_data_collector import (
|
||||
TrainingDataCollector,
|
||||
get_training_data_collector
|
||||
)
|
||||
from .cnn_training_pipeline import (
|
||||
CNNPivotPredictor,
|
||||
CNNTrainer,
|
||||
get_cnn_trainer
|
||||
)
|
||||
from .rl_training_pipeline import (
|
||||
RLTradingAgent,
|
||||
RLTrainer,
|
||||
get_rl_trainer
|
||||
)
|
||||
from .training_integration import TrainingIntegration
|
||||

logger = logging.getLogger(__name__)

# Import existing RL model
try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    logger.warning("Could not import COBRLModelInterface - using fallback")
    COBRLModelInterface = None

@dataclass
class EnhancedTrainingConfig:
    """Enhanced configuration for comprehensive training integration"""
    # Data collection
    collection_interval: float = 1.0
    min_data_completeness: float = 0.8

    # Training triggers
    min_episodes_for_cnn_training: int = 100
    min_experiences_for_rl_training: int = 200
    training_frequency_minutes: int = 30

    # Profitability thresholds
    min_profitability_for_replay: float = 0.1
    high_profitability_threshold: float = 0.5

    # Model integration
    use_existing_cob_rl_model: bool = True
    enable_cross_model_learning: bool = True

    # Performance optimization
    max_concurrent_training_sessions: int = 2
    enable_background_validation: bool = True

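
# Illustrative example (not part of the module's runtime behavior): the defaults above
# can be overridden per deployment. The field names are the ones defined in
# EnhancedTrainingConfig; the values below are arbitrary.
#
#   fast_config = EnhancedTrainingConfig(
#       collection_interval=0.5,
#       min_episodes_for_cnn_training=50,
#       min_experiences_for_rl_training=100,
#       training_frequency_minutes=10,
#   )
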

class EnhancedTrainingIntegration:
    """Enhanced training integration with existing infrastructure"""

    def __init__(self,
                 data_provider: DataProvider,
                 orchestrator: Orchestrator = None,
                 trading_executor: TradingExecutor = None,
                 config: EnhancedTrainingConfig = None):

        self.data_provider = data_provider
        self.orchestrator = orchestrator
        self.trading_executor = trading_executor
        self.config = config or EnhancedTrainingConfig()

        # Initialize training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Initialize RL components
        if self.config.use_existing_cob_rl_model and COBRLModelInterface:
            self.existing_rl_model = COBRLModelInterface()
            logger.info("Using existing COB RL model")
        else:
            self.existing_rl_model = None

        self.rl_agent = RLTradingAgent()
        self.rl_trainer = get_rl_trainer(self.rl_agent)

        # Integration state
        self.is_running = False
        self.training_threads = {}
        self.validation_thread = None

        # Performance tracking
        self.integration_stats = {
            'total_data_packages': 0,
            'cnn_training_sessions': 0,
            'rl_training_sessions': 0,
            'profitable_predictions': 0,
            'total_predictions': 0,
            'cross_model_improvements': 0,
            'last_update': datetime.now()
        }

        # Model prediction tracking
        self.recent_predictions = {}
        self.prediction_outcomes = {}

        # Cross-model learning
        self.model_performance_history = {
            'cnn': [],
            'rl': [],
            'orchestrator': []
        }

        logger.info("Enhanced Training Integration initialized")
        logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
        logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
        logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")
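
    # Lifecycle note: start_enhanced_integration() launches the data collector,
    # per-symbol CNN training, a training coordinator thread, a data collection
    # thread, and (optionally) an outcome validation thread;
    # stop_enhanced_integration() signals them to stop and joins the threads.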

    def start_enhanced_integration(self):
        """Start the enhanced training integration system"""
        if self.is_running:
            logger.warning("Enhanced training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start CNN training
        if self.config.min_episodes_for_cnn_training > 0:
            for symbol in self.data_provider.symbols:
                self.cnn_trainer.start_real_time_training(symbol)

        # Start coordinated training thread
        self.training_threads['coordinator'] = threading.Thread(
            target=self._training_coordinator_worker,
            daemon=True
        )
        self.training_threads['coordinator'].start()

        # Start data collection and validation
        self.training_threads['data_collector'] = threading.Thread(
            target=self._enhanced_data_collection_worker,
            daemon=True
        )
        self.training_threads['data_collector'].start()

        # Start outcome validation if enabled
        if self.config.enable_background_validation:
            self.validation_thread = threading.Thread(
                target=self._outcome_validation_worker,
                daemon=True
            )
            self.validation_thread.start()

        logger.info("Enhanced training integration started")

    def stop_enhanced_integration(self):
        """Stop the enhanced training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish
        for thread_name, thread in self.training_threads.items():
            thread.join(timeout=10)
            logger.info(f"Stopped {thread_name} thread")

        if self.validation_thread:
            self.validation_thread.join(timeout=5)

        logger.info("Enhanced training integration stopped")

    def _enhanced_data_collection_worker(self):
        """Enhanced data collection with real-time model integration"""
        logger.info("Enhanced data collection worker started")

        while self.is_running:
            try:
                for symbol in self.data_provider.symbols:
                    self._collect_enhanced_training_data(symbol)

                time.sleep(self.config.collection_interval)

            except Exception as e:
                logger.error(f"Error in enhanced data collection: {e}")
                time.sleep(5)

        logger.info("Enhanced data collection worker stopped")

    def _collect_enhanced_training_data(self, symbol: str):
        """Collect enhanced training data with model predictions"""
        try:
            # Get comprehensive market data
            market_data = self._get_comprehensive_market_data(symbol)

            if not market_data or not self._validate_market_data(market_data):
                return

            # Get current model predictions
            model_predictions = self._get_all_model_predictions(symbol, market_data)

            # Create enhanced features
            cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
            rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)

            # Collect training data with predictions
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=market_data['ohlcv'],
                tick_data=market_data['ticks'],
                cob_data=market_data['cob'],
                technical_indicators=market_data['indicators'],
                pivot_points=market_data['pivots'],
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=market_data['context'],
                model_predictions=model_predictions
            )

            if episode_id:
                # Store predictions for outcome validation
                self.recent_predictions[episode_id] = {
                    'timestamp': datetime.now(),
                    'symbol': symbol,
                    'predictions': model_predictions,
                    'market_data': market_data
                }

                # Add an RL experience if we have an action
                if 'rl_action' in model_predictions:
                    self._add_rl_experience(symbol, market_data, model_predictions, episode_id)

                self.integration_stats['total_data_packages'] += 1

        except Exception as e:
            logger.error(f"Error collecting enhanced training data for {symbol}: {e}")

    def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
        """Get comprehensive market data from all sources"""
        try:
            market_data = {}

            # OHLCV data
            ohlcv_data = {}
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            market_data['ohlcv'] = ohlcv_data

            # Tick data
            market_data['ticks'] = self._get_recent_tick_data(symbol)

            # COB data
            market_data['cob'] = self._get_cob_data(symbol)

            # Technical indicators
            market_data['indicators'] = self._get_technical_indicators(symbol)

            # Pivot points
            market_data['pivots'] = self._get_pivot_points(symbol)

            # Market context
            market_data['context'] = self._get_market_context(symbol)

            return market_data

        except Exception as e:
            logger.error(f"Error getting comprehensive market data: {e}")
            return {}

    def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get predictions from all available models"""
        predictions = {}

        try:
            # CNN predictions
            if self.cnn_model and market_data.get('ohlcv'):
                cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
                if cnn_features is not None:
                    cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)

                    # Reshape for CNN (add channel dimension)
                    cnn_input = cnn_input.view(1, 10, -1)  # Assuming 10 channels

                    with torch.no_grad():
                        cnn_outputs = self.cnn_model(cnn_input)
                        predictions['cnn'] = {
                            'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
                            'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
                            'confidence': cnn_outputs['confidence'].cpu().numpy(),
                            'timestamp': datetime.now()
                        }

            # RL predictions
            if self.rl_agent and market_data.get('cob'):
                rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
                if rl_state is not None:
                    action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
                    predictions['rl'] = {
                        'action': action,
                        'confidence': confidence,
                        'timestamp': datetime.now()
                    }
                    predictions['rl_action'] = action

            # Existing COB RL model predictions
            if self.existing_rl_model and market_data.get('cob'):
                cob_features = market_data['cob'].get('cob_features', [])
                if cob_features and len(cob_features) >= 2000:
                    cob_array = np.array(cob_features[:2000], dtype=np.float32)
                    cob_prediction = self.existing_rl_model.predict(cob_array)
                    predictions['cob_rl'] = {
                        'predicted_direction': cob_prediction.get('predicted_direction', 1),
                        'confidence': cob_prediction.get('confidence', 0.5),
                        'value': cob_prediction.get('value', 0.0),
                        'timestamp': datetime.now()
                    }

            # Orchestrator predictions (if available)
            if self.orchestrator:
                try:
                    # This would integrate with your orchestrator's prediction method
                    orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
                    if orchestrator_prediction:
                        predictions['orchestrator'] = orchestrator_prediction
                except Exception as e:
                    logger.debug(f"Could not get orchestrator prediction: {e}")

            return predictions

        except Exception as e:
            logger.error(f"Error getting model predictions: {e}")
            return {}

    def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
                           predictions: Dict[str, Any], episode_id: str):
        """Add RL experience to the training buffer"""
        try:
            # Create RL state
            state = self._create_enhanced_rl_state(symbol, market_data, predictions)
            if state is None:
                return

            # Get action from predictions
            action = predictions.get('rl_action', 1)  # Default to HOLD

            # Calculate immediate reward (placeholder - would be updated with actual outcome)
            reward = 0.0

            # Create next state (same as current for now - would be updated)
            next_state = state.copy()

            # Market context
            market_context = {
                'symbol': symbol,
                'episode_id': episode_id,
                'timestamp': datetime.now(),
                'market_session': market_data['context'].get('market_session', 'unknown'),
                'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
            }

            # Add experience
            experience_id = self.rl_trainer.add_experience(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=False,
                market_context=market_context,
                cnn_predictions=predictions.get('cnn'),
                confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
            )

            if experience_id:
                logger.debug(f"Added RL experience: {experience_id}")

        except Exception as e:
            logger.error(f"Error adding RL experience: {e}")

    def _training_coordinator_worker(self):
        """Coordinate training across all models"""
        logger.info("Training coordinator worker started")

        while self.is_running:
            try:
                # Check if we should trigger training
                for symbol in self.data_provider.symbols:
                    self._check_and_trigger_training(symbol)

                # Wait before next check
                time.sleep(self.config.training_frequency_minutes * 60)

            except Exception as e:
                logger.error(f"Error in training coordinator: {e}")
                time.sleep(60)

        logger.info("Training coordinator worker stopped")

    def _check_and_trigger_training(self, symbol: str):
        """Check conditions and trigger training if needed"""
        try:
            # Get training episodes and experiences
            episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)

            # Check CNN training conditions
            if len(episodes) >= self.config.min_episodes_for_cnn_training:
                profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]

                if len(profitable_episodes) >= 20:  # Minimum profitable episodes
                    logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")

                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=self.config.min_profitability_for_replay,
                        max_episodes=len(profitable_episodes)
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['cnn_training_sessions'] += 1
                        logger.info(f"CNN training completed for {symbol}")

            # Check RL training conditions
            buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
            total_experiences = buffer_stats.get('total_experiences', 0)

            if total_experiences >= self.config.min_experiences_for_rl_training:
                profitable_experiences = buffer_stats.get('profitable_experiences', 0)

                if profitable_experiences >= 50:  # Minimum profitable experiences
                    logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")

                    results = self.rl_trainer.train_on_profitable_experiences(
                        min_profitability=self.config.min_profitability_for_replay,
                        max_experiences=min(profitable_experiences, 500),
                        batch_size=32
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['rl_training_sessions'] += 1
                        logger.info("RL training completed")

        except Exception as e:
            logger.error(f"Error checking training conditions for {symbol}: {e}")

    def _outcome_validation_worker(self):
        """Background worker for validating prediction outcomes"""
        logger.info("Outcome validation worker started")

        while self.is_running:
            try:
                self._validate_recent_predictions()
                time.sleep(300)  # Check every 5 minutes

            except Exception as e:
                logger.error(f"Error in outcome validation: {e}")
                time.sleep(60)

        logger.info("Outcome validation worker stopped")

    def _validate_recent_predictions(self):
        """Validate recent predictions against actual outcomes"""
        try:
            current_time = datetime.now()
            validation_delay = timedelta(hours=1)  # Wait 1 hour to validate

            validated_predictions = []

            # Iterate over a snapshot: the collection thread may add predictions concurrently
            for episode_id, prediction_data in list(self.recent_predictions.items()):
                prediction_time = prediction_data['timestamp']

                if current_time - prediction_time >= validation_delay:
                    # Validate this prediction
                    outcome = self._calculate_prediction_outcome(prediction_data)

                    if outcome:
                        self.prediction_outcomes[episode_id] = outcome

                        # Update RL experience if one exists
                        if 'rl_action' in prediction_data['predictions']:
                            self._update_rl_experience_outcome(episode_id, outcome)

                        # Update statistics
                        if outcome['is_profitable']:
                            self.integration_stats['profitable_predictions'] += 1
                        self.integration_stats['total_predictions'] += 1

                        validated_predictions.append(episode_id)

            # Remove validated predictions
            for episode_id in validated_predictions:
                del self.recent_predictions[episode_id]

            if validated_predictions:
                logger.info(f"Validated {len(validated_predictions)} predictions")

        except Exception as e:
            logger.error(f"Error validating predictions: {e}")

    def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Calculate the actual outcome for a prediction"""
        try:
            symbol = prediction_data['symbol']
            prediction_time = prediction_data['timestamp']

            # Get price data after the prediction
            current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)

            if current_df is None or current_df.empty:
                return None

            # Find price at prediction time and current price
            prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
            if prediction_price.empty:
                return None

            base_price = float(prediction_price['close'].iloc[-1])
            current_price = float(current_df['close'].iloc[-1])

            # Calculate outcome
            price_change = (current_price - base_price) / base_price
            is_profitable = abs(price_change) > 0.005  # 0.5% threshold

            return {
                'episode_id': prediction_data.get('episode_id'),
                'base_price': base_price,
                'current_price': current_price,
                'price_change': price_change,
                'is_profitable': is_profitable,
                'profitability_score': abs(price_change) * 10,  # Rough scaling; not clamped to [0, 1]
                'validation_time': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error calculating prediction outcome: {e}")
            return None

    def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
        """Update RL experience with actual outcome"""
        try:
            # Find the experience ID associated with this episode
            # This is a simplified approach - in practice you'd maintain better mapping
            actual_profit = outcome['price_change']

            # Determine optimal action based on outcome
            if outcome['price_change'] > 0.01:
                optimal_action = 2  # BUY
            elif outcome['price_change'] < -0.01:
                optimal_action = 0  # SELL
            else:
                optimal_action = 1  # HOLD

            # Update experience (this would need proper experience ID mapping)
            # For now, we'll update the most recent experience
            # In practice, you'd maintain a mapping between episodes and experiences

        except Exception as e:
            logger.error(f"Error updating RL experience outcome: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add component statistics
        stats['data_collector'] = self.data_collector.get_collection_statistics()
        stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
        stats['rl_trainer'] = self.rl_trainer.get_training_statistics()

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['recent_predictions_count'] = len(self.recent_predictions)
        stats['validated_outcomes_count'] = len(self.prediction_outcomes)

        # Calculate profitability rate
        if stats['total_predictions'] > 0:
            stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
        else:
            stats['overall_profitability_rate'] = 0.0

        return stats

    def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
        """Manually trigger training"""
        results = {}

        try:
            if training_type in ['all', 'cnn']:
                symbols = [symbol] if symbol else self.data_provider.symbols
                for sym in symbols:
                    cnn_results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=sym,
                        min_profitability=0.1,
                        max_episodes=200
                    )
                    results[f'cnn_{sym}'] = cnn_results

            if training_type in ['all', 'rl']:
                rl_results = self.rl_trainer.train_on_profitable_experiences(
                    min_profitability=0.1,
                    max_experiences=500,
                    batch_size=32
                )
                results['rl'] = rl_results

            return {'status': 'success', 'results': results}

        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}

    # Helper methods (simplified implementations)
    def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data"""
        # Implementation would get tick data from data provider
        return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get COB data"""
        # Implementation would get COB data from data provider
        return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators"""
        # Implementation would get indicators from data provider
        return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get pivot points"""
        # Implementation would get pivot points from data provider
        return []

    def _get_market_context(self, symbol: str) -> Dict[str, Any]:
        """Get market context"""
        return {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'market_session': 'unknown',
            'volatility_regime': 'unknown'
        }

    def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
        """Validate market data completeness"""
        required_fields = ['ohlcv', 'indicators']
        return all(field in market_data for field in required_fields)

    def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
        """Create enhanced CNN features"""
        try:
            # Simplified feature creation
            features = []

            # Add OHLCV features
            for timeframe in ['1m', '5m', '15m', '1h']:
                if timeframe in market_data.get('ohlcv', {}):
                    df = market_data['ohlcv'][timeframe]
                    if not df.empty:
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Pad to target size
            target_size = 3000  # 10 channels * 300 sequence length
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating CNN features: {e}")
            return None

    def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
                                  predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
        """Create enhanced RL state"""
        try:
            state_features = []

            # Add market features
            if '1m' in market_data.get('ohlcv', {}):
                df = market_data['ohlcv']['1m']
                if not df.empty:
                    latest = df.iloc[-1]
                    state_features.extend([
                        latest['open'], latest['high'],
                        latest['low'], latest['close'], latest['volume']
                    ])

            # Add technical indicators
            indicators = market_data.get('indicators', {})
            for value in indicators.values():
                state_features.append(value)

            # Add model predictions as features (flattened so only scalars are appended)
            if predictions:
                if 'cnn' in predictions:
                    cnn_pred = predictions['cnn']
                    state_features.extend(np.ravel(cnn_pred.get('pivot_logits', [0, 0, 0])))
                    state_features.append(float(np.ravel(cnn_pred.get('confidence', [0.0]))[0]))

                if 'cob_rl' in predictions:
                    cob_pred = predictions['cob_rl']
                    state_features.append(cob_pred.get('predicted_direction', 1))
                    state_features.append(cob_pred.get('confidence', 0.5))

            # Pad to target size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating RL state: {e}")
            return None

    def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
                                     predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Get orchestrator prediction"""
        # This would integrate with your orchestrator
        return None

# Global instance
enhanced_training_integration = None


def get_enhanced_training_integration(data_provider: DataProvider = None,
                                      orchestrator: Orchestrator = None,
                                      trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
    """Get global enhanced training integration instance"""
    global enhanced_training_integration
    if enhanced_training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        enhanced_training_integration = EnhancedTrainingIntegration(
            data_provider, orchestrator, trading_executor
        )
    return enhanced_training_integration
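

# Illustrative usage sketch, not part of the module's behavior: it only shows how the
# pieces above are intended to be wired together. Assumptions: DataProvider can be
# constructed with no arguments, and the 60-second polling interval is arbitrary.
# Because this file uses relative imports, run it as a module,
# e.g. `python -m core.enhanced_training_integration`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    provider = DataProvider()  # assumed no-arg constructor
    integration = get_enhanced_training_integration(data_provider=provider)
    integration.start_enhanced_integration()

    try:
        while True:
            time.sleep(60)
            logger.info("Integration stats: %s", integration.get_integration_statistics())
    except KeyboardInterrupt:
        integration.stop_enhanced_integration()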