"""
|
|
Enhanced Training Integration Module
|
|
|
|
This module provides comprehensive integration between the training data collection system,
|
|
CNN training pipeline, RL training pipeline, and your existing infrastructure.
|
|
|
|
Key Features:
|
|
- Real-time integration with existing DataProvider
|
|
- Coordinated training across CNN and RL models
|
|
- Automatic outcome validation and profitability tracking
|
|
- Integration with existing COB RL model
|
|
- Performance monitoring and optimization
|
|
- Seamless connection to existing orchestrator and trading executor
|
|
"""
|
|
|
|
import asyncio
import logging
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable

import numpy as np
import pandas as pd
import torch

# Import existing components
from .data_provider import DataProvider
from .orchestrator import Orchestrator
from .trading_executor import TradingExecutor

# Import our training system components
from .training_data_collector import (
    TrainingDataCollector,
    get_training_data_collector
)
from .cnn_training_pipeline import (
    CNNPivotPredictor,
    CNNTrainer,
    get_cnn_trainer
)
from .rl_training_pipeline import (
    RLTradingAgent,
    RLTrainer,
    get_rl_trainer
)
from .training_integration import TrainingIntegration

logger = logging.getLogger(__name__)

# Import existing RL model. The logger must already be defined here,
# since the fallback path below logs a warning.
try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    logger.warning("Could not import COBRLModelInterface - using fallback")
    COBRLModelInterface = None

@dataclass
class EnhancedTrainingConfig:
    """Enhanced configuration for comprehensive training integration"""
    # Data collection
    collection_interval: float = 1.0
    min_data_completeness: float = 0.8

    # Training triggers
    min_episodes_for_cnn_training: int = 100
    min_experiences_for_rl_training: int = 200
    training_frequency_minutes: int = 30

    # Profitability thresholds
    min_profitability_for_replay: float = 0.1
    high_profitability_threshold: float = 0.5

    # Model integration
    use_existing_cob_rl_model: bool = True
    enable_cross_model_learning: bool = True

    # Performance optimization
    max_concurrent_training_sessions: int = 2
    enable_background_validation: bool = True
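
# The defaults above suit live collection. A minimal sketch of a more
# conservative setup (values are illustrative, not tuned recommendations):
#
#   config = EnhancedTrainingConfig(
#       collection_interval=5.0,              # poll every 5 s instead of 1 s
#       min_episodes_for_cnn_training=500,    # require more history
#       training_frequency_minutes=120,       # retrain less often
#   )
#   integration = EnhancedTrainingIntegration(data_provider, config=config)

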
class EnhancedTrainingIntegration:
    """Enhanced training integration with existing infrastructure"""

    def __init__(self,
                 data_provider: DataProvider,
                 orchestrator: Orchestrator = None,
                 trading_executor: TradingExecutor = None,
                 config: EnhancedTrainingConfig = None):
        self.data_provider = data_provider
        self.orchestrator = orchestrator
        self.trading_executor = trading_executor
        self.config = config or EnhancedTrainingConfig()

        # Initialize training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Initialize RL components
        if self.config.use_existing_cob_rl_model and COBRLModelInterface:
            self.existing_rl_model = COBRLModelInterface()
            logger.info("Using existing COB RL model")
        else:
            self.existing_rl_model = None

        self.rl_agent = RLTradingAgent()
        self.rl_trainer = get_rl_trainer(self.rl_agent)

        # Integration state
        self.is_running = False
        self.training_threads = {}
        self.validation_thread = None

        # Performance tracking
        self.integration_stats = {
            'total_data_packages': 0,
            'cnn_training_sessions': 0,
            'rl_training_sessions': 0,
            'profitable_predictions': 0,
            'total_predictions': 0,
            'cross_model_improvements': 0,
            'last_update': datetime.now()
        }

        # Model prediction tracking
        self.recent_predictions = {}
        self.prediction_outcomes = {}

        # Cross-model learning
        self.model_performance_history = {
            'cnn': [],
            'rl': [],
            'orchestrator': []
        }

        logger.info("Enhanced Training Integration initialized")
        logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
        logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
        logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")

    def start_enhanced_integration(self):
        """Start the enhanced training integration system"""
        if self.is_running:
            logger.warning("Enhanced training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start CNN training
        if self.config.min_episodes_for_cnn_training > 0:
            for symbol in self.data_provider.symbols:
                self.cnn_trainer.start_real_time_training(symbol)

        # Start coordinated training thread
        self.training_threads['coordinator'] = threading.Thread(
            target=self._training_coordinator_worker,
            daemon=True
        )
        self.training_threads['coordinator'].start()

        # Start data collection and validation
        self.training_threads['data_collector'] = threading.Thread(
            target=self._enhanced_data_collection_worker,
            daemon=True
        )
        self.training_threads['data_collector'].start()

        # Start outcome validation if enabled
        if self.config.enable_background_validation:
            self.validation_thread = threading.Thread(
                target=self._outcome_validation_worker,
                daemon=True
            )
            self.validation_thread.start()

        logger.info("Enhanced training integration started")

    def stop_enhanced_integration(self):
        """Stop the enhanced training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish
        for thread_name, thread in self.training_threads.items():
            thread.join(timeout=10)
            if thread.is_alive():
                logger.warning(f"{thread_name} thread did not stop within timeout")
            else:
                logger.info(f"Stopped {thread_name} thread")

        if self.validation_thread:
            self.validation_thread.join(timeout=5)

        logger.info("Enhanced training integration stopped")

    def _enhanced_data_collection_worker(self):
        """Enhanced data collection with real-time model integration"""
        logger.info("Enhanced data collection worker started")

        while self.is_running:
            try:
                for symbol in self.data_provider.symbols:
                    self._collect_enhanced_training_data(symbol)

                time.sleep(self.config.collection_interval)

            except Exception as e:
                logger.error(f"Error in enhanced data collection: {e}")
                time.sleep(5)

        logger.info("Enhanced data collection worker stopped")

    def _collect_enhanced_training_data(self, symbol: str):
        """Collect enhanced training data with model predictions"""
        try:
            # Get comprehensive market data
            market_data = self._get_comprehensive_market_data(symbol)

            if not market_data or not self._validate_market_data(market_data):
                return

            # Get current model predictions
            model_predictions = self._get_all_model_predictions(symbol, market_data)

            # Create enhanced features
            cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
            rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)

            # Collect training data with predictions
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=market_data['ohlcv'],
                tick_data=market_data['ticks'],
                cob_data=market_data['cob'],
                technical_indicators=market_data['indicators'],
                pivot_points=market_data['pivots'],
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=market_data['context'],
                model_predictions=model_predictions
            )

            if episode_id:
                # Store predictions for outcome validation. The episode_id is
                # included so _calculate_prediction_outcome can report it.
                self.recent_predictions[episode_id] = {
                    'episode_id': episode_id,
                    'timestamp': datetime.now(),
                    'symbol': symbol,
                    'predictions': model_predictions,
                    'market_data': market_data
                }

                # Add RL experience if we have an action
                if 'rl_action' in model_predictions:
                    self._add_rl_experience(symbol, market_data, model_predictions, episode_id)

                self.integration_stats['total_data_packages'] += 1

        except Exception as e:
            logger.error(f"Error collecting enhanced training data for {symbol}: {e}")

    def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
        """Get comprehensive market data from all sources"""
        try:
            market_data = {}

            # OHLCV data
            ohlcv_data = {}
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            market_data['ohlcv'] = ohlcv_data

            # Tick data
            market_data['ticks'] = self._get_recent_tick_data(symbol)

            # COB data
            market_data['cob'] = self._get_cob_data(symbol)

            # Technical indicators
            market_data['indicators'] = self._get_technical_indicators(symbol)

            # Pivot points
            market_data['pivots'] = self._get_pivot_points(symbol)

            # Market context
            market_data['context'] = self._get_market_context(symbol)

            return market_data

        except Exception as e:
            logger.error(f"Error getting comprehensive market data: {e}")
            return {}

    def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get predictions from all available models"""
        predictions = {}

        try:
            # CNN predictions
            if self.cnn_model and market_data.get('ohlcv'):
                cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
                if cnn_features is not None:
                    cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)

                    # Reshape the flat 3000-element feature vector to
                    # (batch=1, channels=10, sequence=300), matching the
                    # target size used in _create_enhanced_cnn_features
                    cnn_input = cnn_input.view(1, 10, -1)

                    with torch.no_grad():
                        cnn_outputs = self.cnn_model(cnn_input)
                        predictions['cnn'] = {
                            'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
                            'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
                            'confidence': cnn_outputs['confidence'].cpu().numpy(),
                            'timestamp': datetime.now()
                        }

            # RL predictions
            if self.rl_agent and market_data.get('cob'):
                rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
                if rl_state is not None:
                    action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
                    predictions['rl'] = {
                        'action': action,
                        'confidence': confidence,
                        'timestamp': datetime.now()
                    }
                    predictions['rl_action'] = action

            # Existing COB RL model predictions
            if self.existing_rl_model and market_data.get('cob'):
                cob_features = market_data['cob'].get('cob_features', [])
                if cob_features and len(cob_features) >= 2000:
                    cob_array = np.array(cob_features[:2000], dtype=np.float32)
                    cob_prediction = self.existing_rl_model.predict(cob_array)
                    predictions['cob_rl'] = {
                        'predicted_direction': cob_prediction.get('predicted_direction', 1),
                        'confidence': cob_prediction.get('confidence', 0.5),
                        'value': cob_prediction.get('value', 0.0),
                        'timestamp': datetime.now()
                    }

            # Orchestrator predictions (if available)
            if self.orchestrator:
                try:
                    # This would integrate with your orchestrator's prediction method
                    orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
                    if orchestrator_prediction:
                        predictions['orchestrator'] = orchestrator_prediction
                except Exception as e:
                    logger.debug(f"Could not get orchestrator prediction: {e}")

            return predictions

        except Exception as e:
            logger.error(f"Error getting model predictions: {e}")
            return {}

    def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
                           predictions: Dict[str, Any], episode_id: str):
        """Add RL experience to the training buffer"""
        try:
            # Create RL state
            state = self._create_enhanced_rl_state(symbol, market_data, predictions)
            if state is None:
                return

            # Get action from predictions
            action = predictions.get('rl_action', 1)  # Default to HOLD

            # Immediate reward is a placeholder; it is updated once the
            # actual outcome is validated
            reward = 0.0

            # Next state is the current state for now; it would be replaced
            # by the state observed after the action
            next_state = state.copy()

            # Market context
            market_context = {
                'symbol': symbol,
                'episode_id': episode_id,
                'timestamp': datetime.now(),
                'market_session': market_data['context'].get('market_session', 'unknown'),
                'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
            }

            # Add experience
            experience_id = self.rl_trainer.add_experience(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=False,
                market_context=market_context,
                cnn_predictions=predictions.get('cnn'),
                confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
            )

            if experience_id:
                logger.debug(f"Added RL experience: {experience_id}")

        except Exception as e:
            logger.error(f"Error adding RL experience: {e}")

    def _training_coordinator_worker(self):
        """Coordinate training across all models"""
        logger.info("Training coordinator worker started")

        while self.is_running:
            try:
                # Check if we should trigger training
                for symbol in self.data_provider.symbols:
                    self._check_and_trigger_training(symbol)

                # Wait before next check
                time.sleep(self.config.training_frequency_minutes * 60)

            except Exception as e:
                logger.error(f"Error in training coordinator: {e}")
                time.sleep(60)

        logger.info("Training coordinator worker stopped")

    def _check_and_trigger_training(self, symbol: str):
        """Check conditions and trigger training if needed"""
        try:
            # Get training episodes and experiences
            episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)

            # Check CNN training conditions
            if len(episodes) >= self.config.min_episodes_for_cnn_training:
                profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]

                if len(profitable_episodes) >= 20:  # Minimum profitable episodes
                    logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")

                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=self.config.min_profitability_for_replay,
                        max_episodes=len(profitable_episodes)
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['cnn_training_sessions'] += 1
                        logger.info(f"CNN training completed for {symbol}")

            # Check RL training conditions
            buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
            total_experiences = buffer_stats.get('total_experiences', 0)

            if total_experiences >= self.config.min_experiences_for_rl_training:
                profitable_experiences = buffer_stats.get('profitable_experiences', 0)

                if profitable_experiences >= 50:  # Minimum profitable experiences
                    logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")

                    results = self.rl_trainer.train_on_profitable_experiences(
                        min_profitability=self.config.min_profitability_for_replay,
                        max_experiences=min(profitable_experiences, 500),
                        batch_size=32
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['rl_training_sessions'] += 1
                        logger.info("RL training completed")

        except Exception as e:
            logger.error(f"Error checking training conditions for {symbol}: {e}")

    def _outcome_validation_worker(self):
        """Background worker for validating prediction outcomes"""
        logger.info("Outcome validation worker started")

        while self.is_running:
            try:
                self._validate_recent_predictions()
                time.sleep(300)  # Check every 5 minutes

            except Exception as e:
                logger.error(f"Error in outcome validation: {e}")
                time.sleep(60)

        logger.info("Outcome validation worker stopped")

    def _validate_recent_predictions(self):
        """Validate recent predictions against actual outcomes"""
        try:
            current_time = datetime.now()
            validation_delay = timedelta(hours=1)  # Wait 1 hour to validate

            validated_predictions = []

            # Snapshot the items so the collector thread can keep inserting
            # while this loop runs without raising a RuntimeError
            for episode_id, prediction_data in list(self.recent_predictions.items()):
                prediction_time = prediction_data['timestamp']

                if current_time - prediction_time >= validation_delay:
                    # Validate this prediction
                    outcome = self._calculate_prediction_outcome(prediction_data)

                    if outcome:
                        self.prediction_outcomes[episode_id] = outcome

                        # Update RL experience if it exists
                        if 'rl_action' in prediction_data['predictions']:
                            self._update_rl_experience_outcome(episode_id, outcome)

                        # Update statistics
                        if outcome['is_profitable']:
                            self.integration_stats['profitable_predictions'] += 1
                        self.integration_stats['total_predictions'] += 1

                        validated_predictions.append(episode_id)

            # Remove validated predictions
            for episode_id in validated_predictions:
                del self.recent_predictions[episode_id]

            if validated_predictions:
                logger.info(f"Validated {len(validated_predictions)} predictions")

        except Exception as e:
            logger.error(f"Error validating predictions: {e}")

    def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Calculate actual outcome for a prediction"""
        try:
            symbol = prediction_data['symbol']
            prediction_time = prediction_data['timestamp']

            # Get price data after the prediction
            current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)

            if current_df is None or current_df.empty:
                return None

            # Base price is the last close captured at prediction time;
            # current price is the latest close available now
            prediction_ohlcv = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
            if prediction_ohlcv.empty:
                return None

            base_price = float(prediction_ohlcv['close'].iloc[-1])
            current_price = float(current_df['close'].iloc[-1])

            # Calculate outcome
            price_change = (current_price - base_price) / base_price
            is_profitable = abs(price_change) > 0.005  # 0.5% threshold

            return {
                'episode_id': prediction_data.get('episode_id'),
                'base_price': base_price,
                'current_price': current_price,
                'price_change': price_change,
                'is_profitable': is_profitable,
                # Maps a 10% move to 1.0; larger moves exceed 1.0 (not clipped)
                'profitability_score': abs(price_change) * 10,
                'validation_time': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error calculating prediction outcome: {e}")
            return None

    def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
        """Update RL experience with actual outcome"""
        try:
            # Simplified placeholder: without a mapping from episodes to
            # experience IDs, the values computed below cannot yet be applied
            # to a specific experience in the buffer
            actual_profit = outcome['price_change']

            # Determine the optimal action in hindsight
            if outcome['price_change'] > 0.01:
                optimal_action = 2  # BUY
            elif outcome['price_change'] < -0.01:
                optimal_action = 0  # SELL
            else:
                optimal_action = 1  # HOLD

            # Applying the update requires an episode-to-experience ID
            # mapping; see the sketch below this method

        except Exception as e:
            logger.error(f"Error updating RL experience outcome: {e}")

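    # A minimal sketch of the missing episode-to-experience wiring, assuming
    # the RL trainer exposes an update hook; `update_experience_reward` is a
    # hypothetical name that must be adapted to the actual RLTrainer API:
    #
    #   # in __init__:
    #   self.episode_to_experience: Dict[str, str] = {}
    #
    #   # in _add_rl_experience, after add_experience returns:
    #   if experience_id:
    #       self.episode_to_experience[episode_id] = experience_id
    #
    #   # in _update_rl_experience_outcome:
    #   experience_id = self.episode_to_experience.pop(episode_id, None)
    #   if experience_id is not None:
    #       self.rl_trainer.update_experience_reward(experience_id, actual_profit)
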
    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add component statistics
        stats['data_collector'] = self.data_collector.get_collection_statistics()
        stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
        stats['rl_trainer'] = self.rl_trainer.get_training_statistics()

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['recent_predictions_count'] = len(self.recent_predictions)
        stats['validated_outcomes_count'] = len(self.prediction_outcomes)

        # Calculate profitability rate
        if stats['total_predictions'] > 0:
            stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
        else:
            stats['overall_profitability_rate'] = 0.0

        return stats

    def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
        """Manually trigger training"""
        results = {}

        try:
            if training_type in ['all', 'cnn']:
                symbols = [symbol] if symbol else self.data_provider.symbols
                for sym in symbols:
                    cnn_results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=sym,
                        min_profitability=0.1,
                        max_episodes=200
                    )
                    results[f'cnn_{sym}'] = cnn_results

            if training_type in ['all', 'rl']:
                rl_results = self.rl_trainer.train_on_profitable_experiences(
                    min_profitability=0.1,
                    max_experiences=500,
                    batch_size=32
                )
                results['rl'] = rl_results

            return {'status': 'success', 'results': results}

        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}

    # Helper methods (simplified implementations)
    def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data"""
        # Implementation would get tick data from data provider
        return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get COB data"""
        # Implementation would get COB data from data provider
        return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators"""
        # Implementation would get indicators from data provider
        return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get pivot points"""
        # Implementation would get pivot points from data provider
        return []

    def _get_market_context(self, symbol: str) -> Dict[str, Any]:
        """Get market context"""
        return {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'market_session': 'unknown',
            'volatility_regime': 'unknown'
        }

    def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
        """Validate market data completeness"""
        required_fields = ['ohlcv', 'indicators']
        return all(field in market_data for field in required_fields)

    def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
        """Create enhanced CNN features"""
        try:
            # Simplified feature creation
            features = []

            # Add OHLCV features: the last 60 bars of each timeframe,
            # flattened to 300 values (60 bars x 5 columns)
            for timeframe in ['1m', '5m', '15m', '1h']:
                if timeframe in market_data.get('ohlcv', {}):
                    df = market_data['ohlcv'][timeframe]
                    if not df.empty:
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Pad or truncate to the target size
            target_size = 3000  # 10 channels * 300 sequence length
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating CNN features: {e}")
            return None

    def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
                                  predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
        """Create enhanced RL state"""
        try:
            state_features = []

            # Add market features
            if '1m' in market_data.get('ohlcv', {}):
                df = market_data['ohlcv']['1m']
                if not df.empty:
                    latest = df.iloc[-1]
                    state_features.extend([
                        latest['open'], latest['high'],
                        latest['low'], latest['close'], latest['volume']
                    ])

            # Add technical indicators
            indicators = market_data.get('indicators', {})
            for value in indicators.values():
                state_features.append(value)

            # Add model predictions as features. The CNN outputs are batched
            # numpy arrays, so flatten them to scalars before appending
            if predictions:
                if 'cnn' in predictions:
                    cnn_pred = predictions['cnn']
                    state_features.extend(np.ravel(cnn_pred.get('pivot_logits', [0, 0, 0])))
                    state_features.append(float(np.ravel(cnn_pred.get('confidence', [0.0]))[0]))

                if 'cob_rl' in predictions:
                    cob_pred = predictions['cob_rl']
                    state_features.append(cob_pred.get('predicted_direction', 1))
                    state_features.append(cob_pred.get('confidence', 0.5))

            # Pad or truncate to the target size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating RL state: {e}")
            return None

    def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
                                     predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Get orchestrator prediction"""
        # This would integrate with your orchestrator
        return None


# Global instance
enhanced_training_integration = None

def get_enhanced_training_integration(data_provider: DataProvider = None,
                                      orchestrator: Orchestrator = None,
                                      trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
    """Get global enhanced training integration instance"""
    global enhanced_training_integration
    if enhanced_training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        enhanced_training_integration = EnhancedTrainingIntegration(
            data_provider, orchestrator, trading_executor
        )
    return enhanced_training_integration
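
# A minimal usage sketch, assuming DataProvider can be constructed with no
# arguments (adapt to the real constructor). The worker threads are daemons,
# so the calling process must stay alive for training to continue:
#
#   provider = DataProvider()
#   integration = get_enhanced_training_integration(data_provider=provider)
#   integration.start_enhanced_integration()
#   try:
#       while True:
#           time.sleep(60)
#           logger.info(integration.get_integration_statistics())
#   finally:
#       integration.stop_enhanced_integration()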