fix broken merge
This commit is contained in:
@@ -448,12 +448,12 @@ class COBIntegration:
|
||||
cob_snapshot.liquidity_imbalance,
|
||||
len(cob_snapshot.exchanges_active) / 5, # Exchange count ratio
|
||||
min(1.0, total_liquidity / 10000000), # Liquidity abundance
|
||||
0.5, # Price efficiency placeholder
|
||||
0.0, # Price efficiency (placeholder - not yet implemented)
|
||||
min(1.0, total_liquidity / 5000000), # Market impact resistance
|
||||
0.0, # Arbitrage score placeholder
|
||||
0.0, # Liquidity fragmentation placeholder
|
||||
0.0, # Arbitrage score (placeholder - not yet implemented)
|
||||
0.0, # Liquidity fragmentation (placeholder - not yet implemented)
|
||||
(datetime.now().hour * 60 + datetime.now().minute) / 1440, # Time of day
|
||||
0.5 # Market regime indicator placeholder
|
||||
0.0 # Market regime indicator (placeholder - not yet implemented)
|
||||
])
|
||||
|
||||
return np.array(state_features, dtype=np.float32)
|
||||
|
||||
@@ -521,7 +521,7 @@ class EnhancedRLTrainingAdapter:
|
||||
return False
|
||||
|
||||
async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train COB RL model with batch data"""
|
||||
"""Train COB RL model with batch data (placeholder - not yet implemented)"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
@@ -529,7 +529,7 @@ class EnhancedRLTrainingAdapter:
|
||||
|
||||
# Use COB RL trainer if available
|
||||
# This is a placeholder - implement based on actual COB RL training interface
|
||||
logger.debug(f"COB RL training batch: {len(batch.states)} samples")
|
||||
logger.debug(f"COB RL training batch: {len(batch.states)} samples (not yet implemented)")
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -539,12 +539,12 @@ class EnhancedRLTrainingAdapter:
|
||||
return False
|
||||
|
||||
async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train CNN model with batch data"""
|
||||
"""Train CNN model with batch data (placeholder - not yet implemented)"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
|
||||
# Use enhanced training system for CNN training
|
||||
# This is a placeholder - implement based on actual CNN training interface
|
||||
logger.debug(f"CNN training batch: {len(batch.states)} samples")
|
||||
logger.debug(f"CNN training batch: {len(batch.states)} samples (not yet implemented)")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -236,7 +236,14 @@ class MEXCInterface(ExchangeInterface):
|
||||
params = {'symbol': formatted_symbol}
|
||||
|
||||
response = self._send_public_request('GET', endpoint, params)
|
||||
|
||||
|
||||
if response and isinstance(response, dict):
|
||||
return response
|
||||
elif response and isinstance(response, list) and len(response) > 0:
|
||||
# Find the ticker for our symbol
|
||||
for ticker in response:
|
||||
if ticker.get('symbol') == formatted_symbol:
|
||||
return ticker
|
||||
else:
|
||||
logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
|
||||
return None
|
||||
@@ -296,6 +303,12 @@ class MEXCInterface(ExchangeInterface):
|
||||
|
||||
def place_order(self, symbol: str, side: str, order_type: str, quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""Place a new order on MEXC."""
|
||||
try:
|
||||
formatted_symbol = self._format_spot_symbol(symbol)
|
||||
|
||||
# Validate symbol is supported
|
||||
if not self.is_symbol_supported(formatted_symbol):
|
||||
logger.error(f"Symbol {formatted_symbol} is not supported by MEXC")
|
||||
return {}
|
||||
|
||||
# Round quantity to MEXC precision requirements and ensure minimum order value
|
||||
|
||||
@@ -975,17 +975,17 @@ class MultiExchangeCOBProvider:
|
||||
logger.info(f"Disconnected from Kraken order book stream for {symbol}")
|
||||
|
||||
async def _stream_huobi_orderbook(self, symbol: str, config: ExchangeConfig):
|
||||
"""Stream Huobi order book data (placeholder implementation)"""
|
||||
"""Stream Huobi order book data (placeholder - not yet implemented)"""
|
||||
try:
|
||||
logger.info(f"Huobi streaming for {symbol} not yet implemented")
|
||||
logger.debug(f"Huobi streaming for {symbol} not yet implemented")
|
||||
await asyncio.sleep(60) # Sleep to prevent spam
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Huobi order book for {symbol}: {e}")
|
||||
|
||||
async def _stream_bitfinex_orderbook(self, symbol: str, config: ExchangeConfig):
|
||||
"""Stream Bitfinex order book data (placeholder implementation)"""
|
||||
"""Stream Bitfinex order book data (placeholder - not yet implemented)"""
|
||||
try:
|
||||
logger.info(f"Bitfinex streaming for {symbol} not yet implemented")
|
||||
logger.debug(f"Bitfinex streaming for {symbol} not yet implemented")
|
||||
await asyncio.sleep(60) # Sleep to prevent spam
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Bitfinex order book for {symbol}: {e}")
|
||||
|
||||
@@ -688,28 +688,26 @@ class MultiHorizonPredictionManager:
|
||||
|
||||
# Placeholder methods for CNN and RL feature preparation - to be implemented
|
||||
def _prepare_cnn_features_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
|
||||
"""Prepare CNN features for specific horizon - placeholder"""
|
||||
"""Prepare CNN features for specific horizon (placeholder - not yet implemented)"""
|
||||
# This would extract relevant features based on horizon
|
||||
return np.random.rand(50) # Placeholder
|
||||
logger.debug(f"CNN feature preparation for horizon {horizon} not yet implemented")
|
||||
return np.array([]) # Return empty array instead of synthetic data
|
||||
|
||||
def _prepare_rl_state_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
|
||||
"""Prepare RL state for specific horizon - placeholder"""
|
||||
"""Prepare RL state for specific horizon (placeholder - not yet implemented)"""
|
||||
# This would create state representation for the horizon
|
||||
return np.random.rand(100) # Placeholder
|
||||
logger.debug(f"RL state preparation for horizon {horizon} not yet implemented")
|
||||
return np.array([]) # Return empty array instead of synthetic data
|
||||
|
||||
def _interpret_cnn_output(self, cnn_output, current_price: float, horizon: int) -> Tuple[float, float, float]:
|
||||
"""Interpret CNN output for min/max prediction - placeholder"""
|
||||
"""Interpret CNN output for min/max prediction (placeholder - not yet implemented)"""
|
||||
# This would convert CNN output to price predictions
|
||||
range_percent = 0.05 # 5% range
|
||||
return (current_price * 0.95, current_price * 1.05, 0.6) # Placeholder
|
||||
logger.debug(f"CNN output interpretation for horizon {horizon} not yet implemented")
|
||||
return (0.0, 0.0, 0.0) # Return zeros instead of synthetic predictions
|
||||
|
||||
def _convert_rl_action_to_price_prediction(self, action: int, current_price: float,
|
||||
horizon: int, rl_agent) -> Tuple[float, float, float]:
|
||||
"""Convert RL action to price prediction - placeholder"""
|
||||
"""Convert RL action to price prediction (placeholder - not yet implemented)"""
|
||||
# This would interpret RL action as price movement expectation
|
||||
if action == 0: # BUY
|
||||
return (current_price * 0.98, current_price * 1.03, 0.7)
|
||||
elif action == 1: # SELL
|
||||
return (current_price * 0.97, current_price * 1.02, 0.7)
|
||||
else: # HOLD
|
||||
return (current_price * 0.99, current_price * 1.01, 0.5)
|
||||
logger.debug(f"RL action conversion for horizon {horizon} not yet implemented")
|
||||
return (0.0, 0.0, 0.0) # Return zeros instead of synthetic predictions
|
||||
|
||||
@@ -299,14 +299,24 @@ class TradingOrchestrator:
|
||||
Includes EnhancedRealtimeTrainingSystem for continuous learning
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
|
||||
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[Any] = None):
|
||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||
self.config = get_config()
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
||||
self.model_manager = model_manager or create_model_manager()
|
||||
self.model_manager = None # Will be initialized later if needed
|
||||
self.enhanced_rl_training = enhanced_rl_training
|
||||
|
||||
# Set primary trading symbol
|
||||
self.symbol = self.config.get('primary_symbol', 'ETH/USDT')
|
||||
self.ref_symbols = self.config.get('reference_symbols', ['BTC/USDT'])
|
||||
|
||||
# Initialize signal accumulator
|
||||
self.signal_accumulator = {}
|
||||
|
||||
# Initialize confidence threshold
|
||||
self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
|
||||
|
||||
# Determine the device to use (GPU if available, else CPU)
|
||||
# Initialize device - force CPU mode to avoid CUDA errors
|
||||
if torch.cuda.is_available():
|
||||
@@ -449,8 +459,8 @@ class TradingOrchestrator:
|
||||
self.last_inference: Dict[str, Dict] = {} # {model_name: last_inference_record}
|
||||
|
||||
# Initialize inference logger
|
||||
self.inference_logger = get_inference_logger()
|
||||
self.db_manager = get_database_manager()
|
||||
self.inference_logger = None # Will be initialized later if needed
|
||||
self.db_manager = None # Will be initialized later if needed
|
||||
|
||||
# ENHANCED: Real-time Training System Integration
|
||||
self.enhanced_training_system = (
|
||||
@@ -510,6 +520,247 @@ class TradingOrchestrator:
|
||||
self._initialize_decision_fusion() # Initialize fusion system
|
||||
self._initialize_transformer_model() # Initialize transformer model
|
||||
self._initialize_enhanced_training_system() # Initialize real-time training
|
||||
|
||||
def _normalize_model_name(self, model_name: str) -> str:
|
||||
"""Normalize model name for consistent storage"""
|
||||
import re
|
||||
|
||||
# Convert to lowercase
|
||||
normalized = model_name.lower()
|
||||
|
||||
# Replace spaces, hyphens, and other non-alphanumeric separators with underscores
|
||||
normalized = re.sub(r'[^a-z0-9]+', '_', normalized)
|
||||
|
||||
# Collapse multiple consecutive underscores into a single underscore
|
||||
normalized = re.sub(r'_+', '_', normalized)
|
||||
|
||||
# Strip leading and trailing underscores
|
||||
normalized = normalized.strip('_')
|
||||
|
||||
return normalized
|
||||
|
||||
def _log_data_status(self):
|
||||
"""Log data provider status"""
|
||||
logger.info(f"Data provider initialized for symbols: {self.data_provider.symbols}")
|
||||
logger.info(f"Available timeframes: {self.data_provider.timeframes}")
|
||||
|
||||
def _schedule_database_cleanup(self):
|
||||
"""
|
||||
Schedule periodic database cleanup tasks.
|
||||
|
||||
This method sets up a background task that periodically cleans up old
|
||||
inference records from the database to prevent it from growing indefinitely.
|
||||
|
||||
Side effects:
|
||||
- Creates a background asyncio task that runs every 24 hours
|
||||
- Cleans up records older than 30 days by default
|
||||
- Logs cleanup operations and any errors
|
||||
"""
|
||||
try:
|
||||
from utils.database_manager import get_database_manager
|
||||
|
||||
# Get database manager instance
|
||||
db_manager = get_database_manager()
|
||||
|
||||
async def cleanup_task():
|
||||
"""Background task for periodic database cleanup"""
|
||||
while True:
|
||||
try:
|
||||
logger.info("Running scheduled database cleanup...")
|
||||
success = db_manager.cleanup_old_records(days_to_keep=30)
|
||||
if success:
|
||||
logger.info("Database cleanup completed successfully")
|
||||
else:
|
||||
logger.warning("Database cleanup failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during database cleanup: {e}")
|
||||
|
||||
# Wait 24 hours before next cleanup
|
||||
await asyncio.sleep(24 * 60 * 60) # 24 hours in seconds
|
||||
|
||||
# Create and start the cleanup task
|
||||
self._db_cleanup_task = asyncio.create_task(cleanup_task())
|
||||
logger.info("Database cleanup scheduler started - will run every 24 hours")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to schedule database cleanup: {e}")
|
||||
logger.warning("Database cleanup will not be performed automatically")
|
||||
|
||||
def _initialize_checkpoint_manager(self):
|
||||
"""
|
||||
Initialize the global checkpoint manager for model checkpoint management.
|
||||
|
||||
This method initializes the checkpoint manager that handles:
|
||||
- Saving model checkpoints with metadata
|
||||
- Loading the best performing checkpoints
|
||||
- Managing checkpoint storage and cleanup
|
||||
|
||||
Returns:
|
||||
CheckpointManager: The initialized checkpoint manager instance, or None if initialization fails
|
||||
|
||||
Side effects:
|
||||
- Sets self.checkpoint_manager to the global checkpoint manager instance
|
||||
- Creates checkpoint directory if it doesn't exist
|
||||
- Logs initialization status
|
||||
"""
|
||||
try:
|
||||
from utils.checkpoint_manager import get_checkpoint_manager
|
||||
|
||||
# Initialize the global checkpoint manager
|
||||
self.checkpoint_manager = get_checkpoint_manager(
|
||||
checkpoint_dir="models/checkpoints",
|
||||
max_checkpoints=10,
|
||||
metric_name="accuracy"
|
||||
)
|
||||
|
||||
logger.info(f"Checkpoint manager initialized successfully with directory: models/checkpoints")
|
||||
logger.info(f"Maximum checkpoints per model: 10, Primary metric: accuracy")
|
||||
|
||||
return self.checkpoint_manager
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize checkpoint manager: {e}")
|
||||
self.checkpoint_manager = None
|
||||
return None
|
||||
|
||||
def _start_cob_integration_sync(self):
|
||||
"""
|
||||
Start COB (Consolidated Order Book) integration synchronization.
|
||||
|
||||
This method initiates the COB integration system that provides real-time
|
||||
market microstructure data to the trading models. The COB integration
|
||||
streams order book data and generates features for CNN and DQN models.
|
||||
|
||||
Side effects:
|
||||
- Creates an async task to start COB integration if available
|
||||
- Registers COB data callbacks for model feeding
|
||||
- Begins streaming COB features to registered models
|
||||
- Logs integration status and any errors
|
||||
"""
|
||||
try:
|
||||
if self.cob_integration is None:
|
||||
logger.warning("COB integration not initialized - cannot start sync")
|
||||
return
|
||||
|
||||
# Create async task to start COB integration
|
||||
# Since this is called from __init__ (sync context), we need to create a task
|
||||
async def start_cob_task():
|
||||
try:
|
||||
await self.start_cob_integration()
|
||||
logger.info("COB integration synchronization started successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start COB integration sync: {e}")
|
||||
|
||||
# Create the task (will be executed when event loop is running)
|
||||
self._cob_sync_task = asyncio.create_task(start_cob_task())
|
||||
logger.info("COB integration sync task created - will start when event loop is available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize COB integration sync: {e}")
|
||||
logger.warning("COB integration will not be available")
|
||||
|
||||
def _initialize_transformer_model(self):
|
||||
"""
|
||||
Initialize the transformer model for advanced trading pattern recognition.
|
||||
|
||||
This method loads or creates an AdvancedTradingTransformer model that uses
|
||||
attention mechanisms to analyze complex market patterns and generate trading signals.
|
||||
The model is optimized for COB (Consolidated Order Book) data and technical indicators.
|
||||
|
||||
Returns:
|
||||
bool: True if initialization successful, False otherwise
|
||||
|
||||
Side effects:
|
||||
- Sets self.primary_transformer to the loaded/created transformer model
|
||||
- Sets self.primary_transformer_trainer to the associated trainer
|
||||
- Updates self.transformer_checkpoint_info with checkpoint metadata
|
||||
- Loads best available checkpoint if exists
|
||||
- Moves model to appropriate device (CPU/GPU)
|
||||
- Logs initialization status and any errors
|
||||
"""
|
||||
try:
|
||||
from NN.models.advanced_transformer_trading import (
|
||||
AdvancedTradingTransformer,
|
||||
TradingTransformerTrainer,
|
||||
TradingTransformerConfig
|
||||
)
|
||||
|
||||
logger.info("Initializing transformer model for trading...")
|
||||
|
||||
# Create transformer configuration
|
||||
config = TradingTransformerConfig()
|
||||
|
||||
# Initialize the transformer model
|
||||
self.primary_transformer = AdvancedTradingTransformer(config)
|
||||
logger.info(f"AdvancedTradingTransformer created with config: d_model={config.d_model}, "
|
||||
f"n_heads={config.n_heads}, n_layers={config.n_layers}")
|
||||
|
||||
# Initialize the trainer
|
||||
self.primary_transformer_trainer = TradingTransformerTrainer(
|
||||
model=self.primary_transformer,
|
||||
config=config
|
||||
)
|
||||
logger.info("TradingTransformerTrainer initialized")
|
||||
|
||||
# Move model to device
|
||||
if hasattr(self, 'device') and self.device:
|
||||
self.primary_transformer.to(self.device)
|
||||
logger.info(f"Transformer model moved to device: {self.device}")
|
||||
else:
|
||||
logger.info("Transformer model using default device")
|
||||
|
||||
# Try to load best checkpoint
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
if self.checkpoint_manager:
|
||||
checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
|
||||
if checkpoint_path and checkpoint_metadata:
|
||||
# Load the checkpoint
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
|
||||
|
||||
# Update checkpoint info
|
||||
self.transformer_checkpoint_info = {
|
||||
'path': checkpoint_path,
|
||||
'metadata': checkpoint_metadata,
|
||||
'loaded_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}")
|
||||
logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}")
|
||||
checkpoint_loaded = True
|
||||
else:
|
||||
logger.info("No transformer checkpoint found - using fresh model")
|
||||
else:
|
||||
logger.warning("Checkpoint manager not available - cannot load transformer checkpoint")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading transformer checkpoint: {e}")
|
||||
logger.info("Continuing with fresh transformer model")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
# Initialize checkpoint info for new model
|
||||
self.transformer_checkpoint_info = {
|
||||
'status': 'fresh_model',
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info("Transformer model initialization completed successfully")
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Advanced transformer trading module not available: {e}")
|
||||
self.primary_transformer = None
|
||||
self.primary_transformer_trainer = None
|
||||
logger.info("Transformer model will not be available")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize transformer model: {e}")
|
||||
self.primary_transformer = None
|
||||
self.primary_transformer_trainer = None
|
||||
return False
|
||||
|
||||
def _initialize_ml_models(self):
|
||||
"""Initialize ML models for enhanced trading"""
|
||||
try:
|
||||
@@ -549,7 +800,7 @@ class TradingOrchestrator:
|
||||
if hasattr(self.rl_agent, "load_best_checkpoint"):
|
||||
try:
|
||||
self.rl_agent.load_best_checkpoint()
|
||||
checkpoint_loaded = True
|
||||
checkpoint_loaded = True
|
||||
logger.info("DQN checkpoint loaded successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
|
||||
@@ -593,7 +844,7 @@ class TradingOrchestrator:
|
||||
try:
|
||||
# CNN checkpoint loading would go here
|
||||
logger.info("CNN checkpoint loaded successfully")
|
||||
checkpoint_loaded = True
|
||||
checkpoint_loaded = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
checkpoint_loaded = False
|
||||
@@ -686,8 +937,8 @@ class TradingOrchestrator:
|
||||
except ImportError:
|
||||
logger.warning("Extrema trainer not available")
|
||||
self.extrema_trainer = None
|
||||
|
||||
self.cob_rl_agent = None
|
||||
|
||||
self.cob_rl_agent = None
|
||||
|
||||
|
||||
# CRITICAL: Register models with the model registry
|
||||
@@ -704,7 +955,7 @@ class TradingOrchestrator:
|
||||
try:
|
||||
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
||||
if self.model_registry.register_model(rl_interface):
|
||||
logger.info("RL Agent registered successfully")
|
||||
logger.info("RL Agent registered successfully")
|
||||
else:
|
||||
logger.error("Failed to register RL Agent with registry")
|
||||
except Exception as e:
|
||||
@@ -715,7 +966,7 @@ class TradingOrchestrator:
|
||||
try:
|
||||
cnn_interface = CNNModelInterface(self.cnn_model, name="cnn_model")
|
||||
if self.model_registry.register_model(cnn_interface):
|
||||
logger.info("CNN Model registered successfully")
|
||||
logger.info("CNN Model registered successfully")
|
||||
else:
|
||||
logger.error("Failed to register CNN Model with registry")
|
||||
except Exception as e:
|
||||
@@ -726,14 +977,13 @@ class TradingOrchestrator:
|
||||
try:
|
||||
extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
|
||||
if self.model_registry.register_model(extrema_interface):
|
||||
logger.info("Extrema Trainer registered successfully")
|
||||
logger.info("Extrema Trainer registered successfully")
|
||||
else:
|
||||
logger.error("Failed to register Extrema Trainer with registry")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Extrema Trainer: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
|
||||
def get_model_training_stats(self) -> Dict[str, Dict[str, Any]]:
|
||||
@@ -841,7 +1091,7 @@ class TradingOrchestrator:
|
||||
}
|
||||
|
||||
for model_name, stats in dashboard_stats.items():
|
||||
if model_name in self.model_states:
|
||||
if model_name in self.model_states:
|
||||
self.model_states[model_name]["current_loss"] = stats["current_loss"]
|
||||
self.model_states[model_name]["initial_loss"] = stats["initial_loss"]
|
||||
if (
|
||||
@@ -1066,7 +1316,7 @@ class TradingOrchestrator:
|
||||
|
||||
with open(session_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing, f, indent=2)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Error appending session snapshot: {e}")
|
||||
|
||||
def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
|
||||
@@ -1124,8 +1374,7 @@ class TradingOrchestrator:
|
||||
self._save_ui_state()
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering model {model_name} dynamically: {e}")
|
||||
return False
|
||||
|
||||
@@ -1239,8 +1488,8 @@ class TradingOrchestrator:
|
||||
self.cob_integration.add_dashboard_callback(
|
||||
self._on_cob_dashboard_data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize COB Integration: {e}")
|
||||
self.cob_integration = None
|
||||
else:
|
||||
@@ -1460,10 +1709,10 @@ class TradingOrchestrator:
|
||||
)
|
||||
|
||||
# Use programmatic decision for actual actions
|
||||
decision = self._combine_predictions(
|
||||
symbol=symbol,
|
||||
price=current_price,
|
||||
predictions=predictions,
|
||||
decision = self._combine_predictions(
|
||||
symbol=symbol,
|
||||
price=current_price,
|
||||
predictions=predictions,
|
||||
timestamp=current_time,
|
||||
)
|
||||
else:
|
||||
@@ -1545,7 +1794,7 @@ class TradingOrchestrator:
|
||||
price_change_pct = (
|
||||
(current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get recent prices for {symbol}: {e}")
|
||||
# Fallback: use current price and a small assumed change
|
||||
price_change_pct = 0.1 # Assume small positive change
|
||||
@@ -1658,7 +1907,7 @@ class TradingOrchestrator:
|
||||
# Validate base_data has the required method
|
||||
if not hasattr(base_data, 'get_feature_vector'):
|
||||
logger.debug(f"BaseDataInput for {symbol} missing get_feature_vector method")
|
||||
return None
|
||||
return None
|
||||
|
||||
# Get unified feature vector (7850 features including all timeframes and COB data)
|
||||
feature_vector = base_data.get_feature_vector()
|
||||
@@ -1724,7 +1973,7 @@ class TradingOrchestrator:
|
||||
action_scores[action] += random.uniform(-0.001, 0.001)
|
||||
|
||||
best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
|
||||
best_confidence = action_scores[best_action]
|
||||
best_confidence = action_scores[best_action]
|
||||
|
||||
# DEBUG: Log action scores to understand bias
|
||||
logger.debug(f"Action scores for {symbol}: BUY={action_scores['BUY']:.3f}, SELL={action_scores['SELL']:.3f}, HOLD={action_scores['HOLD']:.3f}")
|
||||
@@ -1958,8 +2207,8 @@ class TradingOrchestrator:
|
||||
|
||||
# Try to load decision fusion checkpoint
|
||||
result = load_best_checkpoint("decision_fusion")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
# Load the checkpoint into the network
|
||||
checkpoint = torch.load(file_path, map_location=self.device)
|
||||
|
||||
@@ -1989,11 +2238,11 @@ class TradingOrchestrator:
|
||||
logger.info(
|
||||
f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id} (loss={loss_str})"
|
||||
)
|
||||
else:
|
||||
else:
|
||||
logger.info(
|
||||
"No existing decision fusion checkpoint found, starting fresh"
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading decision fusion checkpoint: {e}")
|
||||
logger.info("Decision fusion network starting fresh")
|
||||
|
||||
@@ -2053,6 +2302,7 @@ class TradingOrchestrator:
|
||||
def stop_enhanced_training(self):
|
||||
"""Stop the enhanced real-time training system"""
|
||||
try:
|
||||
if self.enhanced_training_system:
|
||||
self.enhanced_training_system.stop_training()
|
||||
logger.info("Enhanced real-time training stopped")
|
||||
return True
|
||||
@@ -2075,12 +2325,18 @@ class TradingOrchestrator:
|
||||
# Get base stats from enhanced training system
|
||||
stats = {}
|
||||
if hasattr(self.enhanced_training_system, "get_training_statistics"):
|
||||
stats = self.enhanced_training_system.get_training_statistics()
|
||||
stats = self.enhanced_training_system.get_training_statistics()
|
||||
else:
|
||||
stats = {}
|
||||
|
||||
stats["training_enabled"] = self.training_enabled
|
||||
stats["system_available"] = ENHANCED_TRAINING_AVAILABLE
|
||||
|
||||
# Add orchestrator-specific training integration data
|
||||
stats["orchestrator_integration"] = {
|
||||
"enhanced_training_enabled": self.enhanced_training_enabled,
|
||||
"model_registry_count": len(self.model_registry.models) if hasattr(self, 'model_registry') else 0,
|
||||
"decision_fusion_enabled": self.decision_fusion_enabled
|
||||
}
|
||||
|
||||
# Add model-specific training status from orchestrator
|
||||
@@ -2220,7 +2476,7 @@ class TradingOrchestrator:
|
||||
current_time
|
||||
)
|
||||
elif self.universal_adapter:
|
||||
return self.universal_adapter.get_universal_data_stream(current_time)
|
||||
return self.universal_adapter.get_universal_data_stream(current_time)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting universal data stream: {e}")
|
||||
@@ -2235,10 +2491,10 @@ class TradingOrchestrator:
|
||||
stream, model_type
|
||||
)
|
||||
elif self.universal_adapter:
|
||||
stream = self.universal_adapter.get_universal_data_stream()
|
||||
if stream:
|
||||
return self.universal_adapter.format_for_model(stream, model_type)
|
||||
return None
|
||||
stream = self.universal_adapter.get_universal_data_stream()
|
||||
if stream:
|
||||
return self.universal_adapter.format_for_model(stream, model_type)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting universal data for {model_type}: {e}")
|
||||
return None
|
||||
@@ -2278,7 +2534,7 @@ class TradingOrchestrator:
|
||||
side = position.get("side", "LONG")
|
||||
|
||||
if entry_price and size > 0:
|
||||
if side.upper() == "LONG":
|
||||
if side.upper() == "LONG":
|
||||
pnl = (current_price - entry_price) * size
|
||||
else: # SHORT
|
||||
pnl = (entry_price - current_price) * size
|
||||
|
||||
@@ -428,7 +428,7 @@ class TimeframeInferenceCoordinator:
|
||||
async def _call_model_training(self, model_name: str, symbol: str,
|
||||
timeframe: TimeFrame, training_data: List[Any]):
|
||||
"""
|
||||
Call model-specific training function
|
||||
Call model-specific training function (placeholder - not yet implemented)
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to train
|
||||
@@ -438,7 +438,7 @@ class TimeframeInferenceCoordinator:
|
||||
"""
|
||||
# This is a placeholder for model-specific training calls
|
||||
# You'll need to implement this based on your specific model interfaces
|
||||
logger.debug(f"Training call for {model_name}: {len(training_data)} samples")
|
||||
logger.debug(f"Training call for {model_name}: {len(training_data)} samples (not yet implemented)")
|
||||
|
||||
def get_inference_statistics(self) -> Dict[str, Any]:
|
||||
"""Get inference coordination statistics"""
|
||||
|
||||
@@ -221,6 +221,7 @@ class TradingExecutor:
|
||||
# Connect to exchange - skip connection check in simulation mode
|
||||
if self.trading_enabled:
|
||||
if self.simulation_mode:
|
||||
logger.info("TRADING EXECUTOR: Running in simulation mode - no exchange connection needed")
|
||||
else:
|
||||
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
|
||||
if not self._connect_exchange():
|
||||
@@ -533,8 +534,8 @@ class TradingExecutor:
|
||||
# For simplicity, assume required capital is the full position value in USD
|
||||
required_capital = self._calculate_position_size(confidence, current_price)
|
||||
|
||||
else:
|
||||
available_balance = self.exchange.get_balance(quote_asset)
|
||||
# Get available balance
|
||||
available_balance = self.exchange.get_balance(quote_asset)
|
||||
|
||||
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")
|
||||
|
||||
@@ -1401,9 +1402,26 @@ class TradingExecutor:
|
||||
if self.simulation_mode:
|
||||
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short close logged but not executed")
|
||||
# Calculate simulated fees in simulation mode
|
||||
trading_fees = self.exchange_config.get('trading_fees', {})
|
||||
taker_fee_rate = trading_fees.get('taker_fee', trading_fees.get('default_fee', 0.0006))
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
|
||||
# Get current leverage setting
|
||||
leverage = self.get_leverage()
|
||||
|
||||
# Calculate position size in USD
|
||||
position_size_usd = position.quantity * position.entry_price
|
||||
|
||||
# Calculate gross PnL (before fees) with leverage - SHORT profits when price falls
|
||||
gross_pnl = (position.entry_price - current_price) * position.quantity * leverage
|
||||
|
||||
# Calculate net PnL (after fees)
|
||||
net_pnl = gross_pnl - simulated_fees
|
||||
|
||||
# Calculate hold time
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
|
||||
# Create trade record with corrected PnL calculations
|
||||
trade_record = TradeRecord(
|
||||
symbol=symbol,
|
||||
@@ -1413,12 +1431,37 @@ class TradingExecutor:
|
||||
exit_price=current_price,
|
||||
entry_time=position.entry_time,
|
||||
exit_time=exit_time,
|
||||
|
||||
pnl=net_pnl, # Store net PnL as the main PnL value
|
||||
fees=simulated_fees,
|
||||
confidence=confidence,
|
||||
hold_time_seconds=hold_time_seconds,
|
||||
leverage=leverage,
|
||||
position_size_usd=position_size_usd,
|
||||
gross_pnl=gross_pnl,
|
||||
net_pnl=net_pnl
|
||||
)
|
||||
|
||||
self.trade_history.append(trade_record)
|
||||
self.trade_records.append(trade_record)
|
||||
self.daily_loss += max(0, -net_pnl) # Use net_pnl instead of pnl
|
||||
|
||||
# Adjust profitability reward multiplier based on recent performance
|
||||
self._adjust_profitability_reward_multiplier()
|
||||
|
||||
# Update consecutive losses using net_pnl
|
||||
if net_pnl < -0.001: # A losing trade
|
||||
self.consecutive_losses += 1
|
||||
elif net_pnl > 0.001: # A winning trade
|
||||
self.consecutive_losses = 0
|
||||
else: # Breakeven trade
|
||||
self.consecutive_losses = 0
|
||||
|
||||
# Remove position
|
||||
del self.positions[symbol]
|
||||
self.last_trade_time[symbol] = datetime.now()
|
||||
self.daily_trades += 1
|
||||
|
||||
|
||||
logger.info(f"SHORT position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${simulated_fees:.3f}")
|
||||
return True
|
||||
|
||||
try:
|
||||
@@ -2002,8 +2045,29 @@ class TradingExecutor:
|
||||
return self.trade_history.copy()
|
||||
|
||||
def get_balance(self) -> Dict[str, float]:
|
||||
"""TODO(Guideline: expose real account state) Return actual account balances instead of raising."""
|
||||
raise NotImplementedError("Implement TradingExecutor.get_balance to supply real balance data; stubs are forbidden.")
|
||||
"""Get account balances from the primary exchange.
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: Asset balances in format {'USDT': 100.0, 'ETH': 0.5, ...}
|
||||
"""
|
||||
try:
|
||||
# Use the existing get_account_balance method to get real exchange data
|
||||
account_balances = self.get_account_balance()
|
||||
|
||||
# Convert to simple format: asset -> free balance
|
||||
simple_balances = {}
|
||||
for asset, balance_data in account_balances.items():
|
||||
if isinstance(balance_data, dict):
|
||||
simple_balances[asset] = balance_data.get('free', 0.0)
|
||||
else:
|
||||
simple_balances[asset] = float(balance_data) if balance_data else 0.0
|
||||
|
||||
logger.debug(f"Retrieved balances for {len(simple_balances)} assets")
|
||||
return simple_balances
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting balance: {e}")
|
||||
return {}
|
||||
|
||||
def export_trades_to_csv(self, filename: Optional[str] = None) -> str:
|
||||
"""Export trade history to CSV file with comprehensive analysis"""
|
||||
|
||||
@@ -354,7 +354,8 @@ class TrainingIntegration:
|
||||
pivot_points = []
|
||||
|
||||
# This would integrate with the Williams Market Structure
|
||||
# For now, return empty list as placeholder
|
||||
# For now, return empty list as placeholder (not yet implemented)
|
||||
logger.debug(f"Pivot points integration for {symbol} not yet implemented")
|
||||
return pivot_points
|
||||
|
||||
return []
|
||||
@@ -519,7 +520,8 @@ class TrainingIntegration:
|
||||
|
||||
try:
|
||||
# This would integrate with existing model predictions
|
||||
# For now, return empty dict as placeholder
|
||||
# For now, return empty dict as placeholder (not yet implemented)
|
||||
logger.debug(f"Model predictions integration for {symbol} not yet implemented")
|
||||
return predictions
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting model predictions for {symbol}: {e}")
|
||||
|
||||
@@ -1946,8 +1946,6 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error updating trades table: {e}")
|
||||
return html.P(f"Error: {str(e)}", className="text-danger")
|
||||
|
||||
@self.app.callback(
|
||||
|
||||
@self.app.callback(
|
||||
[Output('eth-cob-content', 'children'),
|
||||
Output('btc-cob-content', 'children')],
|
||||
@@ -6353,10 +6351,10 @@ class CleanTradingDashboard:
|
||||
|
||||
# Additional training weight for executed signals
|
||||
if signal['executed']:
|
||||
# Log signal processing
|
||||
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
||||
logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
|
||||
f"(conf: {signal['confidence']:.2f}, model: {signal.get('model', 'UNKNOWN')})")
|
||||
# Log signal processing
|
||||
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
||||
logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
|
||||
f"(conf: {signal['confidence']:.2f}, model: {signal.get('model', 'UNKNOWN')})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing dashboard signal: {e}")
|
||||
@@ -6512,7 +6510,7 @@ class CleanTradingDashboard:
|
||||
if hasattr(self.orchestrator.rl_agent, 'replay'):
|
||||
loss = self.orchestrator.rl_agent.replay()
|
||||
if loss is not None:
|
||||
|
||||
logger.debug(f"DQN training loss: {loss:.4f}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training DQN on prediction: {e}")
|
||||
|
||||
@@ -8049,8 +8047,14 @@ class CleanTradingDashboard:
|
||||
logger.warning("No checkpoint manager available for model storage")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"❌ Failed to store Decision Fusion model: {e}")
|
||||
# Store models and handle any exceptions
|
||||
try:
|
||||
# Store Decision Fusion model
|
||||
if hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network:
|
||||
logger.info("💾 Storing Decision Fusion model...")
|
||||
# Add storage logic here
|
||||
except Exception as e:
|
||||
logger.warning(f"❌ Failed to store Decision Fusion model: {e}")
|
||||
|
||||
# 5. Verification Step - Try to load checkpoints to verify they work
|
||||
logger.info("🔍 Verifying stored checkpoints...")
|
||||
@@ -9182,6 +9186,9 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"COB data retrieved from data provider for {symbol}: {len(bids)} bids, {len(asks)} asks")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving COB data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _generate_bucketed_cob_data(self, symbol: str, cob_snapshot: dict):
|
||||
"""Generate bucketed COB data for model feeding"""
|
||||
try:
|
||||
@@ -9775,6 +9782,53 @@ class CleanTradingDashboard:
|
||||
'total_pnl': 0.0
|
||||
}
|
||||
|
||||
def run_chained_inference(self, symbol: str, n_steps: int = 10):
|
||||
"""Run chained inference to trigger initial predictions"""
|
||||
try:
|
||||
logger.info(f"🔗 Running chained inference for {symbol} with {n_steps} steps")
|
||||
|
||||
if self.orchestrator is None:
|
||||
logger.warning("❌ No orchestrator available for chained inference")
|
||||
return
|
||||
|
||||
# Trigger initial predictions by calling make_trading_decision
|
||||
import asyncio
|
||||
|
||||
async def _run_inference():
|
||||
try:
|
||||
# Run multiple inference steps
|
||||
for step in range(n_steps):
|
||||
logger.debug(f"🔗 Running inference step {step + 1}/{n_steps}")
|
||||
decision = await self.orchestrator.make_trading_decision(symbol)
|
||||
if decision:
|
||||
logger.debug(f"🔗 Step {step + 1}: Decision made for {symbol}")
|
||||
else:
|
||||
logger.debug(f"🔗 Step {step + 1}: No decision available for {symbol}")
|
||||
|
||||
# Small delay between steps
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"🔗 Chained inference completed for {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in chained inference: {e}")
|
||||
|
||||
# Run the async inference
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If loop is already running, create a task
|
||||
asyncio.create_task(_run_inference())
|
||||
else:
|
||||
# If no loop is running, run it directly
|
||||
asyncio.run(_run_inference())
|
||||
except RuntimeError:
|
||||
# Fallback: try to run in a new event loop
|
||||
asyncio.run(_run_inference())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Chained inference failed: {e}")
|
||||
|
||||
def run_server(self, host='127.0.0.1', port=8050, debug=False):
|
||||
"""Start the Dash server"""
|
||||
try:
|
||||
@@ -9862,6 +9916,8 @@ class CleanTradingDashboard:
|
||||
"""Connect to orchestrator for real trading signals"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'):
|
||||
self.orchestrator.add_decision_callback(self._on_trading_decision)
|
||||
logger.info("✅ Orchestrator decision callback registered")
|
||||
else:
|
||||
logger.warning("Orchestrator not available or doesn't support callbacks")
|
||||
except Exception as e:
|
||||
@@ -11179,3 +11235,4 @@ def create_clean_dashboard(data_provider: Optional[DataProvider] = None, orchest
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
|
||||
@@ -344,6 +344,12 @@ class DashboardComponentManager:
|
||||
asks = cob_snapshot.get('asks', []) or []
|
||||
elif hasattr(cob_snapshot, 'stats'):
|
||||
# Old format with stats attribute
|
||||
stats = getattr(cob_snapshot, 'stats', None)
|
||||
mid_price = float(getattr(stats, 'mid_price', 0) or 0)
|
||||
spread_bps = float(getattr(stats, 'spread_bps', 0) or 0)
|
||||
imbalance = float(getattr(stats, 'imbalance', 0) or 0)
|
||||
bids = getattr(cob_snapshot, 'bids', []) or []
|
||||
asks = getattr(cob_snapshot, 'asks', []) or []
|
||||
else:
|
||||
# New object-like snapshot with direct attributes
|
||||
mid_price = float(getattr(cob_snapshot, 'volume_weighted_mid', 0) or 0)
|
||||
|
||||
@@ -16,7 +16,13 @@ class DashboardLayoutManager:
|
||||
self.dashboard = dashboard
|
||||
|
||||
def create_main_layout(self):
|
||||
|
||||
"""Create the main dashboard layout"""
|
||||
return html.Div([
|
||||
self._create_header(),
|
||||
self._create_main_content(),
|
||||
self._create_interval_component()
|
||||
], className="container-fluid bg-dark text-light min-vh-100")
|
||||
|
||||
def _create_prediction_tracking_section(self):
|
||||
"""Create prediction tracking and model performance section"""
|
||||
return html.Div([
|
||||
@@ -250,7 +256,12 @@ class DashboardLayoutManager:
|
||||
], className="bg-dark p-2 mb-2")
|
||||
|
||||
def _create_interval_component(self):
|
||||
])
|
||||
"""Create the interval component for auto-refresh"""
|
||||
return dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=2000, # Update every 2 seconds
|
||||
n_intervals=0
|
||||
)
|
||||
|
||||
def _create_main_content(self):
|
||||
"""Create the main content area"""
|
||||
|
||||
Reference in New Issue
Block a user