stability

Author: Dobromir Popov
Date:   2025-07-28 12:10:52 +03:00
parent 9219b78241
commit fb72c93743

8 changed files with 207 additions and 53 deletions

View File

@@ -1164,35 +1164,23 @@ class DQNAgent:
         # Check if state is a dict or complex object
         if isinstance(state, dict):
             logger.error(f"State is a dict: {state}")
+            # Handle empty dictionary case
+            if not state:
+                logger.error("No numerical values found in state dict, using default state")
+                expected_size = getattr(self, 'state_size', 403)
+                if isinstance(expected_size, tuple):
+                    expected_size = np.prod(expected_size)
+                return np.zeros(int(expected_size), dtype=np.float32)
             # Extract numerical values from dict if possible
             if 'features' in state:
                 state = state['features']
             elif 'state' in state:
                 state = state['state']
             else:
-                # Try to extract all numerical values
-                numerical_values = []
-                for key, value in state.items():
-                    if isinstance(value, (int, float)):
-                        numerical_values.append(float(value))
-                    elif isinstance(value, (list, np.ndarray)):
-                        try:
-                            # Handle nested structures safely
-                            flattened = np.array(value).flatten()
-                            for x in flattened:
-                                if isinstance(x, (int, float)):
-                                    numerical_values.append(float(x))
-                                elif hasattr(x, 'item'):  # numpy scalar
-                                    numerical_values.append(float(x.item()))
-                        except (ValueError, TypeError):
-                            continue
-                    elif isinstance(value, dict):
-                        # Recursively extract from nested dicts
-                        try:
-                            nested_values = self._extract_numeric_from_dict(value)
-                            numerical_values.extend(nested_values)
-                        except Exception:
-                            continue
+                # Try to extract all numerical values using the helper method
+                numerical_values = self._extract_numeric_from_dict(state)
                 if numerical_values:
                     state = np.array(numerical_values, dtype=np.float32)
                 else:
@@ -1254,6 +1242,31 @@ class DQNAgent:
                 expected_size = np.prod(expected_size)
             return np.zeros(int(expected_size), dtype=np.float32)

+    def _extract_numeric_from_dict(self, data_dict):
+        """Recursively extract numerical values from nested dictionaries"""
+        numerical_values = []
+        try:
+            for key, value in data_dict.items():
+                if isinstance(value, (int, float)):
+                    numerical_values.append(float(value))
+                elif isinstance(value, (list, np.ndarray)):
+                    try:
+                        flattened = np.array(value).flatten()
+                        for x in flattened:
+                            if isinstance(x, (int, float)):
+                                numerical_values.append(float(x))
+                            elif hasattr(x, 'item'):  # numpy scalar
+                                numerical_values.append(float(x.item()))
+                    except (ValueError, TypeError):
+                        continue
+                elif isinstance(value, dict):
+                    # Recursively extract from nested dicts
+                    nested_values = self._extract_numeric_from_dict(value)
+                    numerical_values.extend(nested_values)
+        except Exception as e:
+            logger.debug(f"Error extracting numeric values from dict: {e}")
+        return numerical_values
+
     def _replay_standard(self, states, actions, rewards, next_states, dones):
         """Standard training step without mixed precision"""
         try:
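Note: the standalone sketch below (not part of the commit) illustrates the flattening behaviour that `_extract_numeric_from_dict` centralizes; it coerces whole arrays at once, whereas the method above walks element by element.

# Illustrative sketch only; the commit implements this as a DQNAgent method.
import numpy as np

def extract_numeric_from_dict(data_dict: dict) -> list[float]:
    """Flatten ints, floats, arrays, and nested dicts into a flat list of floats."""
    values: list[float] = []
    for value in data_dict.values():
        if isinstance(value, (int, float)):
            values.append(float(value))
        elif isinstance(value, (list, np.ndarray)):
            try:
                values.extend(float(x) for x in np.asarray(value, dtype=np.float64).flatten())
            except (ValueError, TypeError):
                continue
        elif isinstance(value, dict):
            values.extend(extract_numeric_from_dict(value))
    return values

state = {'price': 3500.0, 'volumes': [1.2, 3.4], 'nested': {'rsi': 55}}
print(extract_numeric_from_dict(state))  # [3500.0, 1.2, 3.4, 55.0]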

View File

@@ -83,6 +83,13 @@ class PivotBounds:
         distances = [abs(current_price - r) for r in self.pivot_resistance_levels]
         return min(distances) / self.get_price_range()

+@dataclass
+class SimplePivotLevel:
+    """Simple pivot level structure for fallback pivot detection"""
+    swing_points: List[Any] = field(default_factory=list)
+    support_levels: List[float] = field(default_factory=list)
+    resistance_levels: List[float] = field(default_factory=list)
+
 @dataclass
 class MarketTick:
     """Standardized market tick data structure"""

@@ -127,6 +134,10 @@ class DataProvider:
         self.real_time_data = {}  # {symbol: {timeframe: deque}}
         self.current_prices = {}  # {symbol: float}

+        # Live price cache for low-latency price updates
+        self.live_price_cache: Dict[str, Tuple[float, datetime]] = {}
+        self.live_price_cache_ttl = timedelta(milliseconds=500)
+
         # Initialize cached data structure
         for symbol in self.symbols:
             self.cached_data[symbol] = {}

@@ -1839,14 +1850,14 @@ class DataProvider:
                 low_pivots = monthly_data[lows == rolling_min]['low'].tolist()
                 pivot_lows.extend(low_pivots)

-            # Create mock level structure
-            mock_level = type('MockLevel', (), {
-                'swing_points': [],
-                'support_levels': list(set(pivot_lows)),
-                'resistance_levels': list(set(pivot_highs))
-            })()
-            return {'level_0': mock_level}
+            # Create proper pivot level structure
+            pivot_level = SimplePivotLevel(
+                swing_points=[],
+                support_levels=list(set(pivot_lows)),
+                resistance_levels=list(set(pivot_highs))
+            )
+            return {'level_0': pivot_level}

         except Exception as e:
             logger.error(f"Error in simple pivot detection: {e}")

View File

@@ -1062,10 +1062,11 @@ class MultiExchangeCOBProvider:
                         consolidated_bids[price].exchange_breakdown[exchange_name] = level

                         # Update dominant exchange based on volume
-                        if level.volume_usd > consolidated_bids[price].exchange_breakdown.get(
-                            consolidated_bids[price].dominant_exchange,
-                            type('obj', (object,), {'volume_usd': 0})()
-                        ).volume_usd:
+                        current_dominant = consolidated_bids[price].exchange_breakdown.get(
+                            consolidated_bids[price].dominant_exchange
+                        )
+                        current_volume = current_dominant.volume_usd if current_dominant else 0
+                        if level.volume_usd > current_volume:
                             consolidated_bids[price].dominant_exchange = exchange_name

             # Process merged asks (similar logic)

@@ -1088,10 +1089,11 @@ class MultiExchangeCOBProvider:
                         consolidated_asks[price].total_orders += level.orders_count
                         consolidated_asks[price].exchange_breakdown[exchange_name] = level

-                        if level.volume_usd > consolidated_asks[price].exchange_breakdown.get(
-                            consolidated_asks[price].dominant_exchange,
-                            type('obj', (object,), {'volume_usd': 0})()
-                        ).volume_usd:
+                        current_dominant = consolidated_asks[price].exchange_breakdown.get(
+                            consolidated_asks[price].dominant_exchange
+                        )
+                        current_volume = current_dominant.volume_usd if current_dominant else 0
+                        if level.volume_usd > current_volume:
                             consolidated_asks[price].dominant_exchange = exchange_name

         logger.debug(f"Consolidated {len(consolidated_bids)} bids and {len(consolidated_asks)} asks for {symbol}")

View File

@@ -1494,6 +1494,17 @@ class TradingOrchestrator:
             logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
             return predictions

+        # Validate base_data has proper feature vector
+        if hasattr(base_data, 'get_feature_vector'):
+            try:
+                feature_vector = base_data.get_feature_vector()
+                if feature_vector is None or (isinstance(feature_vector, np.ndarray) and feature_vector.size == 0):
+                    logger.warning(f"BaseDataInput has empty feature vector for {symbol}")
+                    return predictions
+            except Exception as e:
+                logger.warning(f"Error getting feature vector from BaseDataInput for {symbol}: {e}")
+                return predictions
+
         # log all registered models
         logger.debug(f"inferencing registered models: {self.model_registry.models}")

@@ -1691,6 +1702,15 @@ class TradingOrchestrator:
         try:
             logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")

+            # Validate model_input before storing
+            if model_input is None:
+                logger.warning(f"Skipping inference storage for {model_name}: model_input is None")
+                return
+
+            if isinstance(model_input, dict) and not model_input:
+                logger.warning(f"Skipping inference storage for {model_name}: model_input is empty dict")
+                return
+
             # Extract symbol from prediction if not provided
             if symbol is None:
                 symbol = getattr(prediction, 'symbol', 'ETH/USDT')  # Default to ETH/USDT if not available

@@ -2569,6 +2589,25 @@ class TradingOrchestrator:
             # Method 3: Dictionary with feature data
             if isinstance(model_input, dict):
+                # Check if dictionary is empty
+                if not model_input:
+                    logger.warning(f"Empty dictionary passed as model_input for {model_name}, using fallback")
+                    # Try to use data provider to build state as fallback
+                    if hasattr(self, 'data_provider'):
+                        try:
+                            base_data = self.data_provider.build_base_data_input('ETH/USDT')
+                            if base_data and hasattr(base_data, 'get_feature_vector'):
+                                state = base_data.get_feature_vector()
+                                if isinstance(state, np.ndarray):
+                                    logger.debug(f"Used data provider fallback for empty dict in {model_name}")
+                                    return state
+                        except Exception as e:
+                            logger.debug(f"Data provider fallback failed for empty dict in {model_name}: {e}")
+                    # Final fallback: return default state
+                    logger.warning(f"Using default state for empty dict in {model_name}")
+                    return np.zeros(403, dtype=np.float32)  # Default state size
+
                 # Try to extract features from dictionary
                 if 'features' in model_input:
                     features = model_input['features']

@@ -2589,6 +2628,8 @@ class TradingOrchestrator:
                     if feature_list:
                         return np.array(feature_list, dtype=np.float32)
+                    else:
+                        logger.warning(f"No numerical features found in dictionary for {model_name}, using fallback")

             # Method 4: List or tuple
             if isinstance(model_input, (list, tuple)):
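Note: a condensed sketch of the layered fallback the empty-dict guard implements: data-provider state first, then a 403-element zero vector. The `build_base_data_input` call mirrors the diff, but the standalone function and its signature are illustrative.

import logging
import numpy as np

logger = logging.getLogger(__name__)
DEFAULT_STATE_SIZE = 403  # default state size used by the commit's fallback

def state_from_model_input(model_input, data_provider=None, symbol='ETH/USDT') -> np.ndarray:
    """Sketch: turn a dict model_input into a state vector with layered fallbacks."""
    if isinstance(model_input, dict) and not model_input:
        logger.warning("Empty dict model_input, trying data provider fallback")
        if data_provider is not None:
            try:
                base_data = data_provider.build_base_data_input(symbol)  # assumed provider API
                if base_data is not None and hasattr(base_data, 'get_feature_vector'):
                    state = base_data.get_feature_vector()
                    if isinstance(state, np.ndarray) and state.size > 0:
                        return state
            except Exception as e:
                logger.debug(f"Data provider fallback failed: {e}")
        return np.zeros(DEFAULT_STATE_SIZE, dtype=np.float32)
    if isinstance(model_input, dict) and 'features' in model_input:
        return np.asarray(model_input['features'], dtype=np.float32)
    return np.asarray(model_input, dtype=np.float32)

print(state_from_model_input({}).shape)  # (403,)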

main.py
View File

@@ -65,16 +65,27 @@ async def run_web_dashboard():
         except Exception as e:
             logger.warning(f"[WARNING] Real-time streaming failed: {e}")

-        # Verify data connection
+        # Verify data connection with retry mechanism
         logger.info("[DATA] Verifying live data connection...")
         symbol = config.get('symbols', ['ETH/USDT'])[0]

-        test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
-        if test_df is not None and len(test_df) > 0:
-            logger.info("[SUCCESS] Data connection verified")
-            logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
-        else:
-            logger.error("[ERROR] Data connection failed - no live data available")
-            return
+        # Wait for data provider to initialize and fetch initial data
+        max_retries = 10
+        retry_delay = 2
+        for attempt in range(max_retries):
+            test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
+            if test_df is not None and len(test_df) > 0:
+                logger.info("[SUCCESS] Data connection verified")
+                logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
+                break
+            else:
+                if attempt < max_retries - 1:
+                    logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
+                    logger.warning("The system will attempt to fetch data as needed during operation")

         # Load model registry for integrated pipeline
         try:

@@ -122,6 +133,7 @@ async def run_web_dashboard():
         logger.info("Starting training loop...")

         # Start the training loop
+        logger.info("About to start training loop...")
         await start_training_loop(orchestrator, trading_executor)

     except Exception as e:

@@ -207,6 +219,8 @@ async def start_training_loop(orchestrator, trading_executor):
     logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
     logger.info("=" * 70)

+    logger.info("Training loop function entered successfully")
+
     # Initialize checkpoint management for training loop
     checkpoint_manager = get_checkpoint_manager()
     training_integration = get_training_integration()

@@ -222,8 +236,10 @@ async def start_training_loop(orchestrator, trading_executor):
     try:
         # Start real-time processing (Basic orchestrator doesn't have this method)
+        logger.info("Checking for real-time processing capabilities...")
         try:
             if hasattr(orchestrator, 'start_realtime_processing'):
+                logger.info("Starting real-time processing...")
                 await orchestrator.start_realtime_processing()
                 logger.info("Real-time processing started")
             else:

@@ -231,6 +247,8 @@ async def start_training_loop(orchestrator, trading_executor):
         except Exception as e:
             logger.warning(f"Real-time processing not available: {e}")

+        logger.info("About to enter main training loop...")
+
         # Main training loop
         iteration = 0
         while True:
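Note: the data-connection check above follows a wait-and-retry shape; below is a standalone sketch of that shape with a placeholder fetch function and example limits.

import asyncio
import logging

logger = logging.getLogger(__name__)

async def verify_data_connection(fetch_candles, max_retries: int = 10, retry_delay: float = 2.0) -> bool:
    """Sketch: poll a candle fetcher until it returns data, but never abort startup."""
    for attempt in range(max_retries):
        candles = fetch_candles()
        if candles:
            logger.info(f"Data connection verified with {len(candles)} candles")
            return True
        if attempt < max_retries - 1:
            logger.info(f"Waiting for data provider... (attempt {attempt + 1}/{max_retries})")
            await asyncio.sleep(retry_delay)
    logger.warning("Data connection not verified; continuing startup anyway")
    return False

# Example: a fetcher that only succeeds on the third call.
calls = {'n': 0}
def flaky_fetch():
    calls['n'] += 1
    return [1, 2, 3] if calls['n'] >= 3 else []

print(asyncio.run(verify_data_connection(flaky_fetch, retry_delay=0.01)))  # True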

View File

@@ -492,3 +492,56 @@ class CheckpointManager:
         except Exception as e:
             logger.error(f"Error getting all checkpoints: {e}")
             return []

+    def get_checkpoint_stats(self) -> Dict[str, Any]:
+        """
+        Get statistics about all checkpoints
+
+        Returns:
+            Dict[str, Any]: Statistics about checkpoints
+        """
+        try:
+            stats = {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
+
+            # Iterate through all model directories
+            for model_dir in os.listdir(self.checkpoint_dir):
+                model_path = os.path.join(self.checkpoint_dir, model_dir)
+                if not os.path.isdir(model_path):
+                    continue
+
+                # Count checkpoints for this model
+                checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
+                model_checkpoints = len(checkpoint_files)
+
+                # Calculate total size for this model
+                model_size_mb = 0.0
+                for checkpoint_file in checkpoint_files:
+                    try:
+                        size_bytes = os.path.getsize(checkpoint_file)
+                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
+                    except OSError:
+                        pass
+
+                stats['models'][model_dir] = {
+                    'checkpoints': model_checkpoints,
+                    'size_mb': round(model_size_mb, 2)
+                }
+                stats['total_checkpoints'] += model_checkpoints
+                stats['total_size_mb'] += model_size_mb
+
+            stats['total_size_mb'] = round(stats['total_size_mb'], 2)
+            return stats
+
+        except Exception as e:
+            logger.error(f"Error getting checkpoint stats: {e}")
+            return {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
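Note: a standalone sketch of the directory walk behind `get_checkpoint_stats`, assuming the same `<checkpoint_dir>/<model>/<model>_*.pt` layout; the example path is hypothetical.

import glob
import os
from typing import Any, Dict

def checkpoint_stats(checkpoint_dir: str) -> Dict[str, Any]:
    """Sketch: count checkpoints and sum their size per model subdirectory."""
    stats: Dict[str, Any] = {'total_checkpoints': 0, 'total_size_mb': 0.0, 'models': {}}
    if not os.path.isdir(checkpoint_dir):
        return stats
    for model_dir in os.listdir(checkpoint_dir):
        model_path = os.path.join(checkpoint_dir, model_dir)
        if not os.path.isdir(model_path):
            continue
        files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
        size_mb = sum(os.path.getsize(f) for f in files) / (1024 * 1024)
        stats['models'][model_dir] = {'checkpoints': len(files), 'size_mb': round(size_mb, 2)}
        stats['total_checkpoints'] += len(files)
        stats['total_size_mb'] += size_mb
    stats['total_size_mb'] = round(stats['total_size_mb'], 2)
    return stats

print(checkpoint_stats('models/checkpoints'))  # all zeros if the directory does not exist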

View File

@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
 class TrainingIntegration:
     def __init__(self, enable_wandb: bool = True):
+        self.enable_wandb = enable_wandb
         self.checkpoint_manager = get_checkpoint_manager()

@@ -55,9 +56,13 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")

+            # Save the model first to get the path
+            model_path = f"models/{model_name}_temp.pt"
+            torch.save(cnn_model.state_dict(), model_path)
+
             metadata = self.checkpoint_manager.save_checkpoint(
-                model=cnn_model,
                 model_name=model_name,
+                model_path=model_path,
                 model_type='cnn',
                 performance_metrics=performance_metrics,
                 training_metadata=training_metadata

@@ -114,9 +119,13 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")

+            # Save the model first to get the path
+            model_path = f"models/{model_name}_temp.pt"
+            torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
+
             metadata = self.checkpoint_manager.save_checkpoint(
-                model=rl_agent,
                 model_name=model_name,
+                model_path=model_path,
                 model_type='rl',
                 performance_metrics=performance_metrics,
                 training_metadata=training_metadata
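Note: a reduced sketch of the save-first pattern this change adopts: serialize the model's state_dict to `models/<name>_temp.pt`, then hand that path to the checkpoint manager. The registrar below is a stub standing in for `save_checkpoint`.

import os
import torch
import torch.nn as nn

def save_then_register(model: nn.Module, model_name: str, register) -> str:
    """Sketch: write state_dict to models/<name>_temp.pt, then register that path."""
    os.makedirs('models', exist_ok=True)
    model_path = f"models/{model_name}_temp.pt"
    torch.save(model.state_dict(), model_path)
    register(model_name=model_name, model_path=model_path)
    return model_path

# Usage with a toy model and a stub registrar.
toy = nn.Linear(4, 2)
path = save_then_register(toy, 'example_cnn', register=lambda **kw: print(f"registered: {kw}"))
print(os.path.exists(path))  # True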

View File

@@ -6056,6 +6056,7 @@ class CleanTradingDashboard:
             # Fallback: create BaseDataInput from available data
             from core.data_models import BaseDataInput, OHLCVBar, COBData
+            import random

             # Get OHLCV data for different timeframes - ensure we have enough data
             ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)

@@ -6073,7 +6074,6 @@ class CleanTradingDashboard:
                 if len(bars) > 0:
                     last_bar = bars[-1]
                     # Add small random variation to prevent identical data
-                    import random
                     for i in range(target_count - len(bars)):
                         # Create slight variations of the last bar
                         variation = random.uniform(-0.001, 0.001)  # 0.1% variation

@@ -6090,7 +6090,6 @@ class CleanTradingDashboard:
                         bars.append(new_bar)
                 else:
                     # Create realistic dummy bars with variation
-                    from core.data_models import OHLCVBar
                     base_price = 3500.0
                     for i in range(target_count):
                         # Add realistic price movement

@@ -8725,6 +8724,14 @@ def signal_handler(sig, frame):
            self.shutdown()  # Assuming a shutdown method exists or add one
            sys.exit(0)

-        signal.signal(signal.SIGTERM, signal_handler)
-        signal.signal(signal.SIGINT, signal_handler)
+        # Only set signal handlers if we're in the main thread
+        try:
+            import threading
+            if threading.current_thread() is threading.main_thread():
+                signal.signal(signal.SIGTERM, signal_handler)
+                signal.signal(signal.SIGINT, signal_handler)
+            else:
+                print("Warning: Signal handlers can only be set in main thread, skipping...")
+        except Exception as e:
+            print(f"Warning: Could not set signal handlers: {e}")