233 lines
8.7 KiB
Python
233 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Enhanced COB + ML Training Pipeline
|
|
|
|
Runs the complete pipeline:
|
|
Data -> COB Integration -> CNN Features -> RL States -> Model Training -> Trading Decisions
|
|
|
|
Real-time training with consolidated order book (COB) market-microstructure integration.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
import time
|
|
from datetime import datetime
|
|
|
|
# Prepend this script's own directory to the import search path so that the
# project-local ``core`` package imported below resolves when the script is
# launched directly.
project_root = Path(__file__).parent
sys.path[:0] = [str(project_root)]
|
|
|
|
from core.config import setup_logging, get_config
|
|
from core.data_provider import DataProvider
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from core.trading_executor import TradingExecutor
|
|
|
|
# Configure logging through the project helper (core.config.setup_logging)
# before fetching this module's logger.
# NOTE(review): keep this order — if setup_logging applies a logging.config
# dict with disable_existing_loggers=True, a logger created before the call
# would be disabled. Confirm against core.config.
setup_logging()

# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
|
|
|
|
class EnhancedCOBTrainer:
    """Enhanced COB + ML training pipeline.

    Builds an ``EnhancedTradingOrchestrator`` with COB (consolidated order
    book) integration and real-time processing, then repeatedly requests
    coordinated per-symbol decisions and periodically logs pipeline status.
    Run it with ``await start_training()``. The loop keeps going until
    ``stop()`` is called, start-up fails, or the surrounding task is
    cancelled.
    """

    # Target period of one training iteration, in seconds.
    LOOP_INTERVAL_SECONDS = 5.0
    # Pause after a failed iteration before trying again, in seconds.
    ERROR_BACKOFF_SECONDS = 5.0
    # Log a short status line every N iterations.
    STATUS_EVERY = 5
    # Log the detailed monitoring report every N iterations.
    DETAILED_EVERY = 20
    # Time given to initial data collection after start-up, in seconds.
    WARMUP_SECONDS = 10

    def __init__(self):
        """Load configuration and create the data provider.

        The orchestrator and trading executor are created later, in
        ``_initialize_components()``.
        """
        self.config = get_config()
        self.symbols = ['BTC/USDT', 'ETH/USDT']
        self.data_provider = DataProvider()
        self.orchestrator = None
        self.trading_executor = None
        # Cleared by stop() / _cleanup(); checked once per loop iteration.
        self.running = False

    def stop(self):
        """Ask the training loop to exit once the current iteration finishes."""
        self.running = False

    async def start_training(self):
        """Run the pipeline until it is stopped, interrupted, or fails.

        Errors raised during start-up are logged with a traceback rather than
        propagated, and ``_cleanup()`` always runs. ``asyncio.CancelledError``
        is logged and re-raised so the event loop can finish shutting down;
        on Python 3.11+ this is how ``asyncio.run`` delivers Ctrl+C.
        """
        banner = "=" * 80
        logger.info(banner)
        logger.info("ENHANCED COB + ML TRAINING PIPELINE")
        logger.info(banner)
        logger.info("Pipeline: Data -> COB -> CNN Features -> RL States -> Model Training")
        logger.info(f"Symbols: {self.symbols}")
        logger.info(f"Start time: {datetime.now()}")
        logger.info(banner)

        try:
            await self._initialize_components()
            await self._run_training_loop()
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        except asyncio.CancelledError:
            # Cancellation must keep propagating; only record it here.
            logger.info("Training interrupted by user")
            raise
        except Exception as e:
            # logger.exception attaches the active traceback to the record.
            logger.exception(f"Training error: {e}")
        finally:
            await self._cleanup()

    async def _initialize_components(self):
        """Create the orchestrator, start its COB and real-time processing,
        create the trading executor, then wait for initial data collection."""
        logger.info("1. Initializing Enhanced Trading Orchestrator...")
        self.orchestrator = EnhancedTradingOrchestrator(
            data_provider=self.data_provider,
            symbols=self.symbols,
            enhanced_rl_training=True,
            model_registry={},
        )

        logger.info("2. Starting COB Integration...")
        await self.orchestrator.start_cob_integration()

        logger.info("3. Starting Real-time Processing...")
        await self.orchestrator.start_realtime_processing()

        logger.info("4. Initializing Trading Executor...")
        # NOTE(review): the executor is created here but not otherwise used
        # by this class.
        self.trading_executor = TradingExecutor()

        logger.info("✅ All components initialized successfully")

        # Give the freshly started feeds time to collect data before the
        # first decision round.
        logger.info("Collecting initial data...")
        await asyncio.sleep(self.WARMUP_SECONDS)

    async def _run_training_loop(self):
        """Run decision iterations at a fixed cadence while ``self.running``.

        Each iteration asks the orchestrator for coordinated decisions, which
        drive CNN and RL training. It logs the non-HOLD ones, emits periodic
        status reports, and sleeps for the rest of the
        ``LOOP_INTERVAL_SECONDS`` period. An exception inside an iteration is
        logged, and the loop retries after ``ERROR_BACKOFF_SECONDS``.
        """
        logger.info("Starting main training loop...")
        self.running = True
        iteration = 0

        while self.running:
            iteration += 1
            # Monotonic clock: the pacing must not jump with wall-clock changes.
            started = time.monotonic()

            try:
                # Make coordinated decisions (triggers CNN and RL training)
                decisions = await self.orchestrator.make_coordinated_decisions()
                active_decisions = self._report_decisions(decisions)

                if iteration % self.STATUS_EVERY == 0:
                    await self._log_training_status(iteration, active_decisions)

                if iteration % self.DETAILED_EVERY == 0:
                    await self._detailed_monitoring(iteration)

                elapsed = time.monotonic() - started
                await asyncio.sleep(max(0, self.LOOP_INTERVAL_SECONDS - elapsed))

            except Exception as e:
                logger.error(f"Error in training iteration {iteration}: {e}")
                await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)

    def _report_decisions(self, decisions):
        """Log every non-HOLD decision and return how many there were.

        ``decisions`` maps symbol -> decision object (with ``action`` and
        ``confidence`` attributes) or a falsy placeholder. A ``None`` or empty
        result from the orchestrator counts as zero active decisions.
        """
        active = 0
        for symbol, decision in (decisions or {}).items():
            if decision and decision.action != 'HOLD':
                active += 1
                logger.info(f"🎯 {symbol}: {decision.action} "
                            f"(confidence: {decision.confidence:.3f})")
        return active

    async def _log_training_status(self, iteration, active_decisions):
        """Log the iteration summary plus the per-symbol COB feature/state shapes."""
        logger.info(f"📊 Iteration {iteration} - Active decisions: {active_decisions}")

        # Log COB integration status (only for symbols that have data yet).
        # Values are assumed to be array-like with a ``.shape`` (presumably
        # numpy) — TODO confirm.
        for symbol in self.symbols:
            cob_features = self.orchestrator.latest_cob_features.get(symbol)
            cob_state = self.orchestrator.latest_cob_state.get(symbol)

            if cob_features is not None:
                logger.info(f"  {symbol}: COB CNN features: {cob_features.shape}")
            if cob_state is not None:
                logger.info(f"  {symbol}: COB RL state: {cob_state.shape}")

    async def _detailed_monitoring(self, iteration):
        """Log performance metrics, per-symbol COB status, and positions.

        Each section is wrapped separately, so a failure in one is logged as
        a warning and does not suppress the others.
        """
        logger.info("=" * 60)
        logger.info(f"DETAILED MONITORING - Iteration {iteration}")
        logger.info("=" * 60)

        # Performance metrics
        try:
            metrics = self.orchestrator.get_performance_metrics()
            logger.info("📈 Performance Metrics:")
            for key, value in metrics.items():
                logger.info(f"  {key}: {value}")
        except Exception as e:
            logger.warning(f"Could not get performance metrics: {e}")

        # COB integration status
        logger.info("🔄 COB Integration Status:")
        for symbol in self.symbols:
            try:
                cob_features = self.orchestrator.latest_cob_features.get(symbol)
                cob_state = self.orchestrator.latest_cob_state.get(symbol)
                history_len = len(self.orchestrator.cob_feature_history[symbol])

                logger.info(f"  {symbol}:")
                logger.info(f"    CNN Features: {cob_features.shape if cob_features is not None else 'None'}")
                logger.info(f"    RL State: {cob_state.shape if cob_state is not None else 'None'}")
                logger.info(f"    History Length: {history_len}")

                # Order-book snapshot, when the COB integration is running
                if self.orchestrator.cob_integration:
                    snapshot = self.orchestrator.cob_integration.get_cob_snapshot(symbol)
                    if snapshot:
                        logger.info(f"    Order Book: {len(snapshot.consolidated_bids)} bids, "
                                    f"{len(snapshot.consolidated_asks)} asks")
                        logger.info(f"    Mid Price: ${snapshot.volume_weighted_mid:.2f}")

            except Exception as e:
                logger.warning(f"Error checking {symbol} status: {e}")

        # Model training status
        logger.info("🧠 Model Training Status:")
        # TODO: add model-specific status here when available

        # Position status
        try:
            positions = self.orchestrator.get_position_status()
            logger.info(f"💼 Positions: {positions}")
        except Exception as e:
            logger.warning(f"Could not get position status: {e}")

        logger.info("=" * 60)

    async def _cleanup(self):
        """Stop orchestrator services (best effort) and mark the trainer stopped.

        Services are stopped in reverse start order, and each stop is
        attempted even if the previous one failed.
        """
        logger.info("Cleaning up resources...")

        if self.orchestrator:
            try:
                await self.orchestrator.stop_realtime_processing()
                logger.info("✅ Real-time processing stopped")
            except Exception as e:
                logger.warning(f"Error stopping real-time processing: {e}")

            try:
                await self.orchestrator.stop_cob_integration()
                logger.info("✅ COB integration stopped")
            except Exception as e:
                logger.warning(f"Error stopping COB integration: {e}")

        self.running = False
        logger.info("🏁 Training pipeline stopped")
|
|
|
|
async def main():
    """Script entry point: build a trainer and run its pipeline to completion."""
    await EnhancedCOBTrainer().start_training()
|
|
|
|
if __name__ == "__main__":
    # asyncio.run owns the event loop. On Python 3.11+ a Ctrl+C cancels
    # main() and KeyboardInterrupt is then re-raised from asyncio.run here.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    except Exception as e:
        print(f"Training failed: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells, schedulers and supervisors see the failure.
        sys.exit(1)