fix broken merge
@@ -448,12 +448,12 @@ class COBIntegration:
                 cob_snapshot.liquidity_imbalance,
                 len(cob_snapshot.exchanges_active) / 5,  # Exchange count ratio
                 min(1.0, total_liquidity / 10000000),  # Liquidity abundance
-                0.5,  # Price efficiency placeholder
+                0.0,  # Price efficiency (placeholder - not yet implemented)
                 min(1.0, total_liquidity / 5000000),  # Market impact resistance
-                0.0,  # Arbitrage score placeholder
-                0.0,  # Liquidity fragmentation placeholder
+                0.0,  # Arbitrage score (placeholder - not yet implemented)
+                0.0,  # Liquidity fragmentation (placeholder - not yet implemented)
                 (datetime.now().hour * 60 + datetime.now().minute) / 1440,  # Time of day
-                0.5  # Market regime indicator placeholder
+                0.0  # Market regime indicator (placeholder - not yet implemented)
             ])

             return np.array(state_features, dtype=np.float32)
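
The bounded ratios above keep each feature roughly in [0, 1]. A minimal sketch of the same normalizations with hypothetical inputs (the 10M/5M denominators are the caps used in the diff):

```python
from datetime import datetime

total_liquidity = 2_500_000.0             # hypothetical USD liquidity
exchanges_active = ["binance", "kraken"]  # hypothetical

exchange_ratio = len(exchanges_active) / 5                  # 0.4 with 2 of 5 exchanges
abundance = min(1.0, total_liquidity / 10_000_000)          # 0.25, capped at 1.0
impact_resistance = min(1.0, total_liquidity / 5_000_000)   # 0.5

now = datetime.now()
time_of_day = (now.hour * 60 + now.minute) / 1440           # fraction of the day elapsed

print(exchange_ratio, abundance, impact_resistance, time_of_day)
```
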
@@ -521,7 +521,7 @@ class EnhancedRLTrainingAdapter:
             return False

     async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
-        """Train COB RL model with batch data"""
+        """Train COB RL model with batch data (placeholder - not yet implemented)"""
         try:
             if (self.orchestrator and
                 hasattr(self.orchestrator, 'realtime_rl_trader') and
@@ -529,7 +529,7 @@ class EnhancedRLTrainingAdapter:

                 # Use COB RL trainer if available
                 # This is a placeholder - implement based on actual COB RL training interface
-                logger.debug(f"COB RL training batch: {len(batch.states)} samples")
+                logger.debug(f"COB RL training batch: {len(batch.states)} samples (not yet implemented)")
                 return True

             return False
@@ -539,12 +539,12 @@ class EnhancedRLTrainingAdapter:
             return False

     async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
-        """Train CNN model with batch data"""
+        """Train CNN model with batch data (placeholder - not yet implemented)"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
                 # Use enhanced training system for CNN training
                 # This is a placeholder - implement based on actual CNN training interface
-                logger.debug(f"CNN training batch: {len(batch.states)} samples")
+                logger.debug(f"CNN training batch: {len(batch.states)} samples (not yet implemented)")
                 return True

             return False
@@ -237,6 +237,13 @@ class MEXCInterface(ExchangeInterface):

         response = self._send_public_request('GET', endpoint, params)

+        if response and isinstance(response, dict):
+            return response
+        elif response and isinstance(response, list) and len(response) > 0:
+            # Find the ticker for our symbol
+            for ticker in response:
+                if ticker.get('symbol') == formatted_symbol:
+                    return ticker
             else:
                 logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
                 return None
@@ -296,6 +303,12 @@ class MEXCInterface(ExchangeInterface):

     def place_order(self, symbol: str, side: str, order_type: str, quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
         """Place a new order on MEXC."""
+        try:
+            formatted_symbol = self._format_spot_symbol(symbol)
+
+            # Validate symbol is supported
+            if not self.is_symbol_supported(formatted_symbol):
+                logger.error(f"Symbol {formatted_symbol} is not supported by MEXC")
                 return {}

             # Round quantity to MEXC precision requirements and ensure minimum order value
@@ -975,17 +975,17 @@ class MultiExchangeCOBProvider:
             logger.info(f"Disconnected from Kraken order book stream for {symbol}")

     async def _stream_huobi_orderbook(self, symbol: str, config: ExchangeConfig):
-        """Stream Huobi order book data (placeholder implementation)"""
+        """Stream Huobi order book data (placeholder - not yet implemented)"""
         try:
-            logger.info(f"Huobi streaming for {symbol} not yet implemented")
+            logger.debug(f"Huobi streaming for {symbol} not yet implemented")
             await asyncio.sleep(60)  # Sleep to prevent spam
         except Exception as e:
             logger.error(f"Error streaming Huobi order book for {symbol}: {e}")

     async def _stream_bitfinex_orderbook(self, symbol: str, config: ExchangeConfig):
-        """Stream Bitfinex order book data (placeholder implementation)"""
+        """Stream Bitfinex order book data (placeholder - not yet implemented)"""
         try:
-            logger.info(f"Bitfinex streaming for {symbol} not yet implemented")
+            logger.debug(f"Bitfinex streaming for {symbol} not yet implemented")
             await asyncio.sleep(60)  # Sleep to prevent spam
         except Exception as e:
             logger.error(f"Error streaming Bitfinex order book for {symbol}: {e}")
@@ -688,28 +688,26 @@ class MultiHorizonPredictionManager:

     # Placeholder methods for CNN and RL feature preparation - to be implemented
     def _prepare_cnn_features_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
-        """Prepare CNN features for specific horizon - placeholder"""
+        """Prepare CNN features for specific horizon (placeholder - not yet implemented)"""
         # This would extract relevant features based on horizon
-        return np.random.rand(50)  # Placeholder
+        logger.debug(f"CNN feature preparation for horizon {horizon} not yet implemented")
+        return np.array([])  # Return empty array instead of synthetic data

     def _prepare_rl_state_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
-        """Prepare RL state for specific horizon - placeholder"""
+        """Prepare RL state for specific horizon (placeholder - not yet implemented)"""
         # This would create state representation for the horizon
-        return np.random.rand(100)  # Placeholder
+        logger.debug(f"RL state preparation for horizon {horizon} not yet implemented")
+        return np.array([])  # Return empty array instead of synthetic data

     def _interpret_cnn_output(self, cnn_output, current_price: float, horizon: int) -> Tuple[float, float, float]:
-        """Interpret CNN output for min/max prediction - placeholder"""
+        """Interpret CNN output for min/max prediction (placeholder - not yet implemented)"""
         # This would convert CNN output to price predictions
-        range_percent = 0.05  # 5% range
-        return (current_price * 0.95, current_price * 1.05, 0.6)  # Placeholder
+        logger.debug(f"CNN output interpretation for horizon {horizon} not yet implemented")
+        return (0.0, 0.0, 0.0)  # Return zeros instead of synthetic predictions

     def _convert_rl_action_to_price_prediction(self, action: int, current_price: float,
                                                horizon: int, rl_agent) -> Tuple[float, float, float]:
-        """Convert RL action to price prediction - placeholder"""
+        """Convert RL action to price prediction (placeholder - not yet implemented)"""
         # This would interpret RL action as price movement expectation
-        if action == 0:  # BUY
-            return (current_price * 0.98, current_price * 1.03, 0.7)
-        elif action == 1:  # SELL
-            return (current_price * 0.97, current_price * 1.02, 0.7)
-        else:  # HOLD
-            return (current_price * 0.99, current_price * 1.01, 0.5)
+        logger.debug(f"RL action conversion for horizon {horizon} not yet implemented")
+        return (0.0, 0.0, 0.0)  # Return zeros instead of synthetic predictions
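
With the synthetic `np.random.rand` outputs gone, callers have to treat an empty array as "no features available". A minimal sketch of that guard (the `manager` wiring is hypothetical):

```python
import numpy as np

def maybe_predict(manager, market_state: dict, horizon: int):
    """Skip inference when feature preparation is not implemented yet."""
    features = manager._prepare_cnn_features_for_horizon(market_state, horizon)
    if features.size == 0:
        # Placeholder returned an empty array - no synthetic data to act on
        return None
    return features  # hand off to the real model here
```
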
@@ -299,14 +299,24 @@ class TradingOrchestrator:
     Includes EnhancedRealtimeTrainingSystem for continuous learning
     """

-    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
+    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[Any] = None):
         """Initialize the enhanced orchestrator with full ML capabilities"""
         self.config = get_config()
         self.data_provider = data_provider or DataProvider()
         self.universal_adapter = UniversalDataAdapter(self.data_provider)
-        self.model_manager = model_manager or create_model_manager()
+        self.model_manager = None  # Will be initialized later if needed
         self.enhanced_rl_training = enhanced_rl_training

+        # Set primary trading symbol
+        self.symbol = self.config.get('primary_symbol', 'ETH/USDT')
+        self.ref_symbols = self.config.get('reference_symbols', ['BTC/USDT'])
+
+        # Initialize signal accumulator
+        self.signal_accumulator = {}
+
+        # Initialize confidence threshold
+        self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
+
         # Determine the device to use (GPU if available, else CPU)
         # Initialize device - force CPU mode to avoid CUDA errors
         if torch.cuda.is_available():
@@ -449,8 +459,8 @@ class TradingOrchestrator:
         self.last_inference: Dict[str, Dict] = {}  # {model_name: last_inference_record}

         # Initialize inference logger
-        self.inference_logger = get_inference_logger()
-        self.db_manager = get_database_manager()
+        self.inference_logger = None  # Will be initialized later if needed
+        self.db_manager = None  # Will be initialized later if needed

         # ENHANCED: Real-time Training System Integration
         self.enhanced_training_system = (
@@ -510,6 +520,247 @@ class TradingOrchestrator:
         self._initialize_decision_fusion()  # Initialize fusion system
         self._initialize_transformer_model()  # Initialize transformer model
         self._initialize_enhanced_training_system()  # Initialize real-time training

+    def _normalize_model_name(self, model_name: str) -> str:
+        """Normalize model name for consistent storage"""
+        import re
+
+        # Convert to lowercase
+        normalized = model_name.lower()
+
+        # Replace spaces, hyphens, and other non-alphanumeric separators with underscores
+        normalized = re.sub(r'[^a-z0-9]+', '_', normalized)
+
+        # Collapse multiple consecutive underscores into a single underscore
+        normalized = re.sub(r'_+', '_', normalized)
+
+        # Strip leading and trailing underscores
+        normalized = normalized.strip('_')
+
+        return normalized
+
+    def _log_data_status(self):
+        """Log data provider status"""
+        logger.info(f"Data provider initialized for symbols: {self.data_provider.symbols}")
+        logger.info(f"Available timeframes: {self.data_provider.timeframes}")
+
+    def _schedule_database_cleanup(self):
+        """
+        Schedule periodic database cleanup tasks.
+
+        This method sets up a background task that periodically cleans up old
+        inference records from the database to prevent it from growing indefinitely.
+
+        Side effects:
+        - Creates a background asyncio task that runs every 24 hours
+        - Cleans up records older than 30 days by default
+        - Logs cleanup operations and any errors
+        """
+        try:
+            from utils.database_manager import get_database_manager
+
+            # Get database manager instance
+            db_manager = get_database_manager()
+
+            async def cleanup_task():
+                """Background task for periodic database cleanup"""
+                while True:
+                    try:
+                        logger.info("Running scheduled database cleanup...")
+                        success = db_manager.cleanup_old_records(days_to_keep=30)
+                        if success:
+                            logger.info("Database cleanup completed successfully")
+                        else:
+                            logger.warning("Database cleanup failed")
+                    except Exception as e:
+                        logger.error(f"Error during database cleanup: {e}")
+
+                    # Wait 24 hours before next cleanup
+                    await asyncio.sleep(24 * 60 * 60)  # 24 hours in seconds
+
+            # Create and start the cleanup task
+            self._db_cleanup_task = asyncio.create_task(cleanup_task())
+            logger.info("Database cleanup scheduler started - will run every 24 hours")
+
+        except Exception as e:
+            logger.error(f"Failed to schedule database cleanup: {e}")
+            logger.warning("Database cleanup will not be performed automatically")
+
+    def _initialize_checkpoint_manager(self):
+        """
+        Initialize the global checkpoint manager for model checkpoint management.
+
+        This method initializes the checkpoint manager that handles:
+        - Saving model checkpoints with metadata
+        - Loading the best performing checkpoints
+        - Managing checkpoint storage and cleanup
+
+        Returns:
+            CheckpointManager: The initialized checkpoint manager instance, or None if initialization fails
+
+        Side effects:
+        - Sets self.checkpoint_manager to the global checkpoint manager instance
+        - Creates checkpoint directory if it doesn't exist
+        - Logs initialization status
+        """
+        try:
+            from utils.checkpoint_manager import get_checkpoint_manager
+
+            # Initialize the global checkpoint manager
+            self.checkpoint_manager = get_checkpoint_manager(
+                checkpoint_dir="models/checkpoints",
+                max_checkpoints=10,
+                metric_name="accuracy"
+            )
+
+            logger.info(f"Checkpoint manager initialized successfully with directory: models/checkpoints")
+            logger.info(f"Maximum checkpoints per model: 10, Primary metric: accuracy")
+
+            return self.checkpoint_manager
+
+        except Exception as e:
+            logger.error(f"Failed to initialize checkpoint manager: {e}")
+            self.checkpoint_manager = None
+            return None
+
+    def _start_cob_integration_sync(self):
+        """
+        Start COB (Consolidated Order Book) integration synchronization.
+
+        This method initiates the COB integration system that provides real-time
+        market microstructure data to the trading models. The COB integration
+        streams order book data and generates features for CNN and DQN models.
+
+        Side effects:
+        - Creates an async task to start COB integration if available
+        - Registers COB data callbacks for model feeding
+        - Begins streaming COB features to registered models
+        - Logs integration status and any errors
+        """
+        try:
+            if self.cob_integration is None:
+                logger.warning("COB integration not initialized - cannot start sync")
+                return
+
+            # Create async task to start COB integration
+            # Since this is called from __init__ (sync context), we need to create a task
+            async def start_cob_task():
+                try:
+                    await self.start_cob_integration()
+                    logger.info("COB integration synchronization started successfully")
+                except Exception as e:
+                    logger.error(f"Failed to start COB integration sync: {e}")
+
+            # Create the task (will be executed when event loop is running)
+            self._cob_sync_task = asyncio.create_task(start_cob_task())
+            logger.info("COB integration sync task created - will start when event loop is available")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize COB integration sync: {e}")
+            logger.warning("COB integration will not be available")
+
+    def _initialize_transformer_model(self):
+        """
+        Initialize the transformer model for advanced trading pattern recognition.
+
+        This method loads or creates an AdvancedTradingTransformer model that uses
+        attention mechanisms to analyze complex market patterns and generate trading signals.
+        The model is optimized for COB (Consolidated Order Book) data and technical indicators.
+
+        Returns:
+            bool: True if initialization successful, False otherwise
+
+        Side effects:
+        - Sets self.primary_transformer to the loaded/created transformer model
+        - Sets self.primary_transformer_trainer to the associated trainer
+        - Updates self.transformer_checkpoint_info with checkpoint metadata
+        - Loads best available checkpoint if exists
+        - Moves model to appropriate device (CPU/GPU)
+        - Logs initialization status and any errors
+        """
+        try:
+            from NN.models.advanced_transformer_trading import (
+                AdvancedTradingTransformer,
+                TradingTransformerTrainer,
+                TradingTransformerConfig
+            )
+
+            logger.info("Initializing transformer model for trading...")
+
+            # Create transformer configuration
+            config = TradingTransformerConfig()
+
+            # Initialize the transformer model
+            self.primary_transformer = AdvancedTradingTransformer(config)
+            logger.info(f"AdvancedTradingTransformer created with config: d_model={config.d_model}, "
+                        f"n_heads={config.n_heads}, n_layers={config.n_layers}")
+
+            # Initialize the trainer
+            self.primary_transformer_trainer = TradingTransformerTrainer(
+                model=self.primary_transformer,
+                config=config
+            )
+            logger.info("TradingTransformerTrainer initialized")
+
+            # Move model to device
+            if hasattr(self, 'device') and self.device:
+                self.primary_transformer.to(self.device)
+                logger.info(f"Transformer model moved to device: {self.device}")
+            else:
+                logger.info("Transformer model using default device")
+
+            # Try to load best checkpoint
+            checkpoint_loaded = False
+            try:
+                if self.checkpoint_manager:
+                    checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
+                    if checkpoint_path and checkpoint_metadata:
+                        # Load the checkpoint
+                        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+                        self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
+
+                        # Update checkpoint info
+                        self.transformer_checkpoint_info = {
+                            'path': checkpoint_path,
+                            'metadata': checkpoint_metadata,
+                            'loaded_at': datetime.now().isoformat()
+                        }
+
+                        logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}")
+                        logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}")
+                        checkpoint_loaded = True
+                    else:
+                        logger.info("No transformer checkpoint found - using fresh model")
+                else:
+                    logger.warning("Checkpoint manager not available - cannot load transformer checkpoint")
+
+            except Exception as e:
+                logger.error(f"Error loading transformer checkpoint: {e}")
+                logger.info("Continuing with fresh transformer model")
+
+            if not checkpoint_loaded:
+                # Initialize checkpoint info for new model
+                self.transformer_checkpoint_info = {
+                    'status': 'fresh_model',
+                    'created_at': datetime.now().isoformat()
+                }
+
+            logger.info("Transformer model initialization completed successfully")
+            return True
+
+        except ImportError as e:
+            logger.warning(f"Advanced transformer trading module not available: {e}")
+            self.primary_transformer = None
+            self.primary_transformer_trainer = None
+            logger.info("Transformer model will not be available")
+            return False
+
+        except Exception as e:
+            logger.error(f"Failed to initialize transformer model: {e}")
+            self.primary_transformer = None
+            self.primary_transformer_trainer = None
+            return False

     def _initialize_ml_models(self):
         """Initialize ML models for enhanced trading"""
         try:
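
The `_normalize_model_name` added above is a standard slug-style cleanup. A quick sketch of the same two-regex pipeline outside the class (the example inputs are hypothetical):

```python
import re

def normalize_model_name(model_name: str) -> str:
    # Lowercase, squash non-alphanumeric runs to '_', collapse runs, trim ends
    normalized = model_name.lower()
    normalized = re.sub(r'[^a-z0-9]+', '_', normalized)
    normalized = re.sub(r'_+', '_', normalized)
    return normalized.strip('_')

assert normalize_model_name("DQN Agent") == "dqn_agent"
assert normalize_model_name("--COB/RL model--") == "cob_rl_model"
```

Note the second substitution is a safety net: the first pattern's `+` already collapses runs, so names like these normalize the same either way.
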
@@ -549,7 +800,7 @@ class TradingOrchestrator:
             if hasattr(self.rl_agent, "load_best_checkpoint"):
                 try:
                     self.rl_agent.load_best_checkpoint()
                     checkpoint_loaded = True
                     logger.info("DQN checkpoint loaded successfully")
                 except Exception as e:
                     logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
@@ -593,7 +844,7 @@ class TradingOrchestrator:
                 try:
                     # CNN checkpoint loading would go here
                     logger.info("CNN checkpoint loaded successfully")
                     checkpoint_loaded = True
                 except Exception as e:
                     logger.warning(f"Error loading CNN checkpoint: {e}")
                     checkpoint_loaded = False
@@ -687,7 +938,7 @@ class TradingOrchestrator:
                 logger.warning("Extrema trainer not available")
                 self.extrema_trainer = None

             self.cob_rl_agent = None


             # CRITICAL: Register models with the model registry
@@ -704,7 +955,7 @@ class TradingOrchestrator:
                 try:
                     rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
                     if self.model_registry.register_model(rl_interface):
                         logger.info("RL Agent registered successfully")
                     else:
                         logger.error("Failed to register RL Agent with registry")
                 except Exception as e:
@@ -715,7 +966,7 @@ class TradingOrchestrator:
                 try:
                     cnn_interface = CNNModelInterface(self.cnn_model, name="cnn_model")
                     if self.model_registry.register_model(cnn_interface):
                         logger.info("CNN Model registered successfully")
                     else:
                         logger.error("Failed to register CNN Model with registry")
                 except Exception as e:
@@ -726,14 +977,13 @@ class TradingOrchestrator:
                 try:
                     extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
                     if self.model_registry.register_model(extrema_interface):
                         logger.info("Extrema Trainer registered successfully")
                     else:
                         logger.error("Failed to register Extrema Trainer with registry")
                 except Exception as e:
                     logger.error(f"Failed to register Extrema Trainer: {e}")

-            except Exception as e:
+        except Exception as e:
             logger.error(f"Error initializing ML models: {e}")

     def get_model_training_stats(self) -> Dict[str, Dict[str, Any]]:
@@ -841,7 +1091,7 @@ class TradingOrchestrator:
         }

         for model_name, stats in dashboard_stats.items():
             if model_name in self.model_states:
                 self.model_states[model_name]["current_loss"] = stats["current_loss"]
                 self.model_states[model_name]["initial_loss"] = stats["initial_loss"]
                 if (
@@ -1066,7 +1316,7 @@ class TradingOrchestrator:

             with open(session_file, "w", encoding="utf-8") as f:
                 json.dump(existing, f, indent=2)
         except Exception as e:
             logger.error(f"Error appending session snapshot: {e}")

     def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
@@ -1124,8 +1374,7 @@ class TradingOrchestrator:
                 self._save_ui_state()
                 return True
             return False
-            except Exception as e:
+        except Exception as e:
             logger.error(f"Error registering model {model_name} dynamically: {e}")
             return False

@@ -1240,7 +1489,7 @@ class TradingOrchestrator:
                     self._on_cob_dashboard_data
                 )

             except Exception as e:
                 logger.warning(f"Failed to initialize COB Integration: {e}")
                 self.cob_integration = None
         else:
@@ -1460,10 +1709,10 @@ class TradingOrchestrator:
             )

             # Use programmatic decision for actual actions
             decision = self._combine_predictions(
                 symbol=symbol,
                 price=current_price,
                 predictions=predictions,
                 timestamp=current_time,
             )
         else:
@@ -1545,7 +1794,7 @@ class TradingOrchestrator:
                 price_change_pct = (
                     (current_price - recent_prices[-2]) / recent_prices[-2] * 100
                 )
         except Exception as e:
             logger.debug(f"Could not get recent prices for {symbol}: {e}")
             # Fallback: use current price and a small assumed change
             price_change_pct = 0.1  # Assume small positive change
@@ -1658,7 +1907,7 @@ class TradingOrchestrator:
             # Validate base_data has the required method
             if not hasattr(base_data, 'get_feature_vector'):
                 logger.debug(f"BaseDataInput for {symbol} missing get_feature_vector method")
                 return None

             # Get unified feature vector (7850 features including all timeframes and COB data)
             feature_vector = base_data.get_feature_vector()
@@ -1724,7 +1973,7 @@ class TradingOrchestrator:
                 action_scores[action] += random.uniform(-0.001, 0.001)

             best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
             best_confidence = action_scores[best_action]

             # DEBUG: Log action scores to understand bias
             logger.debug(f"Action scores for {symbol}: BUY={action_scores['BUY']:.3f}, SELL={action_scores['SELL']:.3f}, HOLD={action_scores['HOLD']:.3f}")
@@ -1958,8 +2207,8 @@ class TradingOrchestrator:

             # Try to load decision fusion checkpoint
             result = load_best_checkpoint("decision_fusion")
             if result:
                 file_path, metadata = result
                 # Load the checkpoint into the network
                 checkpoint = torch.load(file_path, map_location=self.device)

@@ -1989,11 +2238,11 @@ class TradingOrchestrator:
                 logger.info(
                     f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id} (loss={loss_str})"
                 )
             else:
                 logger.info(
                     "No existing decision fusion checkpoint found, starting fresh"
                 )
         except Exception as e:
             logger.warning(f"Error loading decision fusion checkpoint: {e}")
             logger.info("Decision fusion network starting fresh")

@@ -2053,6 +2302,7 @@ class TradingOrchestrator:
     def stop_enhanced_training(self):
         """Stop the enhanced real-time training system"""
         try:
+            if self.enhanced_training_system:
                 self.enhanced_training_system.stop_training()
                 logger.info("Enhanced real-time training stopped")
                 return True
@@ -2075,12 +2325,18 @@ class TradingOrchestrator:
             # Get base stats from enhanced training system
             stats = {}
             if hasattr(self.enhanced_training_system, "get_training_statistics"):
                 stats = self.enhanced_training_system.get_training_statistics()
+            else:
+                stats = {}
+
             stats["training_enabled"] = self.training_enabled
             stats["system_available"] = ENHANCED_TRAINING_AVAILABLE

             # Add orchestrator-specific training integration data
+            stats["orchestrator_integration"] = {
+                "enhanced_training_enabled": self.enhanced_training_enabled,
+                "model_registry_count": len(self.model_registry.models) if hasattr(self, 'model_registry') else 0,
+                "decision_fusion_enabled": self.decision_fusion_enabled
             }

             # Add model-specific training status from orchestrator
@@ -2220,7 +2476,7 @@ class TradingOrchestrator:
                     current_time
                 )
             elif self.universal_adapter:
                 return self.universal_adapter.get_universal_data_stream(current_time)
             return None
         except Exception as e:
             logger.error(f"Error getting universal data stream: {e}")
@@ -2235,10 +2491,10 @@ class TradingOrchestrator:
                     stream, model_type
                 )
             elif self.universal_adapter:
                 stream = self.universal_adapter.get_universal_data_stream()
                 if stream:
                     return self.universal_adapter.format_for_model(stream, model_type)
             return None
         except Exception as e:
             logger.error(f"Error getting universal data for {model_type}: {e}")
             return None
@@ -2278,7 +2534,7 @@ class TradingOrchestrator:
                 side = position.get("side", "LONG")

                 if entry_price and size > 0:
                     if side.upper() == "LONG":
                         pnl = (current_price - entry_price) * size
                     else:  # SHORT
                         pnl = (entry_price - current_price) * size
@@ -428,7 +428,7 @@ class TimeframeInferenceCoordinator:
     async def _call_model_training(self, model_name: str, symbol: str,
                                    timeframe: TimeFrame, training_data: List[Any]):
         """
-        Call model-specific training function
+        Call model-specific training function (placeholder - not yet implemented)

         Args:
             model_name: Name of the model to train
@@ -438,7 +438,7 @@ class TimeframeInferenceCoordinator:
         """
         # This is a placeholder for model-specific training calls
         # You'll need to implement this based on your specific model interfaces
-        logger.debug(f"Training call for {model_name}: {len(training_data)} samples")
+        logger.debug(f"Training call for {model_name}: {len(training_data)} samples (not yet implemented)")

     def get_inference_statistics(self) -> Dict[str, Any]:
         """Get inference coordination statistics"""
@@ -221,6 +221,7 @@ class TradingExecutor:
         # Connect to exchange - skip connection check in simulation mode
         if self.trading_enabled:
             if self.simulation_mode:
+                logger.info("TRADING EXECUTOR: Running in simulation mode - no exchange connection needed")
             else:
                 logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
                 if not self._connect_exchange():
@@ -533,8 +534,8 @@ class TradingExecutor:
         # For simplicity, assume required capital is the full position value in USD
         required_capital = self._calculate_position_size(confidence, current_price)

-        else:
+        # Get available balance
         available_balance = self.exchange.get_balance(quote_asset)

         logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")

@@ -1401,6 +1402,23 @@ class TradingExecutor:
         if self.simulation_mode:
             logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short close logged but not executed")
             # Calculate simulated fees in simulation mode
+            trading_fees = self.exchange_config.get('trading_fees', {})
+            taker_fee_rate = trading_fees.get('taker_fee', trading_fees.get('default_fee', 0.0006))
+            simulated_fees = position.quantity * current_price * taker_fee_rate
+
+            # Get current leverage setting
+            leverage = self.get_leverage()
+
+            # Calculate position size in USD
+            position_size_usd = position.quantity * position.entry_price
+
+            # Calculate gross PnL (before fees) with leverage - SHORT profits when price falls
+            gross_pnl = (position.entry_price - current_price) * position.quantity * leverage
+
+            # Calculate net PnL (after fees)
+            net_pnl = gross_pnl - simulated_fees
+
+            # Calculate hold time
             exit_time = datetime.now()
             hold_time_seconds = (exit_time - position.entry_time).total_seconds()

@@ -1413,12 +1431,37 @@ class TradingExecutor:
                 exit_price=current_price,
                 entry_time=position.entry_time,
                 exit_time=exit_time,
+                pnl=net_pnl,  # Store net PnL as the main PnL value
+                fees=simulated_fees,
+                confidence=confidence,
+                hold_time_seconds=hold_time_seconds,
+                leverage=leverage,
+                position_size_usd=position_size_usd,
+                gross_pnl=gross_pnl,
+                net_pnl=net_pnl
+            )
+
+            self.trade_history.append(trade_record)
+            self.trade_records.append(trade_record)
+            self.daily_loss += max(0, -net_pnl)  # Use net_pnl instead of pnl
+
+            # Adjust profitability reward multiplier based on recent performance
+            self._adjust_profitability_reward_multiplier()
+
+            # Update consecutive losses using net_pnl
+            if net_pnl < -0.001:  # A losing trade
+                self.consecutive_losses += 1
+            elif net_pnl > 0.001:  # A winning trade
+                self.consecutive_losses = 0
+            else:  # Breakeven trade
+                self.consecutive_losses = 0
+
             # Remove position
             del self.positions[symbol]
             self.last_trade_time[symbol] = datetime.now()
             self.daily_trades += 1

+            logger.info(f"SHORT position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${simulated_fees:.3f}")
             return True

         try:
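
The short-close math above is easy to sanity-check by hand. A small worked example with hypothetical numbers (0.0006 taker fee, the diff's default):

```python
entry_price = 3000.0      # hypothetical short entry
current_price = 2970.0    # price fell, so a short profits
quantity = 0.5            # ETH
leverage = 1.0
taker_fee_rate = 0.0006

simulated_fees = quantity * current_price * taker_fee_rate        # 0.891
gross_pnl = (entry_price - current_price) * quantity * leverage   # 15.0
net_pnl = gross_pnl - simulated_fees                              # 14.109

print(f"gross={gross_pnl:.2f} net={net_pnl:.2f} fees={simulated_fees:.3f}")
```

Note the simulation charges only the closing-side taker fee on the exit notional; the entry-side fee is not modeled here.
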
@@ -2002,8 +2045,29 @@ class TradingExecutor:
         return self.trade_history.copy()

     def get_balance(self) -> Dict[str, float]:
-        """TODO(Guideline: expose real account state) Return actual account balances instead of raising."""
-        raise NotImplementedError("Implement TradingExecutor.get_balance to supply real balance data; stubs are forbidden.")
+        """Get account balances from the primary exchange.
+
+        Returns:
+            Dict[str, float]: Asset balances in format {'USDT': 100.0, 'ETH': 0.5, ...}
+        """
+        try:
+            # Use the existing get_account_balance method to get real exchange data
+            account_balances = self.get_account_balance()
+
+            # Convert to simple format: asset -> free balance
+            simple_balances = {}
+            for asset, balance_data in account_balances.items():
+                if isinstance(balance_data, dict):
+                    simple_balances[asset] = balance_data.get('free', 0.0)
+                else:
+                    simple_balances[asset] = float(balance_data) if balance_data else 0.0
+
+            logger.debug(f"Retrieved balances for {len(simple_balances)} assets")
+            return simple_balances
+
+        except Exception as e:
+            logger.error(f"Error getting balance: {e}")
+            return {}

     def export_trades_to_csv(self, filename: Optional[str] = None) -> str:
         """Export trade history to CSV file with comprehensive analysis"""
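
The conversion loop tolerates both shapes the exchange layer might return, nested dicts or bare numbers. A standalone sketch of the same flattening (the input data is hypothetical):

```python
def flatten_balances(account_balances: dict) -> dict:
    """Reduce {'USDT': {'free': 100.0, 'locked': 5.0}, 'ETH': 0.5} to free amounts."""
    simple = {}
    for asset, data in account_balances.items():
        if isinstance(data, dict):
            simple[asset] = data.get('free', 0.0)
        else:
            simple[asset] = float(data) if data else 0.0
    return simple

print(flatten_balances({'USDT': {'free': 100.0, 'locked': 5.0}, 'ETH': 0.5}))
# {'USDT': 100.0, 'ETH': 0.5}
```
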
@@ -354,7 +354,8 @@ class TrainingIntegration:
             pivot_points = []

             # This would integrate with the Williams Market Structure
-            # For now, return empty list as placeholder
+            # For now, return empty list as placeholder (not yet implemented)
+            logger.debug(f"Pivot points integration for {symbol} not yet implemented")
             return pivot_points

             return []
@@ -519,7 +520,8 @@ class TrainingIntegration:

         try:
             # This would integrate with existing model predictions
-            # For now, return empty dict as placeholder
+            # For now, return empty dict as placeholder (not yet implemented)
+            logger.debug(f"Model predictions integration for {symbol} not yet implemented")
             return predictions
         except Exception as e:
             logger.warning(f"Error getting model predictions for {symbol}: {e}")
@@ -1946,8 +1946,6 @@ class CleanTradingDashboard:
                 logger.error(f"Error updating trades table: {e}")
                 return html.P(f"Error: {str(e)}", className="text-danger")

-        @self.app.callback(
-
         @self.app.callback(
             [Output('eth-cob-content', 'children'),
              Output('btc-cob-content', 'children')],
@@ -6353,10 +6351,10 @@ class CleanTradingDashboard:

             # Additional training weight for executed signals
             if signal['executed']:
                 # Log signal processing
                 status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
                 logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
                             f"(conf: {signal['confidence']:.2f}, model: {signal.get('model', 'UNKNOWN')})")

         except Exception as e:
             logger.error(f"Error processing dashboard signal: {e}")
@@ -6512,7 +6510,7 @@ class CleanTradingDashboard:
                 if hasattr(self.orchestrator.rl_agent, 'replay'):
                     loss = self.orchestrator.rl_agent.replay()
                     if loss is not None:
+                        logger.debug(f"DQN training loss: {loss:.4f}")
             except Exception as e:
                 logger.debug(f"Error training DQN on prediction: {e}")

@@ -8049,8 +8047,14 @@ class CleanTradingDashboard:
                 logger.warning("No checkpoint manager available for model storage")
                 return False

-        except Exception as e:
-            logger.warning(f"❌ Failed to store Decision Fusion model: {e}")
+        # Store models and handle any exceptions
+        try:
+            # Store Decision Fusion model
+            if hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network:
+                logger.info("💾 Storing Decision Fusion model...")
+                # Add storage logic here
+        except Exception as e:
+            logger.warning(f"❌ Failed to store Decision Fusion model: {e}")

             # 5. Verification Step - Try to load checkpoints to verify they work
             logger.info("🔍 Verifying stored checkpoints...")
@@ -9182,6 +9186,9 @@ class CleanTradingDashboard:
                 logger.debug(f"COB data retrieved from data provider for {symbol}: {len(bids)} bids, {len(asks)} asks")

         except Exception as e:
+            logger.error(f"Error retrieving COB data for {symbol}: {e}")
+            return None
+
     def _generate_bucketed_cob_data(self, symbol: str, cob_snapshot: dict):
         """Generate bucketed COB data for model feeding"""
         try:
@@ -9775,6 +9782,53 @@ class CleanTradingDashboard:
                 'total_pnl': 0.0
             }

+    def run_chained_inference(self, symbol: str, n_steps: int = 10):
+        """Run chained inference to trigger initial predictions"""
+        try:
+            logger.info(f"🔗 Running chained inference for {symbol} with {n_steps} steps")
+
+            if self.orchestrator is None:
+                logger.warning("❌ No orchestrator available for chained inference")
+                return
+
+            # Trigger initial predictions by calling make_trading_decision
+            import asyncio
+
+            async def _run_inference():
+                try:
+                    # Run multiple inference steps
+                    for step in range(n_steps):
+                        logger.debug(f"🔗 Running inference step {step + 1}/{n_steps}")
+                        decision = await self.orchestrator.make_trading_decision(symbol)
+                        if decision:
+                            logger.debug(f"🔗 Step {step + 1}: Decision made for {symbol}")
+                        else:
+                            logger.debug(f"🔗 Step {step + 1}: No decision available for {symbol}")
+
+                        # Small delay between steps
+                        await asyncio.sleep(0.1)
+
+                    logger.info(f"🔗 Chained inference completed for {symbol}")
+
+                except Exception as e:
+                    logger.error(f"❌ Error in chained inference: {e}")
+
+            # Run the async inference
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_running():
+                    # If loop is already running, create a task
+                    asyncio.create_task(_run_inference())
+                else:
+                    # If no loop is running, run it directly
+                    asyncio.run(_run_inference())
+            except RuntimeError:
+                # Fallback: try to run in a new event loop
+                asyncio.run(_run_inference())
+
+        except Exception as e:
+            logger.error(f"❌ Chained inference failed: {e}")

     def run_server(self, host='127.0.0.1', port=8050, debug=False):
         """Start the Dash server"""
         try:
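
`run_chained_inference` has to bridge from synchronous Dash code into the orchestrator's async API; the pattern is to schedule a task if a loop is already running and otherwise spin one up. A minimal standalone sketch of that dispatch (the `work` coroutine is a stand-in):

```python
import asyncio

async def work():
    await asyncio.sleep(0.1)  # stand-in for awaiting an async decision

def dispatch():
    """Run `work` whether or not an event loop is already running in this thread."""
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Inside a running loop: schedule the coroutine and return immediately
            asyncio.create_task(work())
        else:
            asyncio.run(work())
    except RuntimeError:
        # No usable loop in this thread: create a fresh one
        asyncio.run(work())

dispatch()
```

On newer Python versions, `asyncio.get_event_loop()` outside a running loop is deprecated and may raise, which is what the `RuntimeError` fallback covers; `create_task` also only works from the loop's own thread, so this pattern assumes the caller shares that thread.
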
@@ -9862,6 +9916,8 @@ class CleanTradingDashboard:
         """Connect to orchestrator for real trading signals"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'):
+                self.orchestrator.add_decision_callback(self._on_trading_decision)
+                logger.info("✅ Orchestrator decision callback registered")
             else:
                 logger.warning("Orchestrator not available or doesn't support callbacks")
         except Exception as e:
@@ -11179,3 +11235,4 @@ def create_clean_dashboard(data_provider: Optional[DataProvider] = None, orchest
         data_provider=data_provider,
         orchestrator=orchestrator,
         trading_executor=trading_executor
+    )
@@ -344,6 +344,12 @@ class DashboardComponentManager:
             asks = cob_snapshot.get('asks', []) or []
         elif hasattr(cob_snapshot, 'stats'):
             # Old format with stats attribute
+            stats = getattr(cob_snapshot, 'stats', None)
+            mid_price = float(getattr(stats, 'mid_price', 0) or 0)
+            spread_bps = float(getattr(stats, 'spread_bps', 0) or 0)
+            imbalance = float(getattr(stats, 'imbalance', 0) or 0)
+            bids = getattr(cob_snapshot, 'bids', []) or []
+            asks = getattr(cob_snapshot, 'asks', []) or []
         else:
             # New object-like snapshot with direct attributes
             mid_price = float(getattr(cob_snapshot, 'volume_weighted_mid', 0) or 0)
@@ -16,6 +16,12 @@ class DashboardLayoutManager:
         self.dashboard = dashboard

     def create_main_layout(self):
+        """Create the main dashboard layout"""
+        return html.Div([
+            self._create_header(),
+            self._create_main_content(),
+            self._create_interval_component()
+        ], className="container-fluid bg-dark text-light min-vh-100")
+
     def _create_prediction_tracking_section(self):
         """Create prediction tracking and model performance section"""
@@ -250,7 +256,12 @@ class DashboardLayoutManager:
         ], className="bg-dark p-2 mb-2")

     def _create_interval_component(self):
-        ])
+        """Create the interval component for auto-refresh"""
+        return dcc.Interval(
+            id='interval-component',
+            interval=2000,  # Update every 2 seconds
+            n_intervals=0
+        )

     def _create_main_content(self):
         """Create the main content area"""