more models wireup

Dobromir Popov
2025-06-25 21:10:53 +03:00
parent 2f712c9d6a
commit 3da454efb7
6 changed files with 792 additions and 134 deletions


@@ -58,20 +58,23 @@ class TradingDecision:
class TradingOrchestrator:
    """
    Main orchestrator that coordinates multiple AI models for trading decisions
    Enhanced Trading Orchestrator with full ML and COB integration
    Coordinates CNN, DQN, and COB models for advanced trading decisions
    Features real-time COB (Change of Bid) integration for market microstructure data
    """
    def __init__(self, data_provider: DataProvider = None):
        """Initialize the orchestrator with COB integration"""
    def __init__(self, data_provider: DataProvider = None, enhanced_rl_training: bool = True, model_registry: Dict = None):
        """Initialize the enhanced orchestrator with full ML capabilities"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.model_registry = get_model_registry()
        self.model_registry = model_registry or get_model_registry()
        self.enhanced_rl_training = enhanced_rl_training
        # Configuration
        self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.5)
        self.decision_frequency = self.config.orchestrator.get('decision_frequency', 60)
        self.symbols = self.config.get('symbols', ['ETH/USDT']) # Default symbols to trade
        self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.20)
        self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.10)
        self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
        self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
        # Dynamic weights (will be adapted based on performance)
        self.model_weights = {} # {model_name: weight}
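
A minimal sketch of the configuration shape these lookups expect; only the key names and defaults come from the code above, while the surrounding dict layout and variable name are assumptions for illustration:

# Hypothetical config fragment matching the keys read in __init__ above.
# Only the key names and default values (0.20, 0.10, 30, ['ETH/USDT', 'BTC/USDT'])
# are taken from the diff; the structure shown here is illustrative.
example_config = {
    'symbols': ['ETH/USDT', 'BTC/USDT'],
    'orchestrator': {
        'confidence_threshold': 0.20,        # threshold for opening positions
        'confidence_threshold_close': 0.10,  # lower bar for closing positions
        'decision_frequency': 30,            # seconds between decision cycles
    },
}
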
@@ -92,22 +95,85 @@ class TradingOrchestrator:
        self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
        self.cob_feature_history: Dict[str, List] = {symbol: [] for symbol in self.symbols} # Rolling history for models
        logger.info("TradingOrchestrator initialized with modular model system")
        # Enhanced ML Models
        self.rl_agent = None # DQN Agent
        self.cnn_model = None # CNN Model for pattern recognition
        self.extrema_trainer = None # Extrema/pivot trainer
        self.latest_cnn_features: Dict[str, Any] = {} # CNN hidden features
        self.latest_cnn_predictions: Dict[str, Any] = {} # CNN predictions
        # Enhanced RL features
        self.sensitivity_learning_queue = [] # For outcome-based learning
        self.perfect_move_buffer = [] # Buffer for perfect move analysis
        self.position_status = {} # Current positions
        # Real-time processing
        self.realtime_processing = False
        self.realtime_tasks = []
        logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
        logger.info(f"Enhanced RL training: {enhanced_rl_training}")
        logger.info(f"Confidence threshold: {self.confidence_threshold}")
        logger.info(f"Decision frequency: {self.decision_frequency}s")
        logger.info(f"Symbols: {self.symbols}")
        # Initialize COB integration
        # Initialize models and COB integration
        self._initialize_ml_models()
        self._initialize_cob_integration()
    def _initialize_ml_models(self):
        """Initialize ML models for enhanced trading"""
        try:
            logger.info("Initializing ML models...")
            # Initialize DQN Agent
            try:
                from NN.models.dqn_agent import DQNAgent
                state_size = self.config.rl.get('state_size', 13800) # Enhanced with COB features
                action_size = self.config.rl.get('action_space', 3)
                self.rl_agent = DQNAgent(state_size=state_size, action_size=action_size)
                logger.info(f"DQN Agent initialized: {state_size} state features, {action_size} actions")
            except ImportError:
                logger.warning("DQN Agent not available")
                self.rl_agent = None
            # Initialize CNN Model
            try:
                from NN.models.enhanced_cnn import EnhancedCNN
                self.cnn_model = EnhancedCNN()
                logger.info("Enhanced CNN model initialized")
            except ImportError:
                try:
                    from NN.models.cnn_model import CNNModel
                    self.cnn_model = CNNModel()
                    logger.info("Basic CNN model initialized")
                except ImportError:
                    logger.warning("CNN model not available")
                    self.cnn_model = None
            # Initialize Extrema Trainer
            try:
                from core.extrema_trainer import ExtremaTrainer
                self.extrema_trainer = ExtremaTrainer(
                    data_provider=self.data_provider,
                    symbols=self.symbols
                )
                logger.info("Extrema trainer initialized")
            except ImportError:
                logger.warning("Extrema trainer not available")
                self.extrema_trainer = None
            logger.info("ML models initialization completed")
        except Exception as e:
            logger.error(f"Error initializing ML models: {e}")
    def _initialize_cob_integration(self):
        """Initialize real-time COB integration for market microstructure data"""
        try:
            if COB_INTEGRATION_AVAILABLE:
                # Initialize COB integration with our symbols
                self.cob_integration = COBIntegration(
                    data_provider=self.data_provider,
                    symbols=self.symbols
                )
                self.cob_integration = COBIntegration(data_provider=self.data_provider, symbols=self.symbols)
                # Register callbacks to receive real-time COB data
                self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
@@ -116,9 +182,8 @@ class TradingOrchestrator:
                logger.info("COB Integration initialized - real-time market microstructure data available")
                logger.info(f"COB symbols: {self.symbols}")
                # Start COB integration in background
                asyncio.create_task(self._start_cob_integration())
                # COB integration will be started manually via start_cob_integration()
            else:
                logger.warning("COB Integration not available - models will use basic price data only")
@@ -1177,4 +1242,165 @@ class TradingOrchestrator:
        except Exception as e:
            logger.warning(f"Error generating fallback prediction for {symbol}: {e}")
            return None
    # Enhanced Orchestrator Methods
    async def start_cob_integration(self):
        """Start COB integration manually"""
        try:
            if self.cob_integration:
                await self._start_cob_integration()
                logger.info("COB Integration started successfully")
            else:
                logger.warning("COB Integration not available")
        except Exception as e:
            logger.error(f"Error starting COB integration: {e}")
    async def stop_cob_integration(self):
        """Stop COB integration"""
        try:
            if self.cob_integration:
                await self.cob_integration.stop()
                logger.info("COB Integration stopped")
        except Exception as e:
            logger.error(f"Error stopping COB integration: {e}")
    async def start_realtime_processing(self):
        """Start real-time processing"""
        try:
            self.realtime_processing = True
            logger.info("Real-time processing started")
            # Start background tasks for real-time processing
            for symbol in self.symbols:
                task = asyncio.create_task(self._realtime_processing_loop(symbol))
                self.realtime_tasks.append(task)
        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")
    async def stop_realtime_processing(self):
        """Stop real-time processing"""
        try:
            self.realtime_processing = False
            # Cancel all background tasks
            for task in self.realtime_tasks:
                task.cancel()
            self.realtime_tasks = []
            logger.info("Real-time processing stopped")
        except Exception as e:
            logger.error(f"Error stopping real-time processing: {e}")
    async def _realtime_processing_loop(self, symbol: str):
        """Real-time processing loop for a symbol"""
        while self.realtime_processing:
            try:
                # Update CNN features
                await self._update_cnn_features(symbol)
                # Update RL state
                await self._update_rl_state(symbol)
                # Sleep between updates
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Error in real-time processing for {symbol}: {e}")
                await asyncio.sleep(5)
    async def _update_cnn_features(self, symbol: str):
        """Update CNN features for a symbol"""
        try:
            if self.cnn_model and hasattr(self.cnn_model, 'extract_features'):
                # Get current market data
                df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
                if df is not None and not df.empty:
                    # Generate CNN features
                    features = self.cnn_model.extract_features(df)
                    if features is not None:
                        self.latest_cnn_features[symbol] = features
                    # Generate CNN predictions
                    if hasattr(self.cnn_model, 'predict'):
                        predictions = self.cnn_model.predict(df)
                        if predictions is not None:
                            self.latest_cnn_predictions[symbol] = predictions
        except Exception as e:
            logger.debug(f"Error updating CNN features for {symbol}: {e}")
    async def _update_rl_state(self, symbol: str):
        """Update RL state for a symbol"""
        try:
            if self.rl_agent:
                # Build comprehensive RL state
                rl_state = self.build_comprehensive_rl_state(symbol)
                if rl_state and hasattr(self.rl_agent, 'remember'):
                    # Store for training
                    pass
        except Exception as e:
            logger.debug(f"Error updating RL state for {symbol}: {e}")
    async def make_coordinated_decisions(self) -> Dict[str, Any]:
        """Make coordinated trading decisions for all symbols"""
        decisions = {}
        try:
            for symbol in self.symbols:
                decision = await self.make_trading_decision(symbol)
                decisions[symbol] = decision
            return decisions
        except Exception as e:
            logger.error(f"Error making coordinated decisions: {e}")
            return {}
    def get_position_status(self) -> Dict[str, Any]:
        """Get current position status"""
        return self.position_status.copy()
    def cleanup_all_models(self):
        """Cleanup all models"""
        try:
            if hasattr(self.model_registry, 'cleanup_all_models'):
                self.model_registry.cleanup_all_models()
            else:
                logger.debug("Model registry cleanup not available")
        except Exception as e:
            logger.error(f"Error cleaning up models: {e}")
    def _get_cnn_hidden_features_for_rl_enhanced(self, symbol: str) -> Optional[List[float]]:
        """Get CNN hidden features for RL (enhanced version)"""
        try:
            cnn_features = self.latest_cnn_features.get(symbol)
            if cnn_features is not None:
                if hasattr(cnn_features, 'tolist'):
                    return cnn_features.tolist()[:1000] # First 1000 features
                elif isinstance(cnn_features, list):
                    return cnn_features[:1000]
            return None
        except Exception as e:
            logger.debug(f"Error getting CNN hidden features: {e}")
            return None
    def _get_pivot_analysis_features_for_rl_enhanced(self, symbol: str) -> Optional[List[float]]:
        """Get pivot analysis features for RL (enhanced version)"""
        try:
            if self.extrema_trainer and hasattr(self.extrema_trainer, 'get_context_features_for_model'):
                pivot_features = self.extrema_trainer.get_context_features_for_model(symbol)
                if pivot_features is not None:
                    if hasattr(pivot_features, 'tolist'):
                        return pivot_features.tolist()[:300] # First 300 features
                    elif isinstance(pivot_features, list):
                        return pivot_features[:300]
            return None
        except Exception as e:
            logger.debug(f"Error getting pivot analysis features: {e}")
            return None
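
Taken together, the new methods give the orchestrator an explicit lifecycle: start COB integration, start the per-symbol real-time loops, poll for coordinated decisions, then shut everything down. A minimal usage sketch built only from methods added in this commit; the loop count and print handling are illustrative, not part of the commit:

import asyncio

async def run_orchestrator():
    orchestrator = TradingOrchestrator(enhanced_rl_training=True)

    # Explicit startup: COB market-microstructure feed, then per-symbol loops.
    await orchestrator.start_cob_integration()
    await orchestrator.start_realtime_processing()

    try:
        # Periodically ask for decisions across all configured symbols.
        for _ in range(10):  # illustrative: ten decision cycles
            decisions = await orchestrator.make_coordinated_decisions()
            for symbol, decision in decisions.items():
                print(symbol, decision)
            await asyncio.sleep(orchestrator.decision_frequency)
    finally:
        # Explicit shutdown mirrors the startup sequence.
        await orchestrator.stop_realtime_processing()
        await orchestrator.stop_cob_integration()
        orchestrator.cleanup_all_models()

asyncio.run(run_orchestrator())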