stability
@@ -1164,35 +1164,23 @@ class DQNAgent:
        # Check if state is a dict or complex object
        if isinstance(state, dict):
            logger.error(f"State is a dict: {state}")

            # Handle empty dictionary case
            if not state:
                logger.error("No numerical values found in state dict, using default state")
                expected_size = getattr(self, 'state_size', 403)
                if isinstance(expected_size, tuple):
                    expected_size = np.prod(expected_size)
                return np.zeros(int(expected_size), dtype=np.float32)

            # Extract numerical values from dict if possible
            if 'features' in state:
                state = state['features']
            elif 'state' in state:
                state = state['state']
            else:
                # Try to extract all numerical values
                numerical_values = []
                for key, value in state.items():
                    if isinstance(value, (int, float)):
                        numerical_values.append(float(value))
                    elif isinstance(value, (list, np.ndarray)):
                        try:
                            # Handle nested structures safely
                            flattened = np.array(value).flatten()
                            for x in flattened:
                                if isinstance(x, (int, float)):
                                    numerical_values.append(float(x))
                                elif hasattr(x, 'item'):  # numpy scalar
                                    numerical_values.append(float(x.item()))
                        except (ValueError, TypeError):
                            continue
                    elif isinstance(value, dict):
                        # Recursively extract from nested dicts
                        try:
                            nested_values = self._extract_numeric_from_dict(value)
                            numerical_values.extend(nested_values)
                        except Exception:
                            continue
            # Try to extract all numerical values using the helper method
            numerical_values = self._extract_numeric_from_dict(state)
            if numerical_values:
                state = np.array(numerical_values, dtype=np.float32)
            else:
@@ -1254,6 +1242,31 @@ class DQNAgent:
            expected_size = np.prod(expected_size)
        return np.zeros(int(expected_size), dtype=np.float32)

    def _extract_numeric_from_dict(self, data_dict):
        """Recursively extract numerical values from nested dictionaries"""
        numerical_values = []
        try:
            for key, value in data_dict.items():
                if isinstance(value, (int, float)):
                    numerical_values.append(float(value))
                elif isinstance(value, (list, np.ndarray)):
                    try:
                        flattened = np.array(value).flatten()
                        for x in flattened:
                            if isinstance(x, (int, float)):
                                numerical_values.append(float(x))
                            elif hasattr(x, 'item'):  # numpy scalar
                                numerical_values.append(float(x.item()))
                    except (ValueError, TypeError):
                        continue
                elif isinstance(value, dict):
                    # Recursively extract from nested dicts
                    nested_values = self._extract_numeric_from_dict(value)
                    numerical_values.extend(nested_values)
        except Exception as e:
            logger.debug(f"Error extracting numeric values from dict: {e}")
        return numerical_values

    def _replay_standard(self, states, actions, rewards, next_states, dones):
        """Standard training step without mixed precision"""
        try:
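For reference, a minimal standalone sketch of what the new _extract_numeric_from_dict helper does to a nested state dict (the function body here is a simplified re-statement and the example payload is illustrative, not taken from the commit):

import numpy as np

def extract_numeric_from_dict(data_dict):
    """Simplified sketch: collect every numeric leaf of a nested dict into a flat list."""
    values = []
    for value in data_dict.values():
        if isinstance(value, (int, float)):
            values.append(float(value))
        elif isinstance(value, (list, np.ndarray)):
            try:
                for x in np.array(value).flatten():
                    values.append(float(x))
            except (ValueError, TypeError):
                continue
        elif isinstance(value, dict):
            values.extend(extract_numeric_from_dict(value))  # recurse into nested dicts
    return values

# Illustrative input: a partially numeric, nested state dict
state = {'price': 3500.0, 'cob': {'imbalance': 0.12, 'levels': [1.0, 2.0]}, 'symbol': 'ETH/USDT'}
print(extract_numeric_from_dict(state))  # -> [3500.0, 0.12, 1.0, 2.0]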
@@ -83,6 +83,13 @@ class PivotBounds:
        distances = [abs(current_price - r) for r in self.pivot_resistance_levels]
        return min(distances) / self.get_price_range()

@dataclass
class SimplePivotLevel:
    """Simple pivot level structure for fallback pivot detection"""
    swing_points: List[Any] = field(default_factory=list)
    support_levels: List[float] = field(default_factory=list)
    resistance_levels: List[float] = field(default_factory=list)

@dataclass
class MarketTick:
    """Standardized market tick data structure"""
@@ -127,6 +134,10 @@ class DataProvider:
        self.real_time_data = {}  # {symbol: {timeframe: deque}}
        self.current_prices = {}  # {symbol: float}

        # Live price cache for low-latency price updates
        self.live_price_cache: Dict[str, Tuple[float, datetime]] = {}
        self.live_price_cache_ttl = timedelta(milliseconds=500)

        # Initialize cached data structure
        for symbol in self.symbols:
            self.cached_data[symbol] = {}
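The two new fields give DataProvider a short-lived price cache. A minimal sketch of how a 500 ms TTL cache like this is typically read and written (the class and method names below are illustrative; only live_price_cache and live_price_cache_ttl come from the diff):

from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple

class LivePriceCache:
    """Sketch of a TTL price cache mirroring live_price_cache / live_price_cache_ttl."""

    def __init__(self, ttl: timedelta = timedelta(milliseconds=500)):
        self.cache: Dict[str, Tuple[float, datetime]] = {}
        self.ttl = ttl

    def get(self, symbol: str) -> Optional[float]:
        entry = self.cache.get(symbol)
        if entry is not None:
            price, stored_at = entry
            if datetime.now() - stored_at < self.ttl:
                return price  # fresh enough, skip the exchange round-trip
        return None  # missing or stale

    def put(self, symbol: str, price: float) -> None:
        self.cache[symbol] = (price, datetime.now())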
@@ -1839,14 +1850,14 @@ class DataProvider:
                low_pivots = monthly_data[lows == rolling_min]['low'].tolist()
                pivot_lows.extend(low_pivots)

            # Create mock level structure
            mock_level = type('MockLevel', (), {
                'swing_points': [],
                'support_levels': list(set(pivot_lows)),
                'resistance_levels': list(set(pivot_highs))
            })()
            # Create proper pivot level structure
            pivot_level = SimplePivotLevel(
                swing_points=[],
                support_levels=list(set(pivot_lows)),
                resistance_levels=list(set(pivot_highs))
            )

            return {'level_0': mock_level}
            return {'level_0': pivot_level}

        except Exception as e:
            logger.error(f"Error in simple pivot detection: {e}")
@@ -1062,10 +1062,11 @@ class MultiExchangeCOBProvider:
                consolidated_bids[price].exchange_breakdown[exchange_name] = level

                # Update dominant exchange based on volume
                if level.volume_usd > consolidated_bids[price].exchange_breakdown.get(
                    consolidated_bids[price].dominant_exchange,
                    type('obj', (object,), {'volume_usd': 0})()
                ).volume_usd:
                current_dominant = consolidated_bids[price].exchange_breakdown.get(
                    consolidated_bids[price].dominant_exchange
                )
                current_volume = current_dominant.volume_usd if current_dominant else 0
                if level.volume_usd > current_volume:
                    consolidated_bids[price].dominant_exchange = exchange_name

        # Process merged asks (similar logic)
@@ -1088,10 +1089,11 @@ class MultiExchangeCOBProvider:
                consolidated_asks[price].total_orders += level.orders_count
                consolidated_asks[price].exchange_breakdown[exchange_name] = level

                if level.volume_usd > consolidated_asks[price].exchange_breakdown.get(
                    consolidated_asks[price].dominant_exchange,
                    type('obj', (object,), {'volume_usd': 0})()
                ).volume_usd:
                current_dominant = consolidated_asks[price].exchange_breakdown.get(
                    consolidated_asks[price].dominant_exchange
                )
                current_volume = current_dominant.volume_usd if current_dominant else 0
                if level.volume_usd > current_volume:
                    consolidated_asks[price].dominant_exchange = exchange_name

        logger.debug(f"Consolidated {len(consolidated_bids)} bids and {len(consolidated_asks)} asks for {symbol}")
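Both consolidation paths drop the throwaway type('obj', ...) sentinel in favor of an explicit None check before comparing volumes, which reads more clearly and avoids building a dummy object on every level update. A sketch of the pattern in isolation (the ExchangeLevel dataclass and function name here are illustrative, not from the codebase):

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class ExchangeLevel:
    volume_usd: float

def pick_dominant_exchange(breakdown: Dict[str, ExchangeLevel],
                           dominant: Optional[str],
                           exchange: str,
                           level: ExchangeLevel) -> str:
    """Return the exchange holding the most USD volume at this price level."""
    current = breakdown.get(dominant) if dominant else None
    current_volume = current.volume_usd if current else 0.0
    return exchange if level.volume_usd > current_volume else (dominant or exchange)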
@@ -1494,6 +1494,17 @@ class TradingOrchestrator:
            logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
            return predictions

        # Validate base_data has proper feature vector
        if hasattr(base_data, 'get_feature_vector'):
            try:
                feature_vector = base_data.get_feature_vector()
                if feature_vector is None or (isinstance(feature_vector, np.ndarray) and feature_vector.size == 0):
                    logger.warning(f"BaseDataInput has empty feature vector for {symbol}")
                    return predictions
            except Exception as e:
                logger.warning(f"Error getting feature vector from BaseDataInput for {symbol}: {e}")
                return predictions

        # log all registered models
        logger.debug(f"inferencing registered models: {self.model_registry.models}")
@@ -1691,6 +1702,15 @@ class TradingOrchestrator:
        try:
            logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")

            # Validate model_input before storing
            if model_input is None:
                logger.warning(f"Skipping inference storage for {model_name}: model_input is None")
                return

            if isinstance(model_input, dict) and not model_input:
                logger.warning(f"Skipping inference storage for {model_name}: model_input is empty dict")
                return

            # Extract symbol from prediction if not provided
            if symbol is None:
                symbol = getattr(prediction, 'symbol', 'ETH/USDT')  # Default to ETH/USDT if not available
@@ -2569,6 +2589,25 @@ class TradingOrchestrator:

        # Method 3: Dictionary with feature data
        if isinstance(model_input, dict):
            # Check if dictionary is empty
            if not model_input:
                logger.warning(f"Empty dictionary passed as model_input for {model_name}, using fallback")
                # Try to use data provider to build state as fallback
                if hasattr(self, 'data_provider'):
                    try:
                        base_data = self.data_provider.build_base_data_input('ETH/USDT')
                        if base_data and hasattr(base_data, 'get_feature_vector'):
                            state = base_data.get_feature_vector()
                            if isinstance(state, np.ndarray):
                                logger.debug(f"Used data provider fallback for empty dict in {model_name}")
                                return state
                    except Exception as e:
                        logger.debug(f"Data provider fallback failed for empty dict in {model_name}: {e}")

                # Final fallback: return default state
                logger.warning(f"Using default state for empty dict in {model_name}")
                return np.zeros(403, dtype=np.float32)  # Default state size

            # Try to extract features from dictionary
            if 'features' in model_input:
                features = model_input['features']
@@ -2589,6 +2628,8 @@ class TradingOrchestrator:

            if feature_list:
                return np.array(feature_list, dtype=np.float32)
            else:
                logger.warning(f"No numerical features found in dictionary for {model_name}, using fallback")

        # Method 4: List or tuple
        if isinstance(model_input, (list, tuple)):
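Taken together, the dictionary branch of the state-conversion path now falls back in stages: an explicit 'features' entry, then whatever numeric values can be scraped from the dict, then a zero vector of the default size. A compact sketch of that ordering (the function name is illustrative; the 403-element default comes from the diff):

import numpy as np

def dict_to_state(model_input: dict, default_size: int = 403) -> np.ndarray:
    """Sketch of the fallback order for converting a dict model_input into a state vector."""
    if 'features' in model_input:                          # 1) explicit feature vector
        return np.asarray(model_input['features'], dtype=np.float32)
    numeric = [float(v) for v in model_input.values()      # 2) any flat numeric values
               if isinstance(v, (int, float))]
    if numeric:
        return np.array(numeric, dtype=np.float32)
    return np.zeros(default_size, dtype=np.float32)        # 3) default zero vector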
main.py (24 changed lines)
@@ -65,16 +65,27 @@ async def run_web_dashboard():
    except Exception as e:
        logger.warning(f"[WARNING] Real-time streaming failed: {e}")

    # Verify data connection
    # Verify data connection with retry mechanism
    logger.info("[DATA] Verifying live data connection...")
    symbol = config.get('symbols', ['ETH/USDT'])[0]

    # Wait for data provider to initialize and fetch initial data
    max_retries = 10
    retry_delay = 2

    for attempt in range(max_retries):
        test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
        if test_df is not None and len(test_df) > 0:
            logger.info("[SUCCESS] Data connection verified")
            logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
            break
        else:
            logger.error("[ERROR] Data connection failed - no live data available")
            return
        if attempt < max_retries - 1:
            logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
            await asyncio.sleep(retry_delay)
    else:
        logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
        logger.warning("The system will attempt to fetch data as needed during operation")

    # Load model registry for integrated pipeline
    try:
@@ -122,6 +133,7 @@ async def run_web_dashboard():
        logger.info("Starting training loop...")

        # Start the training loop
        logger.info("About to start training loop...")
        await start_training_loop(orchestrator, trading_executor)

    except Exception as e:
@@ -207,6 +219,8 @@ async def start_training_loop(orchestrator, trading_executor):
    logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
    logger.info("=" * 70)

    logger.info("Training loop function entered successfully")

    # Initialize checkpoint management for training loop
    checkpoint_manager = get_checkpoint_manager()
    training_integration = get_training_integration()
@@ -222,8 +236,10 @@ async def start_training_loop(orchestrator, trading_executor):

    try:
        # Start real-time processing (Basic orchestrator doesn't have this method)
        logger.info("Checking for real-time processing capabilities...")
        try:
            if hasattr(orchestrator, 'start_realtime_processing'):
                logger.info("Starting real-time processing...")
                await orchestrator.start_realtime_processing()
                logger.info("Real-time processing started")
            else:
@@ -231,6 +247,8 @@ async def start_training_loop(orchestrator, trading_executor):
        except Exception as e:
            logger.warning(f"Real-time processing not available: {e}")

        logger.info("About to enter main training loop...")

        # Main training loop
        iteration = 0
        while True:
@@ -492,3 +492,56 @@ class CheckpointManager:
        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """
        Get statistics about all checkpoints

        Returns:
            Dict[str, Any]: Statistics about checkpoints
        """
        try:
            stats = {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Iterate through all model directories
            for model_dir in os.listdir(self.checkpoint_dir):
                model_path = os.path.join(self.checkpoint_dir, model_dir)
                if not os.path.isdir(model_path):
                    continue

                # Count checkpoints for this model
                checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
                model_checkpoints = len(checkpoint_files)

                # Calculate total size for this model
                model_size_mb = 0.0
                for checkpoint_file in checkpoint_files:
                    try:
                        size_bytes = os.path.getsize(checkpoint_file)
                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
                    except OSError:
                        pass

                stats['models'][model_dir] = {
                    'checkpoints': model_checkpoints,
                    'size_mb': round(model_size_mb, 2)
                }

                stats['total_checkpoints'] += model_checkpoints
                stats['total_size_mb'] += model_size_mb

            stats['total_size_mb'] = round(stats['total_size_mb'], 2)

            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }
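A short usage sketch for the new stats helper (get_checkpoint_manager is used elsewhere in this commit; the log formatting is illustrative):

# Illustrative: log checkpoint disk usage at startup
checkpoint_manager = get_checkpoint_manager()
stats = checkpoint_manager.get_checkpoint_stats()
logger.info(f"Checkpoints: {stats['total_checkpoints']} files, {stats['total_size_mb']} MB total")
for model_name, model_stats in stats['models'].items():
    logger.info(f"  {model_name}: {model_stats['checkpoints']} checkpoints, {model_stats['size_mb']} MB")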
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)

class TrainingIntegration:
    def __init__(self, enable_wandb: bool = True):
        self.enable_wandb = enable_wandb
        self.checkpoint_manager = get_checkpoint_manager()
@@ -55,9 +56,13 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        # Save the model first to get the path
        model_path = f"models/{model_name}_temp.pt"
        torch.save(cnn_model.state_dict(), model_path)

        metadata = self.checkpoint_manager.save_checkpoint(
            model=cnn_model,
            model_name=model_name,
            model_path=model_path,
            model_type='cnn',
            performance_metrics=performance_metrics,
            training_metadata=training_metadata
@@ -114,9 +119,13 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        # Save the model first to get the path
        model_path = f"models/{model_name}_temp.pt"
        torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)

        metadata = self.checkpoint_manager.save_checkpoint(
            model=rl_agent,
            model_name=model_name,
            model_path=model_path,
            model_type='rl',
            performance_metrics=performance_metrics,
            training_metadata=training_metadata
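Both branches now write a temporary weights file under models/ before handing the path to save_checkpoint. torch.save will raise if that directory does not exist, so a defensive variant (a sketch under that assumption, not part of the commit) would create it first:

import os
import torch

def save_temp_weights(model, model_name: str) -> str:
    """Sketch: write model weights to models/<name>_temp.pt, creating the directory if needed."""
    model_path = f"models/{model_name}_temp.pt"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    state = model.state_dict() if hasattr(model, 'state_dict') else model
    torch.save(state, model_path)
    return model_path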
@@ -6056,6 +6056,7 @@ class CleanTradingDashboard:

        # Fallback: create BaseDataInput from available data
        from core.data_models import BaseDataInput, OHLCVBar, COBData
        import random

        # Get OHLCV data for different timeframes - ensure we have enough data
        ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
@@ -6073,7 +6074,6 @@ class CleanTradingDashboard:
        if len(bars) > 0:
            last_bar = bars[-1]
            # Add small random variation to prevent identical data
            import random
            for i in range(target_count - len(bars)):
                # Create slight variations of the last bar
                variation = random.uniform(-0.001, 0.001)  # 0.1% variation
@@ -6090,7 +6090,6 @@ class CleanTradingDashboard:
                bars.append(new_bar)
        else:
            # Create realistic dummy bars with variation
            from core.data_models import OHLCVBar
            base_price = 3500.0
            for i in range(target_count):
                # Add realistic price movement
@@ -8725,6 +8724,14 @@ def signal_handler(sig, frame):
            self.shutdown()  # Assuming a shutdown method exists or add one
            sys.exit(0)

        # Only set signal handlers if we're in the main thread
        try:
            import threading
            if threading.current_thread() is threading.main_thread():
                signal.signal(signal.SIGTERM, signal_handler)
                signal.signal(signal.SIGINT, signal_handler)
            else:
                print("Warning: Signal handlers can only be set in main thread, skipping...")
        except Exception as e:
            print(f"Warning: Could not set signal handlers: {e}")