cleanup, cob ladder still broken
@@ -34,7 +34,7 @@ class COBIntegration:
     Integration layer for Multi-Exchange COB data with gogo2 trading system
     """
     
-    def __init__(self, data_provider: DataProvider = None, symbols: List[str] = None):
+    def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
         """
         Initialize COB Integration
         
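
Review note: the signature change follows PEP 484, which deprecates implicit Optional; a parameter defaulting to None should say so in its annotation. A minimal illustration of the before/after, with a hypothetical function name:

    from typing import List, Optional

    # Before (implicit Optional, deprecated): def load(symbols: List[str] = None)
    def load(symbols: Optional[List[str]] = None) -> List[str]:
        # Fall back to an empty list when the caller passes nothing
        return symbols if symbols is not None else []
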
@@ -45,15 +45,8 @@ class COBIntegration:
         self.data_provider = data_provider
         self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
         
-        # Initialize COB provider
-        self.cob_provider = MultiExchangeCOBProvider(
-            symbols=self.symbols,
-            bucket_size_bps=1.0  # 1 basis point granularity
-        )
-        
-        # Register callbacks
-        self.cob_provider.subscribe_to_cob_updates(self._on_cob_update)
-        self.cob_provider.subscribe_to_bucket_updates(self._on_bucket_update)
+        # Initialize COB provider to None, will be set in start()
+        self.cob_provider = None
         
         # CNN/DQN integration
         self.cnn_callbacks: List[Callable] = []
@@ -75,13 +68,23 @@ class COBIntegration:
             self.liquidity_alerts[symbol] = []
             self.arbitrage_opportunities[symbol] = []
         
-        logger.info("COB Integration initialized")
+        logger.info("COB Integration initialized (provider will be started in async)")
         logger.info(f"Symbols: {self.symbols}")
     
     async def start(self):
         """Start COB integration"""
         logger.info("Starting COB Integration")
         
+        # Initialize COB provider here, within the async context
+        self.cob_provider = MultiExchangeCOBProvider(
+            symbols=self.symbols,
+            bucket_size_bps=1.0  # 1 basis point granularity
+        )
+        
+        # Register callbacks
+        self.cob_provider.subscribe_to_cob_updates(self._on_cob_update)
+        self.cob_provider.subscribe_to_bucket_updates(self._on_bucket_update)
+        
         # Start COB provider
         await self.cob_provider.start_streaming()
         
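
Review note: the commit moves provider construction out of __init__ and into the async start(), since the provider allocates loop-bound resources (asyncio locks, aiohttp sessions) that should be created while an event loop is running. A stripped-down sketch of that lazy-init pattern, with hypothetical class names:

    import asyncio
    from typing import Optional

    class Provider:
        async def start_streaming(self) -> None:
            await asyncio.sleep(0)  # stand-in for real streaming work

    class Integration:
        def __init__(self) -> None:
            self.provider: Optional[Provider] = None  # deferred; see start()

        async def start(self) -> None:
            # Construct loop-bound resources only once a loop is running
            if self.provider is None:
                self.provider = Provider()
            await self.provider.start_streaming()

    asyncio.run(Integration().start())
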
@@ -94,7 +97,8 @@ class COBIntegration:
     async def stop(self):
         """Stop COB integration"""
         logger.info("Stopping COB Integration")
-        await self.cob_provider.stop_streaming()
+        if self.cob_provider:
+            await self.cob_provider.stop_streaming()
         logger.info("COB Integration stopped")
     
     def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
@@ -293,7 +297,9 @@ class COBIntegration:
         """Generate formatted data for dashboard visualization"""
         try:
             # Get fixed bucket size for the symbol
-            bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
+            bucket_size = 1.0  # Default bucket size
+            if self.cob_provider:
+                bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
             
             # Calculate price range for buckets
             mid_price = cob_snapshot.volume_weighted_mid
@@ -338,15 +344,16 @@ class COBIntegration:
             
             # Get actual Session Volume Profile (SVP) from trade data
             svp_data = []
-            try:
-                svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
-                if svp_result and 'data' in svp_result:
-                    svp_data = svp_result['data']
-                    logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
-                else:
-                    logger.warning(f"No SVP data available for {symbol}")
-            except Exception as e:
-                logger.error(f"Error getting SVP data for {symbol}: {e}")
+            if self.cob_provider:
+                try:
+                    svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
+                    if svp_result and 'data' in svp_result:
+                        svp_data = svp_result['data']
+                        logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
+                    else:
+                        logger.warning(f"No SVP data available for {symbol}")
+                except Exception as e:
+                    logger.error(f"Error getting SVP data for {symbol}: {e}")
             
             # Generate market stats
             stats = {
@@ -381,19 +388,21 @@ class COBIntegration:
                 stats['svp_price_levels'] = 0
                 stats['session_start'] = ''
             
-            # Add real-time statistics for NN models
-            try:
-                realtime_stats = self.cob_provider.get_realtime_stats(symbol)
-                if realtime_stats:
-                    stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
-                    stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
-                else:
-                    stats['realtime_1s'] = {}
-                    stats['realtime_5s'] = {}
-            except Exception as e:
-                logger.error(f"Error getting real-time stats for {symbol}: {e}")
-                stats['realtime_1s'] = {}
-                stats['realtime_5s'] = {}
+            # Get additional real-time stats
+            realtime_stats = {}
+            if self.cob_provider:
+                try:
+                    realtime_stats = self.cob_provider.get_realtime_stats(symbol)
+                    if realtime_stats:
+                        stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
+                        stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
+                    else:
+                        stats['realtime_1s'] = {}
+                        stats['realtime_5s'] = {}
+                except Exception as e:
+                    logger.error(f"Error getting real-time stats for {symbol}: {e}")
+                    stats['realtime_1s'] = {}
+                    stats['realtime_5s'] = {}
             
             return {
                 'type': 'cob_update',
@@ -463,9 +472,10 @@ class COBIntegration:
         while True:
             try:
                 for symbol in self.symbols:
-                    cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
-                    if cob_snapshot:
-                        await self._analyze_cob_patterns(symbol, cob_snapshot)
+                    if self.cob_provider:
+                        cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
+                        if cob_snapshot:
+                            await self._analyze_cob_patterns(symbol, cob_snapshot)
                 
                 await asyncio.sleep(1)
                 
@@ -540,18 +550,26 @@ class COBIntegration:
     
     def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
         """Get latest COB snapshot for a symbol"""
+        if not self.cob_provider:
+            return None
         return self.cob_provider.get_consolidated_orderbook(symbol)
     
     def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
         """Get detailed market depth analysis"""
+        if not self.cob_provider:
+            return None
         return self.cob_provider.get_market_depth_analysis(symbol)
     
     def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
         """Get liquidity breakdown by exchange"""
+        if not self.cob_provider:
+            return None
         return self.cob_provider.get_exchange_breakdown(symbol)
     
     def get_price_buckets(self, symbol: str) -> Optional[Dict]:
         """Get fine-grain price buckets"""
+        if not self.cob_provider:
+            return None
         return self.cob_provider.get_price_buckets(symbol)
     
     def get_recent_signals(self, symbol: str, count: int = 20) -> List[Dict]:
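
Review note: the same `if not self.cob_provider: return None` guard is now repeated across four accessors; a small decorator could centralize it. A hypothetical sketch (requires_provider is not part of the commit):

    import functools

    def requires_provider(method):
        """Short-circuit to None while self.cob_provider is uninitialized."""
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if not self.cob_provider:
                return None
            return method(self, *args, **kwargs)
        return wrapper

    # Usage:
    # @requires_provider
    # def get_price_buckets(self, symbol):
    #     return self.cob_provider.get_price_buckets(symbol)
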
@@ -560,6 +578,16 @@ class COBIntegration:
     
     def get_statistics(self) -> Dict[str, Any]:
         """Get COB integration statistics"""
+        if not self.cob_provider:
+            return {
+                'cnn_callbacks': len(self.cnn_callbacks),
+                'dqn_callbacks': len(self.dqn_callbacks),
+                'dashboard_callbacks': len(self.dashboard_callbacks),
+                'cached_features': list(self.cob_feature_cache.keys()),
+                'total_signals': {symbol: len(signals) for symbol, signals in self.cob_signals.items()},
+                'provider_status': 'Not initialized'
+            }
+        
         provider_stats = self.cob_provider.get_statistics()
         
         return {
@@ -574,6 +602,11 @@ class COBIntegration:
     def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
         """Get real-time statistics formatted for NN models"""
         try:
+            # Check if COB provider is initialized
+            if not self.cob_provider:
+                logger.debug(f"COB provider not initialized yet for {symbol}")
+                return {}
+            
             realtime_stats = self.cob_provider.get_realtime_stats(symbol)
             if not realtime_stats:
                 return {}
@@ -608,4 +641,66 @@ class COBIntegration:
             
         except Exception as e:
             logger.error(f"Error getting NN stats for {symbol}: {e}")
-            return {}
+            return {}
+    
+    def get_realtime_stats(self):
+        # Added null check to ensure the COB provider is initialized
+        if self.cob_provider is None:
+            logger.warning("COB provider is uninitialized; attempting initialization.")
+            self.initialize_provider()
+            if self.cob_provider is None:
+                logger.error("COB provider failed to initialize; returning default empty snapshot.")
+                return COBSnapshot(
+                    symbol="",
+                    timestamp=0,
+                    exchanges_active=0,
+                    total_bid_liquidity=0,
+                    total_ask_liquidity=0,
+                    price_buckets=[],
+                    volume_weighted_mid=0,
+                    spread_bps=0,
+                    liquidity_imbalance=0,
+                    consolidated_bids=[],
+                    consolidated_asks=[]
+                )
+        try:
+            snapshot = self.cob_provider.get_realtime_stats()
+            return snapshot
+        except Exception as e:
+            logger.error(f"Error retrieving COB snapshot: {e}")
+            return COBSnapshot(
+                symbol="",
+                timestamp=0,
+                exchanges_active=0,
+                total_bid_liquidity=0,
+                total_ask_liquidity=0,
+                price_buckets=[],
+                volume_weighted_mid=0,
+                spread_bps=0,
+                liquidity_imbalance=0,
+                consolidated_bids=[],
+                consolidated_asks=[]
+            )
+    
+    def stop_streaming(self):
+        pass
+    
+    def _initialize_cob_integration(self):
+        """Initialize COB integration with high-frequency data handling"""
+        logger.info("Initializing COB integration...")
+        if not COB_INTEGRATION_AVAILABLE:
+            logger.warning("COB integration not available - skipping initialization")
+            return
+        
+        try:
+            if not hasattr(self.orchestrator, 'cob_integration') or self.orchestrator.cob_integration is None:
+                logger.info("Creating new COB integration instance")
+                self.orchestrator.cob_integration = COBIntegration(self.data_provider)
+            else:
+                logger.info("Using existing COB integration from orchestrator")
+            
+            # Start simple COB data collection for both symbols
+            self._start_simple_cob_collection()
+            logger.info("COB integration initialized successfully")
+        except Exception as e:
+            logger.error(f"Error initializing COB integration: {e}")
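
Review note: the empty-COBSnapshot fallback is built twice in get_realtime_stats, and self.initialize_provider() is not defined anywhere in this diff, which may be part of why the ladder is still broken. A hypothetical helper that would dedupe the fallback (field list copied from the code above):

    def _empty_snapshot(self) -> COBSnapshot:
        """Placeholder snapshot returned when no provider data is available."""
        return COBSnapshot(
            symbol="", timestamp=0, exchanges_active=0,
            total_bid_liquidity=0, total_ask_liquidity=0,
            price_buckets=[], volume_weighted_mid=0, spread_bps=0,
            liquidity_imbalance=0, consolidated_bids=[], consolidated_asks=[]
        )
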
@@ -33,7 +33,7 @@ except ImportError:
 import numpy as np
 import pandas as pd
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Callable, Union
+from typing import Dict, List, Optional, Tuple, Any, Callable, Union, Awaitable
 from collections import deque, defaultdict
 from dataclasses import dataclass, field
 from threading import Thread, Lock
@@ -194,6 +194,11 @@ class MultiExchangeCOBProvider:
         # Thread safety
         self.data_lock = asyncio.Lock()
         
+        # Initialize aiohttp session and connector to None, will be set up in start_streaming
+        self.session: Optional[aiohttp.ClientSession] = None
+        self.connector: Optional[aiohttp.TCPConnector] = None
+        self.rest_session: Optional[aiohttp.ClientSession] = None  # Added for explicit None initialization
+        
         # Create REST API session
         # Fix for Windows aiodns issue - use ThreadedResolver instead
         connector = aiohttp.TCPConnector(
@@ -286,64 +291,62 @@ class MultiExchangeCOBProvider:
         return configs
     
     async def start_streaming(self):
-        """Start streaming from all configured exchanges"""
-        if self.is_streaming:
-            logger.warning("COB streaming already active")
-            return
-        
-        logger.info("Starting Multi-Exchange COB streaming")
+        """Start real-time order book streaming from all configured exchanges"""
+        logger.info(f"Starting COB streaming for symbols: {self.symbols}")
         self.is_streaming = True
         
-        # Start streaming tasks for each exchange and symbol
+        # Setup aiohttp session here, within the async context
+        await self._setup_http_session()
+        
+        # Start WebSocket connections for each active exchange and symbol
         tasks = []
        
-        for exchange_name in self.active_exchanges:
-            for symbol in self.symbols:
-                # WebSocket task for real-time top 20 levels
-                task = asyncio.create_task(
-                    self._stream_exchange_orderbook(exchange_name, symbol)
-                )
-                tasks.append(task)
-                
-                # REST API task for deep order book snapshots
-                deep_task = asyncio.create_task(
-                    self._stream_deep_orderbook(exchange_name, symbol)
-                )
-                tasks.append(deep_task)
-                
-                # Trade stream task for SVP
-                if exchange_name == 'binance':
-                    trade_task = asyncio.create_task(
-                        self._stream_binance_trades(symbol)
-                    )
-                    tasks.append(trade_task)
-        
-        # Start consolidation and analysis tasks
-        tasks.extend([
-            asyncio.create_task(self._continuous_consolidation()),
-            asyncio.create_task(self._continuous_bucket_updates())
-        ])
-        
-        # Wait for all tasks
-        try:
-            await asyncio.gather(*tasks)
-        except Exception as e:
-            logger.error(f"Error in streaming tasks: {e}")
-        finally:
-            self.is_streaming = False
+        for symbol in self.symbols:
+            for exchange_name, config in self.exchange_configs.items():
+                if config.enabled and exchange_name in self.active_exchanges:
+                    # Start WebSocket stream
+                    tasks.append(self._stream_exchange_orderbook(exchange_name, symbol))
+                    
+                    # Start deep order book (REST API) stream
+                    tasks.append(self._stream_deep_orderbook(exchange_name, symbol))
+                    
+                    # Start trade stream (for SVP)
+                    if exchange_name == 'binance':  # Only Binance for now
+                        tasks.append(self._stream_binance_trades(symbol))
+        
+        # Start continuous consolidation and bucket updates
+        tasks.append(self._continuous_consolidation())
+        tasks.append(self._continuous_bucket_updates())
+        
+        logger.info(f"Starting {len(tasks)} COB streaming tasks")
+        await asyncio.gather(*tasks)
     
+    async def _setup_http_session(self):
+        """Setup aiohttp session and connector"""
+        self.connector = aiohttp.TCPConnector(
+            resolver=aiohttp.ThreadedResolver()  # This is now created inside async function
+        )
+        self.session = aiohttp.ClientSession(connector=self.connector)
+        self.rest_session = aiohttp.ClientSession(connector=self.connector)  # Moved here from __init__
+        logger.info("aiohttp session and connector setup completed")
+    
     async def stop_streaming(self):
-        """Stop streaming from all exchanges"""
-        logger.info("Stopping Multi-Exchange COB streaming")
+        """Stop real-time order book streaming and close sessions"""
+        logger.info("Stopping COB Integration")
         self.is_streaming = False
         
-        # Close REST API session
-        if self.rest_session:
+        if self.session and not self.session.closed:
+            await self.session.close()
+            logger.info("aiohttp session closed")
+        
+        if self.rest_session and not self.rest_session.closed:
             await self.rest_session.close()
             self.rest_session = None
-        
-        # Wait a bit for tasks to stop gracefully
-        await asyncio.sleep(1)
+            logger.info("aiohttp REST session closed")
+        
+        if self.connector and not self.connector.closed:
+            await self.connector.close()
+            logger.info("aiohttp connector closed")
+        
+        logger.info("COB Integration stopped")
     
     async def _stream_deep_orderbook(self, exchange_name: str, symbol: str):
         """Fetch deep order book data via REST API periodically"""
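
Review note: creating the TCPConnector and ClientSession inside an async method matters because aiohttp binds them to the running event loop, and ThreadedResolver sidesteps the aiodns resolver that fails on Windows event loops. One caution: a ClientSession owns its connector by default, so sharing one connector between session and rest_session (as _setup_http_session does) means the first close() also closes the shared connector unless connector_owner=False is passed. A self-contained sketch of the setup/teardown against a public Binance endpoint:

    import asyncio
    import aiohttp

    async def main() -> None:
        # ThreadedResolver avoids aiodns, which is unreliable on Windows loops
        connector = aiohttp.TCPConnector(resolver=aiohttp.ThreadedResolver())
        session = aiohttp.ClientSession(connector=connector)
        try:
            async with session.get("https://api.binance.com/api/v3/ping") as resp:
                print(resp.status)
        finally:
            await session.close()  # closes the connector it owns

    asyncio.run(main())
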
@@ -1086,12 +1089,12 @@ class MultiExchangeCOBProvider:
     
     # Public interface methods
     
-    def subscribe_to_cob_updates(self, callback: Callable[[str, COBSnapshot], None]):
+    def subscribe_to_cob_updates(self, callback: Callable[[str, COBSnapshot], Awaitable[None]]):
         """Subscribe to consolidated order book updates"""
         self.cob_update_callbacks.append(callback)
         logger.info(f"Added COB update callback: {len(self.cob_update_callbacks)} total")
     
-    def subscribe_to_bucket_updates(self, callback: Callable[[str, Dict], None]):
+    def subscribe_to_bucket_updates(self, callback: Callable[[str, Dict], Awaitable[None]]):
         """Subscribe to price bucket updates"""
         self.bucket_update_callbacks.append(callback)
         logger.info(f"Added bucket update callback: {len(self.bucket_update_callbacks)} total")
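
Review note: widening the callback type to Awaitable[None] documents that subscribers are coroutine functions, which the dispatcher must await. A minimal dispatch sketch that tolerates both sync and async subscribers (the _notify helper itself is not shown in this diff):

    import asyncio
    import inspect

    async def _notify(callbacks, symbol, payload):
        """Invoke subscribers, awaiting any that return coroutines."""
        for cb in callbacks:
            result = cb(symbol, payload)
            if inspect.isawaitable(result):
                await result

    async def on_update(symbol, payload):  # example async subscriber
        print(symbol, payload)

    asyncio.run(_notify([on_update], "BTC/USDT", {"spread_bps": 1.2}))
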
@@ -386,7 +386,7 @@ class TradingOrchestrator:
             # Import COB integration directly (same as working dashboard)
             from core.cob_integration import COBIntegration
             
-            # Initialize COB integration with our symbols
+            # Initialize COB integration with our symbols (but don't start it yet)
             self.cob_integration = COBIntegration(symbols=self.symbols)
             
             # Register callbacks to receive real-time COB data
@@ -440,13 +440,21 @@ class TradingOrchestrator:
             
         except Exception as e:
             logger.error(f"Error initializing COB integration: {e}")
-            self.cob_integration = None
+            self.cob_integration = None  # Ensure it's None if init fails
             logger.info("COB integration will be disabled - models will use basic price data")
     
     async def start_cob_integration(self):
         """Start COB integration with matrix data collection"""
         try:
-            if self.cob_integration:
+            if not self.cob_integration:
+                logger.info("COB integration not initialized yet, creating instance.")
+                from core.cob_integration import COBIntegration
+                self.cob_integration = COBIntegration(symbols=self.symbols)
+                # Re-register callbacks if COBIntegration was just created
+                self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
+                self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
+                self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)
+            
             logger.info("Starting COB integration with 5-minute matrix collection...")
             
             # Start COB integration in background thread
@@ -480,12 +488,11 @@ class TradingOrchestrator:
                 self._start_cob_matrix_worker()
                 
                 logger.info("COB Integration started - 5-minute data matrix streaming active")
-            else:
-                logger.warning("COB integration is None - cannot start")
             
         except Exception as e:
             logger.error(f"Error starting COB integration: {e}")
+            self.cob_integration = None
             logger.info("COB integration will be disabled - models will use basic price data")
     
     def _start_cob_matrix_worker(self):
         """Start background worker for COB matrix updates"""
@@ -760,7 +767,18 @@ class TradingOrchestrator:
     
     def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
         """Get latest COB snapshot for a symbol"""
-        return self.latest_cob_data.get(symbol)
+        try:
+            # First try to get from COB integration (live data)
+            if self.cob_integration:
+                snapshot = self.cob_integration.get_cob_snapshot(symbol)
+                if snapshot:
+                    return snapshot
+            
+            # Fallback to cached data if COB integration not available
+            return self.latest_cob_data.get(symbol)
+        except Exception as e:
+            logger.warning(f"Error getting COB snapshot for {symbol}: {e}")
+            return None
     
     def get_cob_feature_matrix(self, symbol: str, sequence_length: int = 60) -> Optional[np.ndarray]:
         """
@@ -1325,12 +1343,25 @@ class TradingOrchestrator:
             if state is None:
                 return None
             
-            # Get RL agent's action and confidence - use the actual underlying model
+            # Get RL agent's action, confidence, and q_values from the underlying model
             if hasattr(model.model, 'act_with_confidence'):
-                action_idx, confidence = model.model.act_with_confidence(state)
+                # Call act_with_confidence and handle different return formats
+                result = model.model.act_with_confidence(state)
+                
+                if len(result) == 3:
+                    # EnhancedCNN format: (action, confidence, q_values)
+                    action_idx, confidence, raw_q_values = result
+                elif len(result) == 2:
+                    # DQN format: (action, confidence)
+                    action_idx, confidence = result
+                    raw_q_values = None
+                else:
+                    logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values")
+                    return None
             elif hasattr(model.model, 'act'):
                 action_idx = model.model.act(state, explore=False)
                 confidence = 0.7  # Default confidence for basic act method
+                raw_q_values = None  # No raw q_values from simple act
             else:
                 logger.error(f"RL model {model.name} has no act method")
                 return None
@@ -1338,11 +1369,19 @@ class TradingOrchestrator:
             action_names = ['SELL', 'HOLD', 'BUY']
             action = action_names[action_idx]
             
+            # Convert raw_q_values to list if they are a tensor
+            q_values_for_capture = None
+            if raw_q_values is not None and hasattr(raw_q_values, 'tolist'):
+                q_values_for_capture = raw_q_values.tolist()
+            elif raw_q_values is not None and isinstance(raw_q_values, list):
+                q_values_for_capture = raw_q_values
+            
             # Create prediction object
             prediction = Prediction(
                 action=action,
                 confidence=float(confidence),
-                probabilities={action: float(confidence), 'HOLD': 1.0 - float(confidence)},
+                # Use actual q_values if available, otherwise default probabilities
+                probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
                 timeframe='mixed',  # RL uses mixed timeframes
                 timestamp=datetime.now(),
                 model_name=model.name,
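
Review note: raw Q-values are unbounded action-value estimates, not a distribution, so feeding them straight into `probabilities` can produce negative or non-normalized entries. If calibrated probabilities are wanted, a softmax is the usual conversion; a sketch of that alternative, not what the commit does:

    import numpy as np

    def q_values_to_probabilities(q_values, temperature: float = 1.0) -> np.ndarray:
        """Softmax over Q-values; temperature sharpens or flattens the result."""
        q = np.asarray(q_values, dtype=np.float64) / temperature
        q -= q.max()  # subtract max for numerical stability
        exp_q = np.exp(q)
        return exp_q / exp_q.sum()

    # q_values_to_probabilities([-0.2, 0.1, 0.4]) -> non-negative, sums to 1.0
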
@@ -1352,17 +1391,9 @@ class TradingOrchestrator:
             # Capture DQN prediction for dashboard visualization
             current_price = self._get_current_price(symbol)
             if current_price:
-                # Get Q-values if available
-                q_values = [0.33, 0.33, 0.34]  # Default
-                if hasattr(model, 'get_q_values'):
-                    try:
-                        q_values = model.get_q_values(state)
-                        if hasattr(q_values, 'tolist'):
-                            q_values = q_values.tolist()
-                    except:
-                        pass
-                
-                self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values)
+                # Only pass q_values if they exist, otherwise pass empty list
+                q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
+                self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)
             
             return prediction
             
@@ -2434,11 +2465,11 @@ class TradingOrchestrator:
             
             self.decision_fusion_network = DecisionFusionNet()
             logger.info("Decision fusion network initialized")
-            
+
         except Exception as e:
             logger.warning(f"Decision fusion initialization failed: {e}")
             self.decision_fusion_enabled = False
-    
+
     def _initialize_enhanced_training_system(self):
         """Initialize the enhanced real-time training system"""
         try:
@@ -2599,7 +2630,7 @@ class TradingOrchestrator:
             if self.enhanced_training_system:
                 self.enhanced_training_system.dashboard = dashboard
                 logger.info("Dashboard reference set for enhanced training system")
-            
+
         except Exception as e:
             logger.error(f"Error setting training dashboard: {e}")
-    
+
@@ -13,6 +13,9 @@ import logging
 from datetime import datetime
 from typing import Dict, List, Any, Optional
 import numpy as np
+from utils.reward_calculator import RewardCalculator
+import threading
+import time
 
 logger = logging.getLogger(__name__)
 
@@ -21,8 +24,16 @@ class TrainingIntegration:
     
     def __init__(self, orchestrator=None):
         self.orchestrator = orchestrator
+        self.reward_calculator = RewardCalculator()
         self.training_sessions = {}
         self.min_confidence_threshold = 0.15  # Lowered from 0.3 for more aggressive training
+        self.training_active = False
+        self.trainer_thread = None
+        self.stop_event = threading.Event()
+        self.training_lock = threading.Lock()
+        self.last_training_time = 0.0 if orchestrator is None else time.time()
+        self.training_interval = 300  # 5 minutes between training sessions
+        self.min_data_points = 100  # Minimum data points required to trigger training
         
         logger.info("TrainingIntegration initialized")
     
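
Review note: the new fields (training_active, trainer_thread, stop_event, training_interval) imply a periodic background trainer, though its loop is not part of this diff. A hypothetical sketch of how such fields usually fit together:

    import threading

    class TrainerLoop:
        def __init__(self, interval: float = 300.0):
            self.training_interval = interval
            self.stop_event = threading.Event()
            self.trainer_thread = None

        def start(self):
            self.trainer_thread = threading.Thread(target=self._run, daemon=True)
            self.trainer_thread.start()

        def _run(self):
            # Event.wait doubles as an interruptible sleep between sessions
            while not self.stop_event.wait(timeout=self.training_interval):
                self._train_once()

        def _train_once(self):
            pass  # placeholder for one training session

        def stop(self):
            self.stop_event.set()
            if self.trainer_thread:
                self.trainer_thread.join()
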
@@ -347,46 +358,32 @@ class TrainingIntegration:
             return False
     
     def get_training_status(self) -> Dict[str, Any]:
-        """Get current training integration status"""
+        """Get current training status"""
         try:
             status = {
                 'orchestrator_available': self.orchestrator is not None,
-                'training_sessions': len(self.training_sessions),
-                'last_update': datetime.now().isoformat()
+                'active': self.training_active,
+                'last_training_time': self.last_training_time,
+                'training_sessions': self.training_sessions if self.training_sessions else {}
             }
             
             if self.orchestrator:
                 status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None
                 status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None
                 status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None
             
             return status
             
         except Exception as e:
             logger.error(f"Error getting training status: {e}")
-            return {'error': str(e)}
+            return {}
     
     def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str:
         """Start a new training session"""
         try:
             session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
             
-            session_data = {
-                'session_id': session_id,
-                'session_name': session_name,
-                'start_time': datetime.now().isoformat(),
-                'config': config or {},
+            self.training_sessions[session_id] = {
+                'name': session_name,
+                'start_time': datetime.now(),
+                'config': config if config else {},
                 'trades_processed': 0,
-                'successful_trainings': 0,
-                'failed_trainings': 0
+                'training_attempts': 0,
+                'successful_trainings': 0
             }
             
-            self.training_sessions[session_id] = session_data
-            
             logger.info(f"Started training session: {session_id}")
-            
             return session_id
             
         except Exception as e:
             logger.error(f"Error starting training session: {e}")
             return ""