diff --git a/CLEANUP_TODO.md b/CLEANUP_TODO.md
new file mode 100644
index 0000000..00ef5e6
--- /dev/null
+++ b/CLEANUP_TODO.md
@@ -0,0 +1,56 @@
+Cleanup run summary:
+- Deleted files: 183
+ - NN\__init__.py
+ - NN\models\__init__.py
+ - NN\models\cnn_model.py
+ - NN\models\transformer_model.py
+ - NN\start_tensorboard.py
+ - NN\training\enhanced_rl_training_integration.py
+ - NN\training\example_checkpoint_usage.py
+ - NN\training\integrate_checkpoint_management.py
+ - NN\utils\__init__.py
+ - NN\utils\data_interface.py
+ - NN\utils\multi_data_interface.py
+ - NN\utils\realtime_analyzer.py
+ - NN\utils\signal_interpreter.py
+ - NN\utils\trading_env.py
+ - _dev\cleanup_models_now.py
+ - _tools\build_keep_set.py
+ - apply_trading_fixes.py
+ - apply_trading_fixes_to_main.py
+ - audit_training_system.py
+ - balance_trading_signals.py
+ - check_live_trading.py
+ - check_mexc_symbols.py
+ - cleanup_checkpoint_db.py
+ - cleanup_checkpoints.py
+ - core\__init__.py
+ - core\api_rate_limiter.py
+ - core\async_handler.py
+ - core\bookmap_data_provider.py
+ - core\bookmap_integration.py
+ - core\cnn_monitor.py
+ - core\cnn_training_pipeline.py
+ - core\config_sync.py
+ - core\enhanced_cnn_adapter.py
+ - core\enhanced_cob_websocket.py
+ - core\enhanced_orchestrator.py
+ - core\enhanced_training_integration.py
+ - core\exchanges\__init__.py
+ - core\exchanges\binance_interface.py
+ - core\exchanges\bybit\debug\test_bybit_balance.py
+ - core\exchanges\bybit_interface.py
+ - core\exchanges\bybit_rest_client.py
+ - core\exchanges\deribit_interface.py
+ - core\exchanges\mexc\debug\final_mexc_order_test.py
+ - core\exchanges\mexc\debug\fix_mexc_orders.py
+ - core\exchanges\mexc\debug\fix_mexc_orders_v2.py
+ - core\exchanges\mexc\debug\fix_mexc_orders_v3.py
+ - core\exchanges\mexc\debug\test_mexc_interface_debug.py
+ - core\exchanges\mexc\debug\test_mexc_order_signature.py
+ - core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
+ - core\exchanges\mexc\debug\test_mexc_signature_debug.py
+ ... and 133 more
+- Removed test directories: 1
+ - tests
+- Kept (excluded): 1
\ No newline at end of file
diff --git a/DELETE_CANDIDATES.txt b/DELETE_CANDIDATES.txt
new file mode 100644
index 0000000..3e2673b
--- /dev/null
+++ b/DELETE_CANDIDATES.txt
@@ -0,0 +1,184 @@
+NN\__init__.py
+NN\models\__init__.py
+NN\models\cnn_model.py
+NN\models\transformer_model.py
+NN\start_tensorboard.py
+NN\training\enhanced_realtime_training.py
+NN\training\enhanced_rl_training_integration.py
+NN\training\example_checkpoint_usage.py
+NN\training\integrate_checkpoint_management.py
+NN\utils\__init__.py
+NN\utils\data_interface.py
+NN\utils\multi_data_interface.py
+NN\utils\realtime_analyzer.py
+NN\utils\signal_interpreter.py
+NN\utils\trading_env.py
+_dev\cleanup_models_now.py
+_tools\build_keep_set.py
+apply_trading_fixes.py
+apply_trading_fixes_to_main.py
+audit_training_system.py
+balance_trading_signals.py
+check_live_trading.py
+check_mexc_symbols.py
+cleanup_checkpoint_db.py
+cleanup_checkpoints.py
+core\__init__.py
+core\api_rate_limiter.py
+core\async_handler.py
+core\bookmap_data_provider.py
+core\bookmap_integration.py
+core\cnn_monitor.py
+core\cnn_training_pipeline.py
+core\config_sync.py
+core\enhanced_cnn_adapter.py
+core\enhanced_cob_websocket.py
+core\enhanced_orchestrator.py
+core\enhanced_training_integration.py
+core\exchanges\__init__.py
+core\exchanges\binance_interface.py
+core\exchanges\bybit\debug\test_bybit_balance.py
+core\exchanges\bybit_interface.py
+core\exchanges\bybit_rest_client.py
+core\exchanges\deribit_interface.py
+core\exchanges\mexc\debug\final_mexc_order_test.py
+core\exchanges\mexc\debug\fix_mexc_orders.py
+core\exchanges\mexc\debug\fix_mexc_orders_v2.py
+core\exchanges\mexc\debug\fix_mexc_orders_v3.py
+core\exchanges\mexc\debug\test_mexc_interface_debug.py
+core\exchanges\mexc\debug\test_mexc_order_signature.py
+core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
+core\exchanges\mexc\debug\test_mexc_signature_debug.py
+core\exchanges\mexc\debug\test_small_mexc_order.py
+core\exchanges\mexc\test_live_trading.py
+core\exchanges\mexc_interface.py
+core\exchanges\trading_agent_test.py
+core\mexc_webclient\__init__.py
+core\mexc_webclient\auto_browser.py
+core\mexc_webclient\browser_automation.py
+core\mexc_webclient\mexc_futures_client.py
+core\mexc_webclient\session_manager.py
+core\mexc_webclient\test_mexc_futures_webclient.py
+core\model_output_manager.py
+core\negative_case_trainer.py
+core\nn_decision_fusion.py
+core\prediction_tracker.py
+core\realtime_tick_processor.py
+core\retrospective_trainer.py
+core\rl_training_pipeline.py
+core\robust_cob_provider.py
+core\shared_cob_service.py
+core\shared_data_manager.py
+core\tick_aggregator.py
+core\trading_action.py
+core\trading_executor_fix.py
+core\training_data_collector.py
+core\williams_market_structure.py
+dataprovider_realtime.py
+debug\test_fixed_issues.py
+debug\test_trading_fixes.py
+debug\trade_audit.py
+debug_training_methods.py
+docs\exchanges\bybit\examples.py
+example_usage_simplified_data_provider.py
+kill_stale_processes.py
+launch_training.py
+main.py
+main_clean.py
+migrate_existing_models.py
+model_manager.py
+position_sync_enhancement.py
+read_logs.py
+reset_db_manager.py
+reset_models_and_fix_mapping.py
+run_clean_dashboard.py
+run_continuous_training.py
+run_crash_safe_dashboard.py
+run_enhanced_rl_training.py
+run_enhanced_training_dashboard.py
+run_integrated_rl_cob_dashboard.py
+run_mexc_browser.py
+run_optimized_cob_system.py
+run_realtime_rl_cob_trader.py
+run_simple_dashboard.py
+run_stable_dashboard.py
+run_templated_dashboard.py
+run_tensorboard.py
+run_tests.py
+scripts\kill_stale_processes.py
+scripts\restart_dashboard_with_learning.py
+scripts\restart_main_overnight.py
+setup_mexc_browser.py
+start_monitoring.py
+start_overnight_training.py
+system_stability_audit.py
+test_build_base_data_performance.py
+test_bybit_eth_futures.py
+test_bybit_eth_futures_fixed.py
+test_bybit_eth_live.py
+test_bybit_public_api.py
+test_cache_fix.py
+test_cnn_integration.py
+test_cob_dashboard.py
+test_cob_data_quality.py
+test_cob_websocket_only.py
+test_continuous_cnn_training.py
+test_dashboard_data_flow.py
+test_dashboard_performance.py
+test_data_integration.py
+test_data_provider_integration.py
+test_db_migration.py
+test_deribit_integration.py
+test_device_fix.py
+test_device_training_fix.py
+test_enhanced_cnn_adapter.py
+test_enhanced_cob_websocket.py
+test_enhanced_data_provider_websocket.py
+test_enhanced_inference_logging.py
+test_enhanced_training_integration.py
+test_enhanced_training_simple.py
+test_fifo_queues.py
+test_hold_position_fix.py
+test_imbalance_calculation.py
+test_improved_data_integration.py
+test_integrated_standardized_provider.py
+test_leverage_fix.py
+test_massive_dqn.py
+test_mexc_order_fix.py
+test_model_output_manager.py
+test_model_registry.py
+test_model_statistics.py
+test_model_stats.py
+test_model_training.py
+test_orchestrator_fix.py
+test_order_sync_and_fees.py
+test_position_based_rewards.py
+test_profitability_reward_system.py
+test_training_data_collection.py
+test_training_fixes.py
+test_websocket_cob_data.py
+tests\cob\test_cob_comparison.py
+tests\cob\test_cob_data_stability.py
+tests\test_training.py
+tests\test_training_integration.py
+tests\test_training_status.py
+tests\test_universal_data_format.py
+tests\test_universal_stream_integration.py
+trading_main.py
+utils\__init__.py
+utils\async_task_manager.py
+utils\launch_tensorboard.py
+utils\model_utils.py
+utils\port_manager.py
+utils\process_supervisor.py
+utils\reward_calculator.py
+utils\system_monitor.py
+utils\tensorboard_logger.py
+utils\text_logger.py
+verify_checkpoint_system.py
+web\__init__.py
+web\dashboard_fix.py
+web\dashboard_model.py
+web\layout_manager_with_tensorboard.py
+web\tensorboard_component.py
+web\tensorboard_integration.py
\ No newline at end of file
diff --git a/DEPENDENCY_TREE.md b/DEPENDENCY_TREE.md
new file mode 100644
index 0000000..6872583
--- /dev/null
+++ b/DEPENDENCY_TREE.md
@@ -0,0 +1,84 @@
+Dependency tree from dashboards (module -> deps):
+- NN\models\advanced_transformer_trading.py
+- NN\models\cob_rl_model.py
+ - models\__init__.py
+- NN\models\dqn_agent.py
+ - utils\checkpoint_manager.py
+ - utils\training_integration.py
+- NN\models\enhanced_cnn.py
+- NN\models\model_interfaces.py
+- NN\models\standardized_cnn.py
+ - core\data_models.py
+- core\cob_integration.py
+- core\config.py
+ - safe_logging.py
+- core\data_models.py
+- core\data_provider.py
+ - utils\cache_manager.py
+ - utils\timezone_utils.py
+- core\exchanges\exchange_factory.py
+- core\exchanges\exchange_interface.py
+- core\extrema_trainer.py
+ - utils\checkpoint_manager.py
+ - utils\training_integration.py
+- core\multi_exchange_cob_provider.py
+- core\orchestrator.py
+ - NN\models\advanced_transformer_trading.py
+ - NN\models\cob_rl_model.py
+ - NN\models\dqn_agent.py
+ - NN\models\enhanced_cnn.py
+ - NN\models\model_interfaces.py
+ - NN\models\standardized_cnn.py
+ - core\data_models.py
+ - core\extrema_trainer.py
+ - enhanced_realtime_training.py
+ - models\__init__.py
+ - utils\checkpoint_manager.py
+ - utils\database_manager.py
+ - utils\inference_logger.py
+- core\overnight_training_coordinator.py
+- core\realtime_rl_cob_trader.py
+ - NN\models\cob_rl_model.py
+ - core\trading_executor.py
+ - utils\checkpoint_manager.py
+- core\standardized_data_provider.py
+- core\trade_data_manager.py
+- core\trading_executor.py
+ - core\data_provider.py
+ - core\exchanges\exchange_factory.py
+ - core\exchanges\exchange_interface.py
+- core\training_integration.py
+- core\universal_data_adapter.py
+- enhanced_realtime_training.py
+- models\__init__.py
+- safe_logging.py
+- utils\cache_manager.py
+- utils\checkpoint_manager.py
+- utils\database_manager.py
+- utils\inference_logger.py
+- utils\timezone_utils.py
+- utils\training_integration.py
+- web\clean_dashboard.py
+ - NN\models\advanced_transformer_trading.py
+ - NN\models\standardized_cnn.py
+ - core\cob_integration.py
+ - core\config.py
+ - core\data_models.py
+ - core\data_provider.py
+ - core\multi_exchange_cob_provider.py
+ - core\orchestrator.py
+ - core\overnight_training_coordinator.py
+ - core\realtime_rl_cob_trader.py
+ - core\standardized_data_provider.py
+ - core\trade_data_manager.py
+ - core\trading_executor.py
+ - core\training_integration.py
+ - core\universal_data_adapter.py
+ - utils\checkpoint_manager.py
+ - utils\timezone_utils.py
+ - web\component_manager.py
+ - web\layout_manager.py
+- web\cob_realtime_dashboard.py
+ - core\cob_integration.py
+- web\component_manager.py
+- web\layout_manager.py
\ No newline at end of file
diff --git a/KEEP_SET.txt b/KEEP_SET.txt
new file mode 100644
index 0000000..2486020
--- /dev/null
+++ b/KEEP_SET.txt
@@ -0,0 +1,35 @@
+NN\models\advanced_transformer_trading.py
+NN\models\cob_rl_model.py
+NN\models\dqn_agent.py
+NN\models\enhanced_cnn.py
+NN\models\model_interfaces.py
+NN\models\standardized_cnn.py
+core\cob_integration.py
+core\config.py
+core\data_models.py
+core\data_provider.py
+core\exchanges\exchange_factory.py
+core\exchanges\exchange_interface.py
+core\extrema_trainer.py
+core\multi_exchange_cob_provider.py
+core\orchestrator.py
+core\overnight_training_coordinator.py
+core\realtime_rl_cob_trader.py
+core\standardized_data_provider.py
+core\trade_data_manager.py
+core\trading_executor.py
+core\training_integration.py
+core\universal_data_adapter.py
+enhanced_realtime_training.py
+models\__init__.py
+safe_logging.py
+utils\cache_manager.py
+utils\checkpoint_manager.py
+utils\database_manager.py
+utils\inference_logger.py
+utils\timezone_utils.py
+utils\training_integration.py
+web\clean_dashboard.py
+web\cob_realtime_dashboard.py
+web\component_manager.py
+web\layout_manager.py
\ No newline at end of file
diff --git a/NN/__init__.py b/NN/__init__.py
deleted file mode 100644
index 2622416..0000000
--- a/NN/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""
-Neural Network Trading System
-============================
-
-A comprehensive neural network trading system that uses deep learning models
-to analyze cryptocurrency price data and generate trading signals.
-
-The system consists of:
-1. Data Interface: Connects to realtime trading data
-2. CNN Model: Deep convolutional neural network for feature extraction
-3. Transformer Model: Processes high-level features for improved pattern recognition
-4. MoE: Mixture of Experts model that combines multiple neural networks
-"""
-
-__version__ = '0.1.0'
-__author__ = 'Gogo2 Project'
\ No newline at end of file
diff --git a/NN/models/__init__.py b/NN/models/__init__.py
deleted file mode 100644
index df25097..0000000
--- a/NN/models/__init__.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Neural Network Models
-====================
-
-This package contains the neural network models used in the trading system:
-- CNN Model: Deep convolutional neural network for feature extraction
-- DQN Agent: Deep Q-Network for reinforcement learning
-- COB RL Model: Specialized RL model for order book data
-- Advanced Transformer: High-performance transformer for trading
-
-PyTorch implementation only.
-"""
-
-# Import core models
-from NN.models.dqn_agent import DQNAgent
-from NN.models.cob_rl_model import COBRLModelInterface
-from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
-from NN.models.standardized_cnn import StandardizedCNN # Use the unified CNN model
-
-# Import model interfaces
-from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
-
-# Export the unified StandardizedCNN as CNNModel for compatibility
-CNNModel = StandardizedCNN
-
-__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
-'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
diff --git a/NN/models/cnn_model.py b/NN/models/cnn_model.py
deleted file mode 100644
index 4b34bb0..0000000
--- a/NN/models/cnn_model.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# """
-# Legacy CNN Model Compatibility Layer
-
-# This module provides compatibility redirects to the unified StandardizedCNN model.
-# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
-# in favor of the StandardizedCNN architecture.
-# """
-
-# import logging
-# import warnings
-# from typing import Tuple, Dict, Any, Optional
-# import torch
-# import numpy as np
-
-# # Import the standardized CNN model
-# from .standardized_cnn import StandardizedCNN
-
-# logger = logging.getLogger(__name__)
-
-# # Compatibility aliases and wrappers
-# class EnhancedCNNModel:
-# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
-
-# def __init__(self, *args, **kwargs):
-# warnings.warn(
-# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
-# DeprecationWarning,
-# stacklevel=2
-# )
-# # Create StandardizedCNN with default parameters
-# self.standardized_cnn = StandardizedCNN()
-# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
-
-# def __getattr__(self, name):
-# """Delegate all method calls to StandardizedCNN"""
-# return getattr(self.standardized_cnn, name)
-
-
-# class CNNModelTrainer:
-# """Legacy compatibility wrapper for CNN training"""
-
-# def __init__(self, model=None, *args, **kwargs):
-# warnings.warn(
-# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
-# DeprecationWarning,
-# stacklevel=2
-# )
-# if isinstance(model, EnhancedCNNModel):
-# self.model = model.standardized_cnn
-# else:
-# self.model = StandardizedCNN()
-# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
-
-# def train_step(self, x, y, *args, **kwargs):
-# """Legacy train step wrapper"""
-# try:
-# # Convert to BaseDataInput format if needed
-# if hasattr(x, 'get_feature_vector'):
-# # Already BaseDataInput
-# base_input = x
-# else:
-# # Create mock BaseDataInput for legacy compatibility
-# from core.data_models import BaseDataInput
-# base_input = BaseDataInput()
-# # Set mock feature vector
-# if isinstance(x, torch.Tensor):
-# feature_vector = x.flatten().cpu().numpy()
-# else:
-# feature_vector = np.array(x).flatten()
-
-# # Pad or truncate to expected size
-# expected_size = self.model.expected_feature_dim
-# if len(feature_vector) < expected_size:
-# padding = np.zeros(expected_size - len(feature_vector))
-# feature_vector = np.concatenate([feature_vector, padding])
-# else:
-# feature_vector = feature_vector[:expected_size]
-
-# base_input._feature_vector = feature_vector
-
-# # Convert target to string format
-# if isinstance(y, torch.Tensor):
-# y_val = y.item() if y.numel() == 1 else y.argmax().item()
-# else:
-# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
-
-# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
-# target = target_map.get(y_val, 'HOLD')
-
-# # Use StandardizedCNN training
-# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
-# loss = self.model.train_step([base_input], [target], optimizer)
-
-# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
-
-# except Exception as e:
-# logger.error(f"Legacy train_step error: {e}")
-# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
-
-
-# # class CNNModel:
-# # """Legacy compatibility wrapper for CNN model interface"""
-
-# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
-# # warnings.warn(
-# # "CNNModel is deprecated. Use StandardizedCNN directly.",
-# # DeprecationWarning,
-# # stacklevel=2
-# # )
-# # self.input_shape = input_shape
-# # self.output_size = output_size
-# # self.standardized_cnn = StandardizedCNN()
-# # self.trainer = CNNModelTrainer(self.standardized_cnn)
-# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
-
-# # def build_model(self, **kwargs):
-# # """Legacy build method - no-op for StandardizedCNN"""
-# # return self
-
-# # def predict(self, X):
-# # """Legacy predict method"""
-# # try:
-# # # Convert input to BaseDataInput
-# # from core.data_models import BaseDataInput
-# # base_input = BaseDataInput()
-
-# # if isinstance(X, np.ndarray):
-# # feature_vector = X.flatten()
-# # else:
-# # feature_vector = np.array(X).flatten()
-
-# # # Pad or truncate to expected size
-# # expected_size = self.standardized_cnn.expected_feature_dim
-# # if len(feature_vector) < expected_size:
-# # padding = np.zeros(expected_size - len(feature_vector))
-# # feature_vector = np.concatenate([feature_vector, padding])
-# # else:
-# # feature_vector = feature_vector[:expected_size]
-
-# # base_input._feature_vector = feature_vector
-
-# # # Get prediction from StandardizedCNN
-# # result = self.standardized_cnn.predict_from_base_input(base_input)
-
-# # # Convert to legacy format
-# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
-# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
-# # pred_proba = np.array([result.predictions['action_probabilities']])
-
-# # return pred_class, pred_proba
-
-# # except Exception as e:
-# # logger.error(f"Legacy predict error: {e}")
-# # # Return safe defaults
-# # pred_class = np.array([2]) # HOLD
-# # pred_proba = np.array([[0.33, 0.33, 0.34]])
-# # return pred_class, pred_proba
-
-# # def fit(self, X, y, **kwargs):
-# # """Legacy fit method"""
-# # try:
-# # return self.trainer.train_step(X, y)
-# # except Exception as e:
-# # logger.error(f"Legacy fit error: {e}")
-# # return self
-
-# # def save(self, filepath: str):
-# # """Legacy save method"""
-# # try:
-# # torch.save(self.standardized_cnn.state_dict(), filepath)
-# # logger.info(f"StandardizedCNN saved to {filepath}")
-# # except Exception as e:
-# # logger.error(f"Error saving model: {e}")
-
-
-# def create_enhanced_cnn_model(input_size: int = 60,
-# feature_dim: int = 50,
-# output_size: int = 3,
-# base_channels: int = 256,
-# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
-# """Legacy compatibility function - returns StandardizedCNN"""
-# warnings.warn(
-# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
-# DeprecationWarning,
-# stacklevel=2
-# )
-
-# model = StandardizedCNN()
-# trainer = CNNModelTrainer(model)
-
-# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
-# return model, trainer
-
-
-# # Export compatibility symbols
-# __all__ = [
-# 'EnhancedCNNModel',
-# 'CNNModelTrainer',
-# # 'CNNModel',
-# 'create_enhanced_cnn_model'
-# ]
diff --git a/NN/models/transformer_model.py b/NN/models/transformer_model.py
deleted file mode 100644
index 16700b3..0000000
--- a/NN/models/transformer_model.py
+++ /dev/null
@@ -1,821 +0,0 @@
-"""
-Transformer Neural Network for timeseries analysis
-
-This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
-It also includes a Mixture of Experts model that combines predictions from multiple models.
-"""
-
-import os
-import logging
-import numpy as np
-import matplotlib.pyplot as plt
-import tensorflow as tf
-from tensorflow.keras.models import Model, load_model
-from tensorflow.keras.layers import (
- Input, Dense, Dropout, BatchNormalization,
- Concatenate, Layer, LayerNormalization, MultiHeadAttention,
- Add, GlobalAveragePooling1D, Conv1D, Reshape
-)
-from tensorflow.keras.optimizers import Adam
-from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
-import datetime
-import json
-
-logger = logging.getLogger(__name__)
-
-class TransformerBlock(Layer):
- """
- Transformer block implementation with multi-head attention and feed-forward networks.
- """
- def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
- super(TransformerBlock, self).__init__()
- self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
- self.ffn = tf.keras.Sequential([
- Dense(ff_dim, activation="relu"),
- Dense(embed_dim),
- ])
- self.layernorm1 = LayerNormalization(epsilon=1e-6)
- self.layernorm2 = LayerNormalization(epsilon=1e-6)
- self.dropout1 = Dropout(rate)
- self.dropout2 = Dropout(rate)
-
- def call(self, inputs, training=False):
- attn_output = self.att(inputs, inputs)
- attn_output = self.dropout1(attn_output, training=training)
- out1 = self.layernorm1(inputs + attn_output)
- ffn_output = self.ffn(out1)
- ffn_output = self.dropout2(ffn_output, training=training)
- return self.layernorm2(out1 + ffn_output)
-
- def get_config(self):
- config = super().get_config()
- config.update({
- 'att': self.att,
- 'ffn': self.ffn,
- 'layernorm1': self.layernorm1,
- 'layernorm2': self.layernorm2,
- 'dropout1': self.dropout1,
- 'dropout2': self.dropout2
- })
- return config
-
-class PositionalEncoding(Layer):
- """
- Positional encoding layer to add position information to input embeddings.
- """
- def __init__(self, position, d_model):
- super(PositionalEncoding, self).__init__()
- self.position = position
- self.d_model = d_model
- self.pos_encoding = self.positional_encoding(position, d_model)
-
- def get_angles(self, position, i, d_model):
- angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
- return position * angles
-
- def positional_encoding(self, position, d_model):
- angle_rads = self.get_angles(
- position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
- i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
- d_model=d_model
- )
-
- # Apply sin to even indices in the array
- sines = tf.math.sin(angle_rads[:, 0::2])
-
- # Apply cos to odd indices in the array
- cosines = tf.math.cos(angle_rads[:, 1::2])
-
- pos_encoding = tf.concat([sines, cosines], axis=-1)
- pos_encoding = pos_encoding[tf.newaxis, ...]
-
- return tf.cast(pos_encoding, tf.float32)
-
- def call(self, inputs):
- return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
-
- def get_config(self):
- config = super().get_config()
- config.update({
- 'position': self.position,
- 'd_model': self.d_model,
- 'pos_encoding': self.pos_encoding
- })
- return config
-
-class TransformerModel:
- """
- Transformer Neural Network for time series analysis.
-
- This model uses self-attention mechanisms to capture relationships between
- different time points in the input data.
- """
-
- def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
- """
- Initialize the Transformer model.
-
- Args:
- ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
- feature_input_shape (int): Shape of additional feature input (e.g., from CNN)
- output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
- model_dir (str): Directory to save trained models
- """
- self.ts_input_shape = ts_input_shape
- self.feature_input_shape = feature_input_shape
- self.output_size = output_size
- self.model_dir = model_dir
- self.model = None
- self.history = None
-
- # Create model directory if it doesn't exist
- os.makedirs(self.model_dir, exist_ok=True)
-
- logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
- f"feature input shape {feature_input_shape}, and output size {output_size}")
-
- def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
- """
- Build the Transformer model architecture.
-
- Args:
- embed_dim (int): Embedding dimension for transformer
- num_heads (int): Number of attention heads
- ff_dim (int): Hidden dimension of the feed forward network
- num_transformer_blocks (int): Number of transformer blocks
- dropout_rate (float): Dropout rate for regularization
- learning_rate (float): Learning rate for Adam optimizer
-
- Returns:
- The compiled model
- """
- # Time series input
- ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")
-
- # Additional feature input (e.g., from CNN)
- feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")
-
- # Process time series with transformer
- # First, project the input to the embedding dimension
- x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)
-
- # Add positional encoding
- x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)
-
- # Add transformer blocks
- for _ in range(num_transformer_blocks):
- x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)
-
- # Global pooling to get a single vector representation
- x = GlobalAveragePooling1D()(x)
- x = Dropout(dropout_rate)(x)
-
- # Combine with additional features
- combined = Concatenate()([x, feature_inputs])
-
- # Dense layers for final classification/regression
- x = Dense(64, activation="relu")(combined)
- x = BatchNormalization()(x)
- x = Dropout(dropout_rate)(x)
-
- # Output layer
- if self.output_size == 1:
- # Binary classification (up/down)
- outputs = Dense(1, activation='sigmoid', name='output')(x)
- loss = 'binary_crossentropy'
- metrics = ['accuracy']
- elif self.output_size == 3:
- # Multi-class classification (buy/hold/sell)
- outputs = Dense(3, activation='softmax', name='output')(x)
- loss = 'categorical_crossentropy'
- metrics = ['accuracy']
- else:
- # Regression
- outputs = Dense(self.output_size, activation='linear', name='output')(x)
- loss = 'mse'
- metrics = ['mae']
-
- # Create and compile model
- self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)
-
- # Compile with Adam optimizer
- self.model.compile(
- optimizer=Adam(learning_rate=learning_rate),
- loss=loss,
- metrics=metrics
- )
-
- # Log model summary
- self.model.summary(print_fn=lambda x: logger.info(x))
-
- return self.model
-
- def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
- callbacks=None, class_weights=None):
- """
- Train the Transformer model on the provided data.
-
- Args:
- X_ts (numpy.ndarray): Time series input features
- X_features (numpy.ndarray): Additional input features
- y (numpy.ndarray): Target labels
- batch_size (int): Batch size
- epochs (int): Number of epochs
- validation_split (float): Fraction of data to use for validation
- callbacks (list): List of Keras callbacks
- class_weights (dict): Class weights for imbalanced datasets
-
- Returns:
- History object containing training metrics
- """
- if self.model is None:
- self.build_model()
-
- # Default callbacks if none provided
- if callbacks is None:
- # Create a timestamp for model checkpoints
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
- callbacks = [
- EarlyStopping(
- monitor='val_loss',
- patience=10,
- restore_best_weights=True
- ),
- ReduceLROnPlateau(
- monitor='val_loss',
- factor=0.5,
- patience=5,
- min_lr=1e-6
- ),
- ModelCheckpoint(
- filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
- monitor='val_loss',
- save_best_only=True
- )
- ]
-
- # Check if y needs to be one-hot encoded for multi-class
- if self.output_size == 3 and len(y.shape) == 1:
- y = tf.keras.utils.to_categorical(y, num_classes=3)
-
- # Train the model
- logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
- self.history = self.model.fit(
- [X_ts, X_features], y,
- batch_size=batch_size,
- epochs=epochs,
- validation_split=validation_split,
- callbacks=callbacks,
- class_weight=class_weights,
- verbose=2
- )
-
- # Save the trained model
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
- self.model.save(model_path)
- logger.info(f"Model saved to {model_path}")
-
- # Save training history
- history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
- with open(history_path, 'w') as f:
- # Convert numpy values to Python native types for JSON serialization
- history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
- json.dump(history_dict, f, indent=2)
-
- return self.history
-
- def evaluate(self, X_ts, X_features, y):
- """
- Evaluate the model on test data.
-
- Args:
- X_ts (numpy.ndarray): Time series input features
- X_features (numpy.ndarray): Additional input features
- y (numpy.ndarray): Target labels
-
- Returns:
- dict: Evaluation metrics
- """
- if self.model is None:
- raise ValueError("Model has not been built or trained yet")
-
- # Convert y to one-hot encoding for multi-class
- if self.output_size == 3 and len(y.shape) == 1:
- y = tf.keras.utils.to_categorical(y, num_classes=3)
-
- # Evaluate model
- logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
- eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)
-
- metrics = {}
- for metric, value in zip(self.model.metrics_names, eval_results):
- metrics[metric] = value
- logger.info(f"{metric}: {value:.4f}")
-
- return metrics
-
- def predict(self, X_ts, X_features=None):
- """
- Make predictions on new data.
-
- Args:
- X_ts (numpy.ndarray): Time series input features
- X_features (numpy.ndarray): Additional input features
-
- Returns:
- tuple: (y_pred, y_proba) where:
- y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
- y_proba is the class probability
- """
- if self.model is None:
- raise ValueError("Model has not been built or trained yet")
-
- # Ensure X_ts has the right shape
- if len(X_ts.shape) == 2:
- # Single sample, add batch dimension
- X_ts = np.expand_dims(X_ts, axis=0)
-
- # Ensure X_features has the right shape
- if X_features is None:
- # Extract features from time series data if no external features provided
- X_features = self._extract_features_from_timeseries(X_ts)
- elif len(X_features.shape) == 1:
- # Single sample, add batch dimension
- X_features = np.expand_dims(X_features, axis=0)
-
- def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
- """Extract meaningful features from time series data instead of using dummy zeros"""
- try:
- batch_size = X_ts.shape[0]
- features = []
-
- for i in range(batch_size):
- sample = X_ts[i] # Shape: (timesteps, features)
-
- # Extract statistical features from each feature dimension
- sample_features = []
-
- for feature_idx in range(sample.shape[1]):
- feature_data = sample[:, feature_idx]
-
- # Basic statistical features
- sample_features.extend([
- np.mean(feature_data), # Mean
- np.std(feature_data), # Standard deviation
- np.min(feature_data), # Minimum
- np.max(feature_data), # Maximum
- np.percentile(feature_data, 25), # 25th percentile
- np.percentile(feature_data, 75), # 75th percentile
- ])
-
- # Trend features
- if len(feature_data) > 1:
- # Linear trend (slope)
- x = np.arange(len(feature_data))
- slope = np.polyfit(x, feature_data, 1)[0]
- sample_features.append(slope)
-
- # Rate of change
- rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
- sample_features.append(rate_of_change)
- else:
- sample_features.extend([0.0, 0.0])
-
- # Pad or truncate to expected feature size
- while len(sample_features) < self.feature_input_shape:
- sample_features.append(0.0)
- sample_features = sample_features[:self.feature_input_shape]
-
- features.append(sample_features)
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error extracting features from time series: {e}")
- # Fallback to zeros if extraction fails
- return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
-
- # Get predictions
- y_proba = self.model.predict([X_ts, X_features])
-
- # Process based on output type
- if self.output_size == 1:
- # Binary classification
- y_pred = (y_proba > 0.5).astype(int).flatten()
- return y_pred, y_proba.flatten()
- elif self.output_size == 3:
- # Multi-class classification
- y_pred = np.argmax(y_proba, axis=1)
- return y_pred, y_proba
- else:
- # Regression
- return y_proba, y_proba
-
- def save(self, filepath=None):
- """
- Save the model to disk.
-
- Args:
- filepath (str): Path to save the model
-
- Returns:
- str: Path where the model was saved
- """
- if self.model is None:
- raise ValueError("Model has not been built yet")
-
- if filepath is None:
- # Create a default filepath with timestamp
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5")
-
- self.model.save(filepath)
- logger.info(f"Model saved to {filepath}")
- return filepath
-
- def load(self, filepath):
- """
- Load a saved model from disk.
-
- Args:
- filepath (str): Path to the saved model
-
- Returns:
- The loaded model
- """
- # Register custom layers
- custom_objects = {
- 'TransformerBlock': TransformerBlock,
- 'PositionalEncoding': PositionalEncoding
- }
-
- self.model = load_model(filepath, custom_objects=custom_objects)
- logger.info(f"Model loaded from {filepath}")
- return self.model
-
- def plot_training_history(self):
- """
- Plot training history (loss and metrics).
-
- Returns:
- str: Path to the saved plot
- """
- if self.history is None:
- raise ValueError("Model has not been trained yet")
-
- plt.figure(figsize=(12, 5))
-
- # Plot loss
- plt.subplot(1, 2, 1)
- plt.plot(self.history.history['loss'], label='Training Loss')
- if 'val_loss' in self.history.history:
- plt.plot(self.history.history['val_loss'], label='Validation Loss')
- plt.title('Model Loss')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.legend()
-
- # Plot accuracy
- plt.subplot(1, 2, 2)
-
- if 'accuracy' in self.history.history:
- plt.plot(self.history.history['accuracy'], label='Training Accuracy')
- if 'val_accuracy' in self.history.history:
- plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
- plt.title('Model Accuracy')
- plt.ylabel('Accuracy')
- elif 'mae' in self.history.history:
- plt.plot(self.history.history['mae'], label='Training MAE')
- if 'val_mae' in self.history.history:
- plt.plot(self.history.history['val_mae'], label='Validation MAE')
- plt.title('Model MAE')
- plt.ylabel('MAE')
-
- plt.xlabel('Epoch')
- plt.legend()
-
- plt.tight_layout()
-
- # Save figure
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
- plt.savefig(fig_path)
- plt.close()
-
- logger.info(f"Training history plot saved to {fig_path}")
- return fig_path
-
-
-class MixtureOfExpertsModel:
- """
- Mixture of Experts (MoE) model.
-
- This model combines predictions from multiple expert models (such as CNN and Transformer)
- using a weighted ensemble approach.
- """
-
- def __init__(self, output_size=1, model_dir="NN/models/saved"):
- """
- Initialize the MoE model.
-
- Args:
- output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
- model_dir (str): Directory to save trained models
- """
- self.output_size = output_size
- self.model_dir = model_dir
- self.model = None
- self.history = None
- self.experts = {}
-
- # Create model directory if it doesn't exist
- os.makedirs(self.model_dir, exist_ok=True)
-
- logger.info(f"Initialized Mixture of Experts model with output size {output_size}")
-
- def add_expert(self, name, model):
- """
- Add an expert model to the MoE.
-
- Args:
- name (str): Name of the expert model
- model: The expert model instance
-
- Returns:
- None
- """
- self.experts[name] = model
- logger.info(f"Added expert model '{name}' to MoE")
-
- def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001):
- """
- Build the MoE model by combining expert models.
-
- Args:
- ts_input_shape (tuple): Shape of time series input data
- expert_weights (dict): Weights for each expert model
- learning_rate (float): Learning rate for Adam optimizer
-
- Returns:
- The compiled model
- """
- # Time series input
- ts_inputs = Input(shape=ts_input_shape, name="ts_input")
-
- # Additional feature input (from CNN)
- feature_inputs = Input(shape=(64,), name="feature_input") # Default size for features
-
- # Process with each expert model
- expert_outputs = []
- expert_names = []
-
- for name, expert in self.experts.items():
- # Skip if expert model is not valid or doesn't have a call/predict method
- if expert is None:
- logger.warning(f"Expert model '{name}' is None, skipping")
- continue
-
- try:
- # Different handling based on model type
- if name == 'cnn':
- # CNN model takes only time series input
- expert_output = expert(ts_inputs)
- expert_outputs.append(expert_output)
- expert_names.append(name)
- elif name == 'transformer':
- # Transformer model takes both time series and feature inputs
- expert_output = expert([ts_inputs, feature_inputs])
- expert_outputs.append(expert_output)
- expert_names.append(name)
- else:
- logger.warning(f"Unknown expert model type: {name}")
- except Exception as e:
- logger.error(f"Error adding expert '{name}': {str(e)}")
-
- if not expert_outputs:
- logger.error("No valid expert models found")
- return None
-
- # Use expert weighting
- if expert_weights is None:
- # Equal weighting
- weights = [1.0 / len(expert_outputs)] * len(expert_outputs)
- else:
- # User-provided weights
- weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names]
- # Normalize weights
- weights = [w / sum(weights) for w in weights]
-
- # Combine expert outputs using weighted average
- if len(expert_outputs) == 1:
- # Only one expert, use its output directly
- combined_output = expert_outputs[0]
- else:
- # Multiple experts, compute weighted average
- weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
- combined_output = Add()(weighted_outputs)
-
- # Create the MoE model
- moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)
-
- # Compile the model
- if self.output_size == 1:
- # Binary classification
- moe_model.compile(
- optimizer=Adam(learning_rate=learning_rate),
- loss='binary_crossentropy',
- metrics=['accuracy']
- )
- elif self.output_size == 3:
- # Multi-class classification for BUY/HOLD/SELL
- moe_model.compile(
- optimizer=Adam(learning_rate=learning_rate),
- loss='categorical_crossentropy',
- metrics=['accuracy']
- )
- else:
- # Regression
- moe_model.compile(
- optimizer=Adam(learning_rate=learning_rate),
- loss='mse',
- metrics=['mae']
- )
-
- self.model = moe_model
-
- # Log model summary
- self.model.summary(print_fn=lambda x: logger.info(x))
-
- logger.info(f"Built MoE model with weights: {weights}")
- return self.model
-
- def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
- callbacks=None, class_weights=None):
- """
- Train the MoE model on the provided data.
-
- Args:
- X_ts (numpy.ndarray): Time series input features
- X_features (numpy.ndarray): Additional input features
- y (numpy.ndarray): Target labels
- batch_size (int): Batch size
- epochs (int): Number of epochs
- validation_split (float): Fraction of data to use for validation
- callbacks (list): List of Keras callbacks
- class_weights (dict): Class weights for imbalanced datasets
-
- Returns:
- History object containing training metrics
- """
- if self.model is None:
- logger.error("MoE model has not been built yet")
- return None
-
- # Default callbacks if none provided
- if callbacks is None:
- # Create a timestamp for model checkpoints
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
- callbacks = [
- EarlyStopping(
- monitor='val_loss',
- patience=10,
- restore_best_weights=True
- ),
- ReduceLROnPlateau(
- monitor='val_loss',
- factor=0.5,
- patience=5,
- min_lr=1e-6
- ),
- ModelCheckpoint(
- filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
- monitor='val_loss',
- save_best_only=True
- )
- ]
-
- # Check if y needs to be one-hot encoded for multi-class
- if self.output_size == 3 and len(y.shape) == 1:
- y = tf.keras.utils.to_categorical(y, num_classes=3)
-
- # Train the model
- logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
- self.history = self.model.fit(
- [X_ts, X_features], y,
- batch_size=batch_size,
- epochs=epochs,
- validation_split=validation_split,
- callbacks=callbacks,
- class_weight=class_weights,
- verbose=2
- )
-
- # Save the trained model
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
- self.model.save(model_path)
- logger.info(f"Model saved to {model_path}")
-
- # Save training history
- history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
- with open(history_path, 'w') as f:
- # Convert numpy values to Python native types for JSON serialization
- history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
- json.dump(history_dict, f, indent=2)
-
- return self.history
-
- def predict(self, X_ts, X_features=None):
- """
- Make predictions on new data.
-
- Args:
- X_ts (numpy.ndarray): Time series input features
- X_features (numpy.ndarray): Additional input features
-
- Returns:
- tuple: (y_pred, y_proba) where:
- y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
- y_proba is the class probability
- """
- if self.model is None:
- raise ValueError("Model has not been built or trained yet")
-
- # Ensure X_ts has the right shape
- if len(X_ts.shape) == 2:
- # Single sample, add batch dimension
- X_ts = np.expand_dims(X_ts, axis=0)
-
- # Ensure X_features has the right shape
- if X_features is None:
- # Create dummy features with zeros
- X_features = np.zeros((X_ts.shape[0], 64)) # Default size
- elif len(X_features.shape) == 1:
- # Single sample, add batch dimension
- X_features = np.expand_dims(X_features, axis=0)
-
- # Get predictions
- y_proba = self.model.predict([X_ts, X_features])
-
- # Process based on output type
- if self.output_size == 1:
- # Binary classification
- y_pred = (y_proba > 0.5).astype(int).flatten()
- return y_pred, y_proba.flatten()
- elif self.output_size == 3:
- # Multi-class classification
- y_pred = np.argmax(y_proba, axis=1)
- return y_pred, y_proba
- else:
- # Regression
- return y_proba, y_proba
-
- def save(self, filepath=None):
- """
- Save the model to disk.
-
- Args:
- filepath (str): Path to save the model
-
- Returns:
- str: Path where the model was saved
- """
- if self.model is None:
- raise ValueError("Model has not been built yet")
-
- if filepath is None:
- # Create a default filepath with timestamp
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5")
-
- self.model.save(filepath)
- logger.info(f"Model saved to {filepath}")
- return filepath
-
- def load(self, filepath):
- """
- Load a saved model from disk.
-
- Args:
- filepath (str): Path to the saved model
-
- Returns:
- The loaded model
- """
- # Register custom layers
- custom_objects = {
- 'TransformerBlock': TransformerBlock,
- 'PositionalEncoding': PositionalEncoding
- }
-
- self.model = load_model(filepath, custom_objects=custom_objects)
- logger.info(f"Model loaded from {filepath}")
- return self.model
-
-# Example usage:
-if __name__ == "__main__":
- # This would be a complete implementation in a real system
- print("Transformer and MoE models defined, but not implemented here.")
\ No newline at end of file
diff --git a/NN/start_tensorboard.py b/NN/start_tensorboard.py
deleted file mode 100644
index ed27a1b..0000000
--- a/NN/start_tensorboard.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-"""
-Start TensorBoard for monitoring neural network training
-"""
-
-import os
-import sys
-import subprocess
-import webbrowser
-from time import sleep
-
-def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
- """
- Start TensorBoard in a subprocess
-
- Args:
- logdir: Directory containing TensorBoard logs
- port: Port to run TensorBoard on
- open_browser: Whether to open a browser automatically
- """
- # Make sure the log directory exists
- os.makedirs(logdir, exist_ok=True)
-
- # Create command
- cmd = [
- sys.executable,
- "-m",
- "tensorboard.main",
- f"--logdir={logdir}",
- f"--port={port}",
- "--bind_all"
- ]
-
- print(f"Starting TensorBoard with logs from {logdir} on port {port}")
- print(f"Command: {' '.join(cmd)}")
-
- # Start TensorBoard in a subprocess
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- universal_newlines=True
- )
-
- # Wait for TensorBoard to start up
- for line in process.stdout:
- print(line.strip())
- if "TensorBoard" in line and "http://" in line:
- # TensorBoard is running, extract the URL
- url = None
- for part in line.split():
- if part.startswith(("http://", "https://")):
- url = part
- break
-
- # Open browser if requested and URL found
- if open_browser and url:
- print(f"Opening TensorBoard in browser: {url}")
- webbrowser.open(url)
-
- break
-
- # Return the process for the caller to manage
- return process
-
-if __name__ == "__main__":
- import argparse
-
- # Parse command line arguments
- parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
- parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
- parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
- parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
-
- args = parser.parse_args()
-
- # Start TensorBoard
- process = start_tensorboard(args.logdir, args.port, not args.no_browser)
-
- try:
- # Keep the script running until Ctrl+C
- print("TensorBoard is running. Press Ctrl+C to stop.")
- while True:
- sleep(1)
- except KeyboardInterrupt:
- print("Stopping TensorBoard...")
- process.terminate()
- process.wait()
\ No newline at end of file
diff --git a/NN/training/enhanced_rl_training_integration.py b/NN/training/enhanced_rl_training_integration.py
deleted file mode 100644
index 3d240e5..0000000
--- a/NN/training/enhanced_rl_training_integration.py
+++ /dev/null
@@ -1,490 +0,0 @@
-#!/usr/bin/env python3
-"""
-Enhanced RL Training Integration - Comprehensive Fix
-
-This script addresses the critical RL training audit issues:
-1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
-2. Disconnected Training Pipeline - Provides proper data flow integration
-3. Missing Enhanced State Builder - Connects orchestrator to dashboard
-4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
-5. Williams Market Structure Integration - Proper feature extraction
-6. Real-time Data Integration - Live market data to RL
-
-Usage:
- python enhanced_rl_training_integration.py
-"""
-
-import os
-import sys
-import asyncio
-import logging
-import numpy as np
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional, Any
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import setup_logging, get_config
-from core.data_provider import DataProvider
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from core.trading_executor import TradingExecutor
-from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
-from utils.tensorboard_logger import TensorBoardLogger
-
-logger = logging.getLogger(__name__)
-
-class EnhancedRLTrainingIntegrator:
- """
- Comprehensive RL Training Integrator
-
- Fixes all audit issues by ensuring proper data flow and feature completeness.
- """
-
- def __init__(self):
- """Initialize the enhanced RL training integrator"""
- # Setup logging
- setup_logging()
- logger.info("=" * 70)
- logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX")
- logger.info("=" * 70)
-
- # Get configuration
- self.config = get_config()
-
- # Initialize core components
- self.data_provider = DataProvider()
- self.enhanced_orchestrator = None
- self.trading_executor = TradingExecutor()
- self.dashboard = None
-
- # Training metrics
- self.training_stats = {
- 'total_episodes': 0,
- 'successful_state_builds': 0,
- 'enhanced_reward_calculations': 0,
- 'comprehensive_features_used': 0,
- 'pivot_features_extracted': 0,
- 'cob_features_available': 0
- }
-
- # Initialize TensorBoard logger
- experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
- self.tb_logger = TensorBoardLogger(
- log_dir="runs",
- experiment_name=experiment_name,
- enabled=True
- )
- logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
-
- logger.info("Enhanced RL Training Integrator initialized")
-
- async def start_integration(self):
- """Start the comprehensive RL training integration"""
- try:
- logger.info("Starting comprehensive RL training integration...")
-
- # 1. Initialize Enhanced Orchestrator with comprehensive features
- await self._initialize_enhanced_orchestrator()
-
- # 2. Create enhanced dashboard with proper connections
- await self._create_enhanced_dashboard()
-
- # 3. Verify comprehensive state building
- await self._verify_comprehensive_state_building()
-
- # 4. Test enhanced reward calculation
- await self._test_enhanced_reward_calculation()
-
- # 5. Validate Williams market structure integration
- await self._validate_williams_integration()
-
- # 6. Start live training with comprehensive features
- await self._start_live_comprehensive_training()
-
- logger.info("=" * 70)
- logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE")
- logger.info("=" * 70)
- self._log_integration_stats()
-
- except Exception as e:
- logger.error(f"Error in RL training integration: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
- async def _initialize_enhanced_orchestrator(self):
- """Initialize enhanced orchestrator with comprehensive RL capabilities"""
- try:
- logger.info("[STEP 1] Initializing Enhanced Orchestrator...")
-
- # Create enhanced orchestrator with RL training enabled
- self.enhanced_orchestrator = EnhancedTradingOrchestrator(
- data_provider=self.data_provider,
- symbols=['ETH/USDT', 'BTC/USDT'],
- enhanced_rl_training=True,
- model_registry={} # Will be populated as needed
- )
-
- # Start COB integration for real-time market microstructure
- await self.enhanced_orchestrator.start_cob_integration()
-
- # Start real-time processing
- await self.enhanced_orchestrator.start_realtime_processing()
-
- logger.info("[SUCCESS] Enhanced Orchestrator initialized with:")
- logger.info(" - Comprehensive RL state building: ENABLED")
- logger.info(" - Enhanced pivot-based rewards: ENABLED")
- logger.info(" - COB integration: ENABLED")
- logger.info(" - Williams market structure: ENABLED")
- logger.info(" - Real-time tick processing: ENABLED")
-
- except Exception as e:
- logger.error(f"Error initializing enhanced orchestrator: {e}")
- raise
-
- async def _create_enhanced_dashboard(self):
- """Create dashboard with enhanced orchestrator connections"""
- try:
- logger.info("[STEP 2] Creating Enhanced Dashboard...")
-
- # Create trading dashboard with enhanced orchestrator
- self.dashboard = TradingDashboard(
- data_provider=self.data_provider,
- orchestrator=self.enhanced_orchestrator, # Use enhanced orchestrator
- trading_executor=self.trading_executor
- )
-
- # Verify enhanced connections
- has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state')
- has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward')
- has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation')
-
- logger.info("[SUCCESS] Enhanced Dashboard created with:")
- logger.info(f" - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}")
- logger.info(f" - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}")
- logger.info(f" - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}")
-
- if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]):
- logger.warning("Some enhanced features are missing - this will cause fallbacks to basic training")
- else:
- logger.info(" - ALL ENHANCED FEATURES AVAILABLE!")
-
- except Exception as e:
- logger.error(f"Error creating enhanced dashboard: {e}")
- raise
-
- async def _verify_comprehensive_state_building(self):
- """Verify that comprehensive RL state building works correctly"""
- try:
- logger.info("[STEP 3] Verifying Comprehensive State Building...")
-
- # Test comprehensive state building for ETH
- eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT')
-
- if eth_state is not None:
- logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features")
-
- # Verify feature count
- if len(eth_state) == 13400:
- logger.info(" - PERFECT: Exactly 13,400 features as required!")
- self.training_stats['comprehensive_features_used'] += 1
- else:
- logger.warning(f" - MISMATCH: Expected 13,400 features, got {len(eth_state)}")
-
- # Analyze feature distribution
- self._analyze_state_features(eth_state)
- self.training_stats['successful_state_builds'] += 1
-
- else:
- logger.error(" - FAILED: Comprehensive state building returned None")
-
- # Test for BTC reference
- btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT')
- if btc_state is not None:
- logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features")
- self.training_stats['successful_state_builds'] += 1
-
- except Exception as e:
- logger.error(f"Error verifying comprehensive state building: {e}")
-
- def _analyze_state_features(self, state_vector: np.ndarray):
- """Analyze the comprehensive state feature distribution"""
- try:
- # Calculate feature statistics
- non_zero_features = np.count_nonzero(state_vector)
- zero_features = len(state_vector) - non_zero_features
- feature_mean = np.mean(state_vector)
- feature_std = np.std(state_vector)
- feature_min = np.min(state_vector)
- feature_max = np.max(state_vector)
-
- logger.info(" - Feature Analysis:")
- logger.info(f" * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)")
- logger.info(f" * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)")
- logger.info(f" * Mean: {feature_mean:.6f}")
- logger.info(f" * Std: {feature_std:.6f}")
- logger.info(f" * Range: [{feature_min:.6f}, {feature_max:.6f}]")
-
- # Log feature statistics to TensorBoard
- step = self.training_stats['total_episodes']
- self.tb_logger.log_scalars('Features/Distribution', {
- 'non_zero_percentage': non_zero_features/len(state_vector)*100,
- 'mean': feature_mean,
- 'std': feature_std,
- 'min': feature_min,
- 'max': feature_max
- }, step)
-
- # Log feature histogram to TensorBoard
- self.tb_logger.log_histogram('Features/Values', state_vector, step)
-
- # Check if features are properly distributed
- if non_zero_features > len(state_vector) * 0.1: # At least 10% non-zero
- logger.info(" * GOOD: Features are well distributed")
- else:
- logger.warning(" * WARNING: Too many zero features - data may be incomplete")
-
- except Exception as e:
- logger.warning(f"Error analyzing state features: {e}")
-
- async def _test_enhanced_reward_calculation(self):
- """Test enhanced pivot-based reward calculation"""
- try:
- logger.info("[STEP 4] Testing Enhanced Reward Calculation...")
-
- # Create mock trade data for testing
- trade_decision = {
- 'action': 'BUY',
- 'confidence': 0.75,
- 'price': 2500.0,
- 'timestamp': datetime.now()
- }
-
- trade_outcome = {
- 'net_pnl': 50.0,
- 'exit_price': 2550.0,
- 'duration': timedelta(minutes=15)
- }
-
- # Get market data for reward calculation
- market_data = {
- 'volatility': 0.03,
- 'order_flow_direction': 'bullish',
- 'order_flow_strength': 0.8
- }
-
- # Test enhanced reward calculation
- if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'):
- enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward(
- trade_decision, market_data, trade_outcome
- )
-
- logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}")
- logger.info(" - Enhanced pivot-based reward system: WORKING")
- self.training_stats['enhanced_reward_calculations'] += 1
-
- # Log reward metrics to TensorBoard
- step = self.training_stats['enhanced_reward_calculations']
- self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step)
-
- # Log reward components to TensorBoard
- self.tb_logger.log_scalars('Rewards/Components', {
- 'pnl_component': trade_outcome['net_pnl'],
- 'confidence': trade_decision['confidence'],
- 'volatility': market_data['volatility'],
- 'order_flow_strength': market_data['order_flow_strength']
- }, step)
-
- else:
- logger.error(" - FAILED: Enhanced reward calculation method not available")
-
- except Exception as e:
- logger.error(f"Error testing enhanced reward calculation: {e}")
-
- async def _validate_williams_integration(self):
- """Validate Williams market structure integration"""
- try:
- logger.info("[STEP 5] Validating Williams Market Structure Integration...")
-
- # Test Williams pivot feature extraction
- try:
- from training.williams_market_structure import extract_pivot_features, analyze_pivot_context
-
- # Get test market data
- df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100)
-
- if df is not None and not df.empty:
- # Test pivot feature extraction
- pivot_features = extract_pivot_features(df)
-
- if pivot_features is not None:
- logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features")
- self.training_stats['pivot_features_extracted'] += 1
-
- # Test pivot context analysis
- market_data = {'ohlcv_data': df}
- pivot_context = analyze_pivot_context(
- market_data, datetime.now(), 'BUY'
- )
-
- if pivot_context is not None:
- logger.info("[SUCCESS] Williams pivot context analysis: WORKING")
- logger.info(f" - Near pivot: {pivot_context.get('near_pivot', False)}")
- logger.info(f" - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}")
- else:
- logger.warning(" - Williams pivot context analysis returned None")
- else:
- logger.warning(" - Williams pivot feature extraction returned None")
- else:
- logger.warning(" - No market data available for Williams testing")
-
- except ImportError:
- logger.error(" - Williams market structure module not available")
- except Exception as e:
- logger.error(f" - Error in Williams integration: {e}")
-
- except Exception as e:
- logger.error(f"Error validating Williams integration: {e}")
-
- async def _start_live_comprehensive_training(self):
- """Start live training with comprehensive feature integration"""
- try:
- logger.info("[STEP 6] Starting Live Comprehensive Training...")
-
- # Run a few training iterations to verify integration
- for iteration in range(5):
- logger.info(f"Training iteration {iteration + 1}/5")
-
- # Make coordinated decisions using enhanced orchestrator
- decisions = await self.enhanced_orchestrator.make_coordinated_decisions()
-
- # Track iteration metrics for TensorBoard
- iteration_metrics = {
- 'decisions_count': len(decisions),
- 'confidence_avg': 0.0,
- 'state_size_avg': 0.0,
- 'successful_states': 0
- }
-
- # Process each decision
- for symbol, decision in decisions.items():
- if decision:
- logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
-
- # Track confidence for TensorBoard
- iteration_metrics['confidence_avg'] += decision.confidence
-
- # Build comprehensive state for this decision
- comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol)
-
- if comprehensive_state is not None:
- state_size = len(comprehensive_state)
- logger.info(f" - Comprehensive state: {state_size} features")
- self.training_stats['total_episodes'] += 1
-
- # Track state size for TensorBoard
- iteration_metrics['state_size_avg'] += state_size
- iteration_metrics['successful_states'] += 1
-
- # Log individual state metrics to TensorBoard
- self.tb_logger.log_state_metrics(
- symbol=symbol,
- state_info={
- 'size': state_size,
- 'quality': 1.0 if state_size == 13400 else 0.8,
- 'feature_counts': {
- 'total': state_size,
- 'non_zero': np.count_nonzero(comprehensive_state)
- }
- },
- step=self.training_stats['total_episodes']
- )
- else:
- logger.warning(f" - Failed to build comprehensive state for {symbol}")
-
- # Calculate averages for TensorBoard
- if decisions:
- iteration_metrics['confidence_avg'] /= len(decisions)
-
- if iteration_metrics['successful_states'] > 0:
- iteration_metrics['state_size_avg'] /= iteration_metrics['successful_states']
-
- # Log iteration metrics to TensorBoard
- self.tb_logger.log_scalars('Training/Iteration', {
- 'iteration': iteration + 1,
- 'decisions_count': iteration_metrics['decisions_count'],
- 'confidence_avg': iteration_metrics['confidence_avg'],
- 'state_size_avg': iteration_metrics['state_size_avg'],
- 'successful_states': iteration_metrics['successful_states']
- }, iteration + 1)
-
- # Wait between iterations
- await asyncio.sleep(2)
-
- logger.info("[SUCCESS] Live comprehensive training demonstration complete")
-
- except Exception as e:
- logger.error(f"Error in live comprehensive training: {e}")
-
- def _log_integration_stats(self):
- """Log comprehensive integration statistics"""
- logger.info("INTEGRATION STATISTICS:")
- logger.info(f" - Total training episodes: {self.training_stats['total_episodes']}")
- logger.info(f" - Successful state builds: {self.training_stats['successful_state_builds']}")
- logger.info(f" - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}")
- logger.info(f" - Comprehensive features used: {self.training_stats['comprehensive_features_used']}")
- logger.info(f" - Pivot features extracted: {self.training_stats['pivot_features_extracted']}")
-
- # Calculate success rates
- state_success_rate = 0
- if self.training_stats['total_episodes'] > 0:
- state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100
- logger.info(f" - State building success rate: {state_success_rate:.1f}%")
-
- # Log final statistics to TensorBoard
- self.tb_logger.log_scalars('Integration/Statistics', {
- 'total_episodes': self.training_stats['total_episodes'],
- 'successful_state_builds': self.training_stats['successful_state_builds'],
- 'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'],
- 'comprehensive_features_used': self.training_stats['comprehensive_features_used'],
- 'pivot_features_extracted': self.training_stats['pivot_features_extracted'],
- 'state_success_rate': state_success_rate
- }, 0) # Use step 0 for final summary stats
-
- # Integration status
- if self.training_stats['comprehensive_features_used'] > 0:
- logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅")
- logger.info("The system is now using the full 13,400 feature comprehensive state.")
-
- # Log success status to TensorBoard
- self.tb_logger.log_scalar('Integration/Success', 1.0, 0)
- else:
- logger.warning("STATUS: Integration partially successful - some fallbacks may occur")
-
- # Log partial success status to TensorBoard
- self.tb_logger.log_scalar('Integration/Success', 0.5, 0)
-
-async def main():
- """Main entry point"""
- try:
- # Create and run the enhanced RL training integrator
- integrator = EnhancedRLTrainingIntegrator()
- await integrator.start_integration()
-
- logger.info("Enhanced RL training integration completed successfully!")
- return 0
-
- except KeyboardInterrupt:
- logger.info("Integration interrupted by user")
- return 0
- except Exception as e:
- logger.error(f"Fatal error in integration: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return 1
-
-if __name__ == "__main__":
- exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/NN/training/example_checkpoint_usage.py b/NN/training/example_checkpoint_usage.py
deleted file mode 100644
index b54fe38..0000000
--- a/NN/training/example_checkpoint_usage.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-"""
-Example: Using the Checkpoint Management System
-"""
-
-import logging
-import torch
-import torch.nn as nn
-import numpy as np
-from datetime import datetime
-
-from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
-from utils.training_integration import get_training_integration
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-class ExampleCNN(nn.Module):
- def __init__(self, input_channels=5, num_classes=3):
- super().__init__()
- self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
- self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
- self.pool = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(64, num_classes)
-
- def forward(self, x):
- x = torch.relu(self.conv1(x))
- x = torch.relu(self.conv2(x))
- x = self.pool(x)
- x = x.view(x.size(0), -1)
- return self.fc(x)
-
-def example_cnn_training():
- logger.info("=== CNN Training Example ===")
-
- model = ExampleCNN()
- training_integration = get_training_integration()
-
- for epoch in range(5): # Simulate 5 epochs
- # Simulate training metrics
- train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
- train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
- val_loss = train_loss + np.random.normal(0, 0.05)
- val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)
-
- # Clamp values to realistic ranges
- train_acc = max(0.0, min(1.0, train_acc))
- val_acc = max(0.0, min(1.0, val_acc))
- train_loss = max(0.1, train_loss)
- val_loss = max(0.1, val_loss)
-
- logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")
-
- # Save checkpoint
- saved = training_integration.save_cnn_checkpoint(
- cnn_model=model,
- model_name="example_cnn",
- epoch=epoch + 1,
- train_accuracy=train_acc,
- val_accuracy=val_acc,
- train_loss=train_loss,
- val_loss=val_loss,
- training_time_hours=0.1 * (epoch + 1)
- )
-
- if saved:
- logger.info(f" Checkpoint saved for epoch {epoch+1}")
- else:
- logger.info(f" Checkpoint not saved (performance not improved)")
-
- # Load the best checkpoint
- logger.info("\\nLoading best checkpoint...")
- best_result = load_best_checkpoint("example_cnn")
- if best_result:
- file_path, metadata = best_result
- logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
- logger.info(f"Performance score: {metadata.performance_score:.4f}")
-
-def example_manual_checkpoint():
- logger.info("\\n=== Manual Checkpoint Example ===")
-
- model = nn.Linear(10, 3)
-
- performance_metrics = {
- 'accuracy': 0.85,
- 'val_accuracy': 0.82,
- 'loss': 0.45,
- 'val_loss': 0.48
- }
-
- training_metadata = {
- 'epoch': 25,
- 'training_time_hours': 2.5,
- 'total_parameters': sum(p.numel() for p in model.parameters())
- }
-
- logger.info("Saving checkpoint manually...")
- metadata = save_checkpoint(
- model=model,
- model_name="example_manual",
- model_type="cnn",
- performance_metrics=performance_metrics,
- training_metadata=training_metadata,
- force_save=True
- )
-
- if metadata:
- logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
- logger.info(f" Performance score: {metadata.performance_score:.4f}")
-
-def show_checkpoint_stats():
- logger.info("\\n=== Checkpoint Statistics ===")
-
- checkpoint_manager = get_checkpoint_manager()
- stats = checkpoint_manager.get_checkpoint_stats()
-
- logger.info(f"Total models: {stats['total_models']}")
- logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
- logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
-
- for model_name, model_stats in stats['models'].items():
- logger.info(f"\\n{model_name}:")
- logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
- logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
- logger.info(f" Best performance: {model_stats['best_performance']:.4f}")
-
-def main():
- logger.info(" Checkpoint Management System Examples")
- logger.info("=" * 50)
-
- try:
- example_cnn_training()
- example_manual_checkpoint()
- show_checkpoint_stats()
-
- logger.info("\\n All examples completed successfully!")
- logger.info("\\nTo use in your training:")
- logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
- logger.info("2. Or use: from utils.training_integration import get_training_integration")
- logger.info("3. Save checkpoints during training with performance metrics")
- logger.info("4. Load best checkpoints for inference or continued training")
-
- except Exception as e:
- logger.error(f"Error in examples: {e}")
- raise
-
-if __name__ == "__main__":
- main()
diff --git a/NN/training/integrate_checkpoint_management.py b/NN/training/integrate_checkpoint_management.py
deleted file mode 100644
index 6c04c57..0000000
--- a/NN/training/integrate_checkpoint_management.py
+++ /dev/null
@@ -1,517 +0,0 @@
-#!/usr/bin/env python3
-"""
-Comprehensive Checkpoint Management Integration
-
-This script demonstrates how to integrate the checkpoint management system
-across all training pipelines in the gogo2 project.
-
-Features:
-- DQN Agent training with automatic checkpointing
-- CNN Model training with checkpoint management
-- ExtremaTrainer with checkpoint persistence
-- NegativeCaseTrainer with checkpoint integration
-- Unified training orchestration with checkpoint coordination
-"""
-
-import asyncio
-import logging
-import time
-import signal
-import sys
-import numpy as np
-from datetime import datetime
-from pathlib import Path
-from typing import Dict, Any, List
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/checkpoint_integration.log'),
- logging.StreamHandler()
- ]
-)
-logger = logging.getLogger(__name__)
-
-# Import checkpoint management
-from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
-from utils.training_integration import get_training_integration
-
-# Import training components
-from NN.models.dqn_agent import DQNAgent
-from NN.models.standardized_cnn import StandardizedCNN
-from core.extrema_trainer import ExtremaTrainer
-from core.negative_case_trainer import NegativeCaseTrainer
-from core.data_provider import DataProvider
-from core.config import get_config
-
-class CheckpointIntegratedTrainingSystem:
- """Unified training system with comprehensive checkpoint management"""
-
- def __init__(self):
- """Initialize the checkpoint-integrated training system"""
- self.config = get_config()
- self.running = False
-
- # Checkpoint management
- self.checkpoint_manager = get_checkpoint_manager()
- self.training_integration = get_training_integration()
-
- # Data provider
- self.data_provider = DataProvider(
- symbols=['ETH/USDT', 'BTC/USDT'],
- timeframes=['1s', '1m', '1h', '1d']
- )
-
- # Training components with checkpoint management
- self.dqn_agent = None
- self.cnn_trainer = None
- self.extrema_trainer = None
- self.negative_case_trainer = None
-
- # Training statistics
- self.training_stats = {
- 'start_time': None,
- 'total_training_sessions': 0,
- 'checkpoints_saved': 0,
- 'models_loaded': 0,
- 'best_performances': {}
- }
-
- logger.info("Checkpoint-Integrated Training System initialized")
-
- async def initialize_components(self):
- """Initialize all training components with checkpoint management"""
- try:
- logger.info("Initializing training components with checkpoint management...")
-
- # Initialize data provider
- await self.data_provider.start_real_time_streaming()
- logger.info("Data provider streaming started")
-
- # Initialize DQN Agent with checkpoint management
- logger.info("Initializing DQN Agent with checkpoints...")
- self.dqn_agent = DQNAgent(
- state_shape=(100,), # Example state shape
- n_actions=3,
- model_name="integrated_dqn_agent",
- enable_checkpoints=True
- )
- logger.info("✅ DQN Agent initialized with checkpoint management")
-
- # Initialize StandardizedCNN Model with checkpoint management
- logger.info("Initializing StandardizedCNN Model with checkpoints...")
- self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model")
- logger.info("✅ StandardizedCNN Model initialized with checkpoint management")
-
- # Initialize ExtremaTrainer with checkpoint management
- logger.info("Initializing ExtremaTrainer with checkpoints...")
- self.extrema_trainer = ExtremaTrainer(
- data_provider=self.data_provider,
- symbols=['ETH/USDT', 'BTC/USDT'],
- model_name="integrated_extrema_trainer",
- enable_checkpoints=True
- )
- await self.extrema_trainer.initialize_context_data()
- logger.info("✅ ExtremaTrainer initialized with checkpoint management")
-
- # Initialize NegativeCaseTrainer with checkpoint management
- logger.info("Initializing NegativeCaseTrainer with checkpoints...")
- self.negative_case_trainer = NegativeCaseTrainer(
- model_name="integrated_negative_case_trainer",
- enable_checkpoints=True
- )
- logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")
-
- # Load existing checkpoints for all components
- self.training_stats['models_loaded'] = await self._load_all_checkpoints()
-
- logger.info("All training components initialized successfully")
-
- except Exception as e:
- logger.error(f"Error initializing components: {e}")
- raise
-
- async def _load_all_checkpoints(self) -> int:
- """Load checkpoints for all training components"""
- loaded_count = 0
-
- try:
- # DQN Agent checkpoint loading is handled in __init__
- if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
- loaded_count += 1
- logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")
-
- # CNN Trainer checkpoint loading is handled in __init__
- if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
- loaded_count += 1
- logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")
-
- # ExtremaTrainer checkpoint loading is handled in __init__
- if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
- loaded_count += 1
- logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")
-
- # NegativeCaseTrainer checkpoint loading is handled in __init__
- if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
- loaded_count += 1
- logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")
-
- return loaded_count
-
- except Exception as e:
- logger.error(f"Error loading checkpoints: {e}")
- return 0
-
- async def run_integrated_training_loop(self):
- """Run the integrated training loop with checkpoint coordination"""
- logger.info("Starting integrated training loop with checkpoint management...")
-
- self.running = True
- self.training_stats['start_time'] = datetime.now()
-
- training_cycle = 0
-
- try:
- while self.running:
- training_cycle += 1
- cycle_start = time.time()
-
- logger.info(f"=== Training Cycle {training_cycle} ===")
-
- # DQN Training
- dqn_results = await self._train_dqn_agent()
-
- # CNN Training
- cnn_results = await self._train_cnn_model()
-
- # Extrema Detection Training
- extrema_results = await self._train_extrema_detector()
-
- # Negative Case Training (runs in background)
- negative_results = await self._process_negative_cases()
-
- # Coordinate checkpoint saving
- await self._coordinate_checkpoint_saving(
- dqn_results, cnn_results, extrema_results, negative_results
- )
-
- # Update statistics
- self.training_stats['total_training_sessions'] += 1
-
- # Log cycle summary
- cycle_duration = time.time() - cycle_start
- logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
-
- # Wait before next cycle
- await asyncio.sleep(60) # 1-minute cycles
-
- except KeyboardInterrupt:
- logger.info("Training interrupted by user")
- except Exception as e:
- logger.error(f"Error in training loop: {e}")
- finally:
- await self.shutdown()
-
- async def _train_dqn_agent(self) -> Dict[str, Any]:
- """Train DQN agent with automatic checkpointing"""
- try:
- if not self.dqn_agent:
- return {'status': 'skipped', 'reason': 'no_agent'}
-
- # Simulate DQN training episode
- episode_reward = 0.0
-
- # Add some training experiences (simulate real training)
- for _ in range(10): # Simulate 10 training steps
- state = np.random.randn(100).astype(np.float32)
- action = np.random.randint(0, 3)
- reward = np.random.randn() * 0.1
- next_state = np.random.randn(100).astype(np.float32)
- done = np.random.random() < 0.1
-
- self.dqn_agent.remember(state, action, reward, next_state, done)
- episode_reward += reward
-
- # Train if enough experiences
- loss = 0.0
- if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
- loss = self.dqn_agent.replay()
-
- # Save checkpoint (automatic based on performance)
- checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)
-
- if checkpoint_saved:
- self.training_stats['checkpoints_saved'] += 1
-
- return {
- 'status': 'completed',
- 'episode_reward': episode_reward,
- 'loss': loss,
- 'checkpoint_saved': checkpoint_saved,
- 'episode': self.dqn_agent.episode_count
- }
-
- except Exception as e:
- logger.error(f"Error training DQN agent: {e}")
- return {'status': 'error', 'error': str(e)}
-
- async def _train_cnn_model(self) -> Dict[str, Any]:
- """Train CNN model with automatic checkpointing"""
- try:
- if not self.cnn_trainer:
- return {'status': 'skipped', 'reason': 'no_trainer'}
-
- # Simulate CNN training step
- import torch
- import numpy as np
-
- batch_size = 32
- input_size = 60
- feature_dim = 50
-
- # Generate synthetic training data
- x = torch.randn(batch_size, input_size, feature_dim)
- y = torch.randint(0, 3, (batch_size,))
-
- # Training step
- results = self.cnn_trainer.train_step(x, y)
-
- # Simulate validation
- val_x = torch.randn(16, input_size, feature_dim)
- val_y = torch.randint(0, 3, (16,))
- val_results = self.cnn_trainer.train_step(val_x, val_y)
-
- # Save checkpoint (automatic based on performance)
- checkpoint_saved = self.cnn_trainer.save_checkpoint(
- train_accuracy=results.get('accuracy', 0.5),
- val_accuracy=val_results.get('accuracy', 0.5),
- train_loss=results.get('total_loss', 1.0),
- val_loss=val_results.get('total_loss', 1.0)
- )
-
- if checkpoint_saved:
- self.training_stats['checkpoints_saved'] += 1
-
- return {
- 'status': 'completed',
- 'train_accuracy': results.get('accuracy', 0.5),
- 'val_accuracy': val_results.get('accuracy', 0.5),
- 'train_loss': results.get('total_loss', 1.0),
- 'val_loss': val_results.get('total_loss', 1.0),
- 'checkpoint_saved': checkpoint_saved,
- 'epoch': self.cnn_trainer.epoch_count
- }
-
- except Exception as e:
- logger.error(f"Error training CNN model: {e}")
- return {'status': 'error', 'error': str(e)}
-
- async def _train_extrema_detector(self) -> Dict[str, Any]:
- """Train extrema detector with automatic checkpointing"""
- try:
- if not self.extrema_trainer:
- return {'status': 'skipped', 'reason': 'no_trainer'}
-
- # Update context data and detect extrema
- update_results = self.extrema_trainer.update_context_data()
-
- # Get training data
- extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)
-
- # Simulate training accuracy improvement
- if extrema_data:
- self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
- self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
- self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2
-
- # Save checkpoint (automatic based on performance)
- checkpoint_saved = self.extrema_trainer.save_checkpoint()
-
- if checkpoint_saved:
- self.training_stats['checkpoints_saved'] += 1
-
- return {
- 'status': 'completed',
- 'extrema_detected': len(extrema_data),
- 'context_updates': sum(1 for success in update_results.values() if success),
- 'checkpoint_saved': checkpoint_saved,
- 'session': self.extrema_trainer.training_session_count
- }
-
- except Exception as e:
- logger.error(f"Error training extrema detector: {e}")
- return {'status': 'error', 'error': str(e)}
-
- async def _process_negative_cases(self) -> Dict[str, Any]:
- """Process negative cases with automatic checkpointing"""
- try:
- if not self.negative_case_trainer:
- return {'status': 'skipped', 'reason': 'no_trainer'}
-
- # Simulate adding a negative case
- if np.random.random() < 0.1: # 10% chance of negative case
- trade_info = {
- 'symbol': 'ETH/USDT',
- 'action': 'BUY',
- 'price': 2000.0,
- 'pnl': -50.0, # Loss
- 'value': 1000.0,
- 'confidence': 0.7,
- 'timestamp': datetime.now()
- }
-
- market_data = {
- 'exit_price': 1950.0,
- 'state_before': {},
- 'state_after': {},
- 'tick_data': [],
- 'technical_indicators': {}
- }
-
- case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)
-
- # Simulate loss improvement
- loss_improvement = np.random.random() * 0.1
-
- # Save checkpoint (automatic based on performance)
- checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)
-
- if checkpoint_saved:
- self.training_stats['checkpoints_saved'] += 1
-
- return {
- 'status': 'completed',
- 'case_added': case_id,
- 'loss_improvement': loss_improvement,
- 'checkpoint_saved': checkpoint_saved,
- 'session': self.negative_case_trainer.training_session_count
- }
- else:
- return {'status': 'no_cases'}
-
- except Exception as e:
- logger.error(f"Error processing negative cases: {e}")
- return {'status': 'error', 'error': str(e)}
-
- async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
- extrema_results: Dict, negative_results: Dict):
- """Coordinate checkpoint saving across all components"""
- try:
- # Count successful checkpoints
- checkpoints_saved = sum([
- dqn_results.get('checkpoint_saved', False),
- cnn_results.get('checkpoint_saved', False),
- extrema_results.get('checkpoint_saved', False),
- negative_results.get('checkpoint_saved', False)
- ])
-
- if checkpoints_saved > 0:
- logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")
-
- # Update best performances
- if 'episode_reward' in dqn_results:
- current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
- if dqn_results['episode_reward'] > current_best:
- self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']
-
- if 'val_accuracy' in cnn_results:
- current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
- if cnn_results['val_accuracy'] > current_best:
- self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']
-
- # Log checkpoint statistics every 10 cycles
- if self.training_stats['total_training_sessions'] % 10 == 0:
- await self._log_checkpoint_statistics()
-
- except Exception as e:
- logger.error(f"Error coordinating checkpoint saving: {e}")
-
- async def _log_checkpoint_statistics(self):
- """Log comprehensive checkpoint statistics"""
- try:
- stats = get_checkpoint_stats()
-
- logger.info("=== Checkpoint Statistics ===")
- logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
- logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
- logger.info(f"Models managed: {len(stats['models'])}")
-
- for model_name, model_stats in stats['models'].items():
- logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, "
- f"{model_stats['total_size_mb']:.2f} MB, "
- f"best: {model_stats['best_performance']:.4f}")
-
- logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
- logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
- logger.info(f"Best performances: {self.training_stats['best_performances']}")
-
- except Exception as e:
- logger.error(f"Error logging checkpoint statistics: {e}")
-
- async def shutdown(self):
- """Shutdown the training system and save final checkpoints"""
- logger.info("Shutting down checkpoint-integrated training system...")
-
- self.running = False
-
- try:
- # Force save checkpoints for all components
- if self.dqn_agent:
- self.dqn_agent.save_checkpoint(0.0, force_save=True)
-
- if self.cnn_trainer:
- self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)
-
- if self.extrema_trainer:
- self.extrema_trainer.save_checkpoint(force_save=True)
-
- if self.negative_case_trainer:
- self.negative_case_trainer.save_checkpoint(force_save=True)
-
- # Final statistics
- await self._log_checkpoint_statistics()
-
- logger.info("Checkpoint-integrated training system shutdown complete")
-
- except Exception as e:
- logger.error(f"Error during shutdown: {e}")
-
-async def main():
- """Main function to run the checkpoint-integrated training system"""
- logger.info("🚀 Starting Checkpoint-Integrated Training System")
-
- # Create and initialize the training system
- training_system = CheckpointIntegratedTrainingSystem()
-
- # Setup signal handlers for graceful shutdown
- def signal_handler(signum, frame):
- logger.info("Received shutdown signal")
- asyncio.create_task(training_system.shutdown())
-
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- try:
- # Initialize components
- await training_system.initialize_components()
-
- # Run the integrated training loop
- await training_system.run_integrated_training_loop()
-
- except Exception as e:
- logger.error(f"Error in main: {e}")
- raise
- finally:
- await training_system.shutdown()
-
- logger.info("✅ Checkpoint management integration complete!")
- logger.info("All training pipelines now support automatic checkpointing")
-
-if __name__ == "__main__":
- # Ensure logs directory exists
- Path("logs").mkdir(exist_ok=True)
-
- # Run the checkpoint-integrated training system
- asyncio.run(main())
\ No newline at end of file
diff --git a/NN/utils/__init__.py b/NN/utils/__init__.py
deleted file mode 100644
index cad5ee3..0000000
--- a/NN/utils/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""
-Neural Network Utilities
-======================
-
-This package contains utility functions and classes used in the neural network trading system:
-- Data Interface: Connects to realtime trading data and processes it for the neural network models
-"""
-
-from .data_interface import DataInterface
-from .trading_env import TradingEnvironment
-from .signal_interpreter import SignalInterpreter
-
-__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
\ No newline at end of file
diff --git a/NN/utils/multi_data_interface.py b/NN/utils/multi_data_interface.py
deleted file mode 100644
index c747c6c..0000000
--- a/NN/utils/multi_data_interface.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""
-Enhanced Data Interface with additional NN trading parameters
-"""
-from typing import List, Optional, Tuple
-import numpy as np
-import pandas as pd
-from datetime import datetime
-from .data_interface import DataInterface
-
-class MultiDataInterface(DataInterface):
- """
- Enhanced data interface that supports window_size and output_size parameters
- for neural network trading models.
- """
-
- def __init__(self, symbol: str,
- timeframes: List[str],
- window_size: int = 20,
- output_size: int = 3,
- data_dir: str = "NN/data"):
- """
- Initialize with window_size and output_size for NN predictions.
- """
- super().__init__(symbol, timeframes, data_dir)
- self.window_size = window_size
- self.output_size = output_size
- self.scalers = {} # Store scalers for each timeframe
- self.min_window_threshold = 100 # Minimum candles needed for training
-
- def get_feature_count(self) -> int:
- """
- Get number of features (OHLCV) for NN input.
- """
- return 5 # open, high, low, close, volume
-
- def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- """Prepare training data with windowed sequences"""
- # Get historical data for primary timeframe
- primary_tf = self.timeframes[0]
- df = self.get_historical_data(timeframe=primary_tf,
- n_candles=self.min_window_threshold + 1000)
-
- if df is None or len(df) < self.min_window_threshold:
- raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")
-
- # Prepare OHLCV sequences
- ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values
-
- # Create sequences and labels
- X = []
- y = []
-
- for i in range(len(ohlcv) - self.window_size - self.output_size):
- # Input sequence
- seq = ohlcv[i:i+self.window_size]
- X.append(seq)
-
- # Output target (price movement direction)
- close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3] # Close prices
- price_changes = np.diff(close_prices)
-
- if self.output_size == 1:
- # Binary classification (up/down)
- label = 1 if price_changes[0] > 0 else 0
- elif self.output_size == 3:
- # 3-class classification (buy/hold/sell)
- if price_changes[0] > 0.002: # Significant rise
- label = 0 # Buy
- elif price_changes[0] < -0.002: # Significant drop
- label = 2 # Sell
- else:
- label = 1 # Hold
- else:
- raise ValueError(f"Unsupported output_size: {self.output_size}")
-
- y.append(label)
-
- # Convert to numpy arrays
- X = np.array(X)
- y = np.array(y)
-
- # Split into train/validation (80/20)
- split_idx = int(0.8 * len(X))
- X_train, y_train = X[:split_idx], y[:split_idx]
- X_val, y_val = X[split_idx:], y[split_idx:]
-
- return X_train, y_train, X_val, y_val
-
- def prepare_prediction_data(self) -> np.ndarray:
- """Prepare most recent window for predictions"""
- primary_tf = self.timeframes[0]
- df = self.get_historical_data(timeframe=primary_tf,
- n_candles=self.window_size,
- use_cache=False)
-
- if df is None or len(df) < self.window_size:
- raise ValueError(f"Need at least {self.window_size} candles for prediction")
-
- ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
- return np.array([ohlcv]) # Add batch dimension
-
- def process_predictions(self, predictions: np.ndarray):
- """Convert prediction probabilities to trading signals"""
- signals = []
- for pred in predictions:
- if self.output_size == 1:
- signal = "BUY" if pred[0] > 0.5 else "SELL"
- confidence = np.abs(pred[0] - 0.5) * 2 # Convert to 0-1 scale
- elif self.output_size == 3:
- action_idx = np.argmax(pred)
- signal = ["BUY", "HOLD", "SELL"][action_idx]
- confidence = pred[action_idx]
- else:
- signal = "HOLD"
- confidence = 0.0
-
- signals.append({
- 'action': signal,
- 'confidence': confidence,
- 'timestamp': datetime.now().isoformat()
- })
-
- return signals
diff --git a/NN/utils/realtime_analyzer.py b/NN/utils/realtime_analyzer.py
deleted file mode 100644
index 5ed9c07..0000000
--- a/NN/utils/realtime_analyzer.py
+++ /dev/null
@@ -1,364 +0,0 @@
-"""
-Realtime Analyzer for Neural Network Trading System
-
-This module implements real-time analysis of market data using trained neural network models.
-"""
-
-import logging
-import time
-import numpy as np
-from threading import Thread
-from queue import Queue
-from datetime import datetime
-import asyncio
-import websockets
-import json
-import os
-import pandas as pd
-from collections import deque
-
-logger = logging.getLogger(__name__)
-
-class RealtimeAnalyzer:
- """
- Handles real-time analysis of market data using trained neural network models.
-
- Features:
- - Connects to real-time data sources (websockets)
- - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
- - Uses trained models to analyze all timeframes
- - Generates trading signals
- - Manages risk and position sizing
- - Logs all trading decisions
- """
-
- def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
- """
- Initialize the realtime analyzer.
-
- Args:
- data_interface (DataInterface): Preconfigured data interface
- model: Trained neural network model
- symbol (str): Trading pair symbol
- timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
- """
- self.data_interface = data_interface
- self.model = model
- self.symbol = symbol
- self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
- self.running = False
- self.data_queue = Queue()
- self.prediction_interval = 10 # Seconds between predictions
- self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
- self.ws = None
- self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks
- self.candle_cache = {
- '1s': deque(maxlen=5000),
- '1m': deque(maxlen=5000),
- '1h': deque(maxlen=5000),
- '1d': deque(maxlen=5000)
- }
-
- logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")
-
- def start(self):
- """Start the realtime analysis process."""
- if self.running:
- logger.warning("Realtime analyzer already running")
- return
-
- self.running = True
-
- # Start WebSocket connection thread
- self.ws_thread = Thread(target=self._run_websocket, daemon=True)
- self.ws_thread.start()
-
- # Start data processing thread
- self.processing_thread = Thread(target=self._process_data, daemon=True)
- self.processing_thread.start()
-
- # Start analysis thread
- self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
- self.analysis_thread.start()
-
- logger.info("Realtime analysis started")
-
- def stop(self):
- """Stop the realtime analysis process."""
- self.running = False
- if self.ws:
- asyncio.run(self.ws.close())
- if hasattr(self, 'ws_thread'):
- self.ws_thread.join(timeout=1)
- if hasattr(self, 'processing_thread'):
- self.processing_thread.join(timeout=1)
- if hasattr(self, 'analysis_thread'):
- self.analysis_thread.join(timeout=1)
- logger.info("Realtime analysis stopped")
-
- def _run_websocket(self):
- """Thread function for running WebSocket connection."""
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- loop.run_until_complete(self._connect_websocket())
-
- async def _connect_websocket(self):
- """Connect to WebSocket and receive data."""
- while self.running:
- try:
- logger.info(f"Connecting to WebSocket: {self.ws_url}")
- async with websockets.connect(self.ws_url) as ws:
- self.ws = ws
- logger.info("WebSocket connected")
-
- while self.running:
- try:
- message = await ws.recv()
- data = json.loads(message)
-
- if 'e' in data and data['e'] == 'trade':
- tick = {
- 'timestamp': data['T'],
- 'price': float(data['p']),
- 'volume': float(data['q']),
- 'symbol': self.symbol
- }
- self.tick_storage.append(tick)
- self.data_queue.put(tick)
-
- except websockets.exceptions.ConnectionClosed:
- logger.warning("WebSocket connection closed")
- break
- except Exception as e:
- logger.error(f"Error receiving WebSocket message: {str(e)}")
- time.sleep(1)
-
- except Exception as e:
- logger.error(f"WebSocket connection error: {str(e)}")
- time.sleep(5) # Wait before reconnecting
-
- def _process_data(self):
- """Process incoming tick data into candles for all timeframes."""
- logger.info("Starting data processing thread")
-
- while self.running:
- try:
- # Process any new ticks
- while not self.data_queue.empty():
- tick = self.data_queue.get()
-
- # Convert timestamp to datetime
- timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
-
- # Process for each timeframe
- for timeframe in self.timeframes:
- interval = self._get_interval_seconds(timeframe)
- if interval is None:
- continue
-
- # Round timestamp to nearest candle interval
- candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
-
- # Get or create candle for this timeframe
- if not self.candle_cache[timeframe]:
- # First candle for this timeframe
- candle = {
- 'timestamp': candle_ts,
- 'open': tick['price'],
- 'high': tick['price'],
- 'low': tick['price'],
- 'close': tick['price'],
- 'volume': tick['volume']
- }
- self.candle_cache[timeframe].append(candle)
- else:
- # Update existing candle
- last_candle = self.candle_cache[timeframe][-1]
-
- if last_candle['timestamp'] == candle_ts:
- # Update current candle
- last_candle['high'] = max(last_candle['high'], tick['price'])
- last_candle['low'] = min(last_candle['low'], tick['price'])
- last_candle['close'] = tick['price']
- last_candle['volume'] += tick['volume']
- else:
- # New candle
- candle = {
- 'timestamp': candle_ts,
- 'open': tick['price'],
- 'high': tick['price'],
- 'low': tick['price'],
- 'close': tick['price'],
- 'volume': tick['volume']
- }
- self.candle_cache[timeframe].append(candle)
-
- time.sleep(0.1)
-
- except Exception as e:
- logger.error(f"Error in data processing: {str(e)}")
- time.sleep(1)
-
- def _get_interval_seconds(self, timeframe):
- """Convert timeframe string to seconds."""
- intervals = {
- '1s': 1,
- '1m': 60,
- '1h': 3600,
- '1d': 86400
- }
- return intervals.get(timeframe)
-
- def _analyze_data(self):
- """Thread function for analyzing data and generating signals."""
- logger.info("Starting analysis thread")
-
- last_prediction_time = 0
-
- while self.running:
- try:
- current_time = time.time()
-
- # Only make predictions at the specified interval
- if current_time - last_prediction_time < self.prediction_interval:
- time.sleep(0.1)
- continue
-
- # Prepare input data from all timeframes
- input_data = {}
- valid = True
-
- for timeframe in self.timeframes:
- if not self.candle_cache[timeframe]:
- logger.warning(f"No data available for timeframe {timeframe}")
- valid = False
- break
-
- # Get last N candles for this timeframe
- candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
-
- # Convert to numpy array
- ohlcv = np.array([
- [c['open'], c['high'], c['low'], c['close'], c['volume']]
- for c in candles
- ])
-
- # Normalize data
- ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
- input_data[timeframe] = ohlcv_normalized
-
- if not valid:
- time.sleep(0.1)
- continue
-
- # Make prediction using the model
- try:
- prediction = self.model.predict(input_data)
-
- # Get latest timestamp from 1s timeframe
- latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
-
- # Process prediction
- self._process_prediction(
- prediction=prediction,
- timeframe='multi',
- timestamp=latest_ts
- )
-
- last_prediction_time = current_time
- except Exception as e:
- logger.error(f"Error making prediction: {str(e)}")
-
- time.sleep(0.1)
-
- except Exception as e:
- logger.error(f"Error in analysis: {str(e)}")
- time.sleep(1)
-
- def _process_prediction(self, prediction, timeframe, timestamp):
- """
- Process model prediction and generate trading signals.
-
- Args:
- prediction: Model prediction output
- timeframe (str): Timeframe the prediction is for ('multi' for combined)
- timestamp: Timestamp of the prediction (ms)
- """
- # Convert prediction to trading signal
- signal, confidence = self._prediction_to_signal(prediction)
-
- # Convert timestamp to datetime
- try:
- dt = datetime.fromtimestamp(timestamp / 1000)
- except:
- dt = datetime.now()
-
- # Log the signal with all timeframes
- logger.info(
- f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
- f"Timestamp: {dt}, "
- f"Signal: {signal} (Confidence: {confidence:.2f})"
- )
-
- # In a real implementation, we would execute trades here
- # For now, we'll just log the signals
-
- def _prediction_to_signal(self, prediction):
- """
- Convert model prediction to trading signal and confidence.
-
- Args:
- prediction: Model prediction output (can be dict for multi-timeframe)
-
- Returns:
- tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
- confidence is probability (0-1)
- """
- if isinstance(prediction, dict):
- # Multi-timeframe prediction - combine signals
- signals = []
- confidences = []
-
- for tf, pred in prediction.items():
- if len(pred.shape) == 1:
- # Binary classification
- signal = "BUY" if pred[0] > 0.5 else "SELL"
- confidence = pred[0] if signal == "BUY" else 1 - pred[0]
- else:
- # Multi-class
- class_idx = np.argmax(pred)
- signal = ["SELL", "HOLD", "BUY"][class_idx]
- confidence = pred[class_idx]
-
- signals.append(signal)
- confidences.append(confidence)
-
- # Simple voting system - count BUY/SELL signals
- buy_count = signals.count("BUY")
- sell_count = signals.count("SELL")
-
- if buy_count > sell_count:
- final_signal = "BUY"
- final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
- elif sell_count > buy_count:
- final_signal = "SELL"
- final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
- else:
- final_signal = "HOLD"
- final_confidence = np.mean(confidences)
-
- return final_signal, final_confidence
-
- else:
- # Single prediction
- if len(prediction.shape) == 1:
- # Binary classification
- signal = "BUY" if prediction[0] > 0.5 else "SELL"
- confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
- else:
- # Multi-class
- class_idx = np.argmax(prediction)
- signal = ["SELL", "HOLD", "BUY"][class_idx]
- confidence = prediction[class_idx]
-
- return signal, confidence
diff --git a/_dev/cleanup_models_now.py b/_dev/cleanup_models_now.py
deleted file mode 100644
index 2fc94c0..0000000
--- a/_dev/cleanup_models_now.py
+++ /dev/null
@@ -1,98 +0,0 @@
-#!/usr/bin/env python3
-"""
-Immediate Model Cleanup Script
-
-This script will clean up all existing model files and prepare the system
-for fresh training with the new model management system.
-"""
-
-import logging
-import sys
-from model_manager import ModelManager
-
-def main():
- """Run the model cleanup"""
-
- # Configure logging for better output
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
-
- print("=" * 60)
- print("GOGO2 MODEL CLEANUP SYSTEM")
- print("=" * 60)
- print()
- print("This script will:")
- print("1. Delete ALL existing model files (.pt, .pth)")
- print("2. Remove ALL checkpoint directories")
- print("3. Clear model backup directories")
- print("4. Reset the model registry")
- print("5. Create clean directory structure")
- print()
- print("WARNING: This action cannot be undone!")
- print()
-
- # Calculate current space usage first
- try:
- manager = ModelManager()
- storage_stats = manager.get_storage_stats()
- print(f"Current storage usage:")
- print(f"- Models: {storage_stats['total_models']}")
- print(f"- Size: {storage_stats['actual_size_mb']:.1f}MB")
- print()
- except Exception as e:
- print(f"Error checking current storage: {e}")
- print()
-
- # Ask for confirmation
- print("Type 'CLEANUP' to proceed with the cleanup:")
- user_input = input("> ").strip()
-
- if user_input != "CLEANUP":
- print("Cleanup cancelled. No changes made.")
- return
-
- print()
- print("Starting cleanup...")
- print("-" * 40)
-
- try:
- # Create manager and run cleanup
- manager = ModelManager()
- cleanup_result = manager.cleanup_all_existing_models(confirm=True)
-
- print()
- print("=" * 60)
- print("CLEANUP COMPLETE")
- print("=" * 60)
- print(f"Files deleted: {cleanup_result['deleted_files']}")
- print(f"Space freed: {cleanup_result['freed_space_mb']:.1f} MB")
- print(f"Directories cleaned: {len(cleanup_result['deleted_directories'])}")
-
- if cleanup_result['errors']:
- print(f"Errors encountered: {len(cleanup_result['errors'])}")
- print("Errors:")
- for error in cleanup_result['errors'][:5]: # Show first 5 errors
- print(f" - {error}")
- if len(cleanup_result['errors']) > 5:
- print(f" ... and {len(cleanup_result['errors']) - 5} more")
-
- print()
- print("System is now ready for fresh model training!")
- print("The following directories have been created:")
- print("- models/best_models/")
- print("- models/cnn/")
- print("- models/rl/")
- print("- models/checkpoints/")
- print("- NN/models/saved/")
- print()
- print("New models will be automatically managed by the ModelManager.")
-
- except Exception as e:
- print(f"Error during cleanup: {e}")
- logging.exception("Cleanup failed")
- sys.exit(1)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/_tools/delete_candidates.py b/_tools/delete_candidates.py
new file mode 100644
index 0000000..8677b7c
--- /dev/null
+++ b/_tools/delete_candidates.py
@@ -0,0 +1,74 @@
+import os
+import shutil
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parents[1]
+
+EXCLUDE_PREFIXES = (
+ 'COBY' + os.sep, # Do not touch COBY subsystem
+)
+EXCLUDE_FILES = {
+ 'NN' + os.sep + 'training' + os.sep + 'enhanced_realtime_training.py',
+}
+
+delete_list_path = ROOT / 'DELETE_CANDIDATES.txt'
+
+deleted_files: list[str] = []
+kept_files: list[str] = []
+
+if delete_list_path.exists():
+ for line in delete_list_path.read_text(encoding='utf-8').splitlines():
+ rel = line.strip()
+ if not rel:
+ continue
+ # Skip excluded prefixes
+ if any(rel.startswith(p) for p in EXCLUDE_PREFIXES):
+ kept_files.append(rel)
+ continue
+ # Skip explicitly excluded files
+ if rel in EXCLUDE_FILES:
+ kept_files.append(rel)
+ continue
+ fp = ROOT / rel
+ if fp.exists() and fp.is_file():
+ try:
+ fp.unlink()
+ deleted_files.append(rel)
+ except Exception:
+ kept_files.append(rel)
+
+# Remove tests directories outside COBY
+removed_dirs: list[str] = []
+for d in ROOT.rglob('tests'):
+ try:
+ rel = str(d.relative_to(ROOT))
+ except Exception:
+ continue
+ if any(rel.startswith(p) for p in EXCLUDE_PREFIXES):
+ continue
+ if d.is_dir():
+ try:
+ shutil.rmtree(d)
+ removed_dirs.append(rel)
+ except Exception:
+ pass
+
+# Write cleanup log / todo
+log_lines = []
+log_lines.append('Cleanup run summary:')
+log_lines.append(f'- Deleted files: {len(deleted_files)}')
+for x in deleted_files[:50]:
+ log_lines.append(f' - {x}')
+if len(deleted_files) > 50:
+ log_lines.append(f' ... and {len(deleted_files)-50} more')
+log_lines.append(f'- Removed test directories: {len(removed_dirs)}')
+for x in removed_dirs[:50]:
+ log_lines.append(f' - {x}')
+log_lines.append(f'- Kept (excluded): {len(kept_files)}')
+
+(ROOT / 'CLEANUP_TODO.md').write_text('\n'.join(log_lines), encoding='utf-8')
+
+print(f'Deleted files: {len(deleted_files)}')
+print(f'Removed test dirs: {len(removed_dirs)}')
+print(f'Kept (excluded): {len(kept_files)}')
+
diff --git a/apply_trading_fixes.py b/apply_trading_fixes.py
deleted file mode 100644
index 05b2dcf..0000000
--- a/apply_trading_fixes.py
+++ /dev/null
@@ -1,193 +0,0 @@
-#!/usr/bin/env python3
-"""
-Apply Trading System Fixes
-
-This script applies fixes to the trading system to address:
-1. Duplicate entry prices
-2. P&L calculation issues
-3. Position tracking problems
-4. Trade display issues
-
-Usage:
- python apply_trading_fixes.py
-"""
-
-import os
-import sys
-import logging
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.StreamHandler(),
- logging.FileHandler('logs/trading_fixes.log')
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-def apply_fixes():
- """Apply all fixes to the trading system"""
- logger.info("=" * 70)
- logger.info("APPLYING TRADING SYSTEM FIXES")
- logger.info("=" * 70)
-
- # Import fixes
- try:
- from core.trading_executor_fix import TradingExecutorFix
- from web.dashboard_fix import DashboardFix
-
- logger.info("Fix modules imported successfully")
- except ImportError as e:
- logger.error(f"Error importing fix modules: {e}")
- return False
-
- # Apply fixes to trading executor
- try:
- # Import trading executor
- from core.trading_executor import TradingExecutor
-
- # Create a test instance to apply fixes
- test_executor = TradingExecutor()
-
- # Apply fixes
- TradingExecutorFix.apply_fixes(test_executor)
-
- logger.info("Trading executor fixes applied successfully to test instance")
-
- # Verify fixes
- if hasattr(test_executor, 'price_cache_timestamp'):
- logger.info("✅ Price caching fix verified")
- else:
- logger.warning("❌ Price caching fix not verified")
-
- if hasattr(test_executor, 'trade_cooldown_seconds'):
- logger.info("✅ Trade cooldown fix verified")
- else:
- logger.warning("❌ Trade cooldown fix not verified")
-
- if hasattr(test_executor, '_check_trade_cooldown'):
- logger.info("✅ Trade cooldown check method verified")
- else:
- logger.warning("❌ Trade cooldown check method not verified")
-
- except Exception as e:
- logger.error(f"Error applying trading executor fixes: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
- # Create patch for main.py
- try:
- main_patch = """
-# Apply trading system fixes
-try:
- from core.trading_executor_fix import TradingExecutorFix
- from web.dashboard_fix import DashboardFix
-
- # Apply fixes to trading executor
- if trading_executor:
- TradingExecutorFix.apply_fixes(trading_executor)
- logger.info("✅ Trading executor fixes applied")
-
- # Apply fixes to dashboard
- if 'dashboard' in locals() and dashboard:
- DashboardFix.apply_fixes(dashboard)
- logger.info("✅ Dashboard fixes applied")
-
- logger.info("Trading system fixes applied successfully")
-except Exception as e:
- logger.warning(f"Error applying trading system fixes: {e}")
-"""
-
- # Write patch instructions
- with open('patch_instructions.txt', 'w') as f:
- f.write("""
-TRADING SYSTEM FIX INSTRUCTIONS
-==============================
-
-To apply the fixes to your trading system, follow these steps:
-
-1. Add the following code to main.py just before the dashboard.run_server() call:
-
-```python
-# Apply trading system fixes
-try:
- from core.trading_executor_fix import TradingExecutorFix
- from web.dashboard_fix import DashboardFix
-
- # Apply fixes to trading executor
- if trading_executor:
- TradingExecutorFix.apply_fixes(trading_executor)
- logger.info("✅ Trading executor fixes applied")
-
- # Apply fixes to dashboard
- if 'dashboard' in locals() and dashboard:
- DashboardFix.apply_fixes(dashboard)
- logger.info("✅ Dashboard fixes applied")
-
- logger.info("Trading system fixes applied successfully")
-except Exception as e:
- logger.warning(f"Error applying trading system fixes: {e}")
-```
-
-2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method:
-
-```python
-# Apply dashboard fixes if available
-try:
- from web.dashboard_fix import DashboardFix
- DashboardFix.apply_fixes(self)
- logger.info("✅ Dashboard fixes applied during initialization")
-except ImportError:
- logger.warning("Dashboard fixes not available")
-```
-
-3. Run the system with the fixes applied:
-
-```
-python main.py
-```
-
-4. Monitor the logs for any issues with the fixes.
-
-These fixes address:
-- Duplicate entry prices
-- P&L calculation issues
-- Position tracking problems
-- Trade display issues
-- Rapid consecutive trades
-""")
-
- logger.info("Patch instructions written to patch_instructions.txt")
-
- except Exception as e:
- logger.error(f"Error creating patch: {e}")
-
- logger.info("=" * 70)
- logger.info("TRADING SYSTEM FIXES READY TO APPLY")
- logger.info("See patch_instructions.txt for instructions")
- logger.info("=" * 70)
-
- return True
-
-if __name__ == "__main__":
- # Create logs directory if it doesn't exist
- os.makedirs('logs', exist_ok=True)
-
- # Apply fixes
- success = apply_fixes()
-
- if success:
- print("\nTrading system fixes ready to apply!")
- print("See patch_instructions.txt for instructions")
- sys.exit(0)
- else:
- print("\nError preparing trading system fixes")
- sys.exit(1)
\ No newline at end of file
diff --git a/apply_trading_fixes_to_main.py b/apply_trading_fixes_to_main.py
deleted file mode 100644
index 7ffd89e..0000000
--- a/apply_trading_fixes_to_main.py
+++ /dev/null
@@ -1,218 +0,0 @@
-#!/usr/bin/env python3
-"""
-Apply Trading System Fixes to Main.py
-
-This script applies the trading system fixes directly to main.py
-to address the issues with duplicate entry prices and P&L calculation.
-
-Usage:
- python apply_trading_fixes_to_main.py
-"""
-
-import os
-import sys
-import logging
-import re
-from pathlib import Path
-import shutil
-from datetime import datetime
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.StreamHandler(),
- logging.FileHandler('logs/apply_fixes.log')
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-def backup_file(file_path):
- """Create a backup of a file"""
- try:
- backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
- shutil.copy2(file_path, backup_path)
- logger.info(f"Created backup: {backup_path}")
- return True
- except Exception as e:
- logger.error(f"Error creating backup of {file_path}: {e}")
- return False
-
-def apply_fixes_to_main():
- """Apply fixes to main.py"""
- main_py_path = "main.py"
-
- if not os.path.exists(main_py_path):
- logger.error(f"File {main_py_path} not found")
- return False
-
- # Create backup
- if not backup_file(main_py_path):
- logger.error("Failed to create backup, aborting")
- return False
-
- try:
- # Read main.py
- with open(main_py_path, 'r') as f:
- content = f.read()
-
- # Find the position to insert the fixes
- # Look for the line before dashboard.run_server()
- run_server_pattern = r"dashboard\.run_server\("
- match = re.search(run_server_pattern, content)
-
- if not match:
- logger.error("Could not find dashboard.run_server() call in main.py")
- return False
-
- # Find the position to insert the fixes (before the run_server call)
- insert_pos = content.rfind("\n", 0, match.start())
-
- if insert_pos == -1:
- logger.error("Could not find insertion point in main.py")
- return False
-
- # Prepare the fixes to insert
- fixes_code = """
-# Apply trading system fixes
-try:
- from core.trading_executor_fix import TradingExecutorFix
- from web.dashboard_fix import DashboardFix
-
- # Apply fixes to trading executor
- if trading_executor:
- TradingExecutorFix.apply_fixes(trading_executor)
- logger.info("✅ Trading executor fixes applied")
-
- # Apply fixes to dashboard
- if 'dashboard' in locals() and dashboard:
- DashboardFix.apply_fixes(dashboard)
- logger.info("✅ Dashboard fixes applied")
-
- logger.info("Trading system fixes applied successfully")
-except Exception as e:
- logger.warning(f"Error applying trading system fixes: {e}")
-
-"""
-
- # Insert the fixes
- new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
-
- # Write the modified content back to main.py
- with open(main_py_path, 'w') as f:
- f.write(new_content)
-
- logger.info(f"Successfully applied fixes to {main_py_path}")
- return True
-
- except Exception as e:
- logger.error(f"Error applying fixes to {main_py_path}: {e}")
- return False
-
-def apply_fixes_to_dashboard():
- """Apply fixes to web/clean_dashboard.py"""
- dashboard_py_path = "web/clean_dashboard.py"
-
- if not os.path.exists(dashboard_py_path):
- logger.error(f"File {dashboard_py_path} not found")
- return False
-
- # Create backup
- if not backup_file(dashboard_py_path):
- logger.error("Failed to create backup, aborting")
- return False
-
- try:
- # Read dashboard.py
- with open(dashboard_py_path, 'r') as f:
- content = f.read()
-
- # Find the position to insert the fixes
- # Look for the __init__ method
- init_pattern = r"def __init__\(self,"
- match = re.search(init_pattern, content)
-
- if not match:
- logger.error("Could not find __init__ method in dashboard.py")
- return False
-
- # Find the end of the __init__ method
- init_end_pattern = r"logger\.debug\(.*\)"
- init_end_matches = list(re.finditer(init_end_pattern, content[match.end():]))
-
- if not init_end_matches:
- logger.error("Could not find end of __init__ method in dashboard.py")
- return False
-
- # Get the last logger.debug line in the __init__ method
- last_debug_match = init_end_matches[-1]
- insert_pos = match.end() + last_debug_match.end()
-
- # Prepare the fixes to insert
- fixes_code = """
-
- # Apply dashboard fixes if available
- try:
- from web.dashboard_fix import DashboardFix
- DashboardFix.apply_fixes(self)
- logger.info("✅ Dashboard fixes applied during initialization")
- except ImportError:
- logger.warning("Dashboard fixes not available")
-"""
-
- # Insert the fixes
- new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
-
- # Write the modified content back to dashboard.py
- with open(dashboard_py_path, 'w') as f:
- f.write(new_content)
-
- logger.info(f"Successfully applied fixes to {dashboard_py_path}")
- return True
-
- except Exception as e:
- logger.error(f"Error applying fixes to {dashboard_py_path}: {e}")
- return False
-
-def main():
- """Main entry point"""
- logger.info("=" * 70)
- logger.info("APPLYING TRADING SYSTEM FIXES TO MAIN.PY")
- logger.info("=" * 70)
-
- # Create logs directory if it doesn't exist
- os.makedirs('logs', exist_ok=True)
-
- # Apply fixes to main.py
- main_success = apply_fixes_to_main()
-
- # Apply fixes to dashboard.py
- dashboard_success = apply_fixes_to_dashboard()
-
- if main_success and dashboard_success:
- logger.info("=" * 70)
- logger.info("TRADING SYSTEM FIXES APPLIED SUCCESSFULLY")
- logger.info("=" * 70)
- logger.info("The following issues have been fixed:")
- logger.info("1. Duplicate entry prices")
- logger.info("2. P&L calculation issues")
- logger.info("3. Position tracking problems")
- logger.info("4. Trade display issues")
- logger.info("5. Rapid consecutive trades")
- logger.info("=" * 70)
- logger.info("You can now run the trading system with the fixes applied:")
- logger.info("python main.py")
- logger.info("=" * 70)
- return 0
- else:
- logger.error("=" * 70)
- logger.error("FAILED TO APPLY SOME FIXES")
- logger.error("=" * 70)
- logger.error("Please check the logs for details")
- logger.error("=" * 70)
- return 1
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/audit_training_system.py b/audit_training_system.py
deleted file mode 100644
index e69de29..0000000
diff --git a/balance_trading_signals.py b/balance_trading_signals.py
deleted file mode 100644
index 092dfa5..0000000
--- a/balance_trading_signals.py
+++ /dev/null
@@ -1,189 +0,0 @@
-#!/usr/bin/env python3
-"""
-Balance Trading Signals - Analyze and fix SHORT signal bias
-
-This script analyzes the trading signals from the orchestrator and adjusts
-the model weights to balance BUY and SELL signals.
-"""
-
-import os
-import sys
-import logging
-import json
-from pathlib import Path
-from datetime import datetime
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config, setup_logging
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-# Setup logging
-setup_logging()
-logger = logging.getLogger(__name__)
-
-def analyze_trading_signals():
- """Analyze trading signals from the orchestrator"""
- logger.info("Analyzing trading signals...")
-
- # Initialize components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
-
- # Get recent decisions
- symbols = orchestrator.symbols
- all_decisions = {}
-
- for symbol in symbols:
- decisions = orchestrator.get_recent_decisions(symbol)
- all_decisions[symbol] = decisions
-
- # Count actions
- action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
- for decision in decisions:
- action_counts[decision.action] += 1
-
- total_decisions = sum(action_counts.values())
- if total_decisions > 0:
- buy_percent = action_counts['BUY'] / total_decisions * 100
- sell_percent = action_counts['SELL'] / total_decisions * 100
- hold_percent = action_counts['HOLD'] / total_decisions * 100
-
- logger.info(f"Symbol: {symbol}")
- logger.info(f" Total decisions: {total_decisions}")
- logger.info(f" BUY: {action_counts['BUY']} ({buy_percent:.1f}%)")
- logger.info(f" SELL: {action_counts['SELL']} ({sell_percent:.1f}%)")
- logger.info(f" HOLD: {action_counts['HOLD']} ({hold_percent:.1f}%)")
-
- # Check for bias
- if sell_percent > buy_percent * 2: # If SELL signals are more than twice BUY signals
- logger.warning(f" SELL bias detected: {sell_percent:.1f}% vs {buy_percent:.1f}%")
-
- # Adjust model weights to balance signals
- logger.info(" Adjusting model weights to balance signals...")
-
- # Get current model weights
- model_weights = orchestrator.model_weights
- logger.info(f" Current model weights: {model_weights}")
-
- # Identify models with SELL bias
- model_predictions = {}
- for model_name in model_weights:
- model_predictions[model_name] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
-
- # Analyze recent decisions to identify biased models
- for decision in decisions:
- reasoning = decision.reasoning
- if 'models_used' in reasoning:
- for model_name in reasoning['models_used']:
- if model_name in model_predictions:
- model_predictions[model_name][decision.action] += 1
-
- # Calculate bias for each model
- model_bias = {}
- for model_name, actions in model_predictions.items():
- total = sum(actions.values())
- if total > 0:
- buy_pct = actions['BUY'] / total * 100
- sell_pct = actions['SELL'] / total * 100
-
- # Calculate bias score (-100 to 100, negative = SELL bias, positive = BUY bias)
- bias_score = buy_pct - sell_pct
- model_bias[model_name] = bias_score
-
- logger.info(f" Model {model_name}: Bias score = {bias_score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)")
-
- # Adjust weights based on bias
- adjusted_weights = {}
- for model_name, weight in model_weights.items():
- if model_name in model_bias:
- bias = model_bias[model_name]
-
- # If model has strong SELL bias, reduce its weight
- if bias < -30: # Strong SELL bias
- adjusted_weights[model_name] = max(0.05, weight * 0.7) # Reduce weight by 30%
- logger.info(f" Reducing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} due to SELL bias")
- # If model has BUY bias, increase its weight to balance
- elif bias > 10: # BUY bias
- adjusted_weights[model_name] = min(0.5, weight * 1.3) # Increase weight by 30%
- logger.info(f" Increasing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} to balance SELL bias")
- else:
- adjusted_weights[model_name] = weight
- else:
- adjusted_weights[model_name] = weight
-
- # Save adjusted weights
- save_adjusted_weights(adjusted_weights)
-
- logger.info(f" Adjusted weights: {adjusted_weights}")
- logger.info(" Weights saved to 'adjusted_model_weights.json'")
-
- # Recommend next steps
- logger.info("\nRecommended actions:")
- logger.info("1. Update the model weights in the orchestrator")
- logger.info("2. Monitor trading signals for balance")
- logger.info("3. Consider retraining models with balanced data")
-
-def save_adjusted_weights(weights):
- """Save adjusted weights to a file"""
- output = {
- 'timestamp': datetime.now().isoformat(),
- 'weights': weights,
- 'notes': 'Adjusted to balance BUY/SELL signals'
- }
-
- with open('adjusted_model_weights.json', 'w') as f:
- json.dump(output, f, indent=2)
-
-def apply_balanced_weights():
- """Apply balanced weights to the orchestrator"""
- try:
- # Check if weights file exists
- if not os.path.exists('adjusted_model_weights.json'):
- logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.")
- return False
-
- # Load adjusted weights
- with open('adjusted_model_weights.json', 'r') as f:
- data = json.load(f)
-
- weights = data.get('weights', {})
- if not weights:
- logger.error("No weights found in the file.")
- return False
-
- logger.info(f"Loaded adjusted weights: {weights}")
-
- # Initialize components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
-
- # Apply weights
- for model_name, weight in weights.items():
- if model_name in orchestrator.model_weights:
- orchestrator.model_weights[model_name] = weight
-
- # Save updated weights
- orchestrator._save_orchestrator_state()
-
- logger.info("Applied balanced weights to orchestrator.")
- logger.info("Restart the trading system for changes to take effect.")
-
- return True
-
- except Exception as e:
- logger.error(f"Error applying balanced weights: {e}")
- return False
-
-if __name__ == "__main__":
- logger.info("=" * 70)
- logger.info("TRADING SIGNAL BALANCE ANALYZER")
- logger.info("=" * 70)
-
- if len(sys.argv) > 1 and sys.argv[1] == 'apply':
- apply_balanced_weights()
- else:
- analyze_trading_signals()
\ No newline at end of file
diff --git a/check_live_trading.py b/check_live_trading.py
deleted file mode 100644
index dc17e9f..0000000
--- a/check_live_trading.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import os
-import sys
-import logging
-import importlib
-import asyncio
-from dotenv import load_dotenv
-from safe_logging import setup_safe_logging
-
-# Configure logging
-setup_safe_logging()
-logger = logging.getLogger("check_live_trading")
-
-def check_dependencies():
- """Check if all required dependencies are installed"""
- required_packages = [
- "numpy", "pandas", "matplotlib", "mplfinance", "torch",
- "dotenv", "ccxt", "websockets", "tensorboard",
- "sklearn", "PIL", "asyncio"
- ]
-
- missing_packages = []
-
- for package in required_packages:
- try:
- if package == "dotenv":
- importlib.import_module("dotenv")
- elif package == "PIL":
- importlib.import_module("PIL")
- else:
- importlib.import_module(package)
- logger.info(f"✅ {package} is installed")
- except ImportError:
- missing_packages.append(package)
- logger.error(f"❌ {package} is NOT installed")
-
- if missing_packages:
- logger.error(f"Missing packages: {', '.join(missing_packages)}")
- logger.info("Install missing packages with: pip install -r requirements.txt")
- return False
-
- return True
-
-def check_api_keys():
- """Check if API keys are configured"""
- load_dotenv()
-
- api_key = os.getenv('MEXC_API_KEY')
- secret_key = os.getenv('MEXC_SECRET_KEY')
-
- if not api_key or api_key == "your_api_key_here" or not secret_key or secret_key == "your_secret_key_here":
- logger.error("❌ API keys are not properly configured in .env file")
- logger.info("Please update your .env file with valid MEXC API keys")
- return False
-
- logger.info("✅ API keys are configured")
- return True
-
-def check_model_files():
- """Check if trained model files exist"""
- model_files = [
- "models/trading_agent_best_pnl.pt",
- "models/trading_agent_best_reward.pt",
- "models/trading_agent_final.pt"
- ]
-
- missing_models = []
-
- for model_file in model_files:
- if os.path.exists(model_file):
- logger.info(f"✅ Model file exists: {model_file}")
- else:
- missing_models.append(model_file)
- logger.error(f"❌ Model file missing: {model_file}")
-
- if missing_models:
- logger.warning("Some model files are missing. You need to train the model first.")
- return False
-
- return True
-
-async def check_exchange_connection():
- """Test connection to MEXC exchange"""
- try:
- import ccxt
-
- # Load API keys
- load_dotenv()
- api_key = os.getenv('MEXC_API_KEY')
- secret_key = os.getenv('MEXC_SECRET_KEY')
-
- if api_key == "your_api_key_here" or secret_key == "your_secret_key_here":
- logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test")
- return False
-
- # Initialize exchange
- exchange = ccxt.mexc({
- 'apiKey': api_key,
- 'secret': secret_key,
- 'enableRateLimit': True
- })
-
- # Test connection by fetching markets
- markets = exchange.fetch_markets()
- logger.info(f"✅ Successfully connected to MEXC exchange")
- logger.info(f"✅ Found {len(markets)} markets")
-
- return True
- except Exception as e:
- logger.error(f"❌ Failed to connect to MEXC exchange: {str(e)}")
- return False
-
-def check_directories():
- """Check if required directories exist"""
- required_dirs = ["models", "runs", "trade_logs"]
-
- for directory in required_dirs:
- if not os.path.exists(directory):
- logger.info(f"Creating directory: {directory}")
- os.makedirs(directory, exist_ok=True)
-
- logger.info("✅ All required directories exist")
- return True
-
-async def main():
- """Run all checks"""
- logger.info("Running pre-flight checks for live trading...")
-
- checks = [
- ("Dependencies", check_dependencies()),
- ("API Keys", check_api_keys()),
- ("Model Files", check_model_files()),
- ("Directories", check_directories()),
- ("Exchange Connection", await check_exchange_connection())
- ]
-
- # Count failed checks
- failed_checks = sum(1 for _, result in checks if not result)
-
- # Print summary
- logger.info("\n" + "="*50)
- logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY")
- logger.info("="*50)
-
- for check_name, result in checks:
- status = "✅ PASS" if result else "❌ FAIL"
- logger.info(f"{check_name}: {status}")
-
- logger.info("="*50)
-
- if failed_checks == 0:
- logger.info("🚀 All checks passed! You're ready for live trading.")
- logger.info("\nRun live trading with:")
- logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m")
- logger.info("\nFor real trading (after updating API keys):")
- logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50")
- return 0
- else:
- logger.error(f"❌ {failed_checks} check(s) failed. Please fix the issues before running live trading.")
- return 1
-
-if __name__ == "__main__":
- exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/check_mexc_symbols.py b/check_mexc_symbols.py
deleted file mode 100644
index bdd063f..0000000
--- a/check_mexc_symbols.py
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env python3
-"""
-Check MEXC Available Trading Symbols
-"""
-
-import os
-import sys
-import logging
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.trading_executor import TradingExecutor
-
-# Setup logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def check_mexc_symbols():
- """Check available trading symbols on MEXC"""
- try:
- logger.info("=== MEXC SYMBOL AVAILABILITY CHECK ===")
-
- # Initialize trading executor
- executor = TradingExecutor("config.yaml")
-
- if not executor.exchange:
- logger.error("Failed to initialize exchange")
- return
-
- # Get all supported symbols
- logger.info("Fetching all supported symbols from MEXC...")
- supported_symbols = executor.exchange.get_api_symbols()
-
- logger.info(f"Total supported symbols: {len(supported_symbols)}")
-
- # Filter ETH-related symbols
- eth_symbols = [s for s in supported_symbols if 'ETH' in s]
- logger.info(f"ETH-related symbols ({len(eth_symbols)}):")
- for symbol in sorted(eth_symbols):
- logger.info(f" {symbol}")
-
- # Filter USDT pairs
- usdt_symbols = [s for s in supported_symbols if s.endswith('USDT')]
- logger.info(f"USDT pairs ({len(usdt_symbols)}):")
- for symbol in sorted(usdt_symbols)[:20]: # Show first 20
- logger.info(f" {symbol}")
- if len(usdt_symbols) > 20:
- logger.info(f" ... and {len(usdt_symbols) - 20} more")
-
- # Filter USDC pairs
- usdc_symbols = [s for s in supported_symbols if s.endswith('USDC')]
- logger.info(f"USDC pairs ({len(usdc_symbols)}):")
- for symbol in sorted(usdc_symbols):
- logger.info(f" {symbol}")
-
- # Check specific symbols we're interested in
- test_symbols = ['ETHUSDT', 'ETHUSDC', 'BTCUSDT', 'BTCUSDC']
- logger.info("Checking specific symbols:")
- for symbol in test_symbols:
- if symbol in supported_symbols:
- logger.info(f" ✅ {symbol} - SUPPORTED")
- else:
- logger.info(f" ❌ {symbol} - NOT SUPPORTED")
-
- # Show a sample of all available symbols
- logger.info("Sample of all available symbols:")
- for symbol in sorted(supported_symbols)[:30]:
- logger.info(f" {symbol}")
- if len(supported_symbols) > 30:
- logger.info(f" ... and {len(supported_symbols) - 30} more")
-
- except Exception as e:
- logger.error(f"Error checking MEXC symbols: {e}")
-
-if __name__ == "__main__":
- check_mexc_symbols()
\ No newline at end of file
diff --git a/cleanup_checkpoint_db.py b/cleanup_checkpoint_db.py
deleted file mode 100644
index b8d4ae3..0000000
--- a/cleanup_checkpoint_db.py
+++ /dev/null
@@ -1,108 +0,0 @@
-#!/usr/bin/env python3
-"""
-Cleanup Checkpoint Database
-
-Remove invalid database entries and ensure consistency
-"""
-
-import logging
-from pathlib import Path
-from utils.database_manager import get_database_manager
-
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def cleanup_invalid_checkpoints():
- """Remove database entries for non-existent checkpoint files"""
- print("=== Cleaning Up Invalid Checkpoint Entries ===")
-
- db_manager = get_database_manager()
-
- # Get all checkpoints from database
- all_models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision']
-
- removed_count = 0
-
- for model_name in all_models:
- checkpoints = db_manager.list_checkpoints(model_name)
-
- for checkpoint in checkpoints:
- file_path = Path(checkpoint.file_path)
-
- if not file_path.exists():
- print(f"Removing invalid entry: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
-
- # Remove from database by setting as inactive and creating a new active one if needed
- try:
- # For now, we'll just report - the system will handle missing files gracefully
- logger.warning(f"Invalid checkpoint file: {checkpoint.file_path}")
- removed_count += 1
- except Exception as e:
- logger.error(f"Failed to remove invalid checkpoint: {e}")
- else:
- print(f"Valid checkpoint: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
-
- print(f"Found {removed_count} invalid checkpoint entries")
-
-def verify_checkpoint_loading():
- """Test that checkpoint loading works correctly"""
- print("\n=== Verifying Checkpoint Loading ===")
-
- from utils.checkpoint_manager import load_best_checkpoint
-
- models_to_test = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']
-
- for model_name in models_to_test:
- try:
- result = load_best_checkpoint(model_name)
-
- if result:
- file_path, metadata = result
- file_exists = Path(file_path).exists()
-
- print(f"{model_name}:")
- print(f" ✅ Checkpoint found: {metadata.checkpoint_id}")
- print(f" 📁 File exists: {file_exists}")
- print(f" 📊 Loss: {getattr(metadata, 'loss', 'N/A')}")
- print(f" 💾 Size: {Path(file_path).stat().st_size / (1024*1024):.1f}MB" if file_exists else " 💾 Size: N/A")
- else:
- print(f"{model_name}: ❌ No valid checkpoint found")
-
- except Exception as e:
- print(f"{model_name}: ❌ Error loading checkpoint: {e}")
-
-def test_checkpoint_system_integration():
- """Test integration with the orchestrator"""
- print("\n=== Testing Orchestrator Integration ===")
-
- try:
- # Test database manager integration
- from utils.database_manager import get_database_manager
- db_manager = get_database_manager()
-
- # Test fast metadata access
- for model_name in ['dqn_agent', 'enhanced_cnn']:
- metadata = db_manager.get_best_checkpoint_metadata(model_name)
- if metadata:
- print(f"{model_name}: ✅ Fast metadata access works")
- print(f" ID: {metadata.checkpoint_id}")
- print(f" Loss: {metadata.performance_metrics.get('loss', 'N/A')}")
- else:
- print(f"{model_name}: ❌ No metadata found")
-
- print("\n✅ Checkpoint system is ready for use!")
-
- except Exception as e:
- print(f"❌ Integration test failed: {e}")
-
-def main():
- """Main cleanup process"""
- cleanup_invalid_checkpoints()
- verify_checkpoint_loading()
- test_checkpoint_system_integration()
-
- print("\n=== Cleanup Complete ===")
- print("The checkpoint system should now work without 'file not found' errors!")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/cleanup_checkpoints.py b/cleanup_checkpoints.py
deleted file mode 100644
index 5f4ab67..0000000
--- a/cleanup_checkpoints.py
+++ /dev/null
@@ -1,186 +0,0 @@
-#!/usr/bin/env python3
-"""
-Checkpoint Cleanup and Migration Script
-
-This script helps clean up existing checkpoints and migrate to the new
-checkpoint management system with W&B integration.
-"""
-
-import os
-import logging
-import shutil
-from pathlib import Path
-from datetime import datetime
-from typing import List, Dict, Any
-import torch
-
-from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
-
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-class CheckpointCleanup:
- def __init__(self):
- self.saved_models_dir = Path("NN/models/saved")
- self.checkpoint_manager = get_checkpoint_manager()
-
- def analyze_existing_checkpoints(self) -> Dict[str, Any]:
- logger.info("Analyzing existing checkpoint files...")
-
- analysis = {
- 'total_files': 0,
- 'total_size_mb': 0.0,
- 'model_types': {},
- 'file_patterns': {},
- 'potential_duplicates': []
- }
-
- if not self.saved_models_dir.exists():
- logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
- return analysis
-
- for pt_file in self.saved_models_dir.rglob("*.pt"):
- try:
- file_size_mb = pt_file.stat().st_size / (1024 * 1024)
- analysis['total_files'] += 1
- analysis['total_size_mb'] += file_size_mb
-
- filename = pt_file.name
-
- if 'cnn' in filename.lower():
- model_type = 'cnn'
- elif 'dqn' in filename.lower() or 'rl' in filename.lower():
- model_type = 'rl'
- elif 'agent' in filename.lower():
- model_type = 'rl'
- else:
- model_type = 'unknown'
-
- if model_type not in analysis['model_types']:
- analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}
-
- analysis['model_types'][model_type]['count'] += 1
- analysis['model_types'][model_type]['size_mb'] += file_size_mb
-
- base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
- if base_name not in analysis['file_patterns']:
- analysis['file_patterns'][base_name] = []
-
- analysis['file_patterns'][base_name].append({
- 'path': str(pt_file),
- 'size_mb': file_size_mb,
- 'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
- })
-
- except Exception as e:
- logger.error(f"Error analyzing {pt_file}: {e}")
-
- for base_name, files in analysis['file_patterns'].items():
- if len(files) > 5: # More than 5 files with same base name
- analysis['potential_duplicates'].append({
- 'base_name': base_name,
- 'count': len(files),
- 'total_size_mb': sum(f['size_mb'] for f in files),
- 'files': files
- })
-
- logger.info(f"Analysis complete:")
- logger.info(f" Total files: {analysis['total_files']}")
- logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB")
- logger.info(f" Model types: {analysis['model_types']}")
- logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}")
-
- return analysis
-
- def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
- logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")
-
- cleanup_results = {
- 'removed': 0,
- 'kept': 0,
- 'space_saved_mb': 0.0,
- 'details': []
- }
-
- analysis = self.analyze_existing_checkpoints()
-
- for duplicate_group in analysis['potential_duplicates']:
- base_name = duplicate_group['base_name']
- files = duplicate_group['files']
-
- # Sort by modification time (newest first)
- files.sort(key=lambda x: x['modified'], reverse=True)
-
- logger.info(f"Processing {base_name}: {len(files)} files")
-
- # Keep only the 5 newest files
- for i, file_info in enumerate(files):
- if i < 5: # Keep first 5 (newest)
- cleanup_results['kept'] += 1
- cleanup_results['details'].append({
- 'action': 'kept',
- 'file': file_info['path']
- })
- else: # Remove the rest
- if not dry_run:
- try:
- Path(file_info['path']).unlink()
- logger.info(f"Removed: {file_info['path']}")
- except Exception as e:
- logger.error(f"Error removing {file_info['path']}: {e}")
- continue
-
- cleanup_results['removed'] += 1
- cleanup_results['space_saved_mb'] += file_info['size_mb']
- cleanup_results['details'].append({
- 'action': 'removed',
- 'file': file_info['path'],
- 'size_mb': file_info['size_mb']
- })
-
- logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
- logger.info(f" Kept: {cleanup_results['kept']}")
- logger.info(f" Removed: {cleanup_results['removed']}")
- logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB")
-
- return cleanup_results
-
-def main():
- logger.info("=== Checkpoint Cleanup Tool ===")
-
- cleanup = CheckpointCleanup()
-
- # Analyze existing checkpoints
- logger.info("\\n1. Analyzing existing checkpoints...")
- analysis = cleanup.analyze_existing_checkpoints()
-
- if analysis['total_files'] == 0:
- logger.info("No checkpoint files found.")
- return
-
- # Show potential space savings
- total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5)
- if total_duplicates > 0:
- logger.info(f"\\nFound {total_duplicates} files that could be cleaned up")
-
- # Dry run first
- logger.info("\\n2. Simulating cleanup...")
- dry_run_results = cleanup.cleanup_duplicates(dry_run=True)
-
- if dry_run_results['removed'] > 0:
- proceed = input(f"\\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
- f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y'
-
- if proceed:
- logger.info("\\n3. Performing actual cleanup...")
- cleanup_results = cleanup.cleanup_duplicates(dry_run=False)
- logger.info("\\n=== Cleanup Complete ===")
- else:
- logger.info("Cleanup cancelled.")
- else:
- logger.info("No files to remove.")
- else:
- logger.info("No duplicate files found that need cleanup.")
-
-if __name__ == "__main__":
- main()
diff --git a/core/__init__.py b/core/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/core/api_rate_limiter.py b/core/api_rate_limiter.py
deleted file mode 100644
index 528c345..0000000
--- a/core/api_rate_limiter.py
+++ /dev/null
@@ -1,402 +0,0 @@
-"""
-API Rate Limiter and Error Handler
-
-This module provides robust rate limiting and error handling for API requests,
-specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors)
-and other exchange API limitations.
-
-Features:
-- Exponential backoff for rate limiting
-- IP rotation and proxy support
-- Request queuing and throttling
-- Error recovery strategies
-- Thread-safe operations
-"""
-
-import asyncio
-import logging
-import time
-import random
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Callable, Any
-from dataclasses import dataclass, field
-from collections import deque
-import threading
-from concurrent.futures import ThreadPoolExecutor
-import requests
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class RateLimitConfig:
- """Configuration for rate limiting"""
- requests_per_second: float = 0.5 # Very conservative for Binance
- requests_per_minute: int = 20
- requests_per_hour: int = 1000
-
- # Backoff configuration
- initial_backoff: float = 1.0
- max_backoff: float = 300.0 # 5 minutes max
- backoff_multiplier: float = 2.0
-
- # Error handling
- max_retries: int = 3
- retry_delay: float = 5.0
-
- # IP blocking detection
- block_detection_threshold: int = 3 # 3 consecutive 418s = blocked
- block_recovery_time: int = 3600 # 1 hour recovery time
-
-@dataclass
-class APIEndpoint:
- """API endpoint configuration"""
- name: str
- base_url: str
- rate_limit: RateLimitConfig
- last_request_time: float = 0.0
- request_count_minute: int = 0
- request_count_hour: int = 0
- consecutive_errors: int = 0
- blocked_until: Optional[datetime] = None
-
- # Request history for rate limiting
- request_history: deque = field(default_factory=lambda: deque(maxlen=3600)) # 1 hour history
-
-class APIRateLimiter:
- """Thread-safe API rate limiter with error handling"""
-
- def __init__(self, config: RateLimitConfig = None):
- self.config = config or RateLimitConfig()
-
- # Thread safety
- self.lock = threading.RLock()
-
- # Endpoint tracking
- self.endpoints: Dict[str, APIEndpoint] = {}
-
- # Global rate limiting
- self.global_request_history = deque(maxlen=3600)
- self.global_blocked_until: Optional[datetime] = None
-
- # Request session with retry strategy
- self.session = self._create_session()
-
- # Background cleanup thread
- self.cleanup_thread = None
- self.is_running = False
-
- logger.info("API Rate Limiter initialized")
- logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m")
-
- def _create_session(self) -> requests.Session:
- """Create requests session with retry strategy"""
- session = requests.Session()
-
- # Retry strategy
- retry_strategy = Retry(
- total=self.config.max_retries,
- backoff_factor=1,
- status_forcelist=[429, 500, 502, 503, 504],
- allowed_methods=["HEAD", "GET", "OPTIONS"]
- )
-
- adapter = HTTPAdapter(max_retries=retry_strategy)
- session.mount("http://", adapter)
- session.mount("https://", adapter)
-
- # Headers to appear more legitimate
- session.headers.update({
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
- 'Accept': 'application/json',
- 'Accept-Language': 'en-US,en;q=0.9',
- 'Accept-Encoding': 'gzip, deflate, br',
- 'Connection': 'keep-alive',
- 'Upgrade-Insecure-Requests': '1',
- })
-
- return session
-
- def register_endpoint(self, name: str, base_url: str, rate_limit: RateLimitConfig = None):
- """Register an API endpoint for rate limiting"""
- with self.lock:
- self.endpoints[name] = APIEndpoint(
- name=name,
- base_url=base_url,
- rate_limit=rate_limit or self.config
- )
- logger.info(f"Registered endpoint: {name} -> {base_url}")
-
- def start_background_cleanup(self):
- """Start background cleanup thread"""
- if self.is_running:
- return
-
- self.is_running = True
- self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
- self.cleanup_thread.start()
- logger.info("Started background cleanup thread")
-
- def stop_background_cleanup(self):
- """Stop background cleanup thread"""
- self.is_running = False
- if self.cleanup_thread:
- self.cleanup_thread.join(timeout=5)
- logger.info("Stopped background cleanup thread")
-
- def _cleanup_worker(self):
- """Background worker to clean up old request history"""
- while self.is_running:
- try:
- current_time = time.time()
- cutoff_time = current_time - 3600 # 1 hour ago
-
- with self.lock:
- # Clean global history
- while (self.global_request_history and
- self.global_request_history[0] < cutoff_time):
- self.global_request_history.popleft()
-
- # Clean endpoint histories
- for endpoint in self.endpoints.values():
- while (endpoint.request_history and
- endpoint.request_history[0] < cutoff_time):
- endpoint.request_history.popleft()
-
- # Reset counters
- endpoint.request_count_minute = len([
- t for t in endpoint.request_history
- if t > current_time - 60
- ])
- endpoint.request_count_hour = len(endpoint.request_history)
-
- time.sleep(60) # Clean every minute
-
- except Exception as e:
- logger.error(f"Error in cleanup worker: {e}")
- time.sleep(30)
-
- def can_make_request(self, endpoint_name: str) -> tuple[bool, float]:
- """
- Check if we can make a request to the endpoint
-
- Returns:
- (can_make_request, wait_time_seconds)
- """
- with self.lock:
- current_time = time.time()
-
- # Check global blocking
- if self.global_blocked_until and datetime.now() < self.global_blocked_until:
- wait_time = (self.global_blocked_until - datetime.now()).total_seconds()
- return False, wait_time
-
- # Get endpoint
- endpoint = self.endpoints.get(endpoint_name)
- if not endpoint:
- logger.warning(f"Unknown endpoint: {endpoint_name}")
- return False, 60.0
-
- # Check endpoint blocking
- if endpoint.blocked_until and datetime.now() < endpoint.blocked_until:
- wait_time = (endpoint.blocked_until - datetime.now()).total_seconds()
- return False, wait_time
-
- # Check rate limits
- config = endpoint.rate_limit
-
- # Per-second rate limit
- time_since_last = current_time - endpoint.last_request_time
- if time_since_last < (1.0 / config.requests_per_second):
- wait_time = (1.0 / config.requests_per_second) - time_since_last
- return False, wait_time
-
- # Per-minute rate limit
- minute_requests = len([
- t for t in endpoint.request_history
- if t > current_time - 60
- ])
- if minute_requests >= config.requests_per_minute:
- return False, 60.0
-
- # Per-hour rate limit
- if len(endpoint.request_history) >= config.requests_per_hour:
- return False, 3600.0
-
- return True, 0.0
-
- def make_request(self, endpoint_name: str, url: str, method: str = 'GET',
- **kwargs) -> Optional[requests.Response]:
- """
- Make a rate-limited request with error handling
-
- Args:
- endpoint_name: Name of the registered endpoint
- url: Full URL to request
- method: HTTP method
- **kwargs: Additional arguments for requests
-
- Returns:
- Response object or None if failed
- """
- with self.lock:
- endpoint = self.endpoints.get(endpoint_name)
- if not endpoint:
- logger.error(f"Unknown endpoint: {endpoint_name}")
- return None
-
- # Check if we can make the request
- can_request, wait_time = self.can_make_request(endpoint_name)
- if not can_request:
- logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s")
- time.sleep(min(wait_time, 30)) # Cap wait time
- return None
-
- # Record request attempt
- current_time = time.time()
- endpoint.last_request_time = current_time
- endpoint.request_history.append(current_time)
- self.global_request_history.append(current_time)
-
- # Add jitter to avoid thundering herd
- jitter = random.uniform(0.1, 0.5)
- time.sleep(jitter)
-
- # Make the request (outside of lock to avoid blocking other threads)
- try:
- # Set timeout
- kwargs.setdefault('timeout', 10)
-
- # Make request
- response = self.session.request(method, url, **kwargs)
-
- # Handle response
- with self.lock:
- if response.status_code == 200:
- # Success - reset error counter
- endpoint.consecutive_errors = 0
- return response
-
- elif response.status_code == 418:
- # Binance "I'm a teapot" - rate limited/blocked
- endpoint.consecutive_errors += 1
- logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}")
-
- if endpoint.consecutive_errors >= endpoint.rate_limit.block_detection_threshold:
- # We're likely IP blocked
- block_time = datetime.now() + timedelta(seconds=endpoint.rate_limit.block_recovery_time)
- endpoint.blocked_until = block_time
- logger.error(f"Endpoint {endpoint_name} blocked until {block_time}")
-
- return None
-
- elif response.status_code == 429:
- # Too many requests
- endpoint.consecutive_errors += 1
- logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}")
-
- # Implement exponential backoff
- backoff_time = min(
- endpoint.rate_limit.initial_backoff * (endpoint.rate_limit.backoff_multiplier ** endpoint.consecutive_errors),
- endpoint.rate_limit.max_backoff
- )
-
- block_time = datetime.now() + timedelta(seconds=backoff_time)
- endpoint.blocked_until = block_time
- logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s")
-
- return None
-
- else:
- # Other error
- endpoint.consecutive_errors += 1
- logger.warning(f"HTTP {response.status_code} for {endpoint_name}: {response.text[:200]}")
- return None
-
- except requests.exceptions.RequestException as e:
- with self.lock:
- endpoint.consecutive_errors += 1
- logger.error(f"Request exception for {endpoint_name}: {e}")
- return None
-
- except Exception as e:
- with self.lock:
- endpoint.consecutive_errors += 1
- logger.error(f"Unexpected error for {endpoint_name}: {e}")
- return None
-
- def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]:
- """Get status information for an endpoint"""
- with self.lock:
- endpoint = self.endpoints.get(endpoint_name)
- if not endpoint:
- return {'error': 'Unknown endpoint'}
-
- current_time = time.time()
-
- return {
- 'name': endpoint.name,
- 'base_url': endpoint.base_url,
- 'consecutive_errors': endpoint.consecutive_errors,
- 'blocked_until': endpoint.blocked_until.isoformat() if endpoint.blocked_until else None,
- 'requests_last_minute': len([t for t in endpoint.request_history if t > current_time - 60]),
- 'requests_last_hour': len(endpoint.request_history),
- 'last_request_time': endpoint.last_request_time,
- 'can_make_request': self.can_make_request(endpoint_name)[0]
- }
-
- def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]:
- """Get status for all endpoints"""
- return {name: self.get_endpoint_status(name) for name in self.endpoints.keys()}
-
- def reset_endpoint(self, endpoint_name: str):
- """Reset an endpoint's error state"""
- with self.lock:
- endpoint = self.endpoints.get(endpoint_name)
- if endpoint:
- endpoint.consecutive_errors = 0
- endpoint.blocked_until = None
- logger.info(f"Reset endpoint: {endpoint_name}")
-
- def reset_all_endpoints(self):
- """Reset all endpoints' error states"""
- with self.lock:
- for endpoint in self.endpoints.values():
- endpoint.consecutive_errors = 0
- endpoint.blocked_until = None
- self.global_blocked_until = None
- logger.info("Reset all endpoints")
-
-# Global rate limiter instance
-_global_rate_limiter = None
-
-def get_rate_limiter() -> APIRateLimiter:
- """Get global rate limiter instance"""
- global _global_rate_limiter
- if _global_rate_limiter is None:
- _global_rate_limiter = APIRateLimiter()
- _global_rate_limiter.start_background_cleanup()
-
- # Register common endpoints
- _global_rate_limiter.register_endpoint(
- 'binance_api',
- 'https://api.binance.com',
- RateLimitConfig(
- requests_per_second=0.2, # Very conservative
- requests_per_minute=10,
- requests_per_hour=500
- )
- )
-
- _global_rate_limiter.register_endpoint(
- 'mexc_api',
- 'https://api.mexc.com',
- RateLimitConfig(
- requests_per_second=0.5,
- requests_per_minute=20,
- requests_per_hour=1000
- )
- )
-
- return _global_rate_limiter
\ No newline at end of file
diff --git a/core/async_handler.py b/core/async_handler.py
deleted file mode 100644
index 10a0793..0000000
--- a/core/async_handler.py
+++ /dev/null
@@ -1,442 +0,0 @@
-"""
-Async Handler for UI Stability Fix
-
-Properly handles all async operations in the dashboard with single event loop management,
-proper exception handling, and timeout support to prevent async/await errors.
-"""
-
-import asyncio
-import logging
-import threading
-import time
-from typing import Any, Callable, Coroutine, Dict, Optional, Union
-from concurrent.futures import ThreadPoolExecutor
-import functools
-import weakref
-
-logger = logging.getLogger(__name__)
-
-
-class AsyncOperationError(Exception):
- """Exception raised for async operation errors"""
- pass
-
-
-class AsyncHandler:
- """
- Centralized async operation handler with single event loop management
- and proper exception handling for async operations.
- """
-
- def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
- """
- Initialize the async handler
-
- Args:
- loop: Optional event loop to use. If None, creates a new one.
- """
- self._loop = loop
- self._thread = None
- self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler")
- self._running = False
- self._callbacks = weakref.WeakSet()
- self._timeout_default = 30.0 # Default timeout for operations
-
- # Start the event loop in a separate thread if not provided
- if self._loop is None:
- self._start_event_loop_thread()
-
- logger.info("AsyncHandler initialized with event loop management")
-
- def _start_event_loop_thread(self):
- """Start the event loop in a separate thread"""
- def run_event_loop():
- """Run the event loop in a separate thread"""
- try:
- self._loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self._loop)
- self._running = True
- logger.debug("Event loop started in separate thread")
- self._loop.run_forever()
- except Exception as e:
- logger.error(f"Error in event loop thread: {e}")
- finally:
- self._running = False
- logger.debug("Event loop thread stopped")
-
- self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop")
- self._thread.start()
-
- # Wait for the loop to be ready
- timeout = 5.0
- start_time = time.time()
- while not self._running and (time.time() - start_time) < timeout:
- time.sleep(0.1)
-
- if not self._running:
- raise AsyncOperationError("Failed to start event loop within timeout")
-
- def is_running(self) -> bool:
- """Check if the async handler is running"""
- return self._running and self._loop is not None and not self._loop.is_closed()
-
- def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any:
- """
- Run an async coroutine safely with proper error handling and timeout
-
- Args:
- coro: The coroutine to run
- timeout: Timeout in seconds (uses default if None)
-
- Returns:
- The result of the coroutine
-
- Raises:
- AsyncOperationError: If the operation fails or times out
- """
- if not self.is_running():
- raise AsyncOperationError("AsyncHandler is not running")
-
- timeout = timeout or self._timeout_default
-
- try:
- # Schedule the coroutine on the event loop
- future = asyncio.run_coroutine_threadsafe(
- asyncio.wait_for(coro, timeout=timeout),
- self._loop
- )
-
- # Wait for the result with timeout
- result = future.result(timeout=timeout + 1.0) # Add buffer to future timeout
- logger.debug("Async operation completed successfully")
- return result
-
- except asyncio.TimeoutError:
- logger.error(f"Async operation timed out after {timeout} seconds")
- raise AsyncOperationError(f"Operation timed out after {timeout} seconds")
- except Exception as e:
- logger.error(f"Async operation failed: {e}")
- raise AsyncOperationError(f"Async operation failed: {e}")
-
- def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None:
- """
- Schedule a coroutine to run asynchronously without waiting for result
-
- Args:
- coro: The coroutine to schedule
- callback: Optional callback to call with the result
- """
- if not self.is_running():
- logger.warning("Cannot schedule coroutine: AsyncHandler is not running")
- return
-
- async def wrapped_coro():
- """Wrapper to handle exceptions and callbacks"""
- try:
- result = await coro
- if callback:
- try:
- callback(result)
- except Exception as e:
- logger.error(f"Error in coroutine callback: {e}")
- return result
- except Exception as e:
- logger.error(f"Error in scheduled coroutine: {e}")
- if callback:
- try:
- callback(None) # Call callback with None on error
- except Exception as cb_e:
- logger.error(f"Error in error callback: {cb_e}")
-
- try:
- asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop)
- logger.debug("Coroutine scheduled successfully")
- except Exception as e:
- logger.error(f"Failed to schedule coroutine: {e}")
-
- def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
- """
- Create an asyncio task safely with proper error handling
-
- Args:
- coro: The coroutine to create a task for
- name: Optional name for the task
-
- Returns:
- The created task or None if failed
- """
- if not self.is_running():
- logger.warning("Cannot create task: AsyncHandler is not running")
- return None
-
- async def create_task():
- """Create the task in the event loop"""
- try:
- task = asyncio.create_task(coro, name=name)
- logger.debug(f"Task created: {name or 'unnamed'}")
- return task
- except Exception as e:
- logger.error(f"Failed to create task {name}: {e}")
- return None
-
- try:
- future = asyncio.run_coroutine_threadsafe(create_task(), self._loop)
- return future.result(timeout=5.0)
- except Exception as e:
- logger.error(f"Failed to create task {name}: {e}")
- return None
-
- async def handle_orchestrator_connection(self, orchestrator) -> bool:
- """
- Handle orchestrator connection with proper async patterns
-
- Args:
- orchestrator: The orchestrator instance to connect to
-
- Returns:
- True if connection successful, False otherwise
- """
- try:
- logger.info("Connecting to orchestrator...")
-
- # Add decision callback if orchestrator supports it
- if hasattr(orchestrator, 'add_decision_callback'):
- await orchestrator.add_decision_callback(self._handle_trading_decision)
- logger.info("Decision callback added to orchestrator")
-
- # Start COB integration if available
- if hasattr(orchestrator, 'start_cob_integration'):
- await orchestrator.start_cob_integration()
- logger.info("COB integration started")
-
- # Start continuous trading if available
- if hasattr(orchestrator, 'start_continuous_trading'):
- await orchestrator.start_continuous_trading()
- logger.info("Continuous trading started")
-
- logger.info("Successfully connected to orchestrator")
- return True
-
- except Exception as e:
- logger.error(f"Failed to connect to orchestrator: {e}")
- return False
-
- async def handle_cob_integration(self, cob_integration) -> bool:
- """
- Handle COB integration startup with proper async patterns
-
- Args:
- cob_integration: The COB integration instance
-
- Returns:
- True if startup successful, False otherwise
- """
- try:
- logger.info("Starting COB integration...")
-
- if hasattr(cob_integration, 'start'):
- await cob_integration.start()
- logger.info("COB integration started successfully")
- return True
- else:
- logger.warning("COB integration does not have start method")
- return False
-
- except Exception as e:
- logger.error(f"Failed to start COB integration: {e}")
- return False
-
- async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None:
- """
- Handle trading decision with proper async patterns
-
- Args:
- decision: The trading decision dictionary
- """
- try:
- logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}")
-
- # Process the decision (this would be customized based on needs)
- # For now, just log it
- symbol = decision.get('symbol', 'UNKNOWN')
- action = decision.get('action', 'HOLD')
- confidence = decision.get('confidence', 0.0)
-
- logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})")
-
- except Exception as e:
- logger.error(f"Error handling trading decision: {e}")
-
- def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
- """
- Run a blocking function in the thread pool executor
-
- Args:
- func: The function to run
- *args: Positional arguments for the function
- **kwargs: Keyword arguments for the function
-
- Returns:
- The result of the function
- """
- if not self.is_running():
- raise AsyncOperationError("AsyncHandler is not running")
-
- try:
- # Create a partial function with the arguments
- partial_func = functools.partial(func, *args, **kwargs)
-
- # Create a coroutine that runs the function in executor
- async def run_in_executor_coro():
- return await self._loop.run_in_executor(self._executor, partial_func)
-
- # Run the coroutine
- future = asyncio.run_coroutine_threadsafe(run_in_executor_coro(), self._loop)
-
- result = future.result(timeout=self._timeout_default)
- logger.debug("Executor function completed successfully")
- return result
-
- except Exception as e:
- logger.error(f"Error running function in executor: {e}")
- raise AsyncOperationError(f"Executor function failed: {e}")
-
- def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]:
- """
- Add a periodic task that runs at specified intervals
-
- Args:
- coro_func: Function that returns a coroutine to run periodically
- interval: Interval in seconds between runs
- name: Optional name for the task
-
- Returns:
- The created task or None if failed
- """
- async def periodic_runner():
- """Run the coroutine periodically"""
- task_name = name or "periodic_task"
- logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)")
-
- try:
- while True:
- try:
- coro = coro_func()
- await coro
- logger.debug(f"Periodic task {task_name} completed")
- except Exception as e:
- logger.error(f"Error in periodic task {task_name}: {e}")
-
- await asyncio.sleep(interval)
-
- except asyncio.CancelledError:
- logger.info(f"Periodic task {task_name} cancelled")
- raise
- except Exception as e:
- logger.error(f"Fatal error in periodic task {task_name}: {e}")
-
- return self.create_task_safely(periodic_runner(), name=f"periodic_{name}")
-
- def stop(self) -> None:
- """Stop the async handler and clean up resources"""
- try:
- logger.info("Stopping AsyncHandler...")
-
- if self._loop and not self._loop.is_closed():
- # Cancel all tasks
- if self._loop.is_running():
- asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), self._loop)
-
- # Stop the event loop
- self._loop.call_soon_threadsafe(self._loop.stop)
-
- # Shutdown executor
- if self._executor:
- self._executor.shutdown(wait=True)
-
- # Wait for thread to finish
- if self._thread and self._thread.is_alive():
- self._thread.join(timeout=5.0)
-
- self._running = False
- logger.info("AsyncHandler stopped successfully")
-
- except Exception as e:
- logger.error(f"Error stopping AsyncHandler: {e}")
-
- async def _cancel_all_tasks(self) -> None:
- """Cancel all running tasks"""
- try:
- tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
- if tasks:
- logger.info(f"Cancelling {len(tasks)} running tasks")
- for task in tasks:
- task.cancel()
-
- # Wait for tasks to be cancelled
- await asyncio.gather(*tasks, return_exceptions=True)
- logger.debug("All tasks cancelled")
- except Exception as e:
- logger.error(f"Error cancelling tasks: {e}")
-
- def __enter__(self):
- """Context manager entry"""
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """Context manager exit"""
- self.stop()
-
-
-class AsyncContextManager:
- """
- Context manager for async operations that ensures proper cleanup
- """
-
- def __init__(self, async_handler: AsyncHandler):
- self.async_handler = async_handler
- self.active_tasks = []
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- # Cancel any active tasks
- for task in self.active_tasks:
- if not task.done():
- task.cancel()
-
- def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
- """Create a task and track it for cleanup"""
- task = self.async_handler.create_task_safely(coro, name)
- if task:
- self.active_tasks.append(task)
- return task
-
-
-def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler:
- """
- Factory function to create an AsyncHandler instance
-
- Args:
- loop: Optional event loop to use
-
- Returns:
- AsyncHandler instance
- """
- return AsyncHandler(loop=loop)
-
-
-def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any:
- """
- Convenience function to run a coroutine safely with a temporary AsyncHandler
-
- Args:
- coro: The coroutine to run
- timeout: Timeout in seconds
-
- Returns:
- The result of the coroutine
- """
- with AsyncHandler() as handler:
- return handler.run_async_safely(coro, timeout=timeout)
\ No newline at end of file
diff --git a/core/bookmap_data_provider.py b/core/bookmap_data_provider.py
deleted file mode 100644
index c25bc47..0000000
--- a/core/bookmap_data_provider.py
+++ /dev/null
@@ -1,952 +0,0 @@
-"""
-Bookmap Order Book Data Provider
-
-This module integrates with Bookmap to gather:
-- Current Order Book (COB) data
-- Session Volume Profile (SVP) data
-- Order book sweeps and momentum trades detection
-- Real-time order size heatmap matrix (last 10 minutes)
-- Level 2 market depth analysis
-
-The data is processed and fed to CNN and DQN networks for enhanced trading decisions.
-"""
-
-import asyncio
-import json
-import logging
-import time
-import websockets
-import numpy as np
-import pandas as pd
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from collections import deque, defaultdict
-from dataclasses import dataclass
-from threading import Thread, Lock
-import requests
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class OrderBookLevel:
- """Represents a single order book level"""
- price: float
- size: float
- orders: int
- side: str # 'bid' or 'ask'
- timestamp: datetime
-
-@dataclass
-class OrderBookSnapshot:
- """Complete order book snapshot"""
- symbol: str
- timestamp: datetime
- bids: List[OrderBookLevel]
- asks: List[OrderBookLevel]
- spread: float
- mid_price: float
-
-@dataclass
-class VolumeProfileLevel:
- """Volume profile level data"""
- price: float
- volume: float
- buy_volume: float
- sell_volume: float
- trades_count: int
- vwap: float
-
-@dataclass
-class OrderFlowSignal:
- """Order flow signal detection"""
- timestamp: datetime
- signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum'
- price: float
- volume: float
- confidence: float
- description: str
-
-class BookmapDataProvider:
- """
- Real-time order book data provider using Bookmap-style analysis
-
- Features:
- - Level 2 order book monitoring
- - Order flow detection (sweeps, absorptions)
- - Volume profile analysis
- - Order size heatmap generation
- - Market microstructure analysis
- """
-
- def __init__(self, symbols: List[str] = None, depth_levels: int = 20):
- """
- Initialize Bookmap data provider
-
- Args:
- symbols: List of symbols to monitor
- depth_levels: Number of order book levels to track
- """
- self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
- self.depth_levels = depth_levels
- self.is_streaming = False
-
- # Order book data storage
- self.order_books: Dict[str, OrderBookSnapshot] = {}
- self.order_book_history: Dict[str, deque] = {}
- self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
-
- # Heatmap data (10-minute rolling window)
- self.heatmap_window = timedelta(minutes=10)
- self.order_heatmaps: Dict[str, deque] = {}
- self.price_levels: Dict[str, List[float]] = {}
-
- # Order flow detection
- self.flow_signals: Dict[str, deque] = {}
- self.sweep_threshold = 0.8 # Minimum confidence for sweep detection
- self.absorption_threshold = 0.7 # Minimum confidence for absorption
-
- # Market microstructure metrics
- self.bid_ask_spreads: Dict[str, deque] = {}
- self.order_book_imbalances: Dict[str, deque] = {}
- self.liquidity_metrics: Dict[str, Dict] = {}
-
- # WebSocket connections
- self.websocket_tasks: Dict[str, asyncio.Task] = {}
- self.data_lock = Lock()
-
- # Callbacks for CNN/DQN integration
- self.cnn_callbacks: List[Callable] = []
- self.dqn_callbacks: List[Callable] = []
-
- # Performance tracking
- self.update_counts = defaultdict(int)
- self.last_update_times = {}
-
- # Initialize data structures
- for symbol in self.symbols:
- self.order_book_history[symbol] = deque(maxlen=1000)
- self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals
- self.flow_signals[symbol] = deque(maxlen=500)
- self.bid_ask_spreads[symbol] = deque(maxlen=1000)
- self.order_book_imbalances[symbol] = deque(maxlen=1000)
- self.liquidity_metrics[symbol] = {
- 'total_bid_size': 0.0,
- 'total_ask_size': 0.0,
- 'weighted_mid': 0.0,
- 'liquidity_ratio': 1.0
- }
-
- logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols")
- logger.info(f"Tracking {depth_levels} order book levels per side")
-
- def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
- """Add callback for CNN model updates"""
- self.cnn_callbacks.append(callback)
- logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")
-
- def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
- """Add callback for DQN model updates"""
- self.dqn_callbacks.append(callback)
- logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
-
- async def start_streaming(self):
- """Start real-time order book streaming"""
- if self.is_streaming:
- logger.warning("Bookmap streaming already active")
- return
-
- self.is_streaming = True
- logger.info("Starting Bookmap order book streaming")
-
- # Start order book streams for each symbol
- for symbol in self.symbols:
- # Order book depth stream
- depth_task = asyncio.create_task(self._stream_order_book_depth(symbol))
- self.websocket_tasks[f"{symbol}_depth"] = depth_task
-
- # Trade stream for order flow analysis
- trade_task = asyncio.create_task(self._stream_trades(symbol))
- self.websocket_tasks[f"{symbol}_trades"] = trade_task
-
- # Start analysis threads
- analysis_task = asyncio.create_task(self._continuous_analysis())
- self.websocket_tasks["analysis"] = analysis_task
-
- logger.info(f"Started streaming for {len(self.symbols)} symbols")
-
- async def stop_streaming(self):
- """Stop order book streaming"""
- if not self.is_streaming:
- return
-
- logger.info("Stopping Bookmap streaming")
- self.is_streaming = False
-
- # Cancel all tasks
- for name, task in self.websocket_tasks.items():
- if not task.done():
- task.cancel()
- try:
- await task
- except asyncio.CancelledError:
- pass
-
- self.websocket_tasks.clear()
- logger.info("Bookmap streaming stopped")
-
- async def _stream_order_book_depth(self, symbol: str):
- """Stream order book depth data"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Order book depth WebSocket connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_depth_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing depth for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Depth WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _stream_trades(self, symbol: str):
- """Stream trade data for order flow analysis"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Trade WebSocket connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_trade_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing trade for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Trade WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _process_depth_update(self, symbol: str, data: Dict):
- """Process order book depth update"""
- try:
- timestamp = datetime.now()
-
- # Parse bids and asks
- bids = []
- asks = []
-
- for bid_data in data.get('bids', []):
- price = float(bid_data[0])
- size = float(bid_data[1])
- bids.append(OrderBookLevel(
- price=price,
- size=size,
- orders=1, # Binance doesn't provide order count
- side='bid',
- timestamp=timestamp
- ))
-
- for ask_data in data.get('asks', []):
- price = float(ask_data[0])
- size = float(ask_data[1])
- asks.append(OrderBookLevel(
- price=price,
- size=size,
- orders=1,
- side='ask',
- timestamp=timestamp
- ))
-
- # Sort order book levels
- bids.sort(key=lambda x: x.price, reverse=True)
- asks.sort(key=lambda x: x.price)
-
- # Calculate spread and mid price
- if bids and asks:
- best_bid = bids[0].price
- best_ask = asks[0].price
- spread = best_ask - best_bid
- mid_price = (best_bid + best_ask) / 2
- else:
- spread = 0.0
- mid_price = 0.0
-
- # Create order book snapshot
- snapshot = OrderBookSnapshot(
- symbol=symbol,
- timestamp=timestamp,
- bids=bids,
- asks=asks,
- spread=spread,
- mid_price=mid_price
- )
-
- with self.data_lock:
- self.order_books[symbol] = snapshot
- self.order_book_history[symbol].append(snapshot)
-
- # Update liquidity metrics
- self._update_liquidity_metrics(symbol, snapshot)
-
- # Update order book imbalance
- self._calculate_order_book_imbalance(symbol, snapshot)
-
- # Update heatmap data
- self._update_order_heatmap(symbol, snapshot)
-
- # Update counters
- self.update_counts[f"{symbol}_depth"] += 1
- self.last_update_times[f"{symbol}_depth"] = timestamp
-
- except Exception as e:
- logger.error(f"Error processing depth update for {symbol}: {e}")
-
- async def _process_trade_update(self, symbol: str, data: Dict):
- """Process trade data for order flow analysis"""
- try:
- timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
- price = float(data['p'])
- quantity = float(data['q'])
- is_buyer_maker = data['m']
-
- # Analyze for order flow signals
- await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
-
- # Update volume profile
- self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
-
- self.update_counts[f"{symbol}_trades"] += 1
-
- except Exception as e:
- logger.error(f"Error processing trade for {symbol}: {e}")
-
- def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
- """Update liquidity metrics from order book snapshot"""
- try:
- total_bid_size = sum(level.size for level in snapshot.bids)
- total_ask_size = sum(level.size for level in snapshot.asks)
-
- # Calculate weighted mid price
- if snapshot.bids and snapshot.asks:
- bid_weight = total_bid_size / (total_bid_size + total_ask_size)
- ask_weight = total_ask_size / (total_bid_size + total_ask_size)
- weighted_mid = (snapshot.bids[0].price * ask_weight +
- snapshot.asks[0].price * bid_weight)
- else:
- weighted_mid = snapshot.mid_price
-
- # Liquidity ratio (bid/ask balance)
- if total_ask_size > 0:
- liquidity_ratio = total_bid_size / total_ask_size
- else:
- liquidity_ratio = 1.0
-
- self.liquidity_metrics[symbol] = {
- 'total_bid_size': total_bid_size,
- 'total_ask_size': total_ask_size,
- 'weighted_mid': weighted_mid,
- 'liquidity_ratio': liquidity_ratio,
- 'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
- }
-
- except Exception as e:
- logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
-
- def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
- """Calculate order book imbalance ratio"""
- try:
- if not snapshot.bids or not snapshot.asks:
- return
-
- # Calculate imbalance for top N levels
- n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
-
- total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
- total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
-
- if total_bid_size + total_ask_size > 0:
- imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
- else:
- imbalance = 0.0
-
- self.order_book_imbalances[symbol].append({
- 'timestamp': snapshot.timestamp,
- 'imbalance': imbalance,
- 'bid_size': total_bid_size,
- 'ask_size': total_ask_size
- })
-
- except Exception as e:
- logger.error(f"Error calculating imbalance for {symbol}: {e}")
-
- def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
- """Update order size heatmap matrix"""
- try:
- # Create heatmap entry
- heatmap_entry = {
- 'timestamp': snapshot.timestamp,
- 'mid_price': snapshot.mid_price,
- 'levels': {}
- }
-
- # Add bid levels
- for level in snapshot.bids:
- price_offset = level.price - snapshot.mid_price
- heatmap_entry['levels'][price_offset] = {
- 'side': 'bid',
- 'size': level.size,
- 'price': level.price
- }
-
- # Add ask levels
- for level in snapshot.asks:
- price_offset = level.price - snapshot.mid_price
- heatmap_entry['levels'][price_offset] = {
- 'side': 'ask',
- 'size': level.size,
- 'price': level.price
- }
-
- self.order_heatmaps[symbol].append(heatmap_entry)
-
- # Clean old entries (keep 10 minutes)
- cutoff_time = snapshot.timestamp - self.heatmap_window
- while (self.order_heatmaps[symbol] and
- self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
- self.order_heatmaps[symbol].popleft()
-
- except Exception as e:
- logger.error(f"Error updating heatmap for {symbol}: {e}")
-
- def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
- """Update volume profile with new trade"""
- try:
- # Initialize if not exists
- if symbol not in self.volume_profiles:
- self.volume_profiles[symbol] = []
-
- # Find or create price level
- price_level = None
- for level in self.volume_profiles[symbol]:
- if abs(level.price - price) < 0.01: # Price tolerance
- price_level = level
- break
-
- if not price_level:
- price_level = VolumeProfileLevel(
- price=price,
- volume=0.0,
- buy_volume=0.0,
- sell_volume=0.0,
- trades_count=0,
- vwap=price
- )
- self.volume_profiles[symbol].append(price_level)
-
- # Update volume profile
- volume = price * quantity
- old_total = price_level.volume
-
- price_level.volume += volume
- price_level.trades_count += 1
-
- if is_buyer_maker:
- price_level.sell_volume += volume
- else:
- price_level.buy_volume += volume
-
- # Update VWAP
- if price_level.volume > 0:
- price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume
-
- except Exception as e:
- logger.error(f"Error updating volume profile for {symbol}: {e}")
-
- async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
- quantity: float, is_buyer_maker: bool):
- """Analyze order flow for sweep and absorption patterns"""
- try:
- # Get recent order book data
- if symbol not in self.order_book_history or not self.order_book_history[symbol]:
- return
-
- recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots
-
- # Check for order book sweeps
- sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
- if sweep_signal:
- self.flow_signals[symbol].append(sweep_signal)
- await self._notify_flow_signal(symbol, sweep_signal)
-
- # Check for absorption patterns
- absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
- if absorption_signal:
- self.flow_signals[symbol].append(absorption_signal)
- await self._notify_flow_signal(symbol, absorption_signal)
-
- # Check for momentum trades
- momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
- if momentum_signal:
- self.flow_signals[symbol].append(momentum_signal)
- await self._notify_flow_signal(symbol, momentum_signal)
-
- except Exception as e:
- logger.error(f"Error analyzing order flow for {symbol}: {e}")
-
- def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
- price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
- """Detect order book sweep patterns"""
- try:
- if len(snapshots) < 2:
- return None
-
- before_snapshot = snapshots[-2]
- after_snapshot = snapshots[-1]
-
- # Check if multiple levels were consumed
- if is_buyer_maker: # Sell order, check ask side
- levels_consumed = 0
- total_consumed_size = 0
-
- for level in before_snapshot.asks[:5]: # Check top 5 levels
- if level.price <= price:
- levels_consumed += 1
- total_consumed_size += level.size
-
- if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
- confidence = min(0.9, levels_consumed / 5.0 + 0.3)
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='sweep',
- price=price,
- volume=quantity * price,
- confidence=confidence,
- description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
- )
- else: # Buy order, check bid side
- levels_consumed = 0
- total_consumed_size = 0
-
- for level in before_snapshot.bids[:5]:
- if level.price >= price:
- levels_consumed += 1
- total_consumed_size += level.size
-
- if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
- confidence = min(0.9, levels_consumed / 5.0 + 0.3)
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='sweep',
- price=price,
- volume=quantity * price,
- confidence=confidence,
- description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
- )
-
- return None
-
- except Exception as e:
- logger.error(f"Error detecting sweep for {symbol}: {e}")
- return None
-
- def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
- price: float, quantity: float) -> Optional[OrderFlowSignal]:
- """Detect absorption patterns where large orders are absorbed without price movement"""
- try:
- if len(snapshots) < 3:
- return None
-
- # Check if large order was absorbed with minimal price impact
- volume_threshold = 10000 # $10K minimum for absorption
- price_impact_threshold = 0.001 # 0.1% max price impact
-
- trade_value = price * quantity
- if trade_value < volume_threshold:
- return None
-
- # Calculate price impact
- price_before = snapshots[-3].mid_price
- price_after = snapshots[-1].mid_price
- price_impact = abs(price_after - price_before) / price_before
-
- if price_impact < price_impact_threshold:
- confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='absorption',
- price=price,
- volume=trade_value,
- confidence=confidence,
- description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact"
- )
-
- return None
-
- except Exception as e:
- logger.error(f"Error detecting absorption for {symbol}: {e}")
- return None
-
- def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
- is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
- """Detect momentum trades based on size and direction"""
- try:
- trade_value = price * quantity
- momentum_threshold = 25000 # $25K minimum for momentum classification
-
- if trade_value < momentum_threshold:
- return None
-
- # Calculate confidence based on trade size
- confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
-
- direction = "sell" if is_buyer_maker else "buy"
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='momentum',
- price=price,
- volume=trade_value,
- confidence=confidence,
- description=f"Large {direction}: ${trade_value:.0f}"
- )
-
- except Exception as e:
- logger.error(f"Error detecting momentum for {symbol}: {e}")
- return None
-
- async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
- """Notify CNN and DQN models of order flow signals"""
- try:
- signal_data = {
- 'signal_type': signal.signal_type,
- 'price': signal.price,
- 'volume': signal.volume,
- 'confidence': signal.confidence,
- 'timestamp': signal.timestamp,
- 'description': signal.description
- }
-
- # Notify CNN callbacks
- for callback in self.cnn_callbacks:
- try:
- callback(symbol, signal_data)
- except Exception as e:
- logger.warning(f"Error in CNN callback: {e}")
-
- # Notify DQN callbacks
- for callback in self.dqn_callbacks:
- try:
- callback(symbol, signal_data)
- except Exception as e:
- logger.warning(f"Error in DQN callback: {e}")
-
- except Exception as e:
- logger.error(f"Error notifying flow signal: {e}")
-
- async def _continuous_analysis(self):
- """Continuous analysis of market microstructure"""
- while self.is_streaming:
- try:
- await asyncio.sleep(1) # Analyze every second
-
- for symbol in self.symbols:
- # Generate CNN features
- cnn_features = self.get_cnn_features(symbol)
- if cnn_features is not None:
- for callback in self.cnn_callbacks:
- try:
- callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
- except Exception as e:
- logger.warning(f"Error in CNN feature callback: {e}")
-
- # Generate DQN state features
- dqn_features = self.get_dqn_state_features(symbol)
- if dqn_features is not None:
- for callback in self.dqn_callbacks:
- try:
- callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
- except Exception as e:
- logger.warning(f"Error in DQN state callback: {e}")
-
- except Exception as e:
- logger.error(f"Error in continuous analysis: {e}")
- await asyncio.sleep(5)
-
- def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
- """Generate CNN input features from order book data"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
- features = []
-
- # Order book features (40 features: 20 levels x 2 sides)
- for i in range(min(20, len(snapshot.bids))):
- bid = snapshot.bids[i]
- features.append(bid.size)
- features.append(bid.price - snapshot.mid_price) # Price offset
-
- # Pad if not enough bid levels
- while len(features) < 40:
- features.extend([0.0, 0.0])
-
- for i in range(min(20, len(snapshot.asks))):
- ask = snapshot.asks[i]
- features.append(ask.size)
- features.append(ask.price - snapshot.mid_price) # Price offset
-
- # Pad if not enough ask levels
- while len(features) < 80:
- features.extend([0.0, 0.0])
-
- # Liquidity metrics (10 features)
- metrics = self.liquidity_metrics.get(symbol, {})
- features.extend([
- metrics.get('total_bid_size', 0.0),
- metrics.get('total_ask_size', 0.0),
- metrics.get('liquidity_ratio', 1.0),
- metrics.get('spread_bps', 0.0),
- snapshot.spread,
- metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price,
- len(snapshot.bids),
- len(snapshot.asks),
- snapshot.mid_price,
- time.time() % 86400 # Time of day
- ])
-
- # Order book imbalance features (5 features)
- if self.order_book_imbalances[symbol]:
- latest_imbalance = self.order_book_imbalances[symbol][-1]
- features.extend([
- latest_imbalance['imbalance'],
- latest_imbalance['bid_size'],
- latest_imbalance['ask_size'],
- latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
- abs(latest_imbalance['imbalance'])
- ])
- else:
- features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
-
- # Flow signal features (5 features)
- recent_signals = [s for s in self.flow_signals[symbol]
- if (datetime.now() - s.timestamp).seconds < 60]
-
- sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
- absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
- momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')
-
- max_confidence = max([s.confidence for s in recent_signals], default=0.0)
- total_flow_volume = sum(s.volume for s in recent_signals)
-
- features.extend([
- sweep_count,
- absorption_count,
- momentum_count,
- max_confidence,
- total_flow_volume
- ])
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error generating CNN features for {symbol}: {e}")
- return None
-
- def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
- """Generate DQN state features from order book data"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
- state_features = []
-
- # Normalized order book state (20 features)
- total_bid_size = sum(level.size for level in snapshot.bids[:10])
- total_ask_size = sum(level.size for level in snapshot.asks[:10])
- total_size = total_bid_size + total_ask_size
-
- if total_size > 0:
- for i in range(min(10, len(snapshot.bids))):
- state_features.append(snapshot.bids[i].size / total_size)
-
- # Pad bids
- while len(state_features) < 10:
- state_features.append(0.0)
-
- for i in range(min(10, len(snapshot.asks))):
- state_features.append(snapshot.asks[i].size / total_size)
-
- # Pad asks
- while len(state_features) < 20:
- state_features.append(0.0)
- else:
- state_features.extend([0.0] * 20)
-
- # Market state indicators (10 features)
- metrics = self.liquidity_metrics.get(symbol, {})
-
- # Normalize spread as percentage
- spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0
-
- # Liquidity imbalance
- liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
- liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)
-
- # Recent flow signals strength
- recent_signals = [s for s in self.flow_signals[symbol]
- if (datetime.now() - s.timestamp).seconds < 30]
- flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)
-
- # Price volatility (from recent snapshots)
- if len(self.order_book_history[symbol]) >= 10:
- recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]]
- price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0
- else:
- price_volatility = 0
-
- state_features.extend([
- spread_pct * 10000, # Spread in basis points
- liquidity_imbalance,
- flow_strength,
- price_volatility * 100, # Volatility as percentage
- min(len(snapshot.bids), 20) / 20, # Book depth ratio
- min(len(snapshot.asks), 20) / 20,
- sweep_count / 10 if 'sweep_count' in locals() else 0, # From CNN features
- absorption_count / 5 if 'absorption_count' in locals() else 0,
- momentum_count / 5 if 'momentum_count' in locals() else 0,
- (datetime.now().hour * 60 + datetime.now().minute) / 1440 # Time of day normalized
- ])
-
- return np.array(state_features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error generating DQN features for {symbol}: {e}")
- return None
-
- def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
- """Generate order size heatmap matrix for dashboard visualization"""
- try:
- if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
- return None
-
- # Create price levels around current mid price
- current_snapshot = self.order_books.get(symbol)
- if not current_snapshot:
- return None
-
- mid_price = current_snapshot.mid_price
- price_step = mid_price * 0.0001 # 1 basis point steps
-
- # Create matrix: time x price levels
- time_window = min(600, len(self.order_heatmaps[symbol])) # 10 minutes max
- heatmap_matrix = np.zeros((time_window, levels))
-
- # Fill matrix with order sizes
- for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]):
- for price_offset, level_data in entry['levels'].items():
- # Convert price offset to matrix index
- level_idx = int((price_offset + (levels/2) * price_step) / price_step)
-
- if 0 <= level_idx < levels:
- size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
- heatmap_matrix[t, level_idx] = level_data['size'] * size_weight
-
- return heatmap_matrix
-
- except Exception as e:
- logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
- return None
-
- def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
- """Get session volume profile data"""
- try:
- if symbol not in self.volume_profiles:
- return None
-
- profile_data = []
- for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price):
- profile_data.append({
- 'price': level.price,
- 'volume': level.volume,
- 'buy_volume': level.buy_volume,
- 'sell_volume': level.sell_volume,
- 'trades_count': level.trades_count,
- 'vwap': level.vwap,
- 'net_volume': level.buy_volume - level.sell_volume
- })
-
- return profile_data
-
- except Exception as e:
- logger.error(f"Error getting volume profile for {symbol}: {e}")
- return None
-
- def get_current_order_book(self, symbol: str) -> Optional[Dict]:
- """Get current order book snapshot"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
-
- return {
- 'timestamp': snapshot.timestamp.isoformat(),
- 'symbol': symbol,
- 'mid_price': snapshot.mid_price,
- 'spread': snapshot.spread,
- 'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]],
- 'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]],
- 'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
- 'recent_signals': [
- {
- 'type': s.signal_type,
- 'price': s.price,
- 'volume': s.volume,
- 'confidence': s.confidence,
- 'timestamp': s.timestamp.isoformat()
- }
- for s in list(self.flow_signals[symbol])[-5:] # Last 5 signals
- ]
- }
-
- except Exception as e:
- logger.error(f"Error getting order book for {symbol}: {e}")
- return None
-
- def get_statistics(self) -> Dict[str, Any]:
- """Get provider statistics"""
- return {
- 'symbols': self.symbols,
- 'is_streaming': self.is_streaming,
- 'update_counts': dict(self.update_counts),
- 'last_update_times': {k: v.isoformat() if isinstance(v, datetime) else v
- for k, v in self.last_update_times.items()},
- 'order_books_active': len(self.order_books),
- 'flow_signals_total': sum(len(signals) for signals in self.flow_signals.values()),
- 'cnn_callbacks': len(self.cnn_callbacks),
- 'dqn_callbacks': len(self.dqn_callbacks),
- 'websocket_tasks': len(self.websocket_tasks)
- }
\ No newline at end of file
diff --git a/core/bookmap_integration.py b/core/bookmap_integration.py
deleted file mode 100644
index 904e253..0000000
--- a/core/bookmap_integration.py
+++ /dev/null
@@ -1,1839 +0,0 @@
-"""
-Order Book Analysis Integration (Free Data Sources)
-
-This module provides Bookmap-style functionality using free order book data:
-- Current Order Book (COB) analysis using Binance free depth streams
-- Session Volume Profile (SVP) calculated from trade and depth data
-- Order flow detection (sweeps, absorptions, momentum)
-- Real-time order book heatmap generation
-- Level 2 market depth streaming (20 levels via Binance free API)
-
-Data is fed to CNN and DQN networks for enhanced trading decisions.
-Uses only free data sources - no paid APIs required.
-"""
-
-import asyncio
-import json
-import logging
-import time
-import websockets
-import numpy as np
-import pandas as pd
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from collections import deque, defaultdict
-from dataclasses import dataclass
-from threading import Thread, Lock
-import requests
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class OrderBookLevel:
- """Single order book level"""
- price: float
- size: float
- orders: int
- side: str # 'bid' or 'ask'
- timestamp: datetime
-
-@dataclass
-class OrderBookSnapshot:
- """Complete order book snapshot"""
- symbol: str
- timestamp: datetime
- bids: List[OrderBookLevel]
- asks: List[OrderBookLevel]
- spread: float
- mid_price: float
-
-@dataclass
-class VolumeProfileLevel:
- """Volume profile level data"""
- price: float
- volume: float
- buy_volume: float
- sell_volume: float
- trades_count: int
- vwap: float
-
-@dataclass
-class OrderFlowSignal:
- """Order flow signal detection"""
- timestamp: datetime
- signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum'
- price: float
- volume: float
- confidence: float
- description: str
-
-class BookmapIntegration:
- """
- Order book analysis using free data sources
-
- Features:
- - Real-time order book monitoring (Binance free depth@20 levels)
- - Order flow pattern detection
- - Enhanced Session Volume Profile (SVP) analysis
- - Market microstructure metrics
- - CNN/DQN model integration
- - High-frequency order book snapshots for pattern detection
- """
-
- def __init__(self, symbols: List[str] = None):
- self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
- self.is_streaming = False
-
- # Data storage
- self.order_books: Dict[str, OrderBookSnapshot] = {}
- self.order_book_history: Dict[str, deque] = {}
- self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
- self.flow_signals: Dict[str, deque] = {}
-
- # Enhanced Session Volume Profile tracking
- self.session_start_time = {} # Track session start for each symbol
- self.session_volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
- self.price_level_cache: Dict[str, Dict[float, VolumeProfileLevel]] = {}
-
- # Heatmap data (10-minute rolling window)
- self.heatmap_window = timedelta(minutes=10)
- self.order_heatmaps: Dict[str, deque] = {}
-
- # Market metrics
- self.liquidity_metrics: Dict[str, Dict] = {}
- self.order_book_imbalances: Dict[str, deque] = {}
-
- # Enhanced Order Flow Analysis
- self.aggressive_passive_ratios: Dict[str, deque] = {}
- self.trade_size_distributions: Dict[str, deque] = {}
- self.market_maker_taker_flows: Dict[str, deque] = {}
- self.order_flow_intensity: Dict[str, deque] = {}
- self.liquidity_consumption_rates: Dict[str, deque] = {}
- self.price_impact_measurements: Dict[str, deque] = {}
-
- # Advanced metrics for institutional vs retail detection
- self.large_order_threshold = 50000 # $50K+ considered institutional
- self.block_trade_threshold = 100000 # $100K+ considered block trades
- self.iceberg_detection_window = 30 # seconds for iceberg detection
- self.trade_clustering_window = 5 # seconds for trade clustering analysis
-
- # Free data source optimization
- self.depth_snapshots_per_second = 10 # 100ms updates = 10 per second
- self.trade_aggregation_window = 1.0 # 1 second aggregation
- self.last_trade_aggregation = {}
-
- # WebSocket connections
- self.websocket_tasks: Dict[str, asyncio.Task] = {}
- self.data_lock = Lock()
-
- # Model callbacks
- self.cnn_callbacks: List[Callable] = []
- self.dqn_callbacks: List[Callable] = []
-
- # Initialize data structures
- for symbol in self.symbols:
- self.order_book_history[symbol] = deque(maxlen=1000)
- self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals
- self.flow_signals[symbol] = deque(maxlen=500)
- self.order_book_imbalances[symbol] = deque(maxlen=1000)
- self.session_volume_profiles[symbol] = []
- self.price_level_cache[symbol] = {}
- self.session_start_time[symbol] = datetime.now()
- self.last_trade_aggregation[symbol] = datetime.now()
-
- # Enhanced order flow analysis buffers
- self.aggressive_passive_ratios[symbol] = deque(maxlen=300) # 5 minutes at 1s intervals
- self.trade_size_distributions[symbol] = deque(maxlen=1000)
- self.market_maker_taker_flows[symbol] = deque(maxlen=600)
- self.order_flow_intensity[symbol] = deque(maxlen=300)
- self.liquidity_consumption_rates[symbol] = deque(maxlen=300)
- self.price_impact_measurements[symbol] = deque(maxlen=300)
-
- self.liquidity_metrics[symbol] = {
- 'total_bid_size': 0.0,
- 'total_ask_size': 0.0,
- 'weighted_mid': 0.0,
- 'liquidity_ratio': 1.0,
- 'avg_spread_bps': 0.0,
- 'volume_weighted_spread': 0.0
- }
-
- logger.info(f"Order Book Integration initialized for symbols: {self.symbols}")
- logger.info("Using FREE data sources: Binance WebSocket depth@20 + trades")
-
- def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
- """Add CNN model callback"""
- self.cnn_callbacks.append(callback)
- logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")
-
- def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
- """Add DQN model callback"""
- self.dqn_callbacks.append(callback)
- logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
-
- async def start_streaming(self):
- """Start order book data streaming"""
- if self.is_streaming:
- logger.warning("Bookmap streaming already active")
- return
-
- self.is_streaming = True
- logger.info("Starting Bookmap order book streaming")
-
- # Start streams for each symbol
- for symbol in self.symbols:
- # Order book depth stream (20 levels, 100ms updates)
- depth_task = asyncio.create_task(self._stream_order_book_depth(symbol))
- self.websocket_tasks[f"{symbol}_depth"] = depth_task
-
- # Trade stream for order flow analysis
- trade_task = asyncio.create_task(self._stream_trades(symbol))
- self.websocket_tasks[f"{symbol}_trades"] = trade_task
-
- # Aggregated trade stream (for larger trades and better order flow analysis)
- agg_trade_task = asyncio.create_task(self._stream_aggregate_trades(symbol))
- self.websocket_tasks[f"{symbol}_aggTrade"] = agg_trade_task
-
- # 24hr ticker stream (for volume and statistical analysis)
- ticker_task = asyncio.create_task(self._stream_ticker(symbol))
- self.websocket_tasks[f"{symbol}_ticker"] = ticker_task
-
- # Start continuous analysis
- analysis_task = asyncio.create_task(self._continuous_analysis())
- self.websocket_tasks["analysis"] = analysis_task
-
- logger.info(f"Started streaming for {len(self.symbols)} symbols")
-
- async def stop_streaming(self):
- """Stop streaming"""
- if not self.is_streaming:
- return
-
- logger.info("Stopping Bookmap streaming")
- self.is_streaming = False
-
- # Cancel all tasks
- for name, task in self.websocket_tasks.items():
- if not task.done():
- task.cancel()
- try:
- await task
- except asyncio.CancelledError:
- pass
-
- self.websocket_tasks.clear()
- logger.info("Bookmap streaming stopped")
-
- async def _stream_order_book_depth(self, symbol: str):
- """Stream order book depth data"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Order book depth connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_depth_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing depth for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Depth WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _stream_trades(self, symbol: str):
- """Stream individual trade data for order flow analysis"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Trade stream connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_trade_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing trade for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Trade WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _stream_aggregate_trades(self, symbol: str):
- """Stream aggregated trade data for institutional order flow detection"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@aggTrade"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Aggregate Trade stream connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_aggregate_trade_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing aggTrade for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Aggregate Trade WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _stream_ticker(self, symbol: str):
- """Stream 24hr ticker data for volume analysis"""
- binance_symbol = symbol.lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@ticker"
-
- while self.is_streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Ticker stream connected for {symbol}")
-
- async for message in websocket:
- if not self.is_streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_ticker_update(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing ticker for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"Ticker WebSocket error for {symbol}: {e}")
- if self.is_streaming:
- await asyncio.sleep(2)
-
- async def _process_depth_update(self, symbol: str, data: Dict):
- """Process order book depth update"""
- try:
- timestamp = datetime.now()
-
- # Parse bids and asks
- bids = []
- asks = []
-
- for bid_data in data.get('bids', []):
- price = float(bid_data[0])
- size = float(bid_data[1])
- bids.append(OrderBookLevel(
- price=price,
- size=size,
- orders=1,
- side='bid',
- timestamp=timestamp
- ))
-
- for ask_data in data.get('asks', []):
- price = float(ask_data[0])
- size = float(ask_data[1])
- asks.append(OrderBookLevel(
- price=price,
- size=size,
- orders=1,
- side='ask',
- timestamp=timestamp
- ))
-
- # Sort levels
- bids.sort(key=lambda x: x.price, reverse=True)
- asks.sort(key=lambda x: x.price)
-
- # Calculate spread and mid price
- if bids and asks:
- best_bid = bids[0].price
- best_ask = asks[0].price
- spread = best_ask - best_bid
- mid_price = (best_bid + best_ask) / 2
- else:
- spread = 0.0
- mid_price = 0.0
-
- # Create snapshot
- snapshot = OrderBookSnapshot(
- symbol=symbol,
- timestamp=timestamp,
- bids=bids,
- asks=asks,
- spread=spread,
- mid_price=mid_price
- )
-
- with self.data_lock:
- self.order_books[symbol] = snapshot
- self.order_book_history[symbol].append(snapshot)
-
- # Update metrics
- self._update_liquidity_metrics(symbol, snapshot)
- self._calculate_order_book_imbalance(symbol, snapshot)
- self._update_order_heatmap(symbol, snapshot)
-
- except Exception as e:
- logger.error(f"Error processing depth update for {symbol}: {e}")
-
- async def _process_trade_update(self, symbol: str, data: Dict):
- """Process individual trade data with enhanced order flow analysis"""
- try:
- timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
- price = float(data['p'])
- quantity = float(data['q'])
- is_buyer_maker = data['m']
- trade_id = data.get('t', '')
-
- # Calculate trade value
- trade_value = price * quantity
-
- # Enhanced order flow analysis
- await self._analyze_enhanced_order_flow(symbol, timestamp, price, quantity, trade_value, is_buyer_maker, 'individual')
-
- # Traditional order flow analysis
- await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
-
- # Update volume profile
- self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
-
- except Exception as e:
- logger.error(f"Error processing trade for {symbol}: {e}")
-
- async def _process_aggregate_trade_update(self, symbol: str, data: Dict):
- """Process aggregated trade data for institutional flow detection"""
- try:
- timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
- price = float(data['p'])
- quantity = float(data['q'])
- is_buyer_maker = data['m']
- first_trade_id = data.get('f', '')
- last_trade_id = data.get('l', '')
-
- # Calculate trade value and aggregation size
- trade_value = price * quantity
- trade_aggregation_size = int(last_trade_id) - int(first_trade_id) + 1 if first_trade_id and last_trade_id else 1
-
- # Enhanced analysis for aggregated trades (institutional detection)
- await self._analyze_enhanced_order_flow(symbol, timestamp, price, quantity, trade_value, is_buyer_maker, 'aggregated', trade_aggregation_size)
-
- # Detect large block trades and iceberg orders
- await self._detect_institutional_activity(symbol, timestamp, price, quantity, trade_value, trade_aggregation_size, is_buyer_maker)
-
- except Exception as e:
- logger.error(f"Error processing aggregate trade for {symbol}: {e}")
-
- async def _process_ticker_update(self, symbol: str, data: Dict):
- """Process ticker data for volume and statistical analysis"""
- try:
- # Extract relevant ticker data
- volume_24h = float(data.get('v', 0)) # 24hr volume
- quote_volume_24h = float(data.get('q', 0)) # 24hr quote volume
- price_change_24h = float(data.get('P', 0)) # 24hr price change %
- high_24h = float(data.get('h', 0))
- low_24h = float(data.get('l', 0))
- weighted_avg_price = float(data.get('w', 0)) # Weighted average price
-
- # Update volume statistics for relative analysis
- self._update_volume_statistics(symbol, volume_24h, quote_volume_24h, weighted_avg_price)
-
- except Exception as e:
- logger.error(f"Error processing ticker for {symbol}: {e}")
-
- def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
- """Update liquidity metrics"""
- try:
- total_bid_size = sum(level.size for level in snapshot.bids)
- total_ask_size = sum(level.size for level in snapshot.asks)
-
- # Weighted mid price
- if snapshot.bids and snapshot.asks:
- bid_weight = total_bid_size / (total_bid_size + total_ask_size)
- ask_weight = total_ask_size / (total_bid_size + total_ask_size)
- weighted_mid = (snapshot.bids[0].price * ask_weight +
- snapshot.asks[0].price * bid_weight)
- else:
- weighted_mid = snapshot.mid_price
-
- # Liquidity ratio
- liquidity_ratio = total_bid_size / total_ask_size if total_ask_size > 0 else 1.0
-
- self.liquidity_metrics[symbol] = {
- 'total_bid_size': total_bid_size,
- 'total_ask_size': total_ask_size,
- 'weighted_mid': weighted_mid,
- 'liquidity_ratio': liquidity_ratio,
- 'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
- }
-
- except Exception as e:
- logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
-
- def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
- """Calculate order book imbalance"""
- try:
- if not snapshot.bids or not snapshot.asks:
- return
-
- # Top 5 levels imbalance
- n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
-
- total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
- total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
-
- if total_bid_size + total_ask_size > 0:
- imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
- else:
- imbalance = 0.0
-
- self.order_book_imbalances[symbol].append({
- 'timestamp': snapshot.timestamp,
- 'imbalance': imbalance,
- 'bid_size': total_bid_size,
- 'ask_size': total_ask_size
- })
-
- except Exception as e:
- logger.error(f"Error calculating imbalance for {symbol}: {e}")
-
- def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
- """Update order heatmap matrix"""
- try:
- heatmap_entry = {
- 'timestamp': snapshot.timestamp,
- 'mid_price': snapshot.mid_price,
- 'levels': {}
- }
-
- # Add bid levels
- for level in snapshot.bids:
- price_offset = level.price - snapshot.mid_price
- heatmap_entry['levels'][price_offset] = {
- 'side': 'bid',
- 'size': level.size,
- 'price': level.price
- }
-
- # Add ask levels
- for level in snapshot.asks:
- price_offset = level.price - snapshot.mid_price
- heatmap_entry['levels'][price_offset] = {
- 'side': 'ask',
- 'size': level.size,
- 'price': level.price
- }
-
- self.order_heatmaps[symbol].append(heatmap_entry)
-
- # Clean old entries
- cutoff_time = snapshot.timestamp - self.heatmap_window
- while (self.order_heatmaps[symbol] and
- self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
- self.order_heatmaps[symbol].popleft()
-
- except Exception as e:
- logger.error(f"Error updating heatmap for {symbol}: {e}")
-
- def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
- """Enhanced Session Volume Profile (SVP) update using free data"""
- try:
- # Calculate trade volume in USDT
- volume = price * quantity
-
- # Use price level caching for better performance
- price_key = round(price, 2) # Round to 2 decimal places for price level grouping
-
- # Update session volume profile
- if price_key not in self.price_level_cache[symbol]:
- self.price_level_cache[symbol][price_key] = VolumeProfileLevel(
- price=price_key,
- volume=0.0,
- buy_volume=0.0,
- sell_volume=0.0,
- trades_count=0,
- vwap=price
- )
-
- level = self.price_level_cache[symbol][price_key]
- old_total_volume = level.volume
- old_total_quantity = level.trades_count
-
- # Update volume metrics
- level.volume += volume
- level.trades_count += 1
-
- # Update buy/sell volume breakdown
- if is_buyer_maker:
- level.sell_volume += volume # Market maker is selling
- else:
- level.buy_volume += volume # Market maker is buying
-
- # Calculate Volume Weighted Average Price (VWAP) for this level
- if level.volume > 0:
- level.vwap = ((level.vwap * old_total_volume) + (price * volume)) / level.volume
-
- # Also update the rolling volume profile (last 10 minutes)
- self._update_rolling_volume_profile(symbol, price_key, volume, is_buyer_maker)
-
- # Session reset detection (every 24 hours or major price gaps)
- current_time = datetime.now()
- if self._should_reset_session(symbol, current_time, price):
- self._reset_session_volume_profile(symbol, current_time)
-
- except Exception as e:
- logger.error(f"Error updating Session Volume Profile for {symbol}: {e}")
-
- def _update_rolling_volume_profile(self, symbol: str, price_key: float, volume: float, is_buyer_maker: bool):
- """Update rolling 10-minute volume profile for real-time heatmap"""
- try:
- # Find or create level in regular volume profile
- price_level = None
- for level in self.volume_profiles.get(symbol, []):
- if abs(level.price - price_key) < 0.01:
- price_level = level
- break
-
- if not price_level:
- if symbol not in self.volume_profiles:
- self.volume_profiles[symbol] = []
-
- price_level = VolumeProfileLevel(
- price=price_key,
- volume=0.0,
- buy_volume=0.0,
- sell_volume=0.0,
- trades_count=0,
- vwap=price_key
- )
- self.volume_profiles[symbol].append(price_level)
-
- # Update rolling metrics
- old_volume = price_level.volume
- price_level.volume += volume
- price_level.trades_count += 1
-
- if is_buyer_maker:
- price_level.sell_volume += volume
- else:
- price_level.buy_volume += volume
-
- # Update VWAP
- if price_level.volume > 0:
- price_level.vwap = ((price_level.vwap * old_volume) + (price_key * volume)) / price_level.volume
-
- except Exception as e:
- logger.error(f"Error updating rolling volume profile for {symbol}: {e}")
-
- def _should_reset_session(self, symbol: str, current_time: datetime, current_price: float) -> bool:
- """Determine if session volume profile should be reset"""
- try:
- session_start = self.session_start_time.get(symbol)
- if not session_start:
- return False
-
- # Reset every 24 hours (daily session)
- if (current_time - session_start).total_seconds() > 86400: # 24 hours
- return True
-
- # Reset on major price gaps (> 5% from session VWAP)
- if self.price_level_cache.get(symbol):
- total_volume = sum(level.volume for level in self.price_level_cache[symbol].values())
- if total_volume > 0:
- weighted_price = sum(level.vwap * level.volume for level in self.price_level_cache[symbol].values()) / total_volume
- price_gap = abs(current_price - weighted_price) / weighted_price
- if price_gap > 0.05: # 5% gap
- return True
-
- return False
-
- except Exception as e:
- logger.error(f"Error checking session reset for {symbol}: {e}")
- return False
-
- def _reset_session_volume_profile(self, symbol: str, reset_time: datetime):
- """Reset session volume profile"""
- try:
- logger.info(f"Resetting session volume profile for {symbol}")
- self.session_start_time[symbol] = reset_time
- self.price_level_cache[symbol] = {}
- self.session_volume_profiles[symbol] = []
-
- except Exception as e:
- logger.error(f"Error resetting session volume profile for {symbol}: {e}")
-
- async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
- quantity: float, is_buyer_maker: bool):
- """Analyze order flow patterns"""
- try:
- if symbol not in self.order_book_history or not self.order_book_history[symbol]:
- return
-
- recent_snapshots = list(self.order_book_history[symbol])[-10:]
-
- # Check for sweeps
- sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
- if sweep_signal:
- self.flow_signals[symbol].append(sweep_signal)
- await self._notify_flow_signal(symbol, sweep_signal)
-
- # Check for absorption
- absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
- if absorption_signal:
- self.flow_signals[symbol].append(absorption_signal)
- await self._notify_flow_signal(symbol, absorption_signal)
-
- # Check for momentum
- momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
- if momentum_signal:
- self.flow_signals[symbol].append(momentum_signal)
- await self._notify_flow_signal(symbol, momentum_signal)
-
- except Exception as e:
- logger.error(f"Error analyzing order flow for {symbol}: {e}")
-
- async def _analyze_enhanced_order_flow(self, symbol: str, timestamp: datetime, price: float,
- quantity: float, trade_value: float, is_buyer_maker: bool,
- trade_type: str, aggregation_size: int = 1):
- """Enhanced order flow analysis with aggressive vs passive ratios"""
- try:
- # Determine if trade is aggressive (taker) or passive (maker)
- is_aggressive = not is_buyer_maker # In Binance data, m=false means buyer is taker (aggressive)
-
- # Calculate aggressive vs passive ratios
- self._update_aggressive_passive_ratio(symbol, timestamp, trade_value, is_aggressive)
-
- # Update trade size distribution
- self._update_trade_size_distribution(symbol, timestamp, trade_value, trade_type)
-
- # Update market maker vs taker flow
- self._update_market_maker_taker_flow(symbol, timestamp, trade_value, is_buyer_maker, is_aggressive)
-
- # Calculate order flow intensity
- self._update_order_flow_intensity(symbol, timestamp, trade_value, aggregation_size)
-
- # Measure liquidity consumption
- await self._measure_liquidity_consumption(symbol, timestamp, price, quantity, trade_value, is_aggressive)
-
- # Measure price impact
- await self._measure_price_impact(symbol, timestamp, price, trade_value, is_aggressive)
-
- except Exception as e:
- logger.error(f"Error in enhanced order flow analysis for {symbol}: {e}")
-
- def _update_aggressive_passive_ratio(self, symbol: str, timestamp: datetime, trade_value: float, is_aggressive: bool):
- """Update aggressive vs passive participant ratios"""
- try:
- current_window = []
- cutoff_time = timestamp - timedelta(seconds=60) # 1-minute window
-
- # Filter recent trades within window
- for entry in self.aggressive_passive_ratios[symbol]:
- if entry['timestamp'] > cutoff_time:
- current_window.append(entry)
-
- # Add current trade
- current_window.append({
- 'timestamp': timestamp,
- 'trade_value': trade_value,
- 'is_aggressive': is_aggressive
- })
-
- # Calculate ratios
- aggressive_volume = sum(t['trade_value'] for t in current_window if t['is_aggressive'])
- passive_volume = sum(t['trade_value'] for t in current_window if not t['is_aggressive'])
- total_volume = aggressive_volume + passive_volume
-
- if total_volume > 0:
- aggressive_ratio = aggressive_volume / total_volume
- passive_ratio = passive_volume / total_volume
-
- ratio_data = {
- 'timestamp': timestamp,
- 'aggressive_ratio': aggressive_ratio,
- 'passive_ratio': passive_ratio,
- 'aggressive_volume': aggressive_volume,
- 'passive_volume': passive_volume,
- 'total_volume': total_volume,
- 'trade_count': len(current_window),
- 'avg_aggressive_size': aggressive_volume / max(1, sum(1 for t in current_window if t['is_aggressive'])),
- 'avg_passive_size': passive_volume / max(1, sum(1 for t in current_window if not t['is_aggressive']))
- }
-
- # Update buffer
- self.aggressive_passive_ratios[symbol].clear()
- self.aggressive_passive_ratios[symbol].extend(current_window)
-
- # Store calculated ratios for use by models
- if not hasattr(self, 'current_flow_ratios'):
- self.current_flow_ratios = {}
- self.current_flow_ratios[symbol] = ratio_data
-
- except Exception as e:
- logger.error(f"Error updating aggressive/passive ratio for {symbol}: {e}")
-
- def _update_trade_size_distribution(self, symbol: str, timestamp: datetime, trade_value: float, trade_type: str):
- """Update trade size distribution for institutional vs retail detection"""
- try:
- # Classify trade size
- if trade_value < 1000:
- size_category = 'micro' # < $1K (retail)
- elif trade_value < 10000:
- size_category = 'small' # $1K-$10K (retail/small institutional)
- elif trade_value < 50000:
- size_category = 'medium' # $10K-$50K (institutional)
- elif trade_value < 100000:
- size_category = 'large' # $50K-$100K (large institutional)
- else:
- size_category = 'block' # > $100K (block trades)
-
- trade_data = {
- 'timestamp': timestamp,
- 'trade_value': trade_value,
- 'trade_type': trade_type,
- 'size_category': size_category,
- 'is_institutional': trade_value >= self.large_order_threshold,
- 'is_block_trade': trade_value >= self.block_trade_threshold
- }
-
- self.trade_size_distributions[symbol].append(trade_data)
-
- except Exception as e:
- logger.error(f"Error updating trade size distribution for {symbol}: {e}")
-
- def _update_market_maker_taker_flow(self, symbol: str, timestamp: datetime, trade_value: float,
- is_buyer_maker: bool, is_aggressive: bool):
- """Update market maker vs taker flow analysis"""
- try:
- flow_data = {
- 'timestamp': timestamp,
- 'trade_value': trade_value,
- 'is_buyer_maker': is_buyer_maker,
- 'is_aggressive': is_aggressive,
- 'flow_direction': 'buy_aggressive' if not is_buyer_maker else 'sell_aggressive',
- 'market_maker_side': 'sell' if is_buyer_maker else 'buy'
- }
-
- self.market_maker_taker_flows[symbol].append(flow_data)
-
- except Exception as e:
- logger.error(f"Error updating market maker/taker flow for {symbol}: {e}")
-
- def _update_order_flow_intensity(self, symbol: str, timestamp: datetime, trade_value: float, aggregation_size: int):
- """Calculate order flow intensity based on trade frequency and size"""
- try:
- # Calculate intensity based on trade value and aggregation
- base_intensity = trade_value / 10000 # Normalize by $10K
- aggregation_intensity = aggregation_size / 10 # Normalize aggregation factor
-
- # Time-based intensity (trades per second)
- recent_trades = [t for t in self.order_flow_intensity[symbol]
- if (timestamp - t['timestamp']).total_seconds() < 10]
- time_intensity = len(recent_trades) / 10 # Trades per second over 10s window
-
- intensity_score = base_intensity * (1 + aggregation_intensity) * (1 + time_intensity)
-
- intensity_data = {
- 'timestamp': timestamp,
- 'intensity_score': intensity_score,
- 'base_intensity': base_intensity,
- 'aggregation_intensity': aggregation_intensity,
- 'time_intensity': time_intensity,
- 'trade_value': trade_value,
- 'aggregation_size': aggregation_size
- }
-
- self.order_flow_intensity[symbol].append(intensity_data)
-
- except Exception as e:
- logger.error(f"Error updating order flow intensity for {symbol}: {e}")
-
- async def _measure_liquidity_consumption(self, symbol: str, timestamp: datetime, price: float,
- quantity: float, trade_value: float, is_aggressive: bool):
- """Measure liquidity consumption rates"""
- try:
- if not is_aggressive:
- return # Only measure for aggressive trades
-
- current_snapshot = self.order_books.get(symbol)
- if not current_snapshot:
- return
-
- # Calculate how much liquidity was consumed
- if price >= current_snapshot.mid_price: # Buy-side consumption
- consumed_liquidity = 0
- for ask_level in current_snapshot.asks:
- if ask_level.price <= price:
- consumed_liquidity += min(ask_level.size, quantity) * ask_level.price
- quantity -= ask_level.size
- if quantity <= 0:
- break
- else: # Sell-side consumption
- consumed_liquidity = 0
- for bid_level in current_snapshot.bids:
- if bid_level.price >= price:
- consumed_liquidity += min(bid_level.size, quantity) * bid_level.price
- quantity -= bid_level.size
- if quantity <= 0:
- break
-
- consumption_rate = consumed_liquidity / trade_value if trade_value > 0 else 0
-
- consumption_data = {
- 'timestamp': timestamp,
- 'price': price,
- 'trade_value': trade_value,
- 'consumed_liquidity': consumed_liquidity,
- 'consumption_rate': consumption_rate,
- 'side': 'buy' if price >= current_snapshot.mid_price else 'sell'
- }
-
- self.liquidity_consumption_rates[symbol].append(consumption_data)
-
- except Exception as e:
- logger.error(f"Error measuring liquidity consumption for {symbol}: {e}")
-
- async def _measure_price_impact(self, symbol: str, timestamp: datetime, price: float,
- trade_value: float, is_aggressive: bool):
- """Measure price impact of trades"""
- try:
- if not is_aggressive:
- return
-
- # Get price before and after (approximated by looking at recent snapshots)
- recent_snapshots = list(self.order_book_history[symbol])[-5:]
- if len(recent_snapshots) < 2:
- return
-
- price_before = recent_snapshots[-2].mid_price
- price_after = recent_snapshots[-1].mid_price
-
- price_impact = abs(price_after - price_before) / price_before if price_before > 0 else 0
- impact_per_dollar = price_impact / (trade_value / 1000000) if trade_value > 0 else 0 # Impact per $1M
-
- impact_data = {
- 'timestamp': timestamp,
- 'trade_price': price,
- 'trade_value': trade_value,
- 'price_before': price_before,
- 'price_after': price_after,
- 'price_impact': price_impact,
- 'impact_per_million': impact_per_dollar,
- 'impact_category': self._categorize_impact(price_impact)
- }
-
- self.price_impact_measurements[symbol].append(impact_data)
-
- except Exception as e:
- logger.error(f"Error measuring price impact for {symbol}: {e}")
-
- def _categorize_impact(self, price_impact: float) -> str:
- """Categorize price impact level"""
- if price_impact < 0.0001: # < 0.01%
- return 'minimal'
- elif price_impact < 0.001: # < 0.1%
- return 'low'
- elif price_impact < 0.005: # < 0.5%
- return 'medium'
- elif price_impact < 0.01: # < 1%
- return 'high'
- else:
- return 'extreme'
-
- async def _detect_institutional_activity(self, symbol: str, timestamp: datetime, price: float,
- quantity: float, trade_value: float, aggregation_size: int,
- is_buyer_maker: bool):
- """Detect institutional trading activity patterns"""
- try:
- # Block trade detection
- if trade_value >= self.block_trade_threshold:
- signal = OrderFlowSignal(
- timestamp=timestamp,
- signal_type='block_trade',
- price=price,
- volume=trade_value,
- confidence=min(0.95, trade_value / 500000), # Higher confidence for larger trades
- description=f"Block trade: ${trade_value:.0f} ({'Buy' if not is_buyer_maker else 'Sell'})"
- )
- self.flow_signals[symbol].append(signal)
- await self._notify_flow_signal(symbol, signal)
-
- # Iceberg order detection (multiple large aggregated trades in sequence)
- await self._detect_iceberg_orders(symbol, timestamp, price, trade_value, aggregation_size, is_buyer_maker)
-
- # High-frequency activity detection
- await self._detect_hft_activity(symbol, timestamp, trade_value, aggregation_size)
-
- except Exception as e:
- logger.error(f"Error detecting institutional activity for {symbol}: {e}")
-
- async def _detect_iceberg_orders(self, symbol: str, timestamp: datetime, price: float,
- trade_value: float, aggregation_size: int, is_buyer_maker: bool):
- """Detect iceberg order patterns"""
- try:
- if trade_value < self.large_order_threshold:
- return
-
- # Look for similar-sized trades in recent history
- cutoff_time = timestamp - timedelta(seconds=self.iceberg_detection_window)
- recent_large_trades = []
-
- for trade_data in self.trade_size_distributions[symbol]:
- if (trade_data['timestamp'] > cutoff_time and
- trade_data['trade_value'] >= self.large_order_threshold):
- recent_large_trades.append(trade_data)
-
- # Iceberg pattern: 3+ large trades with similar sizes
- if len(recent_large_trades) >= 3:
- avg_size = sum(t['trade_value'] for t in recent_large_trades) / len(recent_large_trades)
- size_consistency = all(abs(t['trade_value'] - avg_size) / avg_size < 0.2 for t in recent_large_trades)
-
- if size_consistency:
- total_iceberg_volume = sum(t['trade_value'] for t in recent_large_trades)
- confidence = min(0.9, len(recent_large_trades) / 10 + total_iceberg_volume / 1000000)
-
- signal = OrderFlowSignal(
- timestamp=timestamp,
- signal_type='iceberg',
- price=price,
- volume=total_iceberg_volume,
- confidence=confidence,
- description=f"Iceberg: {len(recent_large_trades)} trades, ${total_iceberg_volume:.0f} total"
- )
- self.flow_signals[symbol].append(signal)
- await self._notify_flow_signal(symbol, signal)
-
- except Exception as e:
- logger.error(f"Error detecting iceberg orders for {symbol}: {e}")
-
- async def _detect_hft_activity(self, symbol: str, timestamp: datetime, trade_value: float, aggregation_size: int):
- """Detect high-frequency trading activity"""
- try:
- # Look for high-frequency patterns (many small trades in rapid succession)
- cutoff_time = timestamp - timedelta(seconds=5)
- recent_trades = [t for t in self.order_flow_intensity[symbol] if t['timestamp'] > cutoff_time]
-
- if len(recent_trades) >= 20: # 20+ trades in 5 seconds
- avg_trade_size = sum(t['trade_value'] for t in recent_trades) / len(recent_trades)
-
- if avg_trade_size < 5000: # Small average trade size suggests HFT
- total_hft_volume = sum(t['trade_value'] for t in recent_trades)
- confidence = min(0.8, len(recent_trades) / 50)
-
- signal = OrderFlowSignal(
- timestamp=timestamp,
- signal_type='hft_activity',
- price=0, # Multiple prices
- volume=total_hft_volume,
- confidence=confidence,
- description=f"HFT: {len(recent_trades)} trades in 5s, avg ${avg_trade_size:.0f}"
- )
- self.flow_signals[symbol].append(signal)
- await self._notify_flow_signal(symbol, signal)
-
- except Exception as e:
- logger.error(f"Error detecting HFT activity for {symbol}: {e}")
-
- def _update_volume_statistics(self, symbol: str, volume_24h: float, quote_volume_24h: float, weighted_avg_price: float):
- """Update volume statistics for relative analysis"""
- try:
- # Store 24h volume data for relative comparisons
- if not hasattr(self, 'volume_stats'):
- self.volume_stats = {}
-
- self.volume_stats[symbol] = {
- 'volume_24h': volume_24h,
- 'quote_volume_24h': quote_volume_24h,
- 'weighted_avg_price': weighted_avg_price,
- 'timestamp': datetime.now()
- }
-
- except Exception as e:
- logger.error(f"Error updating volume statistics for {symbol}: {e}")
-
- def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
- price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
- """Detect order book sweeps"""
- try:
- if len(snapshots) < 2:
- return None
-
- before_snapshot = snapshots[-2]
-
- if is_buyer_maker: # Sell order, check ask side
- levels_consumed = 0
- total_consumed_size = 0
-
- for level in before_snapshot.asks[:5]:
- if level.price <= price:
- levels_consumed += 1
- total_consumed_size += level.size
-
- if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
- confidence = min(0.9, levels_consumed / 5.0 + 0.3)
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='sweep',
- price=price,
- volume=quantity * price,
- confidence=confidence,
- description=f"Sell sweep: {levels_consumed} levels"
- )
- else: # Buy order, check bid side
- levels_consumed = 0
- total_consumed_size = 0
-
- for level in before_snapshot.bids[:5]:
- if level.price >= price:
- levels_consumed += 1
- total_consumed_size += level.size
-
- if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
- confidence = min(0.9, levels_consumed / 5.0 + 0.3)
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='sweep',
- price=price,
- volume=quantity * price,
- confidence=confidence,
- description=f"Buy sweep: {levels_consumed} levels"
- )
-
- return None
-
- except Exception as e:
- logger.error(f"Error detecting sweep for {symbol}: {e}")
- return None
-
- def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
- price: float, quantity: float) -> Optional[OrderFlowSignal]:
- """Detect absorption patterns"""
- try:
- if len(snapshots) < 3:
- return None
-
- volume_threshold = 10000 # $10K minimum
- price_impact_threshold = 0.001 # 0.1% max impact
-
- trade_value = price * quantity
- if trade_value < volume_threshold:
- return None
-
- # Calculate price impact
- price_before = snapshots[-3].mid_price
- price_after = snapshots[-1].mid_price
- price_impact = abs(price_after - price_before) / price_before
-
- if price_impact < price_impact_threshold:
- confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3)
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='absorption',
- price=price,
- volume=trade_value,
- confidence=confidence,
- description=f"Absorption: ${trade_value:.0f}"
- )
-
- return None
-
- except Exception as e:
- logger.error(f"Error detecting absorption for {symbol}: {e}")
- return None
-
- def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
- is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
- """Detect momentum trades"""
- try:
- trade_value = price * quantity
- momentum_threshold = 25000 # $25K minimum
-
- if trade_value < momentum_threshold:
- return None
-
- confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
- direction = "sell" if is_buyer_maker else "buy"
-
- return OrderFlowSignal(
- timestamp=datetime.now(),
- signal_type='momentum',
- price=price,
- volume=trade_value,
- confidence=confidence,
- description=f"Large {direction}: ${trade_value:.0f}"
- )
-
- except Exception as e:
- logger.error(f"Error detecting momentum for {symbol}: {e}")
- return None
-
- async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
- """Notify models of flow signals"""
- try:
- signal_data = {
- 'signal_type': signal.signal_type,
- 'price': signal.price,
- 'volume': signal.volume,
- 'confidence': signal.confidence,
- 'timestamp': signal.timestamp,
- 'description': signal.description
- }
-
- # Notify CNN callbacks
- for callback in self.cnn_callbacks:
- try:
- callback(symbol, signal_data)
- except Exception as e:
- logger.warning(f"Error in CNN callback: {e}")
-
- # Notify DQN callbacks
- for callback in self.dqn_callbacks:
- try:
- callback(symbol, signal_data)
- except Exception as e:
- logger.warning(f"Error in DQN callback: {e}")
-
- except Exception as e:
- logger.error(f"Error notifying flow signal: {e}")
-
- async def _continuous_analysis(self):
- """Continuous analysis and model feeding"""
- while self.is_streaming:
- try:
- await asyncio.sleep(1) # Analyze every second
-
- for symbol in self.symbols:
- # Generate features for models
- cnn_features = self.get_cnn_features(symbol)
- if cnn_features is not None:
- for callback in self.cnn_callbacks:
- try:
- callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
- except Exception as e:
- logger.warning(f"Error in CNN feature callback: {e}")
-
- dqn_features = self.get_dqn_state_features(symbol)
- if dqn_features is not None:
- for callback in self.dqn_callbacks:
- try:
- callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
- except Exception as e:
- logger.warning(f"Error in DQN state callback: {e}")
-
- except Exception as e:
- logger.error(f"Error in continuous analysis: {e}")
- await asyncio.sleep(5)
-
- def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
- """Generate CNN features from order book data"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
- features = []
-
- # Order book features (80 features: 20 levels x 2 sides x 2 values)
- for i in range(min(20, len(snapshot.bids))):
- bid = snapshot.bids[i]
- features.append(bid.size)
- features.append(bid.price - snapshot.mid_price)
-
- # Pad bids
- while len(features) < 40:
- features.extend([0.0, 0.0])
-
- for i in range(min(20, len(snapshot.asks))):
- ask = snapshot.asks[i]
- features.append(ask.size)
- features.append(ask.price - snapshot.mid_price)
-
- # Pad asks
- while len(features) < 80:
- features.extend([0.0, 0.0])
-
- # Liquidity metrics (10 features)
- metrics = self.liquidity_metrics.get(symbol, {})
- features.extend([
- metrics.get('total_bid_size', 0.0),
- metrics.get('total_ask_size', 0.0),
- metrics.get('liquidity_ratio', 1.0),
- metrics.get('spread_bps', 0.0),
- snapshot.spread,
- metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price,
- len(snapshot.bids),
- len(snapshot.asks),
- snapshot.mid_price,
- time.time() % 86400 # Time of day
- ])
-
- # Order book imbalance (5 features)
- if self.order_book_imbalances[symbol]:
- latest_imbalance = self.order_book_imbalances[symbol][-1]
- features.extend([
- latest_imbalance['imbalance'],
- latest_imbalance['bid_size'],
- latest_imbalance['ask_size'],
- latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
- abs(latest_imbalance['imbalance'])
- ])
- else:
- features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
-
- # Enhanced flow signals (15 features)
- recent_signals = [s for s in self.flow_signals[symbol]
- if (datetime.now() - s.timestamp).seconds < 60]
-
- sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
- absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
- momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')
- block_count = sum(1 for s in recent_signals if s.signal_type == 'block_trade')
- iceberg_count = sum(1 for s in recent_signals if s.signal_type == 'iceberg')
- hft_count = sum(1 for s in recent_signals if s.signal_type == 'hft_activity')
- max_confidence = max([s.confidence for s in recent_signals], default=0.0)
- total_flow_volume = sum(s.volume for s in recent_signals)
-
- # Enhanced order flow metrics
- flow_metrics = self.get_enhanced_order_flow_metrics(symbol)
- if flow_metrics:
- aggressive_ratio = flow_metrics['aggressive_passive']['aggressive_ratio']
- institutional_ratio = flow_metrics['institutional_retail']['institutional_ratio']
- flow_intensity = flow_metrics['flow_intensity']['current_intensity']
- avg_consumption_rate = flow_metrics['liquidity']['avg_consumption_rate']
- avg_price_impact = flow_metrics['price_impact']['avg_impact'] / 10000 # Normalize from basis points
- buy_pressure = flow_metrics['maker_taker_flow']['buy_pressure']
- sell_pressure = flow_metrics['maker_taker_flow']['sell_pressure']
- else:
- aggressive_ratio = 0.5
- institutional_ratio = 0.5
- flow_intensity = 0.0
- avg_consumption_rate = 0.0
- avg_price_impact = 0.0
- buy_pressure = 0.5
- sell_pressure = 0.5
-
- features.extend([
- sweep_count,
- absorption_count,
- momentum_count,
- block_count,
- iceberg_count,
- hft_count,
- max_confidence,
- total_flow_volume,
- aggressive_ratio,
- institutional_ratio,
- flow_intensity,
- avg_consumption_rate,
- avg_price_impact,
- buy_pressure,
- sell_pressure
- ])
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error generating CNN features for {symbol}: {e}")
- return None
-
- def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
- """Generate DQN state features"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
- state_features = []
-
- # Normalized order book state (20 features)
- total_bid_size = sum(level.size for level in snapshot.bids[:10])
- total_ask_size = sum(level.size for level in snapshot.asks[:10])
- total_size = total_bid_size + total_ask_size
-
- if total_size > 0:
- for i in range(min(10, len(snapshot.bids))):
- state_features.append(snapshot.bids[i].size / total_size)
-
- while len(state_features) < 10:
- state_features.append(0.0)
-
- for i in range(min(10, len(snapshot.asks))):
- state_features.append(snapshot.asks[i].size / total_size)
-
- while len(state_features) < 20:
- state_features.append(0.0)
- else:
- state_features.extend([0.0] * 20)
-
- # Enhanced market state indicators (20 features)
- metrics = self.liquidity_metrics.get(symbol, {})
-
- spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0
- liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
- liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)
-
- # Flow strength
- recent_signals = [s for s in self.flow_signals[symbol]
- if (datetime.now() - s.timestamp).seconds < 30]
- flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)
-
- # Price volatility
- if len(self.order_book_history[symbol]) >= 10:
- recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]]
- price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0
- else:
- price_volatility = 0
-
- # Enhanced order flow metrics for DQN
- flow_metrics = self.get_enhanced_order_flow_metrics(symbol)
- if flow_metrics:
- aggressive_ratio = flow_metrics['aggressive_passive']['aggressive_ratio']
- institutional_ratio = flow_metrics['institutional_retail']['institutional_ratio']
- flow_intensity = min(flow_metrics['flow_intensity']['current_intensity'] / 10, 1.0) # Normalize
- consumption_rate = flow_metrics['liquidity']['avg_consumption_rate']
- price_impact = min(flow_metrics['price_impact']['avg_impact'] / 100, 1.0) # Normalize basis points
- buy_pressure = flow_metrics['maker_taker_flow']['buy_pressure']
- sell_pressure = flow_metrics['maker_taker_flow']['sell_pressure']
-
- # Trade size distribution ratios
- size_dist = flow_metrics['size_distribution']
- total_trades = sum(size_dist.values()) or 1
- retail_ratio = (size_dist.get('micro', 0) + size_dist.get('small', 0)) / total_trades
- institutional_trade_ratio = (size_dist.get('large', 0) + size_dist.get('block', 0)) / total_trades
-
- # Recent activity indicators
- block_activity = min(size_dist.get('block', 0) / 10, 1.0) # Normalize
- else:
- aggressive_ratio = 0.5
- institutional_ratio = 0.5
- flow_intensity = 0.0
- consumption_rate = 0.0
- price_impact = 0.0
- buy_pressure = 0.5
- sell_pressure = 0.5
- retail_ratio = 0.5
- institutional_trade_ratio = 0.5
- block_activity = 0.0
-
- state_features.extend([
- spread_pct * 10000, # Spread in basis points
- liquidity_imbalance,
- flow_strength,
- price_volatility * 100,
- min(len(snapshot.bids), 20) / 20,
- min(len(snapshot.asks), 20) / 20,
- len([s for s in recent_signals if s.signal_type == 'sweep']) / 10,
- len([s for s in recent_signals if s.signal_type == 'absorption']) / 5,
- len([s for s in recent_signals if s.signal_type == 'momentum']) / 5,
- (datetime.now().hour * 60 + datetime.now().minute) / 1440,
- # Enhanced order flow state features
- aggressive_ratio,
- institutional_ratio,
- flow_intensity,
- consumption_rate,
- price_impact,
- buy_pressure,
- sell_pressure,
- retail_ratio,
- institutional_trade_ratio,
- block_activity
- ])
-
- return np.array(state_features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error generating DQN features for {symbol}: {e}")
- return None
-
- def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
- """Generate heatmap matrix for visualization"""
- try:
- if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
- return None
-
- current_snapshot = self.order_books.get(symbol)
- if not current_snapshot:
- return None
-
- mid_price = current_snapshot.mid_price
- price_step = mid_price * 0.0001 # 1 basis point steps
-
- # Matrix: time x price levels
- time_window = min(600, len(self.order_heatmaps[symbol]))
- heatmap_matrix = np.zeros((time_window, levels))
-
- # Fill matrix
- for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]):
- for price_offset, level_data in entry['levels'].items():
- level_idx = int((price_offset + (levels/2) * price_step) / price_step)
-
- if 0 <= level_idx < levels:
- size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
- heatmap_matrix[t, level_idx] = level_data['size'] * size_weight
-
- return heatmap_matrix
-
- except Exception as e:
- logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
- return None
-
- def get_dashboard_data(self, symbol: str) -> Optional[Dict]:
- """Get data for dashboard visualization"""
- try:
- if symbol not in self.order_books:
- return None
-
- snapshot = self.order_books[symbol]
-
- return {
- 'timestamp': snapshot.timestamp.isoformat(),
- 'symbol': symbol,
- 'mid_price': snapshot.mid_price,
- 'spread': snapshot.spread,
- 'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]],
- 'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]],
- 'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
- 'volume_profile': self.get_volume_profile_data(symbol),
- 'heatmap_matrix': self.get_order_heatmap_matrix(symbol).tolist() if self.get_order_heatmap_matrix(symbol) is not None else None,
- 'enhanced_order_flow': self.get_enhanced_order_flow_metrics(symbol),
- 'recent_signals': [
- {
- 'type': s.signal_type,
- 'price': s.price,
- 'volume': s.volume,
- 'confidence': s.confidence,
- 'timestamp': s.timestamp.isoformat(),
- 'description': s.description
- }
- for s in list(self.flow_signals[symbol])[-10:]
- ]
- }
-
- except Exception as e:
- logger.error(f"Error getting dashboard data for {symbol}: {e}")
- return None
-
- def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
- """Get rolling volume profile data (10-minute window)"""
- try:
- if symbol not in self.volume_profiles:
- return None
-
- profile_data = []
- for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price):
- profile_data.append({
- 'price': level.price,
- 'volume': level.volume,
- 'buy_volume': level.buy_volume,
- 'sell_volume': level.sell_volume,
- 'trades_count': level.trades_count,
- 'vwap': level.vwap,
- 'net_volume': level.buy_volume - level.sell_volume
- })
-
- return profile_data
-
- except Exception as e:
- logger.error(f"Error getting volume profile for {symbol}: {e}")
- return None
-
- def get_session_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
- """Get Session Volume Profile (SVP) data - full session data"""
- try:
- if symbol not in self.price_level_cache:
- return None
-
- session_data = []
- total_volume = sum(level.volume for level in self.price_level_cache[symbol].values())
-
- for price_key, level in sorted(self.price_level_cache[symbol].items()):
- volume_percentage = (level.volume / total_volume * 100) if total_volume > 0 else 0
-
- session_data.append({
- 'price': level.price,
- 'volume': level.volume,
- 'buy_volume': level.buy_volume,
- 'sell_volume': level.sell_volume,
- 'trades_count': level.trades_count,
- 'vwap': level.vwap,
- 'net_volume': level.buy_volume - level.sell_volume,
- 'volume_percentage': volume_percentage,
- 'is_high_volume_node': volume_percentage > 2.0, # Mark significant price levels
- 'buy_sell_ratio': level.buy_volume / level.sell_volume if level.sell_volume > 0 else float('inf')
- })
-
- return session_data
-
- except Exception as e:
- logger.error(f"Error getting Session Volume Profile for {symbol}: {e}")
- return None
-
- def get_session_statistics(self, symbol: str) -> Optional[Dict]:
- """Get session trading statistics"""
- try:
- if symbol not in self.price_level_cache:
- return None
-
- levels = list(self.price_level_cache[symbol].values())
- if not levels:
- return None
-
- total_volume = sum(level.volume for level in levels)
- total_buy_volume = sum(level.buy_volume for level in levels)
- total_sell_volume = sum(level.sell_volume for level in levels)
- total_trades = sum(level.trades_count for level in levels)
-
- # Calculate session VWAP
- session_vwap = sum(level.vwap * level.volume for level in levels) / total_volume if total_volume > 0 else 0
-
- # Find price extremes
- prices = [level.price for level in levels]
- session_high = max(prices) if prices else 0
- session_low = min(prices) if prices else 0
-
- # Find Point of Control (POC) - price level with highest volume
- poc_level = max(levels, key=lambda x: x.volume) if levels else None
- poc_price = poc_level.price if poc_level else 0
- poc_volume = poc_level.volume if poc_level else 0
-
- # Calculate Value Area (70% of volume around POC)
- sorted_levels = sorted(levels, key=lambda x: x.volume, reverse=True)
- value_area_volume = total_volume * 0.7
- value_area_levels = []
- current_volume = 0
-
- for level in sorted_levels:
- value_area_levels.append(level)
- current_volume += level.volume
- if current_volume >= value_area_volume:
- break
-
- value_area_high = max(level.price for level in value_area_levels) if value_area_levels else 0
- value_area_low = min(level.price for level in value_area_levels) if value_area_levels else 0
-
- session_start = self.session_start_time.get(symbol, datetime.now())
- session_duration = (datetime.now() - session_start).total_seconds() / 3600 # Hours
-
- return {
- 'symbol': symbol,
- 'session_start': session_start.isoformat(),
- 'session_duration_hours': session_duration,
- 'total_volume': total_volume,
- 'total_buy_volume': total_buy_volume,
- 'total_sell_volume': total_sell_volume,
- 'total_trades': total_trades,
- 'session_vwap': session_vwap,
- 'session_high': session_high,
- 'session_low': session_low,
- 'poc_price': poc_price,
- 'poc_volume': poc_volume,
- 'value_area_high': value_area_high,
- 'value_area_low': value_area_low,
- 'value_area_range': value_area_high - value_area_low,
- 'buy_sell_ratio': total_buy_volume / total_sell_volume if total_sell_volume > 0 else float('inf'),
- 'price_levels_traded': len(levels),
- 'avg_trade_size': total_volume / total_trades if total_trades > 0 else 0
- }
-
- except Exception as e:
- logger.error(f"Error getting session statistics for {symbol}: {e}")
- return None
-
- def get_market_profile_analysis(self, symbol: str) -> Optional[Dict]:
- """Get detailed market profile analysis"""
- try:
- current_snapshot = self.order_books.get(symbol)
- session_stats = self.get_session_statistics(symbol)
- svp_data = self.get_session_volume_profile_data(symbol)
-
- if not all([current_snapshot, session_stats, svp_data]):
- return None
-
- current_price = current_snapshot.mid_price
- session_vwap = session_stats['session_vwap']
- poc_price = session_stats['poc_price']
- value_area_high = session_stats['value_area_high']
- value_area_low = session_stats['value_area_low']
-
- # Market structure analysis
- price_vs_vwap = "above" if current_price > session_vwap else "below"
- price_vs_poc = "above" if current_price > poc_price else "below"
-
- in_value_area = value_area_low <= current_price <= value_area_high
-
- # Find support and resistance levels from high volume nodes
- high_volume_nodes = [item for item in svp_data if item['is_high_volume_node']]
- resistance_levels = [node['price'] for node in high_volume_nodes if node['price'] > current_price]
- support_levels = [node['price'] for node in high_volume_nodes if node['price'] < current_price]
-
- # Sort to get nearest levels
- resistance_levels.sort()
- support_levels.sort(reverse=True)
-
- return {
- 'symbol': symbol,
- 'current_price': current_price,
- 'market_structure': {
- 'price_vs_vwap': price_vs_vwap,
- 'price_vs_poc': price_vs_poc,
- 'in_value_area': in_value_area,
- 'distance_from_vwap_bps': int(abs(current_price - session_vwap) / session_vwap * 10000),
- 'distance_from_poc_bps': int(abs(current_price - poc_price) / poc_price * 10000)
- },
- 'key_levels': {
- 'session_vwap': session_vwap,
- 'poc_price': poc_price,
- 'value_area_high': value_area_high,
- 'value_area_low': value_area_low,
- 'nearest_resistance': resistance_levels[0] if resistance_levels else None,
- 'nearest_support': support_levels[0] if support_levels else None
- },
- 'volume_analysis': {
- 'total_high_volume_nodes': len(high_volume_nodes),
- 'resistance_levels': resistance_levels[:3], # Top 3 resistance
- 'support_levels': support_levels[:3], # Top 3 support
- 'poc_strength': session_stats['poc_volume'] / session_stats['total_volume'] * 100
- },
- 'session_statistics': session_stats
- }
-
- except Exception as e:
- logger.error(f"Error getting market profile analysis for {symbol}: {e}")
- return None
-
- def get_enhanced_order_flow_metrics(self, symbol: str) -> Optional[Dict]:
- """Get enhanced order flow metrics including aggressive vs passive ratios"""
- try:
- if symbol not in self.current_flow_ratios:
- return None
-
- current_ratios = self.current_flow_ratios.get(symbol, {})
-
- # Get recent trade size distribution
- recent_trades = list(self.trade_size_distributions[symbol])[-100:] # Last 100 trades
- if not recent_trades:
- return None
-
- # Calculate institutional vs retail breakdown
- institutional_trades = [t for t in recent_trades if t['is_institutional']]
- retail_trades = [t for t in recent_trades if not t['is_institutional']]
- block_trades = [t for t in recent_trades if t['is_block_trade']]
-
- institutional_volume = sum(t['trade_value'] for t in institutional_trades)
- retail_volume = sum(t['trade_value'] for t in retail_trades)
- total_volume = institutional_volume + retail_volume
-
- # Size category breakdown
- size_breakdown = {
- 'micro': len([t for t in recent_trades if t['size_category'] == 'micro']),
- 'small': len([t for t in recent_trades if t['size_category'] == 'small']),
- 'medium': len([t for t in recent_trades if t['size_category'] == 'medium']),
- 'large': len([t for t in recent_trades if t['size_category'] == 'large']),
- 'block': len([t for t in recent_trades if t['size_category'] == 'block'])
- }
-
- # Get recent order flow intensity
- recent_intensity = list(self.order_flow_intensity[symbol])[-10:]
- avg_intensity = sum(i['intensity_score'] for i in recent_intensity) / max(1, len(recent_intensity))
-
- # Get recent liquidity consumption
- recent_consumption = list(self.liquidity_consumption_rates[symbol])[-20:]
- avg_consumption_rate = sum(c['consumption_rate'] for c in recent_consumption) / max(1, len(recent_consumption))
-
- # Get recent price impact
- recent_impacts = list(self.price_impact_measurements[symbol])[-20:]
- avg_price_impact = sum(i['price_impact'] for i in recent_impacts) / max(1, len(recent_impacts))
-
- # Impact distribution
- impact_distribution = {}
- for impact in recent_impacts:
- category = impact['impact_category']
- impact_distribution[category] = impact_distribution.get(category, 0) + 1
-
- # Market maker vs taker flow analysis
- recent_flows = list(self.market_maker_taker_flows[symbol])[-50:]
- buy_aggressive_volume = sum(f['trade_value'] for f in recent_flows if f['flow_direction'] == 'buy_aggressive')
- sell_aggressive_volume = sum(f['trade_value'] for f in recent_flows if f['flow_direction'] == 'sell_aggressive')
-
- return {
- 'symbol': symbol,
- 'timestamp': datetime.now().isoformat(),
-
- # Aggressive vs Passive Analysis
- 'aggressive_passive': {
- 'aggressive_ratio': current_ratios.get('aggressive_ratio', 0),
- 'passive_ratio': current_ratios.get('passive_ratio', 0),
- 'aggressive_volume': current_ratios.get('aggressive_volume', 0),
- 'passive_volume': current_ratios.get('passive_volume', 0),
- 'avg_aggressive_size': current_ratios.get('avg_aggressive_size', 0),
- 'avg_passive_size': current_ratios.get('avg_passive_size', 0),
- 'trade_count': current_ratios.get('trade_count', 0)
- },
-
- # Institutional vs Retail Analysis
- 'institutional_retail': {
- 'institutional_ratio': institutional_volume / total_volume if total_volume > 0 else 0,
- 'retail_ratio': retail_volume / total_volume if total_volume > 0 else 0,
- 'institutional_volume': institutional_volume,
- 'retail_volume': retail_volume,
- 'institutional_trade_count': len(institutional_trades),
- 'retail_trade_count': len(retail_trades),
- 'block_trade_count': len(block_trades),
- 'avg_institutional_size': institutional_volume / max(1, len(institutional_trades)),
- 'avg_retail_size': retail_volume / max(1, len(retail_trades))
- },
-
- # Trade Size Distribution
- 'size_distribution': size_breakdown,
-
- # Order Flow Intensity
- 'flow_intensity': {
- 'current_intensity': avg_intensity,
- 'intensity_category': 'high' if avg_intensity > 5 else 'medium' if avg_intensity > 2 else 'low'
- },
-
- # Liquidity Analysis
- 'liquidity': {
- 'avg_consumption_rate': avg_consumption_rate,
- 'consumption_category': 'high' if avg_consumption_rate > 0.8 else 'medium' if avg_consumption_rate > 0.5 else 'low'
- },
-
- # Price Impact Analysis
- 'price_impact': {
- 'avg_impact': avg_price_impact * 10000, # in basis points
- 'impact_distribution': impact_distribution,
- 'impact_category': 'high' if avg_price_impact > 0.005 else 'medium' if avg_price_impact > 0.001 else 'low'
- },
-
- # Market Maker vs Taker Flow
- 'maker_taker_flow': {
- 'buy_aggressive_volume': buy_aggressive_volume,
- 'sell_aggressive_volume': sell_aggressive_volume,
- 'buy_pressure': buy_aggressive_volume / (buy_aggressive_volume + sell_aggressive_volume) if (buy_aggressive_volume + sell_aggressive_volume) > 0 else 0.5,
- 'sell_pressure': sell_aggressive_volume / (buy_aggressive_volume + sell_aggressive_volume) if (buy_aggressive_volume + sell_aggressive_volume) > 0 else 0.5
- },
-
- # 24h Volume Statistics (if available)
- 'volume_stats': self.volume_stats.get(symbol, {})
- }
-
- except Exception as e:
- logger.error(f"Error getting enhanced order flow metrics for {symbol}: {e}")
- return None
\ No newline at end of file
diff --git a/core/cnn_training_pipeline.py b/core/cnn_training_pipeline.py
deleted file mode 100644
index 15685df..0000000
--- a/core/cnn_training_pipeline.py
+++ /dev/null
@@ -1,785 +0,0 @@
-"""
-CNN Training Pipeline with Comprehensive Data Storage and Replay
-
-This module implements a robust CNN training pipeline that:
-1. Integrates with the comprehensive training data collection system
-2. Stores all backpropagation data for gradient replay
-3. Enables retraining on most profitable setups
-4. Maintains training episode profitability tracking
-5. Supports both real-time and batch training modes
-
-Key Features:
-- Integration with TrainingDataCollector for data validation
-- Gradient and loss storage for each training step
-- Profitable episode prioritization and replay
-- Comprehensive training metrics and validation
-- Real-time pivot point prediction with outcome tracking
-"""
-
-import asyncio
-import logging
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from dataclasses import dataclass, field
-import json
-import pickle
-from collections import deque, defaultdict
-import threading
-from concurrent.futures import ThreadPoolExecutor
-
-from .training_data_collector import (
- TrainingDataCollector,
- TrainingEpisode,
- ModelInputPackage,
- get_training_data_collector
-)
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class CNNTrainingStep:
- """Single CNN training step with complete backpropagation data"""
- step_id: str
- timestamp: datetime
- episode_id: str
-
- # Input data
- input_features: torch.Tensor
- target_labels: torch.Tensor
-
- # Forward pass results
- model_outputs: Dict[str, torch.Tensor]
- predictions: Dict[str, Any]
- confidence_scores: torch.Tensor
-
- # Loss components
- total_loss: float
- pivot_prediction_loss: float
- confidence_loss: float
- regularization_loss: float
-
- # Backpropagation data
- gradients: Dict[str, torch.Tensor] # Gradients for each parameter
- gradient_norms: Dict[str, float] # Gradient norms for monitoring
-
- # Model state
- model_state_dict: Optional[Dict[str, torch.Tensor]] = None
- optimizer_state: Optional[Dict[str, Any]] = None
-
- # Training metadata
- learning_rate: float = 0.001
- batch_size: int = 32
- epoch: int = 0
-
- # Profitability tracking
- actual_profitability: Optional[float] = None
- prediction_accuracy: Optional[float] = None
- training_value: float = 0.0 # Value of this training step for replay
-
-@dataclass
-class CNNTrainingSession:
- """Complete CNN training session with multiple steps"""
- session_id: str
- start_timestamp: datetime
- end_timestamp: Optional[datetime] = None
-
- # Session configuration
- training_mode: str = 'real_time' # 'real_time', 'batch', 'replay'
- symbol: str = ''
-
- # Training steps
- training_steps: List[CNNTrainingStep] = field(default_factory=list)
-
- # Session metrics
- total_steps: int = 0
- average_loss: float = 0.0
- best_loss: float = float('inf')
- convergence_achieved: bool = False
-
- # Profitability metrics
- profitable_predictions: int = 0
- total_predictions: int = 0
- profitability_rate: float = 0.0
-
- # Session value for replay prioritization
- session_value: float = 0.0
-
-class CNNPivotPredictor(nn.Module):
- """CNN model for pivot point prediction with comprehensive output"""
-
- def __init__(self,
- input_channels: int = 10, # Multiple timeframes
- sequence_length: int = 300, # 300 bars
- hidden_dim: int = 256,
- num_pivot_classes: int = 3, # high, low, none
- dropout_rate: float = 0.2):
-
- super(CNNPivotPredictor, self).__init__()
-
- self.input_channels = input_channels
- self.sequence_length = sequence_length
- self.hidden_dim = hidden_dim
-
- # Convolutional layers for pattern extraction
- self.conv_layers = nn.Sequential(
- # First conv block
- nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
- nn.BatchNorm1d(64),
- nn.ReLU(),
- nn.Dropout(dropout_rate),
-
- # Second conv block
- nn.Conv1d(64, 128, kernel_size=5, padding=2),
- nn.BatchNorm1d(128),
- nn.ReLU(),
- nn.Dropout(dropout_rate),
-
- # Third conv block
- nn.Conv1d(128, 256, kernel_size=3, padding=1),
- nn.BatchNorm1d(256),
- nn.ReLU(),
- nn.Dropout(dropout_rate),
- )
-
- # LSTM for temporal dependencies
- self.lstm = nn.LSTM(
- input_size=256,
- hidden_size=hidden_dim,
- num_layers=2,
- batch_first=True,
- dropout=dropout_rate,
- bidirectional=True
- )
-
- # Attention mechanism
- self.attention = nn.MultiheadAttention(
- embed_dim=hidden_dim * 2, # Bidirectional LSTM
- num_heads=8,
- dropout=dropout_rate,
- batch_first=True
- )
-
- # Output heads
- self.pivot_classifier = nn.Sequential(
- nn.Linear(hidden_dim * 2, hidden_dim),
- nn.ReLU(),
- nn.Dropout(dropout_rate),
- nn.Linear(hidden_dim, num_pivot_classes)
- )
-
- self.pivot_price_regressor = nn.Sequential(
- nn.Linear(hidden_dim * 2, hidden_dim),
- nn.ReLU(),
- nn.Dropout(dropout_rate),
- nn.Linear(hidden_dim, 1)
- )
-
- self.confidence_head = nn.Sequential(
- nn.Linear(hidden_dim * 2, hidden_dim // 2),
- nn.ReLU(),
- nn.Linear(hidden_dim // 2, 1),
- nn.Sigmoid()
- )
-
- # Initialize weights
- self.apply(self._init_weights)
-
- def _init_weights(self, module):
- """Initialize weights with proper scaling"""
- if isinstance(module, nn.Linear):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- torch.nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Conv1d):
- torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
-
- def forward(self, x):
- """
- Forward pass through CNN pivot predictor
-
- Args:
- x: Input tensor [batch_size, input_channels, sequence_length]
-
- Returns:
- Dict containing predictions and hidden states
- """
- batch_size = x.size(0)
-
- # Convolutional feature extraction
- conv_features = self.conv_layers(x) # [batch, 256, sequence_length]
-
- # Prepare for LSTM (transpose to [batch, sequence, features])
- lstm_input = conv_features.transpose(1, 2) # [batch, sequence_length, 256]
-
- # LSTM processing
- lstm_output, (hidden, cell) = self.lstm(lstm_input) # [batch, sequence_length, hidden_dim*2]
-
- # Attention mechanism
- attended_output, attention_weights = self.attention(
- lstm_output, lstm_output, lstm_output
- )
-
- # Use the last timestep for predictions
- final_features = attended_output[:, -1, :] # [batch, hidden_dim*2]
-
- # Generate predictions
- pivot_logits = self.pivot_classifier(final_features)
- pivot_price = self.pivot_price_regressor(final_features)
- confidence = self.confidence_head(final_features)
-
- return {
- 'pivot_logits': pivot_logits,
- 'pivot_price': pivot_price,
- 'confidence': confidence,
- 'hidden_states': final_features,
- 'attention_weights': attention_weights,
- 'conv_features': conv_features,
- 'lstm_output': lstm_output
- }
-
-class CNNTrainingDataset(Dataset):
- """Dataset for CNN training with training episodes"""
-
- def __init__(self, training_episodes: List[TrainingEpisode]):
- self.episodes = training_episodes
- self.valid_episodes = self._validate_episodes()
-
- def _validate_episodes(self) -> List[TrainingEpisode]:
- """Validate and filter episodes for training"""
- valid = []
- for episode in self.episodes:
- try:
- # Check if episode has required data
- if (episode.input_package.cnn_features is not None and
- episode.actual_outcome.outcome_validated):
- valid.append(episode)
- except Exception as e:
- logger.warning(f"Invalid episode {episode.episode_id}: {e}")
-
- logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
- return valid
-
- def __len__(self):
- return len(self.valid_episodes)
-
- def __getitem__(self, idx):
- episode = self.valid_episodes[idx]
-
- # Extract features
- features = torch.from_numpy(episode.input_package.cnn_features).float()
-
- # Create labels from actual outcomes
- pivot_class = self._determine_pivot_class(episode.actual_outcome)
- pivot_price = episode.actual_outcome.optimal_exit_price
- confidence_target = episode.actual_outcome.profitability_score
-
- return {
- 'features': features,
- 'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
- 'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
- 'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
- 'episode_id': episode.episode_id,
- 'profitability': episode.actual_outcome.profitability_score
- }
-
- def _determine_pivot_class(self, outcome) -> int:
- """Determine pivot class from outcome"""
- if outcome.price_change_15m > 0.5: # Significant upward movement
- return 0 # High pivot
- elif outcome.price_change_15m < -0.5: # Significant downward movement
- return 1 # Low pivot
- else:
- return 2 # No significant pivot
-
-class CNNTrainer:
- """CNN trainer with comprehensive data storage and replay capabilities"""
-
- def __init__(self,
- model: CNNPivotPredictor,
- device: str = 'cuda',
- learning_rate: float = 0.001,
- storage_dir: str = "cnn_training_storage"):
-
- self.model = model.to(device)
- self.device = device
- self.learning_rate = learning_rate
-
- # Storage
- self.storage_dir = Path(storage_dir)
- self.storage_dir.mkdir(parents=True, exist_ok=True)
-
- # Optimizer
- self.optimizer = torch.optim.AdamW(
- self.model.parameters(),
- lr=learning_rate,
- weight_decay=1e-5
- )
-
- # Learning rate scheduler
- self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- self.optimizer, mode='min', patience=10, factor=0.5
- )
-
- # Training data collector
- self.data_collector = get_training_data_collector()
-
- # Training sessions storage
- self.training_sessions: List[CNNTrainingSession] = []
- self.current_session: Optional[CNNTrainingSession] = None
-
- # Training statistics
- self.training_stats = {
- 'total_sessions': 0,
- 'total_steps': 0,
- 'best_validation_loss': float('inf'),
- 'profitable_predictions': 0,
- 'total_predictions': 0,
- 'replay_sessions': 0
- }
-
- # Background training
- self.is_training = False
- self.training_thread = None
-
- logger.info(f"CNN Trainer initialized")
- logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
- logger.info(f"Storage directory: {self.storage_dir}")
-
- def start_real_time_training(self, symbol: str):
- """Start real-time training for a symbol"""
- if self.is_training:
- logger.warning("CNN training already running")
- return
-
- self.is_training = True
- self.training_thread = threading.Thread(
- target=self._real_time_training_worker,
- args=(symbol,),
- daemon=True
- )
- self.training_thread.start()
-
- logger.info(f"Started real-time CNN training for {symbol}")
-
- def stop_training(self):
- """Stop training"""
- self.is_training = False
- if self.training_thread:
- self.training_thread.join(timeout=10)
-
- if self.current_session:
- self._finalize_training_session()
-
- logger.info("CNN training stopped")
-
- def _real_time_training_worker(self, symbol: str):
- """Real-time training worker"""
- logger.info(f"Real-time CNN training worker started for {symbol}")
-
- while self.is_training:
- try:
- # Get high-priority episodes for training
- episodes = self.data_collector.get_high_priority_episodes(
- symbol=symbol,
- limit=100,
- min_priority=0.3
- )
-
- if len(episodes) >= 32: # Minimum batch size
- self._train_on_episodes(episodes, training_mode='real_time')
-
- # Wait before next training cycle
- threading.Event().wait(300) # Train every 5 minutes
-
- except Exception as e:
- logger.error(f"Error in real-time training worker: {e}")
- threading.Event().wait(60) # Wait before retrying
-
- logger.info(f"Real-time CNN training worker stopped for {symbol}")
-
- def train_on_profitable_episodes(self,
- symbol: str,
- min_profitability: float = 0.7,
- max_episodes: int = 500) -> Dict[str, Any]:
- """Train specifically on most profitable episodes"""
- try:
- # Get all episodes for symbol
- all_episodes = self.data_collector.training_episodes.get(symbol, [])
-
- # Filter for profitable episodes
- profitable_episodes = [
- ep for ep in all_episodes
- if (ep.actual_outcome.is_profitable and
- ep.actual_outcome.profitability_score >= min_profitability)
- ]
-
- # Sort by profitability and limit
- profitable_episodes.sort(
- key=lambda x: x.actual_outcome.profitability_score,
- reverse=True
- )
- profitable_episodes = profitable_episodes[:max_episodes]
-
- if len(profitable_episodes) < 10:
- logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
- return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}
-
- # Train on profitable episodes
- results = self._train_on_episodes(
- profitable_episodes,
- training_mode='profitable_replay'
- )
-
- logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
- return results
-
- except Exception as e:
- logger.error(f"Error training on profitable episodes: {e}")
- return {'status': 'error', 'error': str(e)}
-
- def _train_on_episodes(self,
- episodes: List[TrainingEpisode],
- training_mode: str = 'batch') -> Dict[str, Any]:
- """Train on a batch of episodes with comprehensive data storage"""
- try:
- # Start new training session
- session = CNNTrainingSession(
- session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
- start_timestamp=datetime.now(),
- training_mode=training_mode,
- symbol=episodes[0].input_package.symbol if episodes else 'unknown'
- )
- self.current_session = session
-
- # Create dataset and dataloader
- dataset = CNNTrainingDataset(episodes)
- dataloader = DataLoader(
- dataset,
- batch_size=32,
- shuffle=True,
- num_workers=2
- )
-
- # Training loop
- self.model.train()
- total_loss = 0.0
- num_batches = 0
-
- for batch_idx, batch in enumerate(dataloader):
- # Move to device
- features = batch['features'].to(self.device)
- pivot_class = batch['pivot_class'].to(self.device)
- pivot_price = batch['pivot_price'].to(self.device)
- confidence_target = batch['confidence_target'].to(self.device)
-
- # Forward pass
- self.optimizer.zero_grad()
- outputs = self.model(features)
-
- # Calculate losses
- classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
- regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price)
- confidence_loss = F.binary_cross_entropy(
- outputs['confidence'].squeeze(),
- confidence_target
- )
-
- # Combined loss
- total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss
-
- # Backward pass
- total_batch_loss.backward()
-
- # Gradient clipping
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
- # Store gradients before optimizer step
- gradients = {}
- gradient_norms = {}
- for name, param in self.model.named_parameters():
- if param.grad is not None:
- gradients[name] = param.grad.clone().detach()
- gradient_norms[name] = param.grad.norm().item()
-
- # Optimizer step
- self.optimizer.step()
-
- # Create training step record
- step = CNNTrainingStep(
- step_id=f"{session.session_id}_step_{batch_idx}",
- timestamp=datetime.now(),
- episode_id=f"batch_{batch_idx}",
- input_features=features.detach().cpu(),
- target_labels=pivot_class.detach().cpu(),
- model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
- predictions=self._extract_predictions(outputs),
- confidence_scores=outputs['confidence'].detach().cpu(),
- total_loss=total_batch_loss.item(),
- pivot_prediction_loss=classification_loss.item(),
- confidence_loss=confidence_loss.item(),
- regularization_loss=0.0,
- gradients=gradients,
- gradient_norms=gradient_norms,
- learning_rate=self.optimizer.param_groups[0]['lr'],
- batch_size=features.size(0)
- )
-
- # Calculate training value for this step
- step.training_value = self._calculate_step_training_value(step, batch)
-
- # Add to session
- session.training_steps.append(step)
-
- total_loss += total_batch_loss.item()
- num_batches += 1
-
- # Log progress
- if batch_idx % 10 == 0:
- logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")
-
- # Finalize session
- session.end_timestamp = datetime.now()
- session.total_steps = num_batches
- session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
- session.best_loss = min(step.total_loss for step in session.training_steps)
-
- # Calculate session value
- session.session_value = self._calculate_session_value(session)
-
- # Update scheduler
- self.scheduler.step(session.average_loss)
-
- # Save session
- self._save_training_session(session)
-
- # Update statistics
- self.training_stats['total_sessions'] += 1
- self.training_stats['total_steps'] += session.total_steps
- if training_mode == 'profitable_replay':
- self.training_stats['replay_sessions'] += 1
-
- logger.info(f"Training session completed: {session.session_id}")
- logger.info(f"Average loss: {session.average_loss:.4f}")
- logger.info(f"Session value: {session.session_value:.3f}")
-
- return {
- 'status': 'success',
- 'session_id': session.session_id,
- 'average_loss': session.average_loss,
- 'total_steps': session.total_steps,
- 'session_value': session.session_value
- }
-
- except Exception as e:
- logger.error(f"Error in training session: {e}")
- return {'status': 'error', 'error': str(e)}
- finally:
- self.current_session = None
-
- def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
- """Extract human-readable predictions from model outputs"""
- try:
- pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
- predicted_class = torch.argmax(pivot_probs, dim=1)
-
- return {
- 'pivot_class': predicted_class.cpu().numpy().tolist(),
- 'pivot_probabilities': pivot_probs.cpu().numpy().tolist(),
- 'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(),
- 'confidence': outputs['confidence'].cpu().numpy().tolist()
- }
- except Exception as e:
- logger.warning(f"Error extracting predictions: {e}")
- return {}
-
- def _calculate_step_training_value(self,
- step: CNNTrainingStep,
- batch: Dict[str, Any]) -> float:
- """Calculate the training value of a step for replay prioritization"""
- try:
- value = 0.0
-
- # Base value from loss (lower loss = higher value)
- if step.total_loss > 0:
- value += 1.0 / (1.0 + step.total_loss)
-
- # Bonus for high profitability episodes in batch
- avg_profitability = torch.mean(batch['profitability']).item()
- value += avg_profitability * 0.3
-
- # Bonus for gradient magnitude (indicates learning)
- avg_grad_norm = np.mean(list(step.gradient_norms.values()))
- value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2
-
- return min(value, 1.0)
-
- except Exception as e:
- logger.warning(f"Error calculating step training value: {e}")
- return 0.0
-
- def _calculate_session_value(self, session: CNNTrainingSession) -> float:
- """Calculate overall session value for replay prioritization"""
- try:
- if not session.training_steps:
- return 0.0
-
- # Average step values
- avg_step_value = np.mean([step.training_value for step in session.training_steps])
-
- # Bonus for convergence
- convergence_bonus = 0.0
- if len(session.training_steps) > 10:
- early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
- late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
- if early_loss > late_loss:
- convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)
-
- # Bonus for profitable replay sessions
- mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0
-
- return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)
-
- except Exception as e:
- logger.warning(f"Error calculating session value: {e}")
- return 0.0
-
- def _save_training_session(self, session: CNNTrainingSession):
- """Save training session to disk"""
- try:
- session_dir = self.storage_dir / session.symbol / 'sessions'
- session_dir.mkdir(parents=True, exist_ok=True)
-
- # Save full session data
- session_file = session_dir / f"{session.session_id}.pkl"
- with open(session_file, 'wb') as f:
- pickle.dump(session, f)
-
- # Save session metadata
- metadata = {
- 'session_id': session.session_id,
- 'start_timestamp': session.start_timestamp.isoformat(),
- 'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
- 'training_mode': session.training_mode,
- 'symbol': session.symbol,
- 'total_steps': session.total_steps,
- 'average_loss': session.average_loss,
- 'best_loss': session.best_loss,
- 'session_value': session.session_value
- }
-
- metadata_file = session_dir / f"{session.session_id}_metadata.json"
- with open(metadata_file, 'w') as f:
- json.dump(metadata, f, indent=2)
-
- logger.debug(f"Saved training session: {session.session_id}")
-
- except Exception as e:
- logger.error(f"Error saving training session: {e}")
-
- def _finalize_training_session(self):
- """Finalize current training session"""
- if self.current_session:
- self.current_session.end_timestamp = datetime.now()
- self._save_training_session(self.current_session)
- self.training_sessions.append(self.current_session)
- self.current_session = None
-
- def get_training_statistics(self) -> Dict[str, Any]:
- """Get comprehensive training statistics"""
- stats = self.training_stats.copy()
-
- # Add recent session information
- if self.training_sessions:
- recent_sessions = sorted(
- self.training_sessions,
- key=lambda x: x.start_timestamp,
- reverse=True
- )[:10]
-
- stats['recent_sessions'] = [
- {
- 'session_id': s.session_id,
- 'timestamp': s.start_timestamp.isoformat(),
- 'mode': s.training_mode,
- 'average_loss': s.average_loss,
- 'session_value': s.session_value
- }
- for s in recent_sessions
- ]
-
- # Calculate profitability rate
- if stats['total_predictions'] > 0:
- stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
- else:
- stats['profitability_rate'] = 0.0
-
- return stats
-
- def replay_high_value_sessions(self,
- symbol: str,
- min_session_value: float = 0.7,
- max_sessions: int = 10) -> Dict[str, Any]:
- """Replay high-value training sessions"""
- try:
- # Find high-value sessions
- high_value_sessions = [
- s for s in self.training_sessions
- if (s.symbol == symbol and
- s.session_value >= min_session_value)
- ]
-
- # Sort by value and limit
- high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
- high_value_sessions = high_value_sessions[:max_sessions]
-
- if not high_value_sessions:
- return {'status': 'no_high_value_sessions', 'sessions_found': 0}
-
- # Replay sessions
- total_replayed = 0
- for session in high_value_sessions:
- # Extract episodes from session steps
- episode_ids = list(set(step.episode_id for step in session.training_steps))
-
- # Get corresponding episodes
- episodes = []
- for episode_id in episode_ids:
- # Find episode in data collector
- for ep in self.data_collector.training_episodes.get(symbol, []):
- if ep.episode_id == episode_id:
- episodes.append(ep)
- break
-
- if episodes:
- self._train_on_episodes(episodes, training_mode='high_value_replay')
- total_replayed += 1
-
- logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
- return {
- 'status': 'success',
- 'sessions_replayed': total_replayed,
- 'sessions_found': len(high_value_sessions)
- }
-
- except Exception as e:
- logger.error(f"Error replaying high-value sessions: {e}")
- return {'status': 'error', 'error': str(e)}
-
-# Global instance
-cnn_trainer = None
-
-def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
- """Get global CNN trainer instance"""
- global cnn_trainer
- if cnn_trainer is None:
- if model is None:
- model = CNNPivotPredictor()
- cnn_trainer = CNNTrainer(model)
- return cnn_trainer
\ No newline at end of file
diff --git a/core/enhanced_cnn_adapter.py b/core/enhanced_cnn_adapter.py
deleted file mode 100644
index 7e9f42c..0000000
--- a/core/enhanced_cnn_adapter.py
+++ /dev/null
@@ -1,864 +0,0 @@
-# """
-# Enhanced CNN Adapter for Standardized Input Format
-
-# This module provides an adapter for the EnhancedCNN model to work with the standardized
-# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
-# """
-
-# import torch
-# import numpy as np
-# import logging
-# import os
-# import random
-# from datetime import datetime, timedelta
-# from typing import Dict, List, Optional, Tuple, Any, Union
-# from threading import Lock
-
-# from .data_models import BaseDataInput, ModelOutput, create_model_output
-# from NN.models.enhanced_cnn import EnhancedCNN
-# from utils.inference_logger import log_model_inference
-
-# logger = logging.getLogger(__name__)
-
-# class EnhancedCNNAdapter:
-# """
-# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
-
-# This adapter:
-# 1. Converts BaseDataInput to the format expected by EnhancedCNN
-# 2. Processes model outputs to create standardized ModelOutput
-# 3. Manages model training with collected data
-# 4. Handles checkpoint management
-# """
-
-# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
-# """
-# Initialize the EnhancedCNN adapter
-
-# Args:
-# model_path: Path to load model from, if None a new model is created
-# checkpoint_dir: Directory to save checkpoints to
-# """
-# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-# self.model = None
-# self.model_path = model_path
-# self.checkpoint_dir = checkpoint_dir
-# self.training_lock = Lock()
-# self.training_data = []
-# self.max_training_samples = 10000
-# self.batch_size = 32
-# self.learning_rate = 0.0001
-# self.model_name = "enhanced_cnn"
-
-# # Enhanced metrics tracking
-# self.last_inference_time = None
-# self.last_inference_duration = 0.0
-# self.last_prediction_output = None
-# self.last_training_time = None
-# self.last_training_duration = 0.0
-# self.last_training_loss = 0.0
-# self.inference_count = 0
-# self.training_count = 0
-
-# # Create checkpoint directory if it doesn't exist
-# os.makedirs(checkpoint_dir, exist_ok=True)
-
-# # Initialize the model
-# self._initialize_model()
-
-# # Load checkpoint if available
-# if model_path and os.path.exists(model_path):
-# self._load_checkpoint(model_path)
-# else:
-# self._load_best_checkpoint()
-
-# # Final device check and move
-# self._ensure_model_on_device()
-
-# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
-
-# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
-# """Create realistic synthetic features instead of random data"""
-# try:
-# # Create realistic market-like features
-# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
-
-# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
-# ohlcv_start = 0
-# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
-# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
-# for frame_idx in range(300):
-# # Create realistic price movement
-# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
-# current_price = base_price * (1 + price_change)
-
-# # Realistic OHLCV values
-# open_price = current_price
-# high_price = current_price * torch.uniform(1.0, 1.005)
-# low_price = current_price * torch.uniform(0.995, 1.0)
-# close_price = current_price * torch.uniform(0.998, 1.002)
-# volume = torch.uniform(500.0, 2000.0)
-
-# # Set features
-# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
-# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
-
-# # BTC OHLCV features (1500 features: 300 frames x 5 features)
-# btc_start = 6000
-# btc_base_price = 50000.0
-# for frame_idx in range(300):
-# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
-# current_price = btc_base_price * (1 + price_change)
-
-# open_price = current_price
-# high_price = current_price * torch.uniform(1.0, 1.01)
-# low_price = current_price * torch.uniform(0.99, 1.0)
-# close_price = current_price * torch.uniform(0.995, 1.005)
-# volume = torch.uniform(100.0, 500.0)
-
-# feature_idx = btc_start + frame_idx * 5
-# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
-
-# # COB features (200 features) - realistic order book data
-# cob_start = 7500
-# for i in range(200):
-# features[cob_start + i] = torch.uniform(0.0, 1000.0) # Realistic COB values
-
-# # Technical indicators (100 features)
-# indicator_start = 7700
-# for i in range(100):
-# features[indicator_start + i] = torch.uniform(-1.0, 1.0) # Normalized indicators
-
-# # Last predictions (50 features)
-# prediction_start = 7800
-# for i in range(50):
-# features[prediction_start + i] = torch.uniform(0.0, 1.0) # Probability values
-
-# return features
-
-# except Exception as e:
-# logger.error(f"Error creating realistic synthetic features: {e}")
-# # Fallback to small random variation
-# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
-# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
-# return base_features + noise
-
-# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
-# """Create features from real market data if available"""
-# try:
-# # This would need to be implemented to use actual market data
-# # For now, fall back to synthetic features
-# return self._create_realistic_synthetic_features(symbol)
-# except Exception as e:
-# logger.error(f"Error creating realistic features: {e}")
-# return self._create_realistic_synthetic_features(symbol)
-
-# def _initialize_model(self):
-# """Initialize the EnhancedCNN model"""
-# try:
-# # Calculate input shape based on BaseDataInput structure
-# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
-# # BTC OHLCV: 300 frames x 5 features = 1500 features
-# # COB: ±20 buckets x 4 metrics = 160 features
-# # MA: 4 timeframes x 10 buckets = 40 features
-# # Technical indicators: 100 features
-# # Last predictions: 50 features
-# # Total: 7850 features
-# input_shape = 7850
-# n_actions = 3 # BUY, SELL, HOLD
-
-# # Create model
-# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
-# # Ensure model is moved to the correct device
-# self.model.to(self.device)
-
-# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
-
-# except Exception as e:
-# logger.error(f"Error initializing EnhancedCNN model: {e}")
-# raise
-
-# def _load_checkpoint(self, checkpoint_path: str) -> bool:
-# """Load model from checkpoint path"""
-# try:
-# if self.model and os.path.exists(checkpoint_path):
-# success = self.model.load(checkpoint_path)
-# if success:
-# # Ensure model is moved to the correct device after loading
-# self.model.to(self.device)
-# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
-# return True
-# else:
-# logger.warning(f"Failed to load model from {checkpoint_path}")
-# return False
-# else:
-# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
-# return False
-# except Exception as e:
-# logger.error(f"Error loading checkpoint: {e}")
-# return False
-
-# def _load_best_checkpoint(self) -> bool:
-# """Load the best available checkpoint"""
-# try:
-# return self.load_best_checkpoint()
-# except Exception as e:
-# logger.error(f"Error loading best checkpoint: {e}")
-# return False
-
-# def load_best_checkpoint(self) -> bool:
-# """Load the best checkpoint based on accuracy"""
-# try:
-# # Import checkpoint manager
-# from utils.checkpoint_manager import CheckpointManager
-
-# # Create checkpoint manager
-# checkpoint_manager = CheckpointManager(
-# checkpoint_dir=self.checkpoint_dir,
-# max_checkpoints=10,
-# metric_name="accuracy"
-# )
-
-# # Load best checkpoint
-# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
-
-# if not best_checkpoint_path:
-# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
-# return False
-
-# # Load model
-# success = self.model.load(best_checkpoint_path)
-
-# if success:
-# # Ensure model is moved to the correct device after loading
-# self.model.to(self.device)
-# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
-
-# # Log metrics
-# metrics = best_checkpoint_metadata.get('metrics', {})
-# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
-
-# return True
-# else:
-# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
-# return False
-
-# except Exception as e:
-# logger.error(f"Error loading best checkpoint: {e}")
-# return False
-
-# def _ensure_model_on_device(self):
-# """Ensure model and all its components are on the correct device"""
-# try:
-# if self.model:
-# self.model.to(self.device)
-# # Also ensure the model's internal device is set correctly
-# if hasattr(self.model, 'device'):
-# self.model.device = self.device
-# logger.debug(f"Model ensured on device {self.device}")
-# except Exception as e:
-# logger.error(f"Error ensuring model on device: {e}")
-
-# def _create_default_output(self, symbol: str) -> ModelOutput:
-# """Create default output when prediction fails"""
-# return create_model_output(
-# model_type='cnn',
-# model_name=self.model_name,
-# symbol=symbol,
-# action='HOLD',
-# confidence=0.0,
-# metadata={'error': 'Prediction failed, using default output'}
-# )
-
-# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
-# """Process hidden states for cross-model feeding"""
-# processed_states = {}
-
-# for key, value in hidden_states.items():
-# if isinstance(value, torch.Tensor):
-# # Convert tensor to numpy array
-# processed_states[key] = value.cpu().numpy().tolist()
-# else:
-# processed_states[key] = value
-
-# return processed_states
-
-
-
-# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
-# """
-# Convert BaseDataInput to feature vector for EnhancedCNN
-
-# Args:
-# base_data: Standardized input data
-
-# Returns:
-# torch.Tensor: Feature vector for EnhancedCNN
-# """
-# try:
-# # Use the get_feature_vector method from BaseDataInput
-# features = base_data.get_feature_vector()
-
-# # Validate feature quality before using
-# self._validate_feature_quality(features)
-
-# # Convert to torch tensor
-# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
-
-# return features_tensor
-
-# except Exception as e:
-# logger.error(f"Error converting BaseDataInput to features: {e}")
-# # Return empty tensor with correct shape
-# return torch.zeros(7850, dtype=torch.float32, device=self.device)
-
-# def _validate_feature_quality(self, features: np.ndarray):
-# """Validate that features are realistic and not synthetic/placeholder data"""
-# try:
-# if len(features) != 7850:
-# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
-# return
-
-# # Check for all-zero or all-identical features (indicates placeholder data)
-# if np.all(features == 0):
-# logger.warning("Feature vector contains all zeros - likely placeholder data")
-# return
-
-# # Check for repetitive patterns in OHLCV data (first 6000 features)
-# ohlcv_features = features[:6000]
-# if len(ohlcv_features) >= 20:
-# # Check if first 20 values are identical (indicates padding with same bar)
-# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
-# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
-
-# # Check for unrealistic values
-# if np.any(features > 1e6) or np.any(features < -1e6):
-# logger.warning("Feature vector contains unrealistic values")
-
-# # Check for NaN or infinite values
-# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
-# logger.warning("Feature vector contains NaN or infinite values")
-
-# except Exception as e:
-# logger.error(f"Error validating feature quality: {e}")
-
-# def predict(self, base_data: BaseDataInput) -> ModelOutput:
-# """
-# Make a prediction using the EnhancedCNN model
-
-# Args:
-# base_data: Standardized input data
-
-# Returns:
-# ModelOutput: Standardized model output
-# """
-# try:
-# # Track inference timing
-# start_time = datetime.now()
-# inference_start = start_time.timestamp()
-
-# # Convert BaseDataInput to features
-# features = self._convert_base_data_to_features(base_data)
-
-# # Ensure features has batch dimension
-# if features.dim() == 1:
-# features = features.unsqueeze(0)
-
-# # Ensure model is on correct device before prediction
-# self._ensure_model_on_device()
-
-# # Set model to evaluation mode
-# self.model.eval()
-
-# # Make prediction
-# with torch.no_grad():
-# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
-
-# # Get action and confidence
-# action_probs = torch.softmax(q_values, dim=1)
-# action_idx = torch.argmax(action_probs, dim=1).item()
-# raw_confidence = float(action_probs[0, action_idx].item())
-
-# # Validate confidence - prevent 100% confidence which indicates overfitting
-# if raw_confidence >= 0.99:
-# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
-# # Cap confidence at 0.95 to prevent unrealistic predictions
-# confidence = min(raw_confidence, 0.95)
-# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
-# else:
-# confidence = raw_confidence
-
-# # Map action index to action string
-# actions = ['BUY', 'SELL', 'HOLD']
-# action = actions[action_idx]
-
-# # Extract pivot price prediction (simplified - take first value from price_pred)
-# pivot_price = None
-# if price_pred is not None and len(price_pred.squeeze()) > 0:
-# # Get current price from base_data for context
-# current_price = 0.0
-# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
-# current_price = base_data.ohlcv_1s[-1].close
-
-# # Calculate pivot price as current price + predicted change
-# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
-# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
-
-# # Create predictions dictionary
-# predictions = {
-# 'action': action,
-# 'buy_probability': float(action_probs[0, 0].item()),
-# 'sell_probability': float(action_probs[0, 1].item()),
-# 'hold_probability': float(action_probs[0, 2].item()),
-# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
-# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
-# 'pivot_price': pivot_price
-# }
-
-# # Create hidden states dictionary
-# hidden_states = {
-# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
-# }
-
-# # Calculate inference duration
-# end_time = datetime.now()
-# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
-
-# # Update metrics
-# self.last_inference_time = start_time
-# self.last_inference_duration = inference_duration
-# self.inference_count += 1
-
-# # Store last prediction output for dashboard
-# self.last_prediction_output = {
-# 'action': action,
-# 'confidence': confidence,
-# 'pivot_price': pivot_price,
-# 'timestamp': start_time,
-# 'symbol': base_data.symbol
-# }
-
-# # Create metadata dictionary
-# metadata = {
-# 'model_version': '1.0',
-# 'timestamp': start_time.isoformat(),
-# 'input_shape': features.shape,
-# 'inference_duration_ms': inference_duration,
-# 'inference_count': self.inference_count
-# }
-
-# # Create ModelOutput
-# model_output = ModelOutput(
-# model_type='cnn',
-# model_name=self.model_name,
-# symbol=base_data.symbol,
-# timestamp=start_time,
-# confidence=confidence,
-# predictions=predictions,
-# hidden_states=hidden_states,
-# metadata=metadata
-# )
-
-# # Log inference with full input data for training feedback
-# log_model_inference(
-# model_name=self.model_name,
-# symbol=base_data.symbol,
-# action=action,
-# confidence=confidence,
-# probabilities={
-# 'BUY': predictions['buy_probability'],
-# 'SELL': predictions['sell_probability'],
-# 'HOLD': predictions['hold_probability']
-# },
-# input_features=features.cpu().numpy(), # Store full feature vector
-# processing_time_ms=inference_duration,
-# checkpoint_id=None, # Could be enhanced to track checkpoint
-# metadata={
-# 'base_data_input': {
-# 'symbol': base_data.symbol,
-# 'timestamp': base_data.timestamp.isoformat(),
-# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
-# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
-# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
-# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
-# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
-# 'has_cob_data': base_data.cob_data is not None,
-# 'technical_indicators_count': len(base_data.technical_indicators),
-# 'pivot_points_count': len(base_data.pivot_points),
-# 'last_predictions_count': len(base_data.last_predictions)
-# },
-# 'model_predictions': {
-# 'pivot_price': pivot_price,
-# 'extrema_prediction': predictions['extrema'],
-# 'price_prediction': predictions['price_prediction']
-# }
-# }
-# )
-
-# return model_output
-
-# except Exception as e:
-# logger.error(f"Error making prediction with EnhancedCNN: {e}")
-# # Return default ModelOutput
-# return create_model_output(
-# model_type='cnn',
-# model_name=self.model_name,
-# symbol=base_data.symbol,
-# action='HOLD',
-# confidence=0.0
-# )
-
-# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
-# """
-# Add a training sample to the training data
-
-# Args:
-# symbol_or_base_data: Either a symbol string or BaseDataInput object
-# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
-# reward: Reward received for the action
-# """
-# try:
-# # Handle both symbol string and BaseDataInput object
-# if isinstance(symbol_or_base_data, str):
-# # For cold start mode - create a simple training sample with current features
-# # This is a simplified approach for rapid training
-# symbol = symbol_or_base_data
-
-# # Create a realistic feature vector instead of random data
-# # Use actual market data if available, otherwise create realistic synthetic data
-# try:
-# # Try to get real market data first
-# if hasattr(self, 'data_provider') and self.data_provider:
-# # This would need to be implemented in the adapter
-# features = self._create_realistic_features(symbol)
-# else:
-# # Create realistic synthetic features (not random)
-# features = self._create_realistic_synthetic_features(symbol)
-# except Exception as e:
-# logger.warning(f"Could not create realistic features for {symbol}: {e}")
-# # Fallback to small random variation instead of pure random
-# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
-# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
-# features = base_features + noise
-
-# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
-
-# else:
-# # Full BaseDataInput object
-# base_data = symbol_or_base_data
-# features = self._convert_base_data_to_features(base_data)
-# symbol = base_data.symbol
-
-# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
-
-# # Convert action to index
-# actions = ['BUY', 'SELL', 'HOLD']
-# action_idx = actions.index(actual_action)
-
-# # Add to training data
-# with self.training_lock:
-# self.training_data.append((features, action_idx, reward))
-
-# # Limit training data size
-# if len(self.training_data) > self.max_training_samples:
-# # Sort by reward (highest first) and keep top samples
-# self.training_data.sort(key=lambda x: x[2], reverse=True)
-# self.training_data = self.training_data[:self.max_training_samples]
-
-# except Exception as e:
-# logger.error(f"Error adding training sample: {e}")
-
-# def train(self, epochs: int = 1) -> Dict[str, float]:
-# """
-# Train the model with collected data and inference history
-
-# Args:
-# epochs: Number of epochs to train for
-
-# Returns:
-# Dict[str, float]: Training metrics
-# """
-# try:
-# # Track training timing
-# training_start_time = datetime.now()
-# training_start = training_start_time.timestamp()
-
-# with self.training_lock:
-# # Get additional training data from inference history
-# self._load_training_data_from_inference_history()
-
-# # Check if we have enough data
-# if len(self.training_data) < self.batch_size:
-# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
-# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
-
-# # Ensure model is on correct device before training
-# self._ensure_model_on_device()
-
-# # Set model to training mode
-# self.model.train()
-
-# # Create optimizer
-# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
-
-# # Training metrics
-# total_loss = 0.0
-# correct_predictions = 0
-# total_predictions = 0
-
-# # Train for specified number of epochs
-# for epoch in range(epochs):
-# # Shuffle training data
-# np.random.shuffle(self.training_data)
-
-# # Process in batches
-# for i in range(0, len(self.training_data), self.batch_size):
-# batch = self.training_data[i:i+self.batch_size]
-
-# # Skip if batch is too small
-# if len(batch) < 2:
-# continue
-
-# # Prepare batch - ensure all tensors are on the correct device
-# features = torch.stack([sample[0].to(self.device) for sample in batch])
-# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
-# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
-
-# # Zero gradients
-# optimizer.zero_grad()
-
-# # Forward pass
-# q_values, _, _, _, _ = self.model(features)
-
-# # Calculate loss (CrossEntropyLoss with reward weighting)
-# # First, apply softmax to get probabilities
-# probs = torch.softmax(q_values, dim=1)
-
-# # Get probability of chosen action
-# chosen_probs = probs[torch.arange(len(actions)), actions]
-
-# # Calculate negative log likelihood loss
-# nll_loss = -torch.log(chosen_probs + 1e-10)
-
-# # Weight by reward (higher reward = higher weight)
-# # Normalize rewards to [0, 1] range
-# min_reward = rewards.min()
-# max_reward = rewards.max()
-# if max_reward > min_reward:
-# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
-# else:
-# normalized_rewards = torch.ones_like(rewards)
-
-# # Apply reward weighting (higher reward = higher weight)
-# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
-
-# # Mean loss
-# loss = weighted_loss.mean()
-
-# # Backward pass
-# loss.backward()
-
-# # Update weights
-# optimizer.step()
-
-# # Update metrics
-# total_loss += loss.item()
-
-# # Calculate accuracy
-# predicted_actions = torch.argmax(q_values, dim=1)
-# correct_predictions += (predicted_actions == actions).sum().item()
-# total_predictions += len(actions)
-
-# # Validate training - detect overfitting
-# if total_predictions > 0:
-# current_accuracy = correct_predictions / total_predictions
-# if current_accuracy >= 0.99:
-# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
-# # Add regularization to prevent overfitting
-# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
-# loss = loss + l2_reg
-# logger.info("Added L2 regularization to prevent overfitting")
-
-# # Calculate final metrics
-# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
-# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
-
-# # Calculate training duration
-# training_end_time = datetime.now()
-# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
-
-# # Update training metrics
-# self.last_training_time = training_start_time
-# self.last_training_duration = training_duration
-# self.last_training_loss = avg_loss
-# self.training_count += 1
-
-# # Save checkpoint
-# self._save_checkpoint(avg_loss, accuracy)
-
-# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
-
-# return {
-# 'loss': avg_loss,
-# 'accuracy': accuracy,
-# 'samples': len(self.training_data),
-# 'duration_ms': training_duration,
-# 'training_count': self.training_count
-# }
-
-# except Exception as e:
-# logger.error(f"Error training model: {e}")
-# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
-
-# def _save_checkpoint(self, loss: float, accuracy: float):
-# """
-# Save model checkpoint
-
-# Args:
-# loss: Training loss
-# accuracy: Training accuracy
-# """
-# try:
-# # Import checkpoint manager
-# from utils.checkpoint_manager import CheckpointManager
-
-# # Create checkpoint manager
-# checkpoint_manager = CheckpointManager(
-# checkpoint_dir=self.checkpoint_dir,
-# max_checkpoints=10,
-# metric_name="accuracy"
-# )
-
-# # Create temporary model file
-# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
-# self.model.save(temp_path)
-
-# # Create metrics
-# metrics = {
-# 'loss': loss,
-# 'accuracy': accuracy,
-# 'samples': len(self.training_data)
-# }
-
-# # Create metadata
-# metadata = {
-# 'timestamp': datetime.now().isoformat(),
-# 'model_name': self.model_name,
-# 'input_shape': self.model.input_shape,
-# 'n_actions': self.model.n_actions
-# }
-
-# # Save checkpoint
-# checkpoint_path = checkpoint_manager.save_checkpoint(
-# model_name=self.model_name,
-# model_path=f"{temp_path}.pt",
-# metrics=metrics,
-# metadata=metadata
-# )
-
-# # Delete temporary model file
-# if os.path.exists(f"{temp_path}.pt"):
-# os.remove(f"{temp_path}.pt")
-
-# logger.info(f"Model checkpoint saved to {checkpoint_path}")
-
-# except Exception as e:
-# logger.error(f"Error saving checkpoint: {e}")
-
-# def _load_training_data_from_inference_history(self):
-# """Load training data from inference history for continuous learning"""
-# try:
-# from utils.database_manager import get_database_manager
-
-# db_manager = get_database_manager()
-
-# # Get recent inference records with input features
-# inference_records = db_manager.get_inference_records_for_training(
-# model_name=self.model_name,
-# hours_back=24, # Last 24 hours
-# limit=1000
-# )
-
-# if not inference_records:
-# logger.debug("No inference records found for training")
-# return
-
-# # Convert inference records to training samples
-# # For now, use a simple approach: treat high-confidence predictions as ground truth
-# for record in inference_records:
-# if record.input_features is not None and record.confidence > 0.7:
-# # Convert action to index
-# actions = ['BUY', 'SELL', 'HOLD']
-# if record.action in actions:
-# action_idx = actions.index(record.action)
-
-# # Use confidence as a proxy for reward (high confidence = good prediction)
-# reward = record.confidence * 2 - 1 # Scale to [-1, 1]
-
-# # Convert features to tensor
-# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
-
-# # Add to training data if not already present (avoid duplicates)
-# sample_exists = any(
-# torch.equal(features_tensor, existing[0])
-# for existing in self.training_data
-# )
-
-# if not sample_exists:
-# self.training_data.append((features_tensor, action_idx, reward))
-
-# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
-
-# except Exception as e:
-# logger.error(f"Error loading training data from inference history: {e}")
-
-# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
-# """
-# Evaluate past predictions against actual market outcomes
-
-# Args:
-# hours_back: How many hours back to evaluate
-
-# Returns:
-# Dict with evaluation metrics
-# """
-# try:
-# from utils.database_manager import get_database_manager
-
-# db_manager = get_database_manager()
-
-# # Get inference records from the specified time period
-# inference_records = db_manager.get_inference_records_for_training(
-# model_name=self.model_name,
-# hours_back=hours_back,
-# limit=100
-# )
-
-# if not inference_records:
-# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
-
-# # For now, use a simple evaluation based on confidence
-# # In a real implementation, this would compare against actual price movements
-# correct_predictions = 0
-# total_predictions = len(inference_records)
-
-# # Simple heuristic: high confidence predictions are more likely to be correct
-# for record in inference_records:
-# if record.confidence > 0.8: # High confidence threshold
-# correct_predictions += 1
-# elif record.confidence > 0.6: # Medium confidence
-# correct_predictions += 0.5
-
-# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
-
-# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
-
-# return {
-# 'accuracy': accuracy,
-# 'total_predictions': total_predictions,
-# 'correct_predictions': correct_predictions
-# }
-
-# except Exception as e:
-# logger.error(f"Error evaluating predictions: {e}")
-# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
diff --git a/core/enhanced_orchestrator.py b/core/enhanced_orchestrator.py
deleted file mode 100644
index 7193585..0000000
--- a/core/enhanced_orchestrator.py
+++ /dev/null
@@ -1,464 +0,0 @@
-"""
-Enhanced Trading Orchestrator
-
-Central coordination hub for the multi-modal trading system that manages:
-- Data subscription and management
-- Model inference coordination
-- Cross-model data feeding
-- Training pipeline orchestration
-- Decision making using Mixture of Experts
-"""
-
-import asyncio
-import logging
-import numpy as np
-from datetime import datetime
-from typing import Dict, List, Optional, Any
-from dataclasses import dataclass, field
-
-from core.data_provider import DataProvider
-from core.trading_action import TradingAction
-from utils.tensorboard_logger import TensorBoardLogger
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class ModelOutput:
- """Extensible model output format supporting all model types"""
- model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
- model_name: str # Specific model identifier
- symbol: str
- timestamp: datetime
- confidence: float
- predictions: Dict[str, Any] # Model-specific predictions
- hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
- metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
-
-@dataclass
-class BaseDataInput:
- """Unified base data input for all models"""
- symbol: str
- timestamp: datetime
- ohlcv_data: Dict[str, Any] = field(default_factory=dict) # Multi-timeframe OHLCV
- cob_data: Optional[Dict[str, Any]] = None # COB buckets for 1s timeframe
- technical_indicators: Dict[str, float] = field(default_factory=dict)
- pivot_points: List[Any] = field(default_factory=list)
- last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models
- market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc.
-
-@dataclass
-class COBData:
- """Cumulative Order Book data for price buckets"""
- symbol: str
- timestamp: datetime
- current_price: float
- bucket_size: float # $1 for ETH, $10 for BTC
- price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict) # price -> {bid_volume, ask_volume, etc.}
- bid_ask_imbalance: Dict[float, float] = field(default_factory=dict) # price -> imbalance ratio
- volume_weighted_prices: Dict[float, float] = field(default_factory=dict) # price -> VWAP within bucket
- order_flow_metrics: Dict[str, float] = field(default_factory=dict) # Various order flow indicators
-
-class EnhancedTradingOrchestrator:
- """
- Enhanced Trading Orchestrator implementing the design specification
-
- Coordinates data flow, model inference, and decision making for the multi-modal trading system.
- """
-
- def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None):
- """Initialize the enhanced orchestrator"""
- self.data_provider = data_provider
- self.symbols = symbols
- self.enhanced_rl_training = enhanced_rl_training
- self.model_registry = model_registry or {}
-
- # Data management
- self.data_buffers = {symbol: {} for symbol in symbols}
- self.last_update_times = {symbol: {} for symbol in symbols}
-
- # Model output storage
- self.model_outputs = {symbol: {} for symbol in symbols}
- self.model_output_history = {symbol: {} for symbol in symbols}
-
- # Training pipeline
- self.training_data = {symbol: [] for symbol in symbols}
- self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
-
- # COB integration
- self.cob_data = {symbol: None for symbol in symbols}
-
- # Performance tracking
- self.performance_metrics = {
- 'inference_count': 0,
- 'successful_states': 0,
- 'total_episodes': 0
- }
-
- logger.info("Enhanced Trading Orchestrator initialized")
-
- async def start_cob_integration(self):
- """Start COB data integration for real-time market microstructure"""
- try:
- # Subscribe to COB data updates
- self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
- logger.info("COB integration started")
- except Exception as e:
- logger.error(f"Error starting COB integration: {e}")
-
- async def start_realtime_processing(self):
- """Start real-time data processing"""
- try:
- # Subscribe to tick data for real-time processing
- for symbol in self.symbols:
- self.data_provider.subscribe_to_ticks(
- callback=self._on_tick_data,
- symbols=[symbol],
- subscriber_name=f"orchestrator_{symbol}"
- )
-
- logger.info("Real-time processing started")
- except Exception as e:
- logger.error(f"Error starting real-time processing: {e}")
-
- def _on_cob_data_update(self, symbol: str, cob_data: dict):
- """Handle COB data updates"""
- try:
- # Process and store COB data
- self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
- logger.debug(f"COB data updated for {symbol}")
- except Exception as e:
- logger.error(f"Error processing COB data for {symbol}: {e}")
-
- def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
- """Process raw COB data into structured format"""
- try:
- # Determine bucket size based on symbol
- bucket_size = 1.0 if 'ETH' in symbol else 10.0
-
- # Extract current price
- stats = cob_data.get('stats', {})
- current_price = stats.get('mid_price', 0)
-
- # Create COB data structure
- cob = COBData(
- symbol=symbol,
- timestamp=datetime.now(),
- current_price=current_price,
- bucket_size=bucket_size
- )
-
- # Process order book data into price buckets
- bids = cob_data.get('bids', [])
- asks = cob_data.get('asks', [])
-
- # Create price buckets around current price
- bucket_count = 20 # ±20 buckets
- for i in range(-bucket_count, bucket_count + 1):
- bucket_price = current_price + (i * bucket_size)
- cob.price_buckets[bucket_price] = {
- 'bid_volume': 0.0,
- 'ask_volume': 0.0
- }
-
- # Aggregate bid volumes into buckets
- for price, volume in bids:
- bucket_price = round(price / bucket_size) * bucket_size
- if bucket_price in cob.price_buckets:
- cob.price_buckets[bucket_price]['bid_volume'] += volume
-
- # Aggregate ask volumes into buckets
- for price, volume in asks:
- bucket_price = round(price / bucket_size) * bucket_size
- if bucket_price in cob.price_buckets:
- cob.price_buckets[bucket_price]['ask_volume'] += volume
-
- # Calculate bid/ask imbalances
- for price, volumes in cob.price_buckets.items():
- bid_vol = volumes['bid_volume']
- ask_vol = volumes['ask_volume']
- total_vol = bid_vol + ask_vol
- if total_vol > 0:
- cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
- else:
- cob.bid_ask_imbalance[price] = 0.0
-
- # Calculate volume-weighted prices
- for price, volumes in cob.price_buckets.items():
- bid_vol = volumes['bid_volume']
- ask_vol = volumes['ask_volume']
- total_vol = bid_vol + ask_vol
- if total_vol > 0:
- cob.volume_weighted_prices[price] = (
- (price * bid_vol) + (price * ask_vol)
- ) / total_vol
- else:
- cob.volume_weighted_prices[price] = price
-
- # Calculate order flow metrics
- cob.order_flow_metrics = {
- 'total_bid_volume': sum(v['bid_volume'] for v in cob.price_buckets.values()),
- 'total_ask_volume': sum(v['ask_volume'] for v in cob.price_buckets.values()),
- 'bid_ask_ratio': 0.0 if cob.order_flow_metrics['total_ask_volume'] == 0 else
- cob.order_flow_metrics['total_bid_volume'] / cob.order_flow_metrics['total_ask_volume']
- }
-
- return cob
- except Exception as e:
- logger.error(f"Error processing COB data for {symbol}: {e}")
- return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size)
-
- def _on_tick_data(self, tick):
- """Handle incoming tick data"""
- try:
- # Update data buffers
- symbol = tick.symbol
- if symbol not in self.data_buffers:
- self.data_buffers[symbol] = {}
-
- # Store tick data
- if 'ticks' not in self.data_buffers[symbol]:
- self.data_buffers[symbol]['ticks'] = []
- self.data_buffers[symbol]['ticks'].append(tick)
-
- # Keep only last 1000 ticks
- if len(self.data_buffers[symbol]['ticks']) > 1000:
- self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]
-
- # Update last update time
- self.last_update_times[symbol]['tick'] = datetime.now()
-
- logger.debug(f"Tick data updated for {symbol}")
- except Exception as e:
- logger.error(f"Error processing tick data: {e}")
-
- def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
- """
- Build comprehensive RL state with 13,400 features as specified
-
- Returns:
- np.ndarray: State vector with 13,400 features
- """
- try:
- # Initialize state vector
- state_size = 13400
- state = np.zeros(state_size, dtype=np.float32)
-
- # Get latest data
- ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
- cob_data = self.cob_data.get(symbol)
-
- # Feature index tracking
- idx = 0
-
- # 1. OHLCV features (4000 features)
- if ohlcv_data is not None and not ohlcv_data.empty:
- # Use last 100 1s candles (40 features each: O,H,L,C,V + 36 indicators)
- for i in range(min(100, len(ohlcv_data))):
- if idx + 40 <= state_size:
- row = ohlcv_data.iloc[-(i+1)]
- state[idx] = row.get('open', 0) / 100000 # Normalized
- state[idx+1] = row.get('high', 0) / 100000
- state[idx+2] = row.get('low', 0) / 100000
- state[idx+3] = row.get('close', 0) / 100000
- state[idx+4] = row.get('volume', 0) / 1000000
-
- # Add technical indicators if available
- indicator_idx = 5
- for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
- 'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
- if col in row and idx + indicator_idx < state_size:
- state[idx + indicator_idx] = row[col] / 100000
- indicator_idx += 1
-
- idx += 40
-
- # 2. COB features (8000 features)
- if cob_data and idx + 8000 <= state_size:
- # Use 200 price buckets (40 features each)
- bucket_prices = sorted(cob_data.price_buckets.keys())
- for i, price in enumerate(bucket_prices[:200]):
- if idx + 40 <= state_size:
- bucket = cob_data.price_buckets[price]
- state[idx] = bucket.get('bid_volume', 0) / 1000000 # Normalized
- state[idx+1] = bucket.get('ask_volume', 0) / 1000000
- state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0)
- state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000
-
- # Additional COB metrics
- state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
- state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
- state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)
-
- idx += 40
-
- # 3. Technical indicator features (1000 features)
- # Already included in OHLCV section above
-
- # 4. Market microstructure features (400 features)
- if cob_data and idx + 400 <= state_size:
- # Add order flow metrics
- metrics = list(cob_data.order_flow_metrics.values())
- for i, metric in enumerate(metrics[:400]):
- if idx + i < state_size:
- state[idx + i] = metric
-
- # Log state building success
- self.performance_metrics['successful_states'] += 1
- logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")
-
- # Log to TensorBoard
- self.tensorboard_logger.log_state_metrics(
- symbol=symbol,
- state_info={
- 'size': len(state),
- 'quality': 1.0,
- 'feature_counts': {
- 'total': len(state),
- 'non_zero': np.count_nonzero(state)
- }
- },
- step=self.performance_metrics['successful_states']
- )
-
- return state
-
- except Exception as e:
- logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
- return None
-
- def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
- """
- Calculate enhanced pivot-based reward
-
- Args:
- trade_decision: Trading decision with action and confidence
- market_data: Market context data
- trade_outcome: Actual trade results
-
- Returns:
- float: Enhanced reward value
- """
- try:
- # Base reward from PnL
- pnl_reward = trade_outcome.get('net_pnl', 0) / 100 # Normalize
-
- # Confidence weighting
- confidence = trade_decision.get('confidence', 0.5)
- confidence_reward = confidence * 0.2
-
- # Volatility adjustment
- volatility = market_data.get('volatility', 0.01)
- volatility_reward = (1.0 - volatility * 10) * 0.1 # Prefer low volatility
-
- # Order flow alignment
- order_flow = market_data.get('order_flow_strength', 0)
- order_flow_reward = order_flow * 0.2
-
- # Pivot alignment bonus (if near pivot in favorable direction)
- pivot_bonus = 0.0
- if market_data.get('near_pivot', False):
- action = trade_decision.get('action', '').upper()
- pivot_type = market_data.get('pivot_type', '').upper()
-
- # Bonus for buying near support or selling near resistance
- if (action == 'BUY' and pivot_type == 'LOW') or \
- (action == 'SELL' and pivot_type == 'HIGH'):
- pivot_bonus = 0.5
-
- # Calculate final reward
- enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus
-
- # Log to TensorBoard
- self.tensorboard_logger.log_scalars('Rewards/Components', {
- 'pnl_component': pnl_reward,
- 'confidence': confidence_reward,
- 'volatility': volatility_reward,
- 'order_flow': order_flow_reward,
- 'pivot_bonus': pivot_bonus
- }, self.performance_metrics['total_episodes'])
-
- self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes'])
-
- logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
- return enhanced_reward
-
- except Exception as e:
- logger.error(f"Error calculating enhanced pivot reward: {e}")
- return 0.0
-
- async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
- """
- Make coordinated trading decisions using all available models
-
- Returns:
- Dict[str, TradingAction]: Trading actions for each symbol
- """
- try:
- decisions = {}
-
- # For each symbol, coordinate model inference
- for symbol in self.symbols:
- # Build comprehensive state for RL model
- rl_state = self.build_comprehensive_rl_state(symbol)
-
- if rl_state is not None:
- # Store state for training
- self.performance_metrics['total_episodes'] += 1
-
- # Create mock RL decision (in a real implementation, this would call the RL model)
- action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
- confidence = min(1.0, max(0.0, np.std(rl_state) * 10))
-
- # Create trading action
- decisions[symbol] = TradingAction(
- symbol=symbol,
- timestamp=datetime.now(),
- action=action,
- confidence=confidence,
- source='rl_orchestrator'
- )
-
- logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")
- else:
- logger.warning(f"Failed to build state for {symbol}, skipping decision")
-
- self.performance_metrics['inference_count'] += 1
- return decisions
-
- except Exception as e:
- logger.error(f"Error making coordinated decisions: {e}")
- return {}
-
- def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
- """
- Calculate correlation between two symbols
-
- Args:
- symbol1: First symbol
- symbol2: Second symbol
-
- Returns:
- float: Correlation coefficient (-1 to 1)
- """
- try:
- # Get recent price data for both symbols
- data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
- data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)
-
- if data1 is None or data2 is None or data1.empty or data2.empty:
- return 0.0
-
- # Align data by timestamp
- merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')
-
- if len(merged) < 10:
- return 0.0
-
- # Calculate correlation
- correlation = merged['close_1'].corr(merged['close_2'])
- return correlation if not np.isnan(correlation) else 0.0
-
- except Exception as e:
- logger.error(f"Error calculating symbol correlation: {e}")
- return 0.0
-```
\ No newline at end of file
diff --git a/core/enhanced_training_integration.py b/core/enhanced_training_integration.py
deleted file mode 100644
index 6cdd674..0000000
--- a/core/enhanced_training_integration.py
+++ /dev/null
@@ -1,775 +0,0 @@
-"""
-Enhanced Training Integration Module
-
-This module provides comprehensive integration between the training data collection system,
-CNN training pipeline, RL training pipeline, and your existing infrastructure.
-
-Key Features:
-- Real-time integration with existing DataProvider
-- Coordinated training across CNN and RL models
-- Automatic outcome validation and profitability tracking
-- Integration with existing COB RL model
-- Performance monitoring and optimization
-- Seamless connection to existing orchestrator and trading executor
-"""
-
-import asyncio
-import logging
-import numpy as np
-import pandas as pd
-import torch
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from dataclasses import dataclass
-import threading
-import time
-from pathlib import Path
-
-# Import existing components
-from .data_provider import DataProvider
-from .orchestrator import Orchestrator
-from .trading_executor import TradingExecutor
-
-# Import our training system components
-from .training_data_collector import (
- TrainingDataCollector,
- get_training_data_collector
-)
-from .cnn_training_pipeline import (
- CNNPivotPredictor,
- CNNTrainer,
- get_cnn_trainer
-)
-from .rl_training_pipeline import (
- RLTradingAgent,
- RLTrainer,
- get_rl_trainer
-)
-from .training_integration import TrainingIntegration
-
-# Import existing RL model
-try:
- from NN.models.cob_rl_model import COBRLModelInterface
-except ImportError:
- logger.warning("Could not import COBRLModelInterface - using fallback")
- COBRLModelInterface = None
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class EnhancedTrainingConfig:
- """Enhanced configuration for comprehensive training integration"""
- # Data collection
- collection_interval: float = 1.0
- min_data_completeness: float = 0.8
-
- # Training triggers
- min_episodes_for_cnn_training: int = 100
- min_experiences_for_rl_training: int = 200
- training_frequency_minutes: int = 30
-
- # Profitability thresholds
- min_profitability_for_replay: float = 0.1
- high_profitability_threshold: float = 0.5
-
- # Model integration
- use_existing_cob_rl_model: bool = True
- enable_cross_model_learning: bool = True
-
- # Performance optimization
- max_concurrent_training_sessions: int = 2
- enable_background_validation: bool = True
-
-class EnhancedTrainingIntegration:
- """Enhanced training integration with existing infrastructure"""
-
- def __init__(self,
- data_provider: DataProvider,
- orchestrator: Orchestrator = None,
- trading_executor: TradingExecutor = None,
- config: EnhancedTrainingConfig = None):
-
- self.data_provider = data_provider
- self.orchestrator = orchestrator
- self.trading_executor = trading_executor
- self.config = config or EnhancedTrainingConfig()
-
- # Initialize training components
- self.data_collector = get_training_data_collector()
-
- # Initialize CNN components
- self.cnn_model = CNNPivotPredictor()
- self.cnn_trainer = get_cnn_trainer(self.cnn_model)
-
- # Initialize RL components
- if self.config.use_existing_cob_rl_model and COBRLModelInterface:
- self.existing_rl_model = COBRLModelInterface()
- logger.info("Using existing COB RL model")
- else:
- self.existing_rl_model = None
-
- self.rl_agent = RLTradingAgent()
- self.rl_trainer = get_rl_trainer(self.rl_agent)
-
- # Integration state
- self.is_running = False
- self.training_threads = {}
- self.validation_thread = None
-
- # Performance tracking
- self.integration_stats = {
- 'total_data_packages': 0,
- 'cnn_training_sessions': 0,
- 'rl_training_sessions': 0,
- 'profitable_predictions': 0,
- 'total_predictions': 0,
- 'cross_model_improvements': 0,
- 'last_update': datetime.now()
- }
-
- # Model prediction tracking
- self.recent_predictions = {}
- self.prediction_outcomes = {}
-
- # Cross-model learning
- self.model_performance_history = {
- 'cnn': [],
- 'rl': [],
- 'orchestrator': []
- }
-
- logger.info("Enhanced Training Integration initialized")
- logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
- logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
- logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")
-
- def start_enhanced_integration(self):
- """Start the enhanced training integration system"""
- if self.is_running:
- logger.warning("Enhanced training integration already running")
- return
-
- self.is_running = True
-
- # Start data collection
- self.data_collector.start_collection()
-
- # Start CNN training
- if self.config.min_episodes_for_cnn_training > 0:
- for symbol in self.data_provider.symbols:
- self.cnn_trainer.start_real_time_training(symbol)
-
- # Start coordinated training thread
- self.training_threads['coordinator'] = threading.Thread(
- target=self._training_coordinator_worker,
- daemon=True
- )
- self.training_threads['coordinator'].start()
-
- # Start data collection and validation
- self.training_threads['data_collector'] = threading.Thread(
- target=self._enhanced_data_collection_worker,
- daemon=True
- )
- self.training_threads['data_collector'].start()
-
- # Start outcome validation if enabled
- if self.config.enable_background_validation:
- self.validation_thread = threading.Thread(
- target=self._outcome_validation_worker,
- daemon=True
- )
- self.validation_thread.start()
-
- logger.info("Enhanced training integration started")
-
- def stop_enhanced_integration(self):
- """Stop the enhanced training integration system"""
- self.is_running = False
-
- # Stop data collection
- self.data_collector.stop_collection()
-
- # Stop CNN training
- self.cnn_trainer.stop_training()
-
- # Wait for threads to finish
- for thread_name, thread in self.training_threads.items():
- thread.join(timeout=10)
- logger.info(f"Stopped {thread_name} thread")
-
- if self.validation_thread:
- self.validation_thread.join(timeout=5)
-
- logger.info("Enhanced training integration stopped")
-
- def _enhanced_data_collection_worker(self):
- """Enhanced data collection with real-time model integration"""
- logger.info("Enhanced data collection worker started")
-
- while self.is_running:
- try:
- for symbol in self.data_provider.symbols:
- self._collect_enhanced_training_data(symbol)
-
- time.sleep(self.config.collection_interval)
-
- except Exception as e:
- logger.error(f"Error in enhanced data collection: {e}")
- time.sleep(5)
-
- logger.info("Enhanced data collection worker stopped")
-
- def _collect_enhanced_training_data(self, symbol: str):
- """Collect enhanced training data with model predictions"""
- try:
- # Get comprehensive market data
- market_data = self._get_comprehensive_market_data(symbol)
-
- if not market_data or not self._validate_market_data(market_data):
- return
-
- # Get current model predictions
- model_predictions = self._get_all_model_predictions(symbol, market_data)
-
- # Create enhanced features
- cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
- rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)
-
- # Collect training data with predictions
- episode_id = self.data_collector.collect_training_data(
- symbol=symbol,
- ohlcv_data=market_data['ohlcv'],
- tick_data=market_data['ticks'],
- cob_data=market_data['cob'],
- technical_indicators=market_data['indicators'],
- pivot_points=market_data['pivots'],
- cnn_features=cnn_features,
- rl_state=rl_state,
- orchestrator_context=market_data['context'],
- model_predictions=model_predictions
- )
-
- if episode_id:
- # Store predictions for outcome validation
- self.recent_predictions[episode_id] = {
- 'timestamp': datetime.now(),
- 'symbol': symbol,
- 'predictions': model_predictions,
- 'market_data': market_data
- }
-
- # Add RL experience if we have action
- if 'rl_action' in model_predictions:
- self._add_rl_experience(symbol, market_data, model_predictions, episode_id)
-
- self.integration_stats['total_data_packages'] += 1
-
- except Exception as e:
- logger.error(f"Error collecting enhanced training data for {symbol}: {e}")
-
- def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
- """Get comprehensive market data from all sources"""
- try:
- market_data = {}
-
- # OHLCV data
- ohlcv_data = {}
- for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
- df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
- if df is not None and not df.empty:
- ohlcv_data[timeframe] = df
- market_data['ohlcv'] = ohlcv_data
-
- # Tick data
- market_data['ticks'] = self._get_recent_tick_data(symbol)
-
- # COB data
- market_data['cob'] = self._get_cob_data(symbol)
-
- # Technical indicators
- market_data['indicators'] = self._get_technical_indicators(symbol)
-
- # Pivot points
- market_data['pivots'] = self._get_pivot_points(symbol)
-
- # Market context
- market_data['context'] = self._get_market_context(symbol)
-
- return market_data
-
- except Exception as e:
- logger.error(f"Error getting comprehensive market data: {e}")
- return {}
-
- def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
- """Get predictions from all available models"""
- predictions = {}
-
- try:
- # CNN predictions
- if self.cnn_model and market_data.get('ohlcv'):
- cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
- if cnn_features is not None:
- cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)
-
- # Reshape for CNN (add channel dimension)
- cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels
-
- with torch.no_grad():
- cnn_outputs = self.cnn_model(cnn_input)
- predictions['cnn'] = {
- 'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
- 'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
- 'confidence': cnn_outputs['confidence'].cpu().numpy(),
- 'timestamp': datetime.now()
- }
-
- # RL predictions
- if self.rl_agent and market_data.get('cob'):
- rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
- if rl_state is not None:
- action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
- predictions['rl'] = {
- 'action': action,
- 'confidence': confidence,
- 'timestamp': datetime.now()
- }
- predictions['rl_action'] = action
-
- # Existing COB RL model predictions
- if self.existing_rl_model and market_data.get('cob'):
- cob_features = market_data['cob'].get('cob_features', [])
- if cob_features and len(cob_features) >= 2000:
- cob_array = np.array(cob_features[:2000], dtype=np.float32)
- cob_prediction = self.existing_rl_model.predict(cob_array)
- predictions['cob_rl'] = {
- 'predicted_direction': cob_prediction.get('predicted_direction', 1),
- 'confidence': cob_prediction.get('confidence', 0.5),
- 'value': cob_prediction.get('value', 0.0),
- 'timestamp': datetime.now()
- }
-
- # Orchestrator predictions (if available)
- if self.orchestrator:
- try:
- # This would integrate with your orchestrator's prediction method
- orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
- if orchestrator_prediction:
- predictions['orchestrator'] = orchestrator_prediction
- except Exception as e:
- logger.debug(f"Could not get orchestrator prediction: {e}")
-
- return predictions
-
- except Exception as e:
- logger.error(f"Error getting model predictions: {e}")
- return {}
-
- def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
- predictions: Dict[str, Any], episode_id: str):
- """Add RL experience to the training buffer"""
- try:
- # Create RL state
- state = self._create_enhanced_rl_state(symbol, market_data, predictions)
- if state is None:
- return
-
- # Get action from predictions
- action = predictions.get('rl_action', 1) # Default to HOLD
-
- # Calculate immediate reward (placeholder - would be updated with actual outcome)
- reward = 0.0
-
- # Create next state (same as current for now - would be updated)
- next_state = state.copy()
-
- # Market context
- market_context = {
- 'symbol': symbol,
- 'episode_id': episode_id,
- 'timestamp': datetime.now(),
- 'market_session': market_data['context'].get('market_session', 'unknown'),
- 'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
- }
-
- # Add experience
- experience_id = self.rl_trainer.add_experience(
- state=state,
- action=action,
- reward=reward,
- next_state=next_state,
- done=False,
- market_context=market_context,
- cnn_predictions=predictions.get('cnn'),
- confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
- )
-
- if experience_id:
- logger.debug(f"Added RL experience: {experience_id}")
-
- except Exception as e:
- logger.error(f"Error adding RL experience: {e}")
-
- def _training_coordinator_worker(self):
- """Coordinate training across all models"""
- logger.info("Training coordinator worker started")
-
- while self.is_running:
- try:
- # Check if we should trigger training
- for symbol in self.data_provider.symbols:
- self._check_and_trigger_training(symbol)
-
- # Wait before next check
- time.sleep(self.config.training_frequency_minutes * 60)
-
- except Exception as e:
- logger.error(f"Error in training coordinator: {e}")
- time.sleep(60)
-
- logger.info("Training coordinator worker stopped")
-
- def _check_and_trigger_training(self, symbol: str):
- """Check conditions and trigger training if needed"""
- try:
- # Get training episodes and experiences
- episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)
-
- # Check CNN training conditions
- if len(episodes) >= self.config.min_episodes_for_cnn_training:
- profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
-
- if len(profitable_episodes) >= 20: # Minimum profitable episodes
- logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
-
- results = self.cnn_trainer.train_on_profitable_episodes(
- symbol=symbol,
- min_profitability=self.config.min_profitability_for_replay,
- max_episodes=len(profitable_episodes)
- )
-
- if results.get('status') == 'success':
- self.integration_stats['cnn_training_sessions'] += 1
- logger.info(f"CNN training completed for {symbol}")
-
- # Check RL training conditions
- buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
- total_experiences = buffer_stats.get('total_experiences', 0)
-
- if total_experiences >= self.config.min_experiences_for_rl_training:
- profitable_experiences = buffer_stats.get('profitable_experiences', 0)
-
- if profitable_experiences >= 50: # Minimum profitable experiences
- logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
-
- results = self.rl_trainer.train_on_profitable_experiences(
- min_profitability=self.config.min_profitability_for_replay,
- max_experiences=min(profitable_experiences, 500),
- batch_size=32
- )
-
- if results.get('status') == 'success':
- self.integration_stats['rl_training_sessions'] += 1
- logger.info("RL training completed")
-
- except Exception as e:
- logger.error(f"Error checking training conditions for {symbol}: {e}")
-
- def _outcome_validation_worker(self):
- """Background worker for validating prediction outcomes"""
- logger.info("Outcome validation worker started")
-
- while self.is_running:
- try:
- self._validate_recent_predictions()
- time.sleep(300) # Check every 5 minutes
-
- except Exception as e:
- logger.error(f"Error in outcome validation: {e}")
- time.sleep(60)
-
- logger.info("Outcome validation worker stopped")
-
- def _validate_recent_predictions(self):
- """Validate recent predictions against actual outcomes"""
- try:
- current_time = datetime.now()
- validation_delay = timedelta(hours=1) # Wait 1 hour to validate
-
- validated_predictions = []
-
- for episode_id, prediction_data in self.recent_predictions.items():
- prediction_time = prediction_data['timestamp']
-
- if current_time - prediction_time >= validation_delay:
- # Validate this prediction
- outcome = self._calculate_prediction_outcome(prediction_data)
-
- if outcome:
- self.prediction_outcomes[episode_id] = outcome
-
- # Update RL experience if exists
- if 'rl_action' in prediction_data['predictions']:
- self._update_rl_experience_outcome(episode_id, outcome)
-
- # Update statistics
- if outcome['is_profitable']:
- self.integration_stats['profitable_predictions'] += 1
- self.integration_stats['total_predictions'] += 1
-
- validated_predictions.append(episode_id)
-
- # Remove validated predictions
- for episode_id in validated_predictions:
- del self.recent_predictions[episode_id]
-
- if validated_predictions:
- logger.info(f"Validated {len(validated_predictions)} predictions")
-
- except Exception as e:
- logger.error(f"Error validating predictions: {e}")
-
- def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- """Calculate actual outcome for a prediction"""
- try:
- symbol = prediction_data['symbol']
- prediction_time = prediction_data['timestamp']
-
- # Get price data after prediction
- current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)
-
- if current_df is None or current_df.empty:
- return None
-
- # Find price at prediction time and current price
- prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
- if prediction_price.empty:
- return None
-
- base_price = float(prediction_price['close'].iloc[-1])
- current_price = float(current_df['close'].iloc[-1])
-
- # Calculate outcome
- price_change = (current_price - base_price) / base_price
- is_profitable = abs(price_change) > 0.005 # 0.5% threshold
-
- return {
- 'episode_id': prediction_data.get('episode_id'),
- 'base_price': base_price,
- 'current_price': current_price,
- 'price_change': price_change,
- 'is_profitable': is_profitable,
- 'profitability_score': abs(price_change) * 10, # Scale to 0-1 range
- 'validation_time': datetime.now()
- }
-
- except Exception as e:
- logger.error(f"Error calculating prediction outcome: {e}")
- return None
-
- def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
- """Update RL experience with actual outcome"""
- try:
- # Find the experience ID associated with this episode
- # This is a simplified approach - in practice you'd maintain better mapping
- actual_profit = outcome['price_change']
-
- # Determine optimal action based on outcome
- if outcome['price_change'] > 0.01:
- optimal_action = 2 # BUY
- elif outcome['price_change'] < -0.01:
- optimal_action = 0 # SELL
- else:
- optimal_action = 1 # HOLD
-
- # Update experience (this would need proper experience ID mapping)
- # For now, we'll update the most recent experience
- # In practice, you'd maintain a mapping between episodes and experiences
-
- except Exception as e:
- logger.error(f"Error updating RL experience outcome: {e}")
-
- def get_integration_statistics(self) -> Dict[str, Any]:
- """Get comprehensive integration statistics"""
- stats = self.integration_stats.copy()
-
- # Add component statistics
- stats['data_collector'] = self.data_collector.get_collection_statistics()
- stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
- stats['rl_trainer'] = self.rl_trainer.get_training_statistics()
-
- # Add performance metrics
- stats['is_running'] = self.is_running
- stats['active_symbols'] = len(self.data_provider.symbols)
- stats['recent_predictions_count'] = len(self.recent_predictions)
- stats['validated_outcomes_count'] = len(self.prediction_outcomes)
-
- # Calculate profitability rate
- if stats['total_predictions'] > 0:
- stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
- else:
- stats['overall_profitability_rate'] = 0.0
-
- return stats
-
- def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
- """Manually trigger training"""
- results = {}
-
- try:
- if training_type in ['all', 'cnn']:
- symbols = [symbol] if symbol else self.data_provider.symbols
- for sym in symbols:
- cnn_results = self.cnn_trainer.train_on_profitable_episodes(
- symbol=sym,
- min_profitability=0.1,
- max_episodes=200
- )
- results[f'cnn_{sym}'] = cnn_results
-
- if training_type in ['all', 'rl']:
- rl_results = self.rl_trainer.train_on_profitable_experiences(
- min_profitability=0.1,
- max_experiences=500,
- batch_size=32
- )
- results['rl'] = rl_results
-
- return {'status': 'success', 'results': results}
-
- except Exception as e:
- logger.error(f"Error in manual training trigger: {e}")
- return {'status': 'error', 'error': str(e)}
-
- # Helper methods (simplified implementations)
- def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
- """Get recent tick data"""
- # Implementation would get tick data from data provider
- return []
-
- def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
- """Get COB data"""
- # Implementation would get COB data from data provider
- return {}
-
- def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
- """Get technical indicators"""
- # Implementation would get indicators from data provider
- return {}
-
- def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
- """Get pivot points"""
- # Implementation would get pivot points from data provider
- return []
-
- def _get_market_context(self, symbol: str) -> Dict[str, Any]:
- """Get market context"""
- return {
- 'symbol': symbol,
- 'timestamp': datetime.now(),
- 'market_session': 'unknown',
- 'volatility_regime': 'unknown'
- }
-
- def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
- """Validate market data completeness"""
- required_fields = ['ohlcv', 'indicators']
- return all(field in market_data for field in required_fields)
-
- def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
- """Create enhanced CNN features"""
- try:
- # Simplified feature creation
- features = []
-
- # Add OHLCV features
- for timeframe in ['1m', '5m', '15m', '1h']:
- if timeframe in market_data.get('ohlcv', {}):
- df = market_data['ohlcv'][timeframe]
- if not df.empty:
- ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
- if len(ohlcv_values) > 0:
- recent_values = ohlcv_values[-60:].flatten()
- features.extend(recent_values)
-
- # Pad to target size
- target_size = 3000 # 10 channels * 300 sequence length
- if len(features) < target_size:
- features.extend([0.0] * (target_size - len(features)))
- else:
- features = features[:target_size]
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.warning(f"Error creating CNN features: {e}")
- return None
-
- def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
- predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
- """Create enhanced RL state"""
- try:
- state_features = []
-
- # Add market features
- if '1m' in market_data.get('ohlcv', {}):
- df = market_data['ohlcv']['1m']
- if not df.empty:
- latest = df.iloc[-1]
- state_features.extend([
- latest['open'], latest['high'],
- latest['low'], latest['close'], latest['volume']
- ])
-
- # Add technical indicators
- indicators = market_data.get('indicators', {})
- for value in indicators.values():
- state_features.append(value)
-
- # Add model predictions as features
- if predictions:
- if 'cnn' in predictions:
- cnn_pred = predictions['cnn']
- state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0]))
- state_features.append(cnn_pred.get('confidence', [0.0])[0])
-
- if 'cob_rl' in predictions:
- cob_pred = predictions['cob_rl']
- state_features.append(cob_pred.get('predicted_direction', 1))
- state_features.append(cob_pred.get('confidence', 0.5))
-
- # Pad to target size
- target_size = 2000
- if len(state_features) < target_size:
- state_features.extend([0.0] * (target_size - len(state_features)))
- else:
- state_features = state_features[:target_size]
-
- return np.array(state_features, dtype=np.float32)
-
- except Exception as e:
- logger.warning(f"Error creating RL state: {e}")
- return None
-
- def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
- predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- """Get orchestrator prediction"""
- # This would integrate with your orchestrator
- return None
-
-# Global instance
-enhanced_training_integration = None
-
-def get_enhanced_training_integration(data_provider: DataProvider = None,
- orchestrator: Orchestrator = None,
- trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
- """Get global enhanced training integration instance"""
- global enhanced_training_integration
- if enhanced_training_integration is None:
- if data_provider is None:
- raise ValueError("DataProvider required for first initialization")
- enhanced_training_integration = EnhancedTrainingIntegration(
- data_provider, orchestrator, trading_executor
- )
- return enhanced_training_integration
\ No newline at end of file
diff --git a/core/mexc_webclient/__init__.py b/core/mexc_webclient/__init__.py
deleted file mode 100644
index 449bcac..0000000
--- a/core/mexc_webclient/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# MEXC Web Client Module
-#
-# This module provides web-based trading capabilities for MEXC futures trading
-# which is not supported by their official API.
-
-from .mexc_futures_client import MEXCFuturesWebClient
-
-__all__ = ['MEXCFuturesWebClient']
\ No newline at end of file
diff --git a/core/mexc_webclient/auto_browser.py b/core/mexc_webclient/auto_browser.py
deleted file mode 100644
index d8b5c31..0000000
--- a/core/mexc_webclient/auto_browser.py
+++ /dev/null
@@ -1,555 +0,0 @@
-#!/usr/bin/env python3
-"""
-MEXC Auto Browser with Request Interception
-
-This script automatically spawns a ChromeDriver instance and captures
-all MEXC futures trading requests in real-time, including full request
-and response data needed for reverse engineering.
-"""
-
-import logging
-import time
-import json
-import sys
-import os
-from typing import Dict, List, Optional, Any
-from datetime import datetime
-import threading
-import queue
-
-# Selenium imports
-try:
- from selenium import webdriver
- from selenium.webdriver.chrome.options import Options
- from selenium.webdriver.chrome.service import Service
- from selenium.webdriver.common.by import By
- from selenium.webdriver.support.ui import WebDriverWait
- from selenium.webdriver.support import expected_conditions as EC
- from selenium.common.exceptions import TimeoutException, WebDriverException
- from webdriver_manager.chrome import ChromeDriverManager
-except ImportError:
- print("Please install selenium and webdriver-manager:")
- print("pip install selenium webdriver-manager")
- sys.exit(1)
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-class MEXCRequestInterceptor:
- """
- Automatically spawns ChromeDriver and intercepts all MEXC API requests
- """
-
- def __init__(self, headless: bool = False, save_to_file: bool = True):
- """
- Initialize the request interceptor
-
- Args:
- headless: Run browser in headless mode
- save_to_file: Save captured requests to JSON file
- """
- self.driver = None
- self.headless = headless
- self.save_to_file = save_to_file
- self.captured_requests = []
- self.captured_responses = []
- self.session_cookies = {}
- self.monitoring = False
- self.request_queue = queue.Queue()
-
- # File paths for saving data
- self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- self.requests_file = f"mexc_requests_{self.timestamp}.json"
- self.cookies_file = f"mexc_cookies_{self.timestamp}.json"
-
- def setup_browser(self):
- """Setup Chrome browser with necessary options"""
- chrome_options = webdriver.ChromeOptions()
- # Enable headless mode if needed
- if self.headless:
- chrome_options.add_argument('--headless')
- chrome_options.add_argument('--disable-gpu')
- chrome_options.add_argument('--window-size=1920,1080')
- chrome_options.add_argument('--disable-extensions')
-
- # Set up Chrome options with a user data directory to persist session
- user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data')
- os.makedirs(user_data_base_dir, exist_ok=True)
-
- # Check for existing session directories
- session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')]
- session_dirs.sort(reverse=True) # Sort descending to get the most recent first
-
- user_data_dir = None
- if session_dirs:
- use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y'
- if use_existing:
- print("Available sessions:")
- for i, session in enumerate(session_dirs[:5], 1): # Show up to 5 most recent
- print(f"{i}. {session}")
- choice = input("Enter session number (default 1) or any other key for most recent: ")
- if choice.isdigit() and 1 <= int(choice) <= len(session_dirs):
- selected_session = session_dirs[int(choice) - 1]
- else:
- selected_session = session_dirs[0]
- user_data_dir = os.path.join(user_data_base_dir, selected_session)
- print(f"Using session: {selected_session}")
-
- if user_data_dir is None:
- user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}')
- os.makedirs(user_data_dir, exist_ok=True)
- print(f"Creating new session: session_{self.timestamp}")
-
- chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
-
- # Enable logging to capture JS console output and network activity
- chrome_options.set_capability('goog:loggingPrefs', {
- 'browser': 'ALL',
- 'performance': 'ALL'
- })
-
- try:
- self.driver = webdriver.Chrome(options=chrome_options)
- except Exception as e:
- print(f"Failed to start browser with session: {e}")
- print("Falling back to a new session...")
- user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback')
- os.makedirs(user_data_dir, exist_ok=True)
- print(f"Creating fallback session: session_{self.timestamp}_fallback")
- chrome_options = webdriver.ChromeOptions()
- if self.headless:
- chrome_options.add_argument('--headless')
- chrome_options.add_argument('--disable-gpu')
- chrome_options.add_argument('--window-size=1920,1080')
- chrome_options.add_argument('--disable-extensions')
- chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
- chrome_options.set_capability('goog:loggingPrefs', {
- 'browser': 'ALL',
- 'performance': 'ALL'
- })
- self.driver = webdriver.Chrome(options=chrome_options)
-
- return self.driver
-
- def start_monitoring(self):
- """Start the browser and begin monitoring"""
- logger.info("Starting MEXC Request Interceptor...")
-
- try:
- # Setup ChromeDriver
- self.driver = self.setup_browser()
-
- # Navigate to MEXC futures
- mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap"
- logger.info(f"Navigating to: {mexc_url}")
- self.driver.get(mexc_url)
-
- # Wait for page load
- WebDriverWait(self.driver, 10).until(
- EC.presence_of_element_located((By.TAG_NAME, "body"))
- )
-
- logger.info("✅ MEXC page loaded successfully!")
- logger.info("📝 Please log in manually in the browser window")
- logger.info("🔍 Request monitoring is now active...")
-
- # Start monitoring in background thread
- self.monitoring = True
- monitor_thread = threading.Thread(target=self._monitor_requests, daemon=True)
- monitor_thread.start()
-
- # Wait for manual login
- self._wait_for_login()
-
- return True
-
- except Exception as e:
- logger.error(f"Failed to start monitoring: {e}")
- return False
-
- def _wait_for_login(self):
- """Wait for user to log in and show interactive menu"""
- logger.info("\n" + "="*60)
- logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE")
- logger.info("="*60)
-
- while True:
- print("\nOptions:")
- print("1. Check login status")
- print("2. Extract current cookies")
- print("3. Show captured requests summary")
- print("4. Save captured data to files")
- print("5. Perform test trade (manual)")
- print("6. Monitor for 60 seconds")
- print("0. Stop and exit")
-
- choice = input("\nEnter choice (0-6): ").strip()
-
- if choice == "1":
- self._check_login_status()
- elif choice == "2":
- self._extract_cookies()
- elif choice == "3":
- self._show_requests_summary()
- elif choice == "4":
- self._save_all_data()
- elif choice == "5":
- self._guide_test_trade()
- elif choice == "6":
- self._monitor_for_duration(60)
- elif choice == "0":
- break
- else:
- print("Invalid choice. Please try again.")
-
- self.stop_monitoring()
-
- def _check_login_status(self):
- """Check if user is logged into MEXC"""
- try:
- cookies = self.driver.get_cookies()
- auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint']
- found_auth = []
-
- for cookie in cookies:
- if cookie['name'] in auth_cookies and cookie['value']:
- found_auth.append(cookie['name'])
-
- if len(found_auth) >= 2:
- print("✅ LOGIN DETECTED - You appear to be logged in!")
- print(f" Found auth cookies: {', '.join(found_auth)}")
- return True
- else:
- print("❌ NOT LOGGED IN - Please log in to MEXC in the browser")
- print(" Missing required authentication cookies")
- return False
-
- except Exception as e:
- print(f"❌ Error checking login: {e}")
- return False
-
- def _extract_cookies(self):
- """Extract and display current session cookies"""
- try:
- cookies = self.driver.get_cookies()
- cookie_dict = {}
-
- for cookie in cookies:
- cookie_dict[cookie['name']] = cookie['value']
-
- self.session_cookies = cookie_dict
-
- print(f"\n📊 Extracted {len(cookie_dict)} cookies:")
-
- # Show important cookies
- important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
- for name in important:
- if name in cookie_dict:
- value = cookie_dict[name]
- display_value = value[:20] + "..." if len(value) > 20 else value
- print(f" ✅ {name}: {display_value}")
- else:
- print(f" ❌ {name}: Missing")
-
- # Save cookies to file
- if self.save_to_file:
- with open(self.cookies_file, 'w') as f:
- json.dump(cookie_dict, f, indent=2)
- print(f"\n💾 Cookies saved to: {self.cookies_file}")
-
- except Exception as e:
- print(f"❌ Error extracting cookies: {e}")
-
- def _monitor_requests(self):
- """Background thread to monitor network requests"""
- last_log_count = 0
-
- while self.monitoring:
- try:
- # Get performance logs
- logs = self.driver.get_log('performance')
-
- for log in logs:
- try:
- message = json.loads(log['message'])
- method = message.get('message', {}).get('method', '')
-
- # Capture network requests
- if method == 'Network.requestWillBeSent':
- self._process_request(message['message']['params'])
- elif method == 'Network.responseReceived':
- self._process_response(message['message']['params'])
-
- except (json.JSONDecodeError, KeyError) as e:
- continue
-
- # Show progress every 10 new requests
- if len(self.captured_requests) >= last_log_count + 10:
- last_log_count = len(self.captured_requests)
- logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses")
-
- except Exception as e:
- if self.monitoring: # Only log if we're still supposed to be monitoring
- logger.debug(f"Monitor error: {e}")
-
- time.sleep(0.5) # Check every 500ms
-
- def _process_request(self, request_data):
- """Process a captured network request"""
- try:
- url = request_data.get('request', {}).get('url', '')
-
- # Filter for MEXC API requests
- if self._is_mexc_request(url):
- request_info = {
- 'type': 'request',
- 'timestamp': datetime.now().isoformat(),
- 'url': url,
- 'method': request_data.get('request', {}).get('method', ''),
- 'headers': request_data.get('request', {}).get('headers', {}),
- 'postData': request_data.get('request', {}).get('postData', ''),
- 'requestId': request_data.get('requestId', '')
- }
-
- self.captured_requests.append(request_info)
-
- # Show important requests immediately
- if ('futures.mexc.com' in url or 'captcha' in url):
- print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}")
- if request_info['postData']:
- print(f" 📄 POST Data: {request_info['postData'][:100]}...")
-
- # Enhanced captcha detection and detailed logging
- if 'captcha' in url.lower() or 'robot' in url.lower():
- logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}")
- logger.info(f" Headers: {request_data.get('request', {}).get('headers', {})}")
- if request_data.get('request', {}).get('postData', ''):
- logger.info(f" Data: {request_data.get('request', {}).get('postData', '')}")
- # Attempt to capture related JavaScript or DOM elements (if possible)
- if self.driver is not None:
- try:
- js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';")
- logger.info(f" Related JS Snippet: {js_snippet}")
- except Exception as e:
- logger.warning(f" Could not capture JS snippet: {e}")
- try:
- dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';")
- logger.info(f" Related DOM Element: {dom_element}")
- except Exception as e:
- logger.warning(f" Could not capture DOM element: {e}")
- else:
- logger.warning(" Driver not initialized, cannot capture JS or DOM elements")
-
- except Exception as e:
- logger.debug(f"Error processing request: {e}")
-
- def _process_response(self, response_data):
- """Process a captured network response"""
- try:
- url = response_data.get('response', {}).get('url', '')
-
- # Filter for MEXC API responses
- if self._is_mexc_request(url):
- response_info = {
- 'type': 'response',
- 'timestamp': datetime.now().isoformat(),
- 'url': url,
- 'status': response_data.get('response', {}).get('status', 0),
- 'headers': response_data.get('response', {}).get('headers', {}),
- 'requestId': response_data.get('requestId', '')
- }
-
- self.captured_responses.append(response_info)
-
- # Show important responses immediately
- if ('futures.mexc.com' in url or 'captcha' in url):
- status = response_info['status']
- status_emoji = "✅" if status == 200 else "❌"
- print(f" {status_emoji} RESPONSE: {status} for {url}")
-
- except Exception as e:
- logger.debug(f"Error processing response: {e}")
-
- def _is_mexc_request(self, url: str) -> bool:
- """Check if URL is a relevant MEXC API request"""
- mexc_indicators = [
- 'futures.mexc.com',
- 'ucgateway/captcha_api',
- 'api/v1/private',
- 'api/v3/order',
- 'mexc.com/api'
- ]
-
- return any(indicator in url for indicator in mexc_indicators)
-
- def _show_requests_summary(self):
- """Show summary of captured requests"""
- print(f"\n📊 CAPTURE SUMMARY:")
- print(f" Total Requests: {len(self.captured_requests)}")
- print(f" Total Responses: {len(self.captured_responses)}")
-
- # Group by URL pattern
- url_counts = {}
- for req in self.captured_requests:
- base_url = req['url'].split('?')[0] # Remove query params
- url_counts[base_url] = url_counts.get(base_url, 0) + 1
-
- print("\n🔗 Top URLs:")
- for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
- print(f" {count}x {url}")
-
- # Show recent futures API calls
- futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']]
- if futures_requests:
- print(f"\n🚀 Futures API Calls: {len(futures_requests)}")
- for req in futures_requests[-3:]: # Show last 3
- print(f" {req['method']} {req['url']}")
-
- def _save_all_data(self):
- """Save all captured data to files"""
- if not self.save_to_file:
- print("File saving is disabled")
- return
-
- try:
- # Save requests
- with open(self.requests_file, 'w') as f:
- json.dump({
- 'requests': self.captured_requests,
- 'responses': self.captured_responses,
- 'summary': {
- 'total_requests': len(self.captured_requests),
- 'total_responses': len(self.captured_responses),
- 'capture_session': self.timestamp
- }
- }, f, indent=2)
-
- # Save cookies if we have them
- if self.session_cookies:
- with open(self.cookies_file, 'w') as f:
- json.dump(self.session_cookies, f, indent=2)
-
- print(f"\n💾 Data saved to:")
- print(f" 📋 Requests: {self.requests_file}")
- if self.session_cookies:
- print(f" 🍪 Cookies: {self.cookies_file}")
-
- # Extract and save CAPTCHA tokens from captured requests
- captcha_tokens = self.extract_captcha_tokens()
- if captcha_tokens:
- captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json"
- with open(captcha_file, 'w') as f:
- json.dump(captcha_tokens, f, indent=2)
- logger.info(f"Saved CAPTCHA tokens to {captcha_file}")
- else:
- logger.warning("No CAPTCHA tokens found in captured requests")
-
- except Exception as e:
- print(f"❌ Error saving data: {e}")
-
- def _guide_test_trade(self):
- """Guide user through performing a test trade"""
- print("\n🧪 TEST TRADE GUIDE:")
- print("1. Make sure you're logged into MEXC")
- print("2. Go to the trading interface")
- print("3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)")
- print("4. Watch the console for captured API calls")
- print("\n⚠️ IMPORTANT: Use very small amounts for testing!")
- input("\nPress Enter when you're ready to start monitoring...")
-
- self._monitor_for_duration(120) # Monitor for 2 minutes
-
- def _monitor_for_duration(self, seconds: int):
- """Monitor requests for a specific duration"""
- print(f"\n🔍 Monitoring requests for {seconds} seconds...")
- print("Perform your trading actions now!")
-
- start_time = time.time()
- initial_count = len(self.captured_requests)
-
- while time.time() - start_time < seconds:
- current_count = len(self.captured_requests)
- new_requests = current_count - initial_count
-
- remaining = seconds - int(time.time() - start_time)
- print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True)
-
- time.sleep(1)
-
- final_count = len(self.captured_requests)
- new_total = final_count - initial_count
- print(f"\n✅ Monitoring complete! Captured {new_total} new requests")
-
- def stop_monitoring(self):
- """Stop monitoring and close browser"""
- logger.info("Stopping request monitoring...")
- self.monitoring = False
-
- if self.driver:
- self.driver.quit()
- logger.info("Browser closed")
-
- # Final save
- if self.save_to_file and (self.captured_requests or self.captured_responses):
- self._save_all_data()
- logger.info("Final data save complete")
-
- def extract_captcha_tokens(self):
- """Extract CAPTCHA tokens from captured requests"""
- captcha_tokens = []
- for request in self.captured_requests:
- if 'captcha-token' in request.get('headers', {}):
- token = request['headers']['captcha-token']
- captcha_tokens.append({
- 'token': token,
- 'url': request.get('url', ''),
- 'timestamp': request.get('timestamp', '')
- })
- elif 'captcha' in request.get('url', '').lower():
- response = request.get('response', {})
- if response and 'captcha-token' in response.get('headers', {}):
- token = response['headers']['captcha-token']
- captcha_tokens.append({
- 'token': token,
- 'url': request.get('url', ''),
- 'timestamp': request.get('timestamp', '')
- })
- return captcha_tokens
-
-def main():
- """Main function to run the interceptor"""
- print("🚀 MEXC Request Interceptor with ChromeDriver")
- print("=" * 50)
- print("This will automatically:")
- print("✅ Download/setup ChromeDriver")
- print("✅ Open MEXC futures page")
- print("✅ Capture all API requests/responses")
- print("✅ Extract session cookies")
- print("✅ Save data to JSON files")
- print("\nPress Ctrl+C to stop at any time")
-
- # Ask for preferences
- headless = input("\nRun in headless mode? (y/n): ").lower().strip() == 'y'
-
- interceptor = MEXCRequestInterceptor(headless=headless, save_to_file=True)
-
- try:
- success = interceptor.start_monitoring()
- if not success:
- print("❌ Failed to start monitoring")
- return
-
- except KeyboardInterrupt:
- print("\n\n⏹️ Stopping interceptor...")
- except Exception as e:
- print(f"\n❌ Error: {e}")
- finally:
- interceptor.stop_monitoring()
- print("\n👋 Goodbye!")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/core/mexc_webclient/browser_automation.py b/core/mexc_webclient/browser_automation.py
deleted file mode 100644
index f2f6e43..0000000
--- a/core/mexc_webclient/browser_automation.py
+++ /dev/null
@@ -1,358 +0,0 @@
-"""
-MEXC Browser Automation for Cookie Extraction and Request Monitoring
-
-This module uses Selenium to automate browser interactions and extract
-session cookies and request data for MEXC futures trading.
-"""
-
-import logging
-import time
-import json
-from typing import Dict, List, Optional, Any
-from selenium import webdriver
-from selenium.webdriver.chrome.options import Options
-from selenium.webdriver.common.by import By
-from selenium.webdriver.support.ui import WebDriverWait
-from selenium.webdriver.support import expected_conditions as EC
-from selenium.common.exceptions import TimeoutException, WebDriverException
-from selenium.webdriver.chrome.service import Service
-from webdriver_manager.chrome import ChromeDriverManager
-
-logger = logging.getLogger(__name__)
-
-class MEXCBrowserAutomation:
- """
- Browser automation for MEXC futures trading session management
- """
-
- def __init__(self, headless: bool = False, proxy: Optional[str] = None):
- """
- Initialize browser automation
-
- Args:
- headless: Run browser in headless mode
- proxy: HTTP proxy to use (format: host:port)
- """
- self.driver = None
- self.headless = headless
- self.proxy = proxy
- self.logged_in = False
-
- def setup_chrome_driver(self) -> webdriver.Chrome:
- """Setup Chrome driver with appropriate options"""
- chrome_options = Options()
-
- if self.headless:
- chrome_options.add_argument("--headless")
-
- # Basic Chrome options for automation
- chrome_options.add_argument("--no-sandbox")
- chrome_options.add_argument("--disable-dev-shm-usage")
- chrome_options.add_argument("--disable-blink-features=AutomationControlled")
- chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
- chrome_options.add_experimental_option('useAutomationExtension', False)
-
- # Set user agent to avoid detection
- chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")
-
- # Proxy setup if provided
- if self.proxy:
- chrome_options.add_argument(f"--proxy-server=http://{self.proxy}")
-
- # Enable network logging
- chrome_options.add_argument("--enable-logging")
- chrome_options.add_argument("--log-level=0")
- chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})
-
- # Automatically download and setup ChromeDriver
- service = Service(ChromeDriverManager().install())
-
- try:
- driver = webdriver.Chrome(service=service, options=chrome_options)
-
- # Execute script to avoid detection
- driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")
-
- return driver
- except WebDriverException as e:
- logger.error(f"Failed to setup Chrome driver: {e}")
- raise
-
- def start_browser(self):
- """Start the browser session"""
- if self.driver is None:
- logger.info("Starting Chrome browser for MEXC automation")
- self.driver = self.setup_chrome_driver()
- logger.info("Browser started successfully")
-
- def stop_browser(self):
- """Stop the browser session"""
- if self.driver:
- logger.info("Stopping browser")
- self.driver.quit()
- self.driver = None
-
- def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"):
- """
- Navigate to MEXC futures trading page
-
- Args:
- symbol: Trading symbol to navigate to
- """
- if not self.driver:
- self.start_browser()
-
- url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap"
- logger.info(f"Navigating to MEXC futures: {url}")
-
- self.driver.get(url)
-
- # Wait for page to load
- try:
- WebDriverWait(self.driver, 10).until(
- EC.presence_of_element_located((By.TAG_NAME, "body"))
- )
- logger.info("MEXC futures page loaded")
- except TimeoutException:
- logger.error("Timeout waiting for MEXC page to load")
-
- def wait_for_login(self, timeout: int = 300) -> bool:
- """
- Wait for user to manually log in to MEXC
-
- Args:
- timeout: Maximum time to wait for login (seconds)
-
- Returns:
- bool: True if login detected, False if timeout
- """
- logger.info("Please log in to MEXC manually in the browser window")
- logger.info("Waiting for login completion...")
-
- start_time = time.time()
-
- while time.time() - start_time < timeout:
- # Check if we can find elements that indicate logged in state
- try:
- # Look for user-specific elements that appear after login
- cookies = self.driver.get_cookies()
-
- # Check for authentication cookies
- auth_cookies = ['uc_token', 'u_id']
- logged_in_indicators = 0
-
- for cookie in cookies:
- if cookie['name'] in auth_cookies and cookie['value']:
- logged_in_indicators += 1
-
- if logged_in_indicators >= 2:
- logger.info("Login detected!")
- self.logged_in = True
- return True
-
- except Exception as e:
- logger.debug(f"Error checking login status: {e}")
-
- time.sleep(2) # Check every 2 seconds
-
- logger.error(f"Login timeout after {timeout} seconds")
- return False
-
- def extract_session_cookies(self) -> Dict[str, str]:
- """
- Extract all cookies from current browser session
-
- Returns:
- Dictionary of cookie name-value pairs
- """
- if not self.driver:
- logger.error("Browser not started")
- return {}
-
- cookies = {}
-
- try:
- browser_cookies = self.driver.get_cookies()
-
- for cookie in browser_cookies:
- cookies[cookie['name']] = cookie['value']
-
- logger.info(f"Extracted {len(cookies)} cookies from browser session")
-
- # Log important cookies (without values for security)
- important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
- for cookie_name in important_cookies:
- if cookie_name in cookies:
- logger.info(f"Found important cookie: {cookie_name}")
- else:
- logger.warning(f"Missing important cookie: {cookie_name}")
-
- return cookies
-
- except Exception as e:
- logger.error(f"Failed to extract cookies: {e}")
- return {}
-
- def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]:
- """
- Monitor network requests for the specified duration
-
- Args:
- duration: How long to monitor requests (seconds)
-
- Returns:
- List of captured network requests
- """
- if not self.driver:
- logger.error("Browser not started")
- return []
-
- logger.info(f"Starting network monitoring for {duration} seconds")
- logger.info("Please perform trading actions in the browser (open/close positions)")
-
- start_time = time.time()
- captured_requests = []
-
- while time.time() - start_time < duration:
- try:
- # Get performance logs (network requests)
- logs = self.driver.get_log('performance')
-
- for log in logs:
- message = json.loads(log['message'])
-
- # Filter for relevant MEXC API requests
- if (message.get('message', {}).get('method') == 'Network.responseReceived'):
- response = message['message']['params']['response']
- url = response.get('url', '')
-
- # Look for futures API calls
- if ('futures.mexc.com' in url or
- 'ucgateway/captcha_api' in url or
- 'api/v1/private' in url):
-
- request_data = {
- 'url': url,
- 'method': response.get('mimeType', ''),
- 'status': response.get('status'),
- 'headers': response.get('headers', {}),
- 'timestamp': log['timestamp']
- }
-
- captured_requests.append(request_data)
- logger.info(f"Captured request: {url}")
-
- except Exception as e:
- logger.debug(f"Error in network monitoring: {e}")
-
- time.sleep(1)
-
- logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests")
- return captured_requests
-
- def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200):
- """
- Attempt to perform a test trade to capture the complete request flow
-
- Args:
- symbol: Trading symbol
- volume: Position size
- leverage: Leverage multiplier
- """
- if not self.logged_in:
- logger.error("Not logged in - cannot perform test trade")
- return
-
- logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x")
- logger.info("This will attempt to click trading interface elements")
-
- try:
- # This would need to be implemented based on MEXC's specific UI elements
- # For now, just wait and let user perform manual actions
- logger.info("Please manually place a small test trade while monitoring is active")
- time.sleep(30)
-
- except Exception as e:
- logger.error(f"Error during test trade: {e}")
-
- def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]:
- """
- Complete session capture workflow
-
- Args:
- symbol: Trading symbol to use
-
- Returns:
- Dictionary containing cookies and captured requests
- """
- logger.info("Starting full MEXC session capture")
-
- try:
- # Start browser and navigate to MEXC
- self.navigate_to_mexc_futures(symbol)
-
- # Wait for manual login
- if not self.wait_for_login():
- return {'success': False, 'error': 'Login timeout'}
-
- # Extract session cookies
- cookies = self.extract_session_cookies()
-
- if not cookies:
- return {'success': False, 'error': 'Failed to extract cookies'}
-
- # Monitor network requests while user performs actions
- logger.info("Starting network monitoring - please perform trading actions now")
- requests = self.monitor_network_requests(duration=120) # 2 minutes
-
- return {
- 'success': True,
- 'cookies': cookies,
- 'network_requests': requests,
- 'timestamp': int(time.time())
- }
-
- except Exception as e:
- logger.error(f"Error in session capture: {e}")
- return {'success': False, 'error': str(e)}
-
- finally:
- self.stop_browser()
-
-def main():
- """Main function for standalone execution"""
- logging.basicConfig(level=logging.INFO)
-
- print("MEXC Browser Automation - Session Capture")
- print("This will open a browser window for you to log into MEXC")
- print("Make sure you have Chrome browser installed")
-
- automation = MEXCBrowserAutomation(headless=False)
-
- try:
- result = automation.full_session_capture()
-
- if result['success']:
- print(f"\nSession capture successful!")
- print(f"Extracted {len(result['cookies'])} cookies")
- print(f"Captured {len(result['network_requests'])} network requests")
-
- # Save results to file
- output_file = f"mexc_session_capture_{int(time.time())}.json"
- with open(output_file, 'w') as f:
- json.dump(result, f, indent=2)
-
- print(f"Results saved to: {output_file}")
-
- else:
- print(f"Session capture failed: {result['error']}")
-
- except KeyboardInterrupt:
- print("\nSession capture interrupted by user")
- except Exception as e:
- print(f"Error: {e}")
- finally:
- automation.stop_browser()
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/core/mexc_webclient/mexc_futures_client.py b/core/mexc_webclient/mexc_futures_client.py
deleted file mode 100644
index f8c83fe..0000000
--- a/core/mexc_webclient/mexc_futures_client.py
+++ /dev/null
@@ -1,525 +0,0 @@
-"""
-MEXC Futures Web Client
-
-This module implements a web-based client for MEXC futures trading
-since their official API doesn't support futures (leverage) trading.
-
-It mimics browser behavior by replicating the exact HTTP requests
-that the web interface makes.
-"""
-
-import logging
-import requests
-import time
-import json
-import hmac
-import hashlib
-import base64
-from typing import Dict, List, Optional, Any
-from datetime import datetime
-import uuid
-from urllib.parse import urlencode
-import glob
-import os
-
-logger = logging.getLogger(__name__)
-
-class MEXCSessionManager:
- def __init__(self):
- self.captcha_token = None
-
- def get_captcha_token(self) -> str:
- return self.captcha_token if self.captcha_token else ""
-
- def save_captcha_token(self, token: str):
- self.captcha_token = token
- logger.info("MEXC: Captcha token saved in session manager")
-
-class MEXCFuturesWebClient:
- """
- MEXC Futures Web Client that mimics browser behavior for futures trading.
-
- Since MEXC's official API doesn't support futures, this client replicates
- the exact HTTP requests made by their web interface.
- """
-
- def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True):
- """
- Initialize the MEXC Futures Web Client
-
- Args:
- api_key: API key for authentication
- api_secret: API secret for authentication
- user_id: User ID for authentication
- base_url: Base URL for the MEXC website
- headless: Whether to run the browser in headless mode
- """
- self.api_key = api_key
- self.api_secret = api_secret
- self.user_id = user_id
- self.base_url = base_url
- self.is_authenticated = False
- self.headless = headless
- self.session = requests.Session()
- self.session_manager = MEXCSessionManager() # Adding session_manager attribute
- self.captcha_url = f'{base_url}/ucgateway/captcha_api'
- self.futures_api_url = "https://futures.mexc.com/api/v1"
-
- # Setup default headers that mimic a real browser
- self.setup_browser_headers()
-
- def setup_browser_headers(self):
- """Setup default headers that mimic Chrome browser"""
- self.session.headers.update({
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
- 'Accept': '*/*',
- 'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8',
- 'Accept-Encoding': 'gzip, deflate, br',
- 'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
- 'sec-ch-ua-mobile': '?0',
- 'sec-ch-ua-platform': '"Windows"',
- 'sec-fetch-dest': 'empty',
- 'sec-fetch-mode': 'cors',
- 'sec-fetch-site': 'same-origin',
- 'Cache-Control': 'no-cache',
- 'Pragma': 'no-cache',
- 'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap',
- 'Language': 'English',
- 'X-Language': 'en-GB',
- 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
- 'trochilus-uid': str(self.user_id) if self.user_id is not None else ''
- })
-
- def load_session_cookies(self, cookies: Dict[str, str]):
- """
- Load session cookies from browser
-
- Args:
- cookies: Dictionary of cookie name-value pairs
- """
- for name, value in cookies.items():
- self.session.cookies.set(name, value)
-
- # Extract important session info from cookies
- self.auth_token = cookies.get('uc_token')
- self.user_id = cookies.get('u_id')
- self.fingerprint = cookies.get('x-mxc-fingerprint')
- self.visitor_id = cookies.get('mexc_fingerprint_visitorId')
-
- if self.auth_token and self.user_id:
- self.is_authenticated = True
- logger.info("MEXC: Loaded authenticated session")
- else:
- logger.warning("MEXC: Session cookies incomplete - authentication may fail")
-
- def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]:
- """
- Extract cookies from a browser cookie string
-
- Args:
- cookie_string: Raw cookie string from browser (copy from Network tab)
-
- Returns:
- Dictionary of parsed cookies
- """
- cookies = {}
- cookie_pairs = cookie_string.split(';')
-
- for pair in cookie_pairs:
- if '=' in pair:
- name, value = pair.strip().split('=', 1)
- cookies[name] = value
-
- return cookies
-
- def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool:
- """
- Verify captcha for robot trading protection
-
- Args:
- symbol: Trading symbol (e.g., 'ETH_USDT')
- side: 'openlong', 'closelong', 'openshort', 'closeshort'
- leverage: Leverage string (e.g., '200X')
-
- Returns:
- bool: True if captcha verification successful
- """
- if not self.is_authenticated:
- logger.error("MEXC: Cannot verify captcha - not authenticated")
- return False
-
- # Build captcha endpoint URL
- endpoint = f"robot.future.{side}.{symbol}.{leverage}"
- url = f"{self.captcha_url}/{endpoint}"
-
- # Attempt to get captcha token from session manager
- captcha_token = self.session_manager.get_captcha_token()
- if not captcha_token:
- logger.warning("MEXC: No captcha token available, attempting to fetch from browser")
- captcha_token = self._extract_captcha_token_from_browser()
- if captcha_token:
- self.session_manager.save_captcha_token(captcha_token)
- else:
- logger.error("MEXC: Failed to extract captcha token from browser")
- return False
-
- headers = {
- 'Content-Type': 'application/json',
- 'Language': 'en-GB',
- 'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap',
- 'trochilus-uid': self.user_id if self.user_id else '',
- 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
- 'captcha-token': captcha_token
- }
-
- logger.info(f"MEXC: Verifying captcha for {endpoint}")
- try:
- response = self.session.get(url, headers=headers, timeout=10)
- if response.status_code == 200:
- data = response.json()
- if data.get('success'):
- logger.info(f"MEXC: Captcha verified successfully for {endpoint}")
- return True
- else:
- logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}")
- return False
- else:
- logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}")
- return False
- except Exception as e:
- logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}")
- return False
-
- def _extract_captcha_token_from_browser(self) -> str:
- """
- Extract captcha token from browser session using stored cookies or requests.
- This method looks for the most recent mexc_captcha_tokens JSON file to retrieve a token.
- """
- try:
- # Look for the most recent mexc_captcha_tokens file
- captcha_files = glob.glob("mexc_captcha_tokens_*.json")
- if not captcha_files:
- logger.error("MEXC: No CAPTCHA token files found")
- return ""
-
- # Sort files by timestamp (most recent first)
- latest_file = max(captcha_files, key=os.path.getctime)
- logger.info(f"MEXC: Using CAPTCHA token file {latest_file}")
-
- with open(latest_file, 'r') as f:
- captcha_data = json.load(f)
-
- if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0:
- # Return the most recent token
- return captcha_data[0].get('token', '')
- else:
- logger.error("MEXC: No valid CAPTCHA tokens found in file")
- return ""
- except Exception as e:
- logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}")
- return ""
-
- def generate_signature(self, method: str, path: str, params: Dict[str, Any],
- timestamp: int, nonce: int) -> str:
- """
- Generate signature for MEXC futures API requests
-
- This is reverse-engineered from the browser requests
- """
- # This is a placeholder - the actual signature generation would need
- # to be reverse-engineered from the browser's JavaScript
- # For now, return empty string and rely on cookie authentication
- return ""
-
- def open_long_position(self, symbol: str, volume: float, leverage: int = 200,
- price: Optional[float] = None) -> Dict[str, Any]:
- """
- Open a long futures position
-
- Args:
- symbol: Trading symbol (e.g., 'ETH_USDT')
- volume: Position size (contracts)
- leverage: Leverage multiplier (default 200)
- price: Limit price (None for market order)
-
- Returns:
- dict: Order response with order ID
- """
- if not self.is_authenticated:
- logger.error("MEXC: Cannot open position - not authenticated")
- return {'success': False, 'error': 'Not authenticated'}
-
- # First verify captcha
- if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'):
- logger.error("MEXC: Captcha verification failed for opening long position")
- return {'success': False, 'error': 'Captcha verification failed'}
-
- # Prepare order parameters based on the request dump
- timestamp = int(time.time() * 1000)
- nonce = timestamp
-
- order_data = {
- 'symbol': symbol,
- 'side': 1, # 1 = long, 2 = short
- 'openType': 2, # Open position
- 'type': '5', # Market order (might be '1' for limit)
- 'vol': volume,
- 'leverage': leverage,
- 'marketCeiling': False,
- 'priceProtect': '0',
- 'ts': timestamp,
- 'mhash': self._generate_mhash(), # This needs to be implemented
- 'mtoken': self.visitor_id
- }
-
- # Add price for limit orders
- if price is not None:
- order_data['price'] = price
- order_data['type'] = '1' # Limit order
-
- # Add encrypted parameters (these would need proper implementation)
- order_data['p0'] = self._encrypt_p0(order_data) # Placeholder
- order_data['k0'] = self._encrypt_k0(order_data) # Placeholder
- order_data['chash'] = self._generate_chash(order_data) # Placeholder
-
- # Setup headers for the order request
- headers = {
- 'Authorization': self.auth_token,
- 'Content-Type': 'application/json',
- 'Language': 'English',
- 'x-language': 'en-GB',
- 'x-mxc-nonce': str(nonce),
- 'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
- 'trochilus-uid': self.user_id,
- 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
- 'Referer': 'https://www.mexc.com/'
- }
-
- # Make the order request
- url = f"{self.futures_api_url}/private/order/create"
-
- try:
- # First make OPTIONS request (preflight)
- options_response = self.session.options(url, headers=headers, timeout=10)
-
- if options_response.status_code == 200:
- # Now make the actual POST request
- response = self.session.post(url, json=order_data, headers=headers, timeout=15)
-
- if response.status_code == 200:
- data = response.json()
- if data.get('success') and data.get('code') == 0:
- order_id = data.get('data', {}).get('orderId')
- logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}")
- return {
- 'success': True,
- 'order_id': order_id,
- 'timestamp': data.get('data', {}).get('ts'),
- 'symbol': symbol,
- 'side': 'long',
- 'volume': volume,
- 'leverage': leverage
- }
- else:
- logger.error(f"MEXC: Order failed: {data}")
- return {'success': False, 'error': data.get('msg', 'Unknown error')}
- else:
- logger.error(f"MEXC: Order request failed with status {response.status_code}")
- return {'success': False, 'error': f'HTTP {response.status_code}'}
- else:
- logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}")
- return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'}
-
- except Exception as e:
- logger.error(f"MEXC: Order execution error: {e}")
- return {'success': False, 'error': str(e)}
-
- def close_long_position(self, symbol: str, volume: float, leverage: int = 200,
- price: Optional[float] = None) -> Dict[str, Any]:
- """
- Close a long futures position
-
- Args:
- symbol: Trading symbol (e.g., 'ETH_USDT')
- volume: Position size to close (contracts)
- leverage: Leverage multiplier
- price: Limit price (None for market order)
-
- Returns:
- dict: Order response
- """
- if not self.is_authenticated:
- logger.error("MEXC: Cannot close position - not authenticated")
- return {'success': False, 'error': 'Not authenticated'}
-
- # First verify captcha
- if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'):
- logger.error("MEXC: Captcha verification failed for closing long position")
- return {'success': False, 'error': 'Captcha verification failed'}
-
- # Similar to open_long_position but with closeType instead of openType
- timestamp = int(time.time() * 1000)
- nonce = timestamp
-
- order_data = {
- 'symbol': symbol,
- 'side': 2, # Close side is opposite
- 'closeType': 1, # Close position
- 'type': '5', # Market order
- 'vol': volume,
- 'leverage': leverage,
- 'marketCeiling': False,
- 'priceProtect': '0',
- 'ts': timestamp,
- 'mhash': self._generate_mhash(),
- 'mtoken': self.visitor_id
- }
-
- if price is not None:
- order_data['price'] = price
- order_data['type'] = '1'
-
- order_data['p0'] = self._encrypt_p0(order_data)
- order_data['k0'] = self._encrypt_k0(order_data)
- order_data['chash'] = self._generate_chash(order_data)
-
- return self._execute_order(order_data, 'close_long')
-
- def open_short_position(self, symbol: str, volume: float, leverage: int = 200,
- price: Optional[float] = None) -> Dict[str, Any]:
- """Open a short futures position"""
- if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'):
- return {'success': False, 'error': 'Captcha verification failed'}
-
- order_data = {
- 'symbol': symbol,
- 'side': 2, # 2 = short
- 'openType': 2,
- 'type': '5',
- 'vol': volume,
- 'leverage': leverage,
- 'marketCeiling': False,
- 'priceProtect': '0',
- 'ts': int(time.time() * 1000),
- 'mhash': self._generate_mhash(),
- 'mtoken': self.visitor_id
- }
-
- if price is not None:
- order_data['price'] = price
- order_data['type'] = '1'
-
- order_data['p0'] = self._encrypt_p0(order_data)
- order_data['k0'] = self._encrypt_k0(order_data)
- order_data['chash'] = self._generate_chash(order_data)
-
- return self._execute_order(order_data, 'open_short')
-
- def close_short_position(self, symbol: str, volume: float, leverage: int = 200,
- price: Optional[float] = None) -> Dict[str, Any]:
- """Close a short futures position"""
- if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'):
- return {'success': False, 'error': 'Captcha verification failed'}
-
- order_data = {
- 'symbol': symbol,
- 'side': 1, # Close side is opposite
- 'closeType': 1,
- 'type': '5',
- 'vol': volume,
- 'leverage': leverage,
- 'marketCeiling': False,
- 'priceProtect': '0',
- 'ts': int(time.time() * 1000),
- 'mhash': self._generate_mhash(),
- 'mtoken': self.visitor_id
- }
-
- if price is not None:
- order_data['price'] = price
- order_data['type'] = '1'
-
- order_data['p0'] = self._encrypt_p0(order_data)
- order_data['k0'] = self._encrypt_k0(order_data)
- order_data['chash'] = self._generate_chash(order_data)
-
- return self._execute_order(order_data, 'close_short')
-
- def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]:
- """Common order execution logic"""
- timestamp = order_data['ts']
- nonce = timestamp
-
- headers = {
- 'Authorization': self.auth_token,
- 'Content-Type': 'application/json',
- 'Language': 'English',
- 'x-language': 'en-GB',
- 'x-mxc-nonce': str(nonce),
- 'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
- 'trochilus-uid': self.user_id,
- 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
- 'Referer': 'https://www.mexc.com/'
- }
-
- url = f"{self.futures_api_url}/private/order/create"
-
- try:
- response = self.session.post(url, json=order_data, headers=headers, timeout=15)
-
- if response.status_code == 200:
- data = response.json()
- if data.get('success') and data.get('code') == 0:
- order_id = data.get('data', {}).get('orderId')
- logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}")
- return {
- 'success': True,
- 'order_id': order_id,
- 'timestamp': data.get('data', {}).get('ts'),
- 'action': action
- }
- else:
- logger.error(f"MEXC: {action} failed: {data}")
- return {'success': False, 'error': data.get('msg', 'Unknown error')}
- else:
- logger.error(f"MEXC: {action} request failed with status {response.status_code}")
- return {'success': False, 'error': f'HTTP {response.status_code}'}
-
- except Exception as e:
- logger.error(f"MEXC: {action} execution error: {e}")
- return {'success': False, 'error': str(e)}
-
- # Placeholder methods for encryption/hashing - these need proper implementation
- def _generate_mhash(self) -> str:
- """Generate mhash parameter (needs reverse engineering)"""
- return "a0015441fd4c3b6ba427b894b76cb7dd" # Placeholder from request dump
-
- def _encrypt_p0(self, order_data: Dict[str, Any]) -> str:
- """Encrypt p0 parameter (needs reverse engineering)"""
- return "placeholder_p0_encryption" # This needs proper implementation
-
- def _encrypt_k0(self, order_data: Dict[str, Any]) -> str:
- """Encrypt k0 parameter (needs reverse engineering)"""
- return "placeholder_k0_encryption" # This needs proper implementation
-
- def _generate_chash(self, order_data: Dict[str, Any]) -> str:
- """Generate chash parameter (needs reverse engineering)"""
- return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8" # Placeholder
-
- def get_account_info(self) -> Dict[str, Any]:
- """Get account information including positions and balances"""
- if not self.is_authenticated:
- return {'success': False, 'error': 'Not authenticated'}
-
- # This would need to be implemented by reverse engineering the account info endpoints
- logger.info("MEXC: Account info endpoint not yet implemented")
- return {'success': False, 'error': 'Not implemented'}
-
- def get_open_positions(self) -> List[Dict[str, Any]]:
- """Get list of open futures positions"""
- if not self.is_authenticated:
- return []
-
- # This would need to be implemented by reverse engineering the positions endpoint
- logger.info("MEXC: Open positions endpoint not yet implemented")
- return []
\ No newline at end of file
diff --git a/core/mexc_webclient/session_manager.py b/core/mexc_webclient/session_manager.py
deleted file mode 100644
index 9b3c412..0000000
--- a/core/mexc_webclient/session_manager.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""
-MEXC Session Manager
-
-Helper utilities for managing MEXC web sessions and extracting cookies from browser.
-"""
-
-import logging
-import json
-import re
-from typing import Dict, Optional, Any
-from pathlib import Path
-
-logger = logging.getLogger(__name__)
-
-class MEXCSessionManager:
- """
- Helper class for managing MEXC web sessions and extracting browser cookies
- """
-
- def __init__(self):
- self.session_file = Path("mexc_session.json")
-
- def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]:
- """
- Extract cookies from browser Network tab cookie header
-
- Args:
- cookie_header: Raw cookie string from browser (copy from Request Headers)
-
- Returns:
- Dictionary of parsed cookies
- """
- cookies = {}
-
- # Remove 'Cookie: ' prefix if present
- if cookie_header.startswith('Cookie: '):
- cookie_header = cookie_header[8:]
- elif cookie_header.startswith('cookie: '):
- cookie_header = cookie_header[8:]
-
- # Split by semicolon and parse each cookie
- cookie_pairs = cookie_header.split(';')
-
- for pair in cookie_pairs:
- pair = pair.strip()
- if '=' in pair:
- name, value = pair.split('=', 1)
- cookies[name.strip()] = value.strip()
-
- logger.info(f"Extracted {len(cookies)} cookies from browser")
- return cookies
-
- def validate_session_cookies(self, cookies: Dict[str, str]) -> bool:
- """
- Validate that essential cookies are present for authentication
-
- Args:
- cookies: Dictionary of cookie name-value pairs
-
- Returns:
- bool: True if cookies appear valid for authentication
- """
- required_cookies = [
- 'uc_token', # User authentication token
- 'u_id', # User ID
- 'x-mxc-fingerprint', # Browser fingerprint
- 'mexc_fingerprint_visitorId' # Visitor ID
- ]
-
- missing_cookies = []
- for cookie_name in required_cookies:
- if cookie_name not in cookies or not cookies[cookie_name]:
- missing_cookies.append(cookie_name)
-
- if missing_cookies:
- logger.warning(f"Missing required cookies: {missing_cookies}")
- return False
-
- logger.info("All required cookies are present")
- return True
-
- def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None):
- """
- Save session cookies to file for reuse
-
- Args:
- cookies: Dictionary of cookies to save
- metadata: Optional metadata about the session
- """
- session_data = {
- 'cookies': cookies,
- 'metadata': metadata or {},
- 'timestamp': int(time.time())
- }
-
- try:
- with open(self.session_file, 'w') as f:
- json.dump(session_data, f, indent=2)
- logger.info(f"Session saved to {self.session_file}")
- except Exception as e:
- logger.error(f"Failed to save session: {e}")
-
- def load_session(self) -> Optional[Dict[str, str]]:
- """
- Load session cookies from file
-
- Returns:
- Dictionary of cookies if successful, None otherwise
- """
- if not self.session_file.exists():
- logger.info("No saved session found")
- return None
-
- try:
- with open(self.session_file, 'r') as f:
- session_data = json.load(f)
-
- cookies = session_data.get('cookies', {})
- timestamp = session_data.get('timestamp', 0)
-
- # Check if session is too old (24 hours)
- import time
- if time.time() - timestamp > 24 * 3600:
- logger.warning("Saved session is too old (>24h), may be expired")
-
- if self.validate_session_cookies(cookies):
- logger.info("Loaded valid session from file")
- return cookies
- else:
- logger.warning("Loaded session has invalid cookies")
- return None
-
- except Exception as e:
- logger.error(f"Failed to load session: {e}")
- return None
-
- def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]:
- """
- Extract cookies from a curl command copied from browser
-
- Args:
- curl_command: Complete curl command from browser "Copy as cURL"
-
- Returns:
- Dictionary of extracted cookies
- """
- cookies = {}
-
- # Find cookie header in curl command
- cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
- if not cookie_match:
- cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
-
- if cookie_match:
- cookie_header = cookie_match.group(1)
- cookies = self.extract_cookies_from_network_tab(cookie_header)
- logger.info(f"Extracted {len(cookies)} cookies from curl command")
- else:
- logger.warning("No cookie header found in curl command")
-
- return cookies
-
- def print_cookie_extraction_guide(self):
- """Print instructions for extracting cookies from browser"""
- print("\n" + "="*80)
- print("MEXC COOKIE EXTRACTION GUIDE")
- print("="*80)
- print("""
-To extract cookies from your browser for MEXC futures trading:
-
-METHOD 1: Browser Network Tab
-1. Open MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT
-2. Open browser Developer Tools (F12)
-3. Go to Network tab
-4. Try to place a small futures trade (it will fail, but we need the request)
-5. Find the request to 'futures.mexc.com' in the Network tab
-6. Right-click on the request -> Copy -> Copy request headers
-7. Find the 'Cookie:' line and copy everything after 'Cookie: '
-
-METHOD 2: Copy as cURL
-1. Follow steps 1-5 above
-2. Right-click on the futures API request -> Copy -> Copy as cURL
-3. Paste the entire cURL command
-
-METHOD 3: Manual Cookie Extraction
-1. While logged into MEXC, press F12 -> Application/Storage tab
-2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com'
-3. Copy the values for these important cookies:
- - uc_token
- - u_id
- - x-mxc-fingerprint
- - mexc_fingerprint_visitorId
-
-IMPORTANT NOTES:
-- Cookies expire after some time (usually 24 hours)
-- You must be logged into MEXC futures (not just spot trading)
-- Keep your cookies secure - they provide access to your account
-- Test with small amounts first
-
-Example usage:
- session_manager = MEXCSessionManager()
-
- # Method 1: From cookie header
- cookie_header = "uc_token=ABC123; u_id=DEF456; ..."
- cookies = session_manager.extract_cookies_from_network_tab(cookie_header)
-
- # Method 2: From cURL command
- curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'"
- cookies = session_manager.extract_from_curl_command(curl_cmd)
-
- # Save session for reuse
- session_manager.save_session(cookies)
- """)
- print("="*80)
-
-if __name__ == "__main__":
- # When run directly, show the extraction guide
- import time
-
- manager = MEXCSessionManager()
- manager.print_cookie_extraction_guide()
-
- print("\nWould you like to:")
- print("1. Load saved session")
- print("2. Extract cookies from clipboard")
- print("3. Exit")
-
- choice = input("\nEnter choice (1-3): ").strip()
-
- if choice == "1":
- cookies = manager.load_session()
- if cookies:
- print(f"\nLoaded {len(cookies)} cookies from saved session")
- if manager.validate_session_cookies(cookies):
- print("Session appears valid for trading")
- else:
- print("Warning: Session may be incomplete or expired")
- else:
- print("No valid saved session found")
-
- elif choice == "2":
- print("\nPaste your cookie header or cURL command:")
- user_input = input().strip()
-
- if user_input.startswith('curl'):
- cookies = manager.extract_from_curl_command(user_input)
- else:
- cookies = manager.extract_cookies_from_network_tab(user_input)
-
- if cookies and manager.validate_session_cookies(cookies):
- print(f"\nSuccessfully extracted {len(cookies)} valid cookies")
- save = input("Save session for reuse? (y/n): ").strip().lower()
- if save == 'y':
- manager.save_session(cookies)
- else:
- print("Failed to extract valid cookies")
-
- else:
- print("Goodbye!")
\ No newline at end of file
diff --git a/core/mexc_webclient/test_mexc_futures_webclient.py b/core/mexc_webclient/test_mexc_futures_webclient.py
deleted file mode 100644
index f41b13b..0000000
--- a/core/mexc_webclient/test_mexc_futures_webclient.py
+++ /dev/null
@@ -1,346 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test MEXC Futures Web Client
-
-This script demonstrates how to use the MEXC Futures Web Client
-for futures trading that isn't supported by their official API.
-
-IMPORTANT: This requires extracting cookies from your browser session.
-"""
-
-import logging
-import sys
-import os
-import time
-import json
-import uuid
-
-# Add the project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from mexc_futures_client import MEXCFuturesWebClient
-from session_manager import MEXCSessionManager
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Constants
-SYMBOL = "ETH_USDT"
-LEVERAGE = 300
-CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
-
-# Read credentials from mexc_credentials.json in JSON format
-def load_credentials():
- credentials_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
- cookies = {}
- captcha_token_open = ''
- captcha_token_close = ''
- try:
- with open(credentials_file, 'r') as f:
- data = json.load(f)
- cookies = data.get('credentials', {}).get('cookies', {})
- captcha_token_open = data.get('credentials', {}).get('captcha_token_open', '')
- captcha_token_close = data.get('credentials', {}).get('captcha_token_close', '')
- logger.info(f"Loaded credentials from {credentials_file}")
- except Exception as e:
- logger.error(f"Error loading credentials: {e}")
- return cookies, captcha_token_open, captcha_token_close
-
-def test_basic_connection():
- """Test basic connection and authentication"""
- logger.info("Testing MEXC Futures Web Client")
-
- # Initialize session manager
- session_manager = MEXCSessionManager()
-
- # Try to load saved session first
- cookies = session_manager.load_session()
-
- if not cookies:
- # Explicitly load the cookies from the file we have
- cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json')
- if os.path.exists(cookies_file):
- try:
- with open(cookies_file, 'r') as f:
- cookies = json.load(f)
- logger.info(f"Loaded cookies from {cookies_file}")
- except Exception as e:
- logger.error(f"Failed to load cookies from {cookies_file}: {e}")
- cookies = None
- else:
- logger.error(f"Cookies file not found at {cookies_file}")
- cookies = None
-
- if not cookies:
- print("\nNo saved session found. You need to extract cookies from your browser.")
- session_manager.print_cookie_extraction_guide()
-
- print("\nPaste your cookie header or cURL command (or press Enter to exit):")
- user_input = input().strip()
-
- if not user_input:
- print("No input provided. Exiting.")
- return False
-
- # Extract cookies from user input
- if user_input.startswith('curl'):
- cookies = session_manager.extract_from_curl_command(user_input)
- else:
- cookies = session_manager.extract_cookies_from_network_tab(user_input)
-
- if not cookies:
- logger.error("Failed to extract cookies from input")
- return False
-
- # Validate and save session
- if session_manager.validate_session_cookies(cookies):
- session_manager.save_session(cookies)
- logger.info("Session saved for future use")
- else:
- logger.warning("Extracted cookies may be incomplete")
-
- # Initialize the web client
- client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True)
- # Load cookies into the client's session
- for name, value in cookies.items():
- client.session.cookies.set(name, value)
-
- # Update headers to include additional parameters from captured requests
- client.session.headers.update({
- 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
- 'trochilus-uid': cookies.get('u_id', ''),
- 'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap',
- 'Language': 'English',
- 'X-Language': 'en-GB'
- })
-
- if not client.is_authenticated:
- logger.error("Failed to authenticate with extracted cookies")
- return False
-
- logger.info("Successfully authenticated with MEXC")
- logger.info(f"User ID: {client.user_id}")
- logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token")
-
- return True
-
-def test_captcha_verification(client: MEXCFuturesWebClient):
- """Test captcha verification system"""
- logger.info("Testing captcha verification...")
-
- # Test captcha for ETH_USDT long position with 200x leverage
- success = client.verify_captcha('ETH_USDT', 'openlong', '200X')
-
- if success:
- logger.info("Captcha verification successful")
- else:
- logger.warning("Captcha verification failed - this may be normal if no position is being opened")
-
- return success
-
-def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True):
- """Test opening a position (dry run by default)"""
- if dry_run:
- logger.info("DRY RUN: Testing position opening (no actual trade)")
- else:
- logger.warning("LIVE TRADING: Opening actual position!")
-
- symbol = 'ETH_USDT'
- volume = 1 # Small test position
- leverage = 200
-
- logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
-
- if not dry_run:
- result = client.open_long_position(symbol, volume, leverage)
-
- if result['success']:
- logger.info(f"Position opened successfully!")
- logger.info(f"Order ID: {result['order_id']}")
- logger.info(f"Timestamp: {result['timestamp']}")
- return True
- else:
- logger.error(f"Failed to open position: {result['error']}")
- return False
- else:
- logger.info("DRY RUN: Would attempt to open position here")
- # Test just the captcha verification part
- return client.verify_captcha(symbol, 'openlong', f'{leverage}X')
-
-def test_position_opening_live(client):
- symbol = "ETH_USDT"
- volume = 1 # Small volume for testing
- leverage = 200
-
- logger.info(f"LIVE TRADING: Opening actual position!")
- logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
-
- result = client.open_long_position(symbol, volume, leverage)
- if result.get('success'):
- logger.info(f"Successfully opened position: {result}")
- else:
- logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}")
-
-def interactive_menu(client: MEXCFuturesWebClient):
- """Interactive menu for testing different functions"""
- while True:
- print("\n" + "="*50)
- print("MEXC Futures Web Client Test Menu")
- print("="*50)
- print("1. Test captcha verification")
- print("2. Test position opening (DRY RUN)")
- print("3. Test position opening (LIVE - BE CAREFUL!)")
- print("4. Test position closing (DRY RUN)")
- print("5. Show session info")
- print("6. Refresh session")
- print("0. Exit")
-
- choice = input("\nEnter choice (0-6): ").strip()
-
- if choice == "1":
- test_captcha_verification(client)
-
- elif choice == "2":
- test_position_opening(client, dry_run=True)
-
- elif choice == "3":
- test_position_opening_live(client)
-
- elif choice == "4":
- logger.info("DRY RUN: Position closing test")
- success = client.verify_captcha('ETH_USDT', 'closelong', '200X')
- if success:
- logger.info("DRY RUN: Would close position here")
- else:
- logger.warning("Captcha verification failed for position closing")
-
- elif choice == "5":
- print(f"\nSession Information:")
- print(f"Authenticated: {client.is_authenticated}")
- print(f"User ID: {client.user_id}")
- print(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None")
- print(f"Fingerprint: {client.fingerprint}")
- print(f"Visitor ID: {client.visitor_id}")
-
- elif choice == "6":
- session_manager = MEXCSessionManager()
- session_manager.print_cookie_extraction_guide()
-
- elif choice == "0":
- print("Goodbye!")
- break
-
- else:
- print("Invalid choice. Please try again.")
-
-def main():
- """Main test function"""
- print("MEXC Futures Web Client Test")
- print("WARNING: This is experimental software for futures trading")
- print("Use at your own risk and test with small amounts first!")
-
- # Load cookies and tokens
- cookies, captcha_token_open, captcha_token_close = load_credentials()
- if not cookies:
- logger.error("Failed to load cookies from credentials file")
- sys.exit(1)
-
- # Initialize client with loaded cookies and tokens
- client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='')
- # Load cookies into the client's session
- for name, value in cookies.items():
- client.session.cookies.set(name, value)
- # Set captcha tokens
- client.captcha_token_open = captcha_token_open
- client.captcha_token_close = captcha_token_close
-
- # Try to load credentials from the new JSON file
- try:
- with open(CREDENTIALS_FILE, 'r') as f:
- credentials_data = json.load(f)
- cookies = credentials_data['credentials']['cookies']
- captcha_token_open = credentials_data['credentials']['captcha_token_open']
- captcha_token_close = credentials_data['credentials']['captcha_token_close']
- client.load_session_cookies(cookies)
- client.session_manager.save_captcha_token(captcha_token_open) # Assuming this is for opening
-
- except FileNotFoundError:
- logger.error(f"Credentials file not found at {CREDENTIALS_FILE}")
- return False
- except json.JSONDecodeError as e:
- logger.error(f"Error loading credentials: {e}")
- return False
- except KeyError as e:
- logger.error(f"Missing key in credentials file: {e}")
- return False
-
- if not client.is_authenticated:
- logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json")
- return False
-
- # Test connection and authentication
- logger.info("Successfully authenticated with MEXC")
-
- # Set leverage
- leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE)
- if leverage_response and leverage_response.get('code') == 200:
- logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}")
- else:
- logger.error(f"Failed to set leverage: {leverage_response}")
- sys.exit(1)
-
- # Get current price
- ticker = client.get_ticker_data(symbol=SYMBOL)
- if ticker and ticker.get('code') == 200:
- current_price = float(ticker['data']['last'])
- logger.info(f"Current {SYMBOL} price: {current_price}")
- else:
- logger.error(f"Failed to get ticker data: {ticker}")
- sys.exit(1)
-
- # Calculate order size for a small test trade (e.g., $10 worth)
- trade_usdt = 10.0
- order_qty = round((trade_usdt / current_price) * LEVERAGE, 3)
- logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x")
-
- # Test 1: Open LONG position
- logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}")
- open_long_order = client.create_order(
- symbol=SYMBOL,
- side=1, # 1 for BUY
- position_side=1, # 1 for LONG
- order_type=1, # 1 for LIMIT
- price=current_price,
- vol=order_qty
- )
- if open_long_order and open_long_order.get('code') == 200:
- logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}")
- else:
- logger.error(f"❌ Failed to open LONG position: {open_long_order}")
- sys.exit(1)
-
- # Test 2: Close LONG position
- logger.info(f"Closing LONG position for {SYMBOL}")
- close_long_order = client.create_order(
- symbol=SYMBOL,
- side=2, # 2 for SELL
- position_side=1, # 1 for LONG
- order_type=1, # 1 for LIMIT
- price=current_price,
- vol=order_qty,
- reduce_only=True
- )
- if close_long_order and close_long_order.get('code') == 200:
- logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}")
- else:
- logger.error(f"❌ Failed to close LONG position: {close_long_order}")
- sys.exit(1)
-
- logger.info("All tests completed successfully!")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/core/negative_case_trainer.py b/core/negative_case_trainer.py
deleted file mode 100644
index 089ef0f..0000000
--- a/core/negative_case_trainer.py
+++ /dev/null
@@ -1,595 +0,0 @@
-"""
-Negative Case Trainer - Intensive Training on Losing Trades
-
-This module focuses on learning from losses to prevent future mistakes.
-Stores negative cases in testcases/negative folder for reuse and retraining.
-Supports simultaneous inference and training.
-"""
-
-import os
-import json
-import logging
-import pickle
-import threading
-import time
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional, Tuple
-from dataclasses import dataclass, asdict
-from collections import deque
-import numpy as np
-import pandas as pd
-
-# Import checkpoint management
-import torch
-from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from utils.training_integration import get_training_integration
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class NegativeCase:
- """Represents a losing trade case for intensive training"""
- case_id: str
- timestamp: datetime
- symbol: str
- action: str # 'BUY' or 'SELL'
- entry_price: float
- exit_price: float
- loss_amount: float
- loss_percentage: float
- confidence_used: float
- market_state_before: Dict[str, Any]
- market_state_after: Dict[str, Any]
- tick_data: List[Dict[str, Any]] # 15 minutes of tick data around the trade
- technical_indicators: Dict[str, float]
- what_should_have_been_done: str # 'HOLD', 'OPPOSITE', 'WAIT'
- lesson_learned: str
- training_priority: int # 1-5, 5 being highest priority
- retraining_count: int = 0
- last_retrained: Optional[datetime] = None
-
-@dataclass
-class TrainingSession:
- """Represents an intensive training session on negative cases"""
- session_id: str
- start_time: datetime
- cases_trained: List[str] # case_ids
- epochs_completed: int
- loss_improvement: float
- accuracy_improvement: float
- inference_paused: bool = False
- training_active: bool = True
-
-class NegativeCaseTrainer:
- """
- Intensive trainer focused on learning from losing trades with checkpoint management
-
- Features:
- - Stores all losing trades as negative cases
- - Intensive retraining on losses
- - Simultaneous inference and training
- - Persistent storage in testcases/negative
- - Priority-based training (bigger losses = higher priority)
- - Checkpoint management for training progress
- """
-
- def __init__(self, storage_dir: str = "testcases/negative",
- model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
- self.storage_dir = storage_dir
- self.stored_cases: List[NegativeCase] = []
- self.training_queue = deque(maxlen=1000)
- self.training_lock = threading.Lock()
- self.inference_lock = threading.Lock()
-
- # Checkpoint management
- self.model_name = model_name
- self.enable_checkpoints = enable_checkpoints
- self.training_integration = get_training_integration() if enable_checkpoints else None
- self.training_session_count = 0
- self.best_loss_reduction = 0.0
- self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
-
- # Training configuration
- self.max_concurrent_training = 3 # Max parallel training sessions
- self.intensive_training_epochs = 50 # Epochs per negative case
- self.priority_multiplier = 2.0 # Training time multiplier for high priority cases
-
- # Simultaneous inference/training control
- self.inference_active = True
- self.training_active = False
- self.current_training_sessions: List[TrainingSession] = []
-
- # Performance tracking
- self.total_cases_processed = 0
- self.total_training_time = 0.0
- self.accuracy_improvements = []
-
- # Initialize storage
- self._initialize_storage()
- self._load_existing_cases()
-
- # Load best checkpoint if available
- if self.enable_checkpoints:
- self.load_best_checkpoint()
-
- # Start background training thread
- self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
- self.training_thread.start()
-
- logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
- logger.info(f"Storage directory: {self.storage_dir}")
- logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
- logger.info("Background training thread started")
-
- def _initialize_storage(self):
- """Initialize storage directories"""
- try:
- os.makedirs(self.storage_dir, exist_ok=True)
- os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
- os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
- os.makedirs(f"{self.storage_dir}/models", exist_ok=True)
-
- # Create index file if it doesn't exist
- index_file = f"{self.storage_dir}/case_index.json"
- if not os.path.exists(index_file):
- with open(index_file, 'w') as f:
- json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)
-
- logger.info(f"Storage initialized at {self.storage_dir}")
-
- except Exception as e:
- logger.error(f"Error initializing storage: {e}")
-
- def _load_existing_cases(self):
- """Load existing negative cases from storage"""
- try:
- index_file = f"{self.storage_dir}/case_index.json"
- if os.path.exists(index_file):
- with open(index_file, 'r') as f:
- index_data = json.load(f)
-
- for case_info in index_data.get("cases", []):
- case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
- if os.path.exists(case_file):
- try:
- with open(case_file, 'rb') as f:
- case = pickle.load(f)
- self.stored_cases.append(case)
- except Exception as e:
- logger.warning(f"Error loading case {case_info['case_id']}: {e}")
-
- logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")
-
- except Exception as e:
- logger.error(f"Error loading existing cases: {e}")
-
- def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
- """
- Add a losing trade as a negative case for intensive training
-
- Args:
- trade_info: Trade information including P&L
- market_data: Market state and tick data around the trade
-
- Returns:
- case_id: Unique identifier for the negative case
- """
- try:
- # Generate unique case ID
- case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"
-
- # Calculate loss metrics
- loss_amount = abs(trade_info.get('pnl', 0))
- loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100
-
- # Determine training priority based on loss size
- if loss_percentage > 10:
- priority = 5 # Critical loss
- elif loss_percentage > 5:
- priority = 4 # High loss
- elif loss_percentage > 2:
- priority = 3 # Medium loss
- elif loss_percentage > 1:
- priority = 2 # Small loss
- else:
- priority = 1 # Minimal loss
-
- # Analyze what should have been done
- what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
- lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)
-
- # Create negative case
- negative_case = NegativeCase(
- case_id=case_id,
- timestamp=trade_info['timestamp'],
- symbol=trade_info['symbol'],
- action=trade_info['action'],
- entry_price=trade_info['price'],
- exit_price=market_data.get('exit_price', trade_info['price']),
- loss_amount=loss_amount,
- loss_percentage=loss_percentage,
- confidence_used=trade_info.get('confidence', 0.5),
- market_state_before=market_data.get('state_before', {}),
- market_state_after=market_data.get('state_after', {}),
- tick_data=market_data.get('tick_data', []),
- technical_indicators=market_data.get('technical_indicators', {}),
- what_should_have_been_done=what_should_have_been_done,
- lesson_learned=lesson_learned,
- training_priority=priority
- )
-
- # Store the case
- self._store_case(negative_case)
-
- # Add to training queue with priority
- with self.training_lock:
- self.training_queue.append(negative_case)
- self.stored_cases.append(negative_case)
-
- logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
- logger.error(f"Lesson: {lesson_learned}")
-
- return case_id
-
- except Exception as e:
- logger.error(f"Error adding losing trade: {e}")
- return ""
-
- def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
- """Analyze what the optimal action should have been"""
- try:
- # Simple analysis based on price movement
- entry_price = trade_info['price']
- exit_price = market_data.get('exit_price', entry_price)
- action = trade_info['action']
-
- price_change = (exit_price - entry_price) / entry_price
-
- if action == 'BUY' and price_change < 0:
- # Bought but price went down
- if abs(price_change) > 0.005: # >0.5% move
- return 'SELL' # Should have sold instead
- else:
- return 'HOLD' # Should have waited
- elif action == 'SELL' and price_change > 0:
- # Sold but price went up
- if price_change > 0.005: # >0.5% move
- return 'BUY' # Should have bought instead
- else:
- return 'HOLD' # Should have waited
- else:
- return 'HOLD' # Should have done nothing
-
- except Exception as e:
- logger.error(f"Error analyzing optimal action: {e}")
- return 'HOLD'
-
- def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
- """Generate a lesson learned from the losing trade"""
- try:
- action = trade_info['action']
- symbol = trade_info['symbol']
- loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
- confidence = trade_info.get('confidence', 0.5)
-
- if optimal_action == 'HOLD':
- return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
- elif optimal_action == 'BUY' and action == 'SELL':
- return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
- elif optimal_action == 'SELL' and action == 'BUY':
- return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
- else:
- return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."
-
- except Exception as e:
- logger.error(f"Error generating lesson: {e}")
- return "Learn from this loss to improve future decisions."
-
- def _store_case(self, case: NegativeCase):
- """Store negative case to persistent storage"""
- try:
- # Store case file
- case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
- with open(case_file, 'wb') as f:
- pickle.dump(case, f)
-
- # Update index
- index_file = f"{self.storage_dir}/case_index.json"
- with open(index_file, 'r') as f:
- index_data = json.load(f)
-
- # Add case to index
- case_info = {
- 'case_id': case.case_id,
- 'timestamp': case.timestamp.isoformat(),
- 'symbol': case.symbol,
- 'loss_amount': case.loss_amount,
- 'loss_percentage': case.loss_percentage,
- 'training_priority': case.training_priority,
- 'retraining_count': case.retraining_count
- }
-
- index_data['cases'].append(case_info)
- index_data['last_updated'] = datetime.now().isoformat()
-
- with open(index_file, 'w') as f:
- json.dump(index_data, f, indent=2)
-
- logger.info(f"Stored negative case: {case.case_id}")
-
- except Exception as e:
- logger.error(f"Error storing case: {e}")
-
- def _background_training_loop(self):
- """Background loop for intensive training on negative cases"""
- logger.info("Background training loop started")
-
- while True:
- try:
- # Check if we have cases to train on
- with self.training_lock:
- if not self.training_queue:
- time.sleep(5) # Wait for new cases
- continue
-
- # Get highest priority case
- cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
- case_to_train = cases_by_priority[0]
- self.training_queue.remove(case_to_train)
-
- # Start intensive training session
- self._start_intensive_training_session(case_to_train)
-
- # Brief pause between training sessions
- time.sleep(2)
-
- except Exception as e:
- logger.error(f"Error in background training loop: {e}")
- time.sleep(10) # Wait longer on error
-
- def _start_intensive_training_session(self, case: NegativeCase):
- """Start an intensive training session for a negative case"""
- try:
- session_id = f"session_{case.case_id}_{int(time.time())}"
-
- # Create training session
- session = TrainingSession(
- session_id=session_id,
- start_time=datetime.now(),
- cases_trained=[case.case_id],
- epochs_completed=0,
- loss_improvement=0.0,
- accuracy_improvement=0.0
- )
-
- self.current_training_sessions.append(session)
- self.training_active = True
-
- logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
- logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")
-
- # Calculate training epochs based on priority
- epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)
-
- # Simulate intensive training (replace with actual model training)
- for epoch in range(epochs):
- # Pause inference during critical training phases
- if case.training_priority >= 4 and epoch % 10 == 0:
- with self.inference_lock:
- session.inference_paused = True
- time.sleep(0.1) # Brief pause for critical training
- session.inference_paused = False
-
- # Simulate training step
- session.epochs_completed = epoch + 1
-
- # Log progress for high priority cases
- if case.training_priority >= 4 and epoch % 10 == 0:
- logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")
-
- time.sleep(0.05) # Simulate training time
-
- # Update case retraining info
- case.retraining_count += 1
- case.last_retrained = datetime.now()
-
- # Calculate improvements (simulated)
- session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
- session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
-
- # Store training session results
- self._store_training_session(session)
-
- # Update statistics
- self.total_cases_processed += 1
- self.total_training_time += (datetime.now() - session.start_time).total_seconds()
- self.accuracy_improvements.append(session.accuracy_improvement)
-
- # Remove from active sessions
- self.current_training_sessions.remove(session)
- if not self.current_training_sessions:
- self.training_active = False
-
- logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
- logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")
-
- except Exception as e:
- logger.error(f"Error in intensive training session: {e}")
-
- def _store_training_session(self, session: TrainingSession):
- """Store training session results"""
- try:
- session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
- session_data = {
- 'session_id': session.session_id,
- 'start_time': session.start_time.isoformat(),
- 'end_time': datetime.now().isoformat(),
- 'cases_trained': session.cases_trained,
- 'epochs_completed': session.epochs_completed,
- 'loss_improvement': session.loss_improvement,
- 'accuracy_improvement': session.accuracy_improvement
- }
-
- with open(session_file, 'w') as f:
- json.dump(session_data, f, indent=2)
-
- except Exception as e:
- logger.error(f"Error storing training session: {e}")
-
- def can_inference_proceed(self) -> bool:
- """Check if inference can proceed (not blocked by critical training)"""
- with self.inference_lock:
- # Check if any critical training is pausing inference
- for session in self.current_training_sessions:
- if session.inference_paused:
- return False
- return True
-
- def get_training_stats(self) -> Dict[str, Any]:
- """Get training statistics"""
- try:
- avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0
-
- return {
- 'total_negative_cases': len(self.stored_cases),
- 'cases_in_queue': len(self.training_queue),
- 'total_cases_processed': self.total_cases_processed,
- 'total_training_time': self.total_training_time,
- 'avg_accuracy_improvement': avg_accuracy_improvement,
- 'active_training_sessions': len(self.current_training_sessions),
- 'training_active': self.training_active,
- 'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
- 'storage_directory': self.storage_dir
- }
-
- except Exception as e:
- logger.error(f"Error getting training stats: {e}")
- return {}
-
- def get_recent_lessons(self, count: int = 5) -> List[str]:
- """Get recent lessons learned from negative cases"""
- try:
- recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
- return [case.lesson_learned for case in recent_cases]
- except Exception as e:
- logger.error(f"Error getting recent lessons: {e}")
- return []
-
- def retrain_all_cases(self):
- """Retrain all stored negative cases (for periodic retraining)"""
- try:
- logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")
-
- with self.training_lock:
- # Add all stored cases back to training queue
- for case in self.stored_cases:
- if case not in self.training_queue:
- self.training_queue.append(case)
-
- logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
-
- except Exception as e:
- logger.error(f"Error retraining all cases: {e}")
-
- def load_best_checkpoint(self):
- """Load the best checkpoint for this negative case trainer"""
- try:
- if not self.enable_checkpoints:
- return
-
- result = load_best_checkpoint(self.model_name)
- if result:
- file_path, metadata = result
- checkpoint = torch.load(file_path, map_location='cpu')
-
- # Load training state
- if 'training_session_count' in checkpoint:
- self.training_session_count = checkpoint['training_session_count']
- if 'best_loss_reduction' in checkpoint:
- self.best_loss_reduction = checkpoint['best_loss_reduction']
- if 'total_cases_processed' in checkpoint:
- self.total_cases_processed = checkpoint['total_cases_processed']
- if 'total_training_time' in checkpoint:
- self.total_training_time = checkpoint['total_training_time']
- if 'accuracy_improvements' in checkpoint:
- self.accuracy_improvements = checkpoint['accuracy_improvements']
-
- logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
- logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
-
- except Exception as e:
- logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
-
- def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
- """Save checkpoint if performance improved or forced"""
- try:
- if not self.enable_checkpoints:
- return False
-
- self.training_session_count += 1
-
- # Update best loss reduction
- improved = False
- if loss_improvement > self.best_loss_reduction:
- self.best_loss_reduction = loss_improvement
- improved = True
-
- # Save checkpoint if improved, forced, or at regular intervals
- should_save = (
- force_save or
- improved or
- self.training_session_count % self.checkpoint_frequency == 0
- )
-
- if should_save:
- # Prepare checkpoint data
- checkpoint_data = {
- 'training_session_count': self.training_session_count,
- 'best_loss_reduction': self.best_loss_reduction,
- 'total_cases_processed': self.total_cases_processed,
- 'total_training_time': self.total_training_time,
- 'accuracy_improvements': self.accuracy_improvements,
- 'storage_dir': self.storage_dir,
- 'max_concurrent_training': self.max_concurrent_training,
- 'intensive_training_epochs': self.intensive_training_epochs
- }
-
- # Create performance metrics for checkpoint manager
- avg_accuracy_improvement = (
- sum(self.accuracy_improvements) / len(self.accuracy_improvements)
- if self.accuracy_improvements else 0.0
- )
-
- performance_metrics = {
- 'loss_reduction': self.best_loss_reduction,
- 'avg_accuracy_improvement': avg_accuracy_improvement,
- 'total_cases_processed': self.total_cases_processed,
- 'training_efficiency': (
- self.total_cases_processed / self.total_training_time
- if self.total_training_time > 0 else 0.0
- )
- }
-
- # Save using checkpoint manager
- metadata = save_checkpoint(
- model=checkpoint_data, # We're saving data dict instead of model
- model_name=self.model_name,
- model_type="negative_case_trainer",
- performance_metrics=performance_metrics,
- training_metadata={
- 'session': self.training_session_count,
- 'cases_processed': self.total_cases_processed,
- 'training_time_hours': self.total_training_time / 3600
- },
- force_save=force_save
- )
-
- if metadata:
- logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
- return True
-
- return False
-
- except Exception as e:
- logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
- return False
\ No newline at end of file
diff --git a/core/nn_decision_fusion.py b/core/nn_decision_fusion.py
deleted file mode 100644
index c6d9a4e..0000000
--- a/core/nn_decision_fusion.py
+++ /dev/null
@@ -1,277 +0,0 @@
-#!/usr/bin/env python3
-"""
-Neural Network Decision Fusion System
-Central NN that merges all model outputs + market data for final trading decisions
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from typing import Dict, List, Optional, Any
-from dataclasses import dataclass
-from datetime import datetime
-import logging
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class ModelPrediction:
- """Standardized prediction from any model"""
- model_name: str
- prediction_type: str # 'price', 'direction', 'action'
- value: float # -1 to 1 for direction, actual price for price predictions
- confidence: float # 0 to 1
- timestamp: datetime
- metadata: Optional[Dict[str, Any]] = None
-
-@dataclass
-class MarketContext:
- """Current market context for decision fusion"""
- symbol: str
- current_price: float
- price_change_1m: float
- price_change_5m: float
- volume_ratio: float
- volatility: float
- timestamp: datetime
-
-@dataclass
-class FusionDecision:
- """Final trading decision from fusion NN"""
- action: str # 'BUY', 'SELL', 'HOLD'
- confidence: float # 0 to 1
- expected_return: float # Expected return percentage
- risk_score: float # 0 to 1, higher = riskier
- position_size: float # Recommended position size
- reasoning: str # Human-readable explanation
- model_contributions: Dict[str, float] # How much each model contributed
- timestamp: datetime
-
-class DecisionFusionNetwork(nn.Module):
- """Small NN that fuses model predictions with market context"""
-
- def __init__(self, input_dim: int = 32, hidden_dim: int = 64):
- super().__init__()
-
- self.fusion_layers = nn.Sequential(
- nn.Linear(input_dim, hidden_dim),
- nn.ReLU(),
- nn.Dropout(0.2),
- nn.Linear(hidden_dim, hidden_dim // 2),
- nn.ReLU(),
- nn.Linear(hidden_dim // 2, 16)
- )
-
- # Output heads
- self.action_head = nn.Linear(16, 3) # BUY, SELL, HOLD
- self.confidence_head = nn.Linear(16, 1)
- self.return_head = nn.Linear(16, 1)
-
- def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
- """Forward pass through fusion network"""
- fusion_output = self.fusion_layers(features)
-
- action_logits = self.action_head(fusion_output)
- action_probs = F.softmax(action_logits, dim=1)
-
- confidence = torch.sigmoid(self.confidence_head(fusion_output))
- expected_return = torch.tanh(self.return_head(fusion_output))
-
- return {
- 'action_probs': action_probs,
- 'confidence': confidence.squeeze(),
- 'expected_return': expected_return.squeeze()
- }
-
-class NeuralDecisionFusion:
- """Main NN-based decision fusion system"""
-
- def __init__(self, training_mode: bool = True):
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.network = DecisionFusionNetwork().to(self.device)
- self.training_mode = training_mode
- self.registered_models = {}
- self.last_predictions = {}
-
- logger.info(f"Neural Decision Fusion initialized on {self.device}")
-
- def register_model(self, model_name: str, model_type: str, prediction_format: str):
- """Register a model that will provide predictions"""
- self.registered_models[model_name] = {
- 'type': model_type,
- 'format': prediction_format,
- 'prediction_count': 0
- }
- logger.info(f"Registered NN model: {model_name} ({model_type})")
-
- def add_prediction(self, prediction: ModelPrediction):
- """Add a prediction from a registered model"""
- self.last_predictions[prediction.model_name] = prediction
- if prediction.model_name in self.registered_models:
- self.registered_models[prediction.model_name]['prediction_count'] += 1
-
- logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} "
- f"(confidence: {prediction.confidence:.3f})")
-
- def make_decision(self, symbol: str, market_context: MarketContext,
- min_confidence: float = 0.25) -> Optional[FusionDecision]:
- """Make NN-driven trading decision"""
- try:
- if len(self.last_predictions) < 1:
- logger.debug("No NN predictions available")
- return None
-
- # Prepare features
- features = self._prepare_features(market_context)
- if features is None:
- return None
-
- # Run NN inference
- with torch.no_grad():
- self.network.eval()
- features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
- outputs = self.network(features_tensor)
-
- action_probs = outputs['action_probs'][0].cpu().numpy()
- confidence = outputs['confidence'].cpu().item()
- expected_return = outputs['expected_return'].cpu().item()
-
- # Determine action
- action_idx = np.argmax(action_probs)
- actions = ['BUY', 'SELL', 'HOLD']
- action = actions[action_idx]
-
- # Check confidence threshold
- if confidence < min_confidence:
- action = 'HOLD'
- logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD")
-
- # Calculate position size
- position_size = self._calculate_position_size(confidence, expected_return)
-
- # Generate reasoning
- reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs)
-
- # Calculate risk score and model contributions
- risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5)
- model_contributions = self._calculate_model_contributions()
-
- decision = FusionDecision(
- action=action,
- confidence=confidence,
- expected_return=expected_return,
- risk_score=risk_score,
- position_size=position_size,
- reasoning=reasoning,
- model_contributions=model_contributions,
- timestamp=datetime.now()
- )
-
- logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, "
- f"return: {expected_return:.3f}, size: {position_size:.4f})")
-
- return decision
-
- except Exception as e:
- logger.error(f"Error in NN decision making: {e}")
- return None
-
- def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]:
- """Prepare feature vector for NN"""
- try:
- features = np.zeros(32)
-
- # Model predictions (slots 0-15)
- idx = 0
- for model_name, prediction in self.last_predictions.items():
- if idx < 14: # Leave room for other features
- features[idx] = prediction.value
- features[idx + 1] = prediction.confidence
- idx += 2
-
- # Market context (slots 16-31)
- features[16] = np.tanh(context.price_change_1m * 100) # 1m change
- features[17] = np.tanh(context.price_change_5m * 100) # 5m change
- features[18] = np.tanh(context.volume_ratio - 1) # Volume ratio
- features[19] = np.tanh(context.volatility * 100) # Volatility
- features[20] = context.current_price / 10000.0 # Normalized price
-
- # Time features
- now = context.timestamp
- features[21] = now.hour / 24.0
- features[22] = now.weekday() / 7.0
-
- # Model agreement features
- if len(self.last_predictions) >= 2:
- values = [p.value for p in self.last_predictions.values()]
- features[23] = np.mean(values) # Average prediction
- features[24] = np.std(values) # Prediction variance
- features[25] = len(self.last_predictions) # Model count
-
- return features
-
- except Exception as e:
- logger.error(f"Error preparing NN features: {e}")
- return None
-
- def _calculate_position_size(self, confidence: float, expected_return: float) -> float:
- """Calculate position size based on NN outputs"""
- base_size = 0.01 # 0.01 ETH base
-
- # Scale by confidence
- confidence_multiplier = max(0.1, min(2.0, confidence * 1.5))
-
- # Scale by expected return
- return_multiplier = 1.0 + abs(expected_return) * 0.5
-
- final_size = base_size * confidence_multiplier * return_multiplier
- return max(0.001, min(0.05, final_size))
-
- def _generate_reasoning(self, action: str, confidence: float,
- expected_return: float, action_probs: np.ndarray) -> str:
- """Generate human-readable reasoning"""
- reasons = []
-
- if action == 'BUY':
- reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})")
- elif action == 'SELL':
- reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})")
- else:
- reasons.append(f"NN suggests HOLD")
-
- if confidence > 0.7:
- reasons.append("High confidence")
- elif confidence > 0.5:
- reasons.append("Moderate confidence")
- else:
- reasons.append("Low confidence")
-
- if abs(expected_return) > 0.01:
- direction = "positive" if expected_return > 0 else "negative"
- reasons.append(f"Expected {direction} return: {expected_return:.2%}")
-
- reasons.append(f"Based on {len(self.last_predictions)} NN models")
-
- return " | ".join(reasons)
-
- def _calculate_model_contributions(self) -> Dict[str, float]:
- """Calculate how much each model contributed to the decision"""
- contributions = {}
- total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0
-
- if total_confidence > 0:
- for model_name, prediction in self.last_predictions.items():
- contributions[model_name] = prediction.confidence / total_confidence
-
- return contributions
-
- def get_status(self) -> Dict[str, Any]:
- """Get NN fusion system status"""
- return {
- 'device': str(self.device),
- 'training_mode': self.training_mode,
- 'registered_models': len(self.registered_models),
- 'recent_predictions': len(self.last_predictions),
- 'model_parameters': sum(p.numel() for p in self.network.parameters())
- }
\ No newline at end of file
diff --git a/core/prediction_tracker.py b/core/prediction_tracker.py
deleted file mode 100644
index 6de008e..0000000
Binary files a/core/prediction_tracker.py and /dev/null differ
diff --git a/core/realtime_tick_processor.py b/core/realtime_tick_processor.py
deleted file mode 100644
index 904a5fc..0000000
--- a/core/realtime_tick_processor.py
+++ /dev/null
@@ -1,649 +0,0 @@
-"""
-Real-Time Tick Processing Neural Network Module
-
-This module acts as a Neural Network DPS (Data Processing System) alternative,
-processing raw tick data with ultra-low latency and feeding processed features
-to trading models in real-time.
-
-Features:
-- Real-time tick ingestion with volume processing
-- Neural network feature extraction from tick streams
-- Ultra-low latency processing (sub-millisecond)
-- Volume-weighted price analysis
-- Microstructure pattern detection
-- Real-time feature streaming to models
-"""
-
-import asyncio
-import logging
-import time
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Deque
-from collections import deque
-from threading import Thread, Lock
-import websockets
-import json
-from dataclasses import dataclass
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class TickData:
- """Raw tick data structure"""
- timestamp: datetime
- price: float
- volume: float
- side: str # 'buy' or 'sell'
- trade_id: Optional[str] = None
-
-@dataclass
-class ProcessedTickFeatures:
- """Processed tick features for model consumption"""
- timestamp: datetime
- price_features: np.ndarray # Price-based features
- volume_features: np.ndarray # Volume-based features
- microstructure_features: np.ndarray # Market microstructure features
- neural_features: np.ndarray # Neural network extracted features
- confidence: float # Feature quality confidence
-
-class TickProcessingNN(nn.Module):
- """
- Neural Network for real-time tick processing
- Extracts high-level features from raw tick data
- """
-
- def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64):
- super(TickProcessingNN, self).__init__()
-
- # Tick sequence processing layers
- self.tick_encoder = nn.Sequential(
- nn.Linear(input_size, hidden_size),
- nn.ReLU(),
- nn.Dropout(0.1),
- nn.Linear(hidden_size, hidden_size),
- nn.ReLU(),
- nn.Dropout(0.1)
- )
-
- # LSTM for temporal patterns
- self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2)
-
- # Attention mechanism for important tick selection
- self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
-
- # Feature extraction heads
- self.price_head = nn.Linear(hidden_size, 16) # Price pattern features
- self.volume_head = nn.Linear(hidden_size, 16) # Volume pattern features
- self.microstructure_head = nn.Linear(hidden_size, 16) # Microstructure features
-
- # Final feature fusion
- self.feature_fusion = nn.Sequential(
- nn.Linear(48, output_size), # 16+16+16 = 48
- nn.ReLU(),
- nn.Linear(output_size, output_size)
- )
-
- # Confidence estimation
- self.confidence_head = nn.Sequential(
- nn.Linear(output_size, 32),
- nn.ReLU(),
- nn.Linear(32, 1),
- nn.Sigmoid()
- )
-
- def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Process tick sequence and extract features
-
- Args:
- tick_sequence: [batch, sequence_length, features]
-
- Returns:
- features: [batch, output_size] - extracted features
- confidence: [batch, 1] - feature confidence
- """
- batch_size, seq_len, _ = tick_sequence.shape
-
- # Encode each tick
- encoded = self.tick_encoder(tick_sequence) # [batch, seq_len, hidden_size]
-
- # LSTM processing for temporal patterns
- lstm_out, _ = self.lstm(encoded) # [batch, seq_len, hidden_size]
-
- # Attention to focus on important ticks
- attended, _ = self.attention(lstm_out, lstm_out, lstm_out) # [batch, seq_len, hidden_size]
-
- # Use the last attended output
- final_features = attended[:, -1, :] # [batch, hidden_size]
-
- # Extract specialized features
- price_features = self.price_head(final_features)
- volume_features = self.volume_head(final_features)
- microstructure_features = self.microstructure_head(final_features)
-
- # Fuse all features
- combined_features = torch.cat([price_features, volume_features, microstructure_features], dim=1)
- final_features = self.feature_fusion(combined_features)
-
- # Estimate confidence
- confidence = self.confidence_head(final_features)
-
- return final_features, confidence
-
-class RealTimeTickProcessor:
- """
- Real-time tick processing system with neural network feature extraction
- Acts as a DPS alternative for ultra-low latency tick processing
- """
-
- def __init__(self, symbols: List[str] = None, tick_buffer_size: int = 1000):
- """Initialize the real-time tick processor"""
- self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
- self.tick_buffer_size = tick_buffer_size
-
- # Tick storage buffers
- self.tick_buffers: Dict[str, Deque[TickData]] = {}
- self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {}
-
- # Initialize buffers for each symbol
- for symbol in self.symbols:
- self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size)
- self.processed_features[symbol] = deque(maxlen=100) # Keep last 100 processed features
-
- # Neural network for feature extraction
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- self.tick_nn = TickProcessingNN(input_size=9).to(self.device)
- self.tick_nn.eval() # Start in evaluation mode
-
- # Processing parameters
- self.processing_window = 50 # Number of ticks to process at once
- self.min_ticks_for_processing = 10 # Minimum ticks before processing
-
- # Real-time streaming
- self.streaming = False
- self.websocket_tasks = {}
- self.processing_threads = {}
-
- # Performance tracking
- self.processing_times = deque(maxlen=1000)
- self.tick_counts = {symbol: 0 for symbol in self.symbols}
-
- # Thread safety
- self.data_lock = Lock()
-
- # Feature subscribers (models that want real-time features)
- self.feature_subscribers = []
-
- logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}")
- logger.info(f"Neural network device: {self.device}")
- logger.info(f"Tick buffer size: {tick_buffer_size}")
-
- def add_feature_subscriber(self, callback):
- """Add a callback function to receive processed features"""
- self.feature_subscribers.append(callback)
- logger.info(f"Added feature subscriber: {callback.__name__}")
-
- def remove_feature_subscriber(self, callback):
- """Remove a feature subscriber"""
- if callback in self.feature_subscribers:
- self.feature_subscribers.remove(callback)
- logger.info(f"Removed feature subscriber: {callback.__name__}")
-
- async def start_processing(self):
- """Start real-time tick processing"""
- logger.info("Starting real-time tick processing...")
- self.streaming = True
-
- # Start WebSocket streams for each symbol
- for symbol in self.symbols:
- task = asyncio.create_task(self._websocket_stream(symbol))
- self.websocket_tasks[symbol] = task
-
- # Start processing thread for each symbol
- thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True)
- thread.start()
- self.processing_threads[symbol] = thread
-
- logger.info("Real-time tick processing started")
-
- async def stop_processing(self):
- """Stop real-time tick processing"""
- logger.info("Stopping real-time tick processing...")
- self.streaming = False
-
- # Cancel WebSocket tasks
- for symbol, task in self.websocket_tasks.items():
- if not task.done():
- task.cancel()
- try:
- await task
- except asyncio.CancelledError:
- pass
-
- self.websocket_tasks.clear()
- logger.info("Real-time tick processing stopped")
-
- async def _websocket_stream(self, symbol: str):
- """WebSocket stream for real-time tick data"""
- binance_symbol = symbol.replace('/', '').lower()
- url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
-
- while self.streaming:
- try:
- async with websockets.connect(url) as websocket:
- logger.info(f"Tick WebSocket connected for {symbol}")
-
- async for message in websocket:
- if not self.streaming:
- break
-
- try:
- data = json.loads(message)
- await self._process_raw_tick(symbol, data)
- except Exception as e:
- logger.warning(f"Error processing tick for {symbol}: {e}")
-
- except Exception as e:
- logger.error(f"WebSocket error for {symbol}: {e}")
- if self.streaming:
- logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...")
- await asyncio.sleep(2)
-
- async def _process_raw_tick(self, symbol: str, raw_data: Dict):
- """Process raw tick data from WebSocket"""
- try:
- # Extract tick information
- tick = TickData(
- timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000),
- price=float(raw_data['p']),
- volume=float(raw_data['q']),
- side='buy' if raw_data['m'] == False else 'sell', # m=true means buyer is market maker (sell)
- trade_id=raw_data.get('t')
- )
-
- # Add to buffer
- with self.data_lock:
- self.tick_buffers[symbol].append(tick)
- self.tick_counts[symbol] += 1
-
- except Exception as e:
- logger.error(f"Error processing raw tick for {symbol}: {e}")
-
- def _processing_loop(self, symbol: str):
- """Main processing loop for a symbol"""
- logger.info(f"Starting processing loop for {symbol}")
-
- while self.streaming:
- try:
- # Check if we have enough ticks to process
- with self.data_lock:
- tick_count = len(self.tick_buffers[symbol])
-
- if tick_count >= self.min_ticks_for_processing:
- start_time = time.time()
-
- # Process ticks
- features = self._extract_neural_features(symbol)
-
- if features is not None:
- # Store processed features
- with self.data_lock:
- self.processed_features[symbol].append(features)
-
- # Notify subscribers
- self._notify_feature_subscribers(symbol, features)
-
- # Track processing time
- processing_time = (time.time() - start_time) * 1000 # Convert to ms
- self.processing_times.append(processing_time)
-
- if len(self.processing_times) % 100 == 0:
- avg_time = np.mean(list(self.processing_times))
- logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms")
-
- # Small sleep to prevent CPU overload
- time.sleep(0.001) # 1ms sleep for ultra-low latency
-
- except Exception as e:
- logger.error(f"Error in processing loop for {symbol}: {e}")
- time.sleep(0.01) # Longer sleep on error
-
- def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
- """Extract neural network features from recent ticks"""
- try:
- with self.data_lock:
- # Get recent ticks
- recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:]
-
- if len(recent_ticks) < self.min_ticks_for_processing:
- return None
-
- # Convert ticks to neural network input
- tick_features = self._ticks_to_features(recent_ticks)
-
- # Process with neural network
- with torch.no_grad():
- tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device)
- neural_features, confidence = self.tick_nn(tick_tensor)
-
- neural_features = neural_features.cpu().numpy().flatten()
- confidence = confidence.cpu().numpy().item()
-
- # Extract traditional features
- price_features = self._extract_price_features(recent_ticks)
- volume_features = self._extract_volume_features(recent_ticks)
- microstructure_features = self._extract_microstructure_features(recent_ticks)
-
- # Create processed features object
- processed = ProcessedTickFeatures(
- timestamp=recent_ticks[-1].timestamp,
- price_features=price_features,
- volume_features=volume_features,
- microstructure_features=microstructure_features,
- neural_features=neural_features,
- confidence=confidence
- )
-
- return processed
-
- except Exception as e:
- logger.error(f"Error extracting neural features for {symbol}: {e}")
- return None
-
- def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray:
- """Convert tick data to neural network input features"""
- features = []
-
- for i, tick in enumerate(ticks):
- tick_features = [
- tick.price,
- tick.volume,
- 1.0 if tick.side == 'buy' else 0.0, # Buy/sell indicator
- tick.timestamp.timestamp(), # Timestamp
- ]
-
- # Add relative features if we have previous ticks
- if i > 0:
- prev_tick = ticks[i-1]
- price_change = (tick.price - prev_tick.price) / prev_tick.price
- volume_ratio = tick.volume / (prev_tick.volume + 1e-8)
- time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds()
-
- tick_features.extend([
- price_change,
- volume_ratio,
- time_delta
- ])
- else:
- tick_features.extend([0.0, 1.0, 0.0]) # Default values for first tick
-
- # Add moving averages if we have enough data
- if i >= 5:
- recent_prices = [t.price for t in ticks[max(0, i-4):i+1]]
- recent_volumes = [t.volume for t in ticks[max(0, i-4):i+1]]
-
- price_ma = np.mean(recent_prices)
- volume_ma = np.mean(recent_volumes)
-
- tick_features.extend([
- (tick.price - price_ma) / price_ma, # Price deviation from MA
- (tick.volume - volume_ma) / (volume_ma + 1e-8) # Volume deviation from MA
- ])
- else:
- tick_features.extend([0.0, 0.0])
-
- features.append(tick_features)
-
- # Pad or truncate to fixed size
- target_length = self.processing_window
- if len(features) < target_length:
- # Pad with zeros
- padding = [[0.0] * len(features[0])] * (target_length - len(features))
- features = padding + features
- elif len(features) > target_length:
- # Take the most recent ticks
- features = features[-target_length:]
-
- return np.array(features, dtype=np.float32)
-
- def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray:
- """Extract price-based features"""
- prices = np.array([tick.price for tick in ticks])
-
- features = [
- prices[-1], # Current price
- np.mean(prices), # Average price
- np.std(prices), # Price volatility
- np.max(prices), # High
- np.min(prices), # Low
- (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0, # Total return
- ]
-
- # Price momentum features
- if len(prices) >= 10:
- short_ma = np.mean(prices[-5:])
- long_ma = np.mean(prices[-10:])
- momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
- features.append(momentum)
- else:
- features.append(0.0)
-
- return np.array(features, dtype=np.float32)
-
- def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray:
- """Extract volume-based features"""
- volumes = np.array([tick.volume for tick in ticks])
- buy_volumes = np.array([tick.volume for tick in ticks if tick.side == 'buy'])
- sell_volumes = np.array([tick.volume for tick in ticks if tick.side == 'sell'])
-
- features = [
- np.sum(volumes), # Total volume
- np.mean(volumes), # Average volume
- np.std(volumes), # Volume volatility
- np.sum(buy_volumes) if len(buy_volumes) > 0 else 0, # Buy volume
- np.sum(sell_volumes) if len(sell_volumes) > 0 else 0, # Sell volume
- ]
-
- # Volume imbalance
- total_buy = np.sum(buy_volumes) if len(buy_volumes) > 0 else 0
- total_sell = np.sum(sell_volumes) if len(sell_volumes) > 0 else 0
- total_volume = total_buy + total_sell
-
- if total_volume > 0:
- buy_ratio = total_buy / total_volume
- volume_imbalance = buy_ratio - 0.5 # -0.5 to 0.5 range
- else:
- volume_imbalance = 0.0
-
- features.append(volume_imbalance)
-
- # VWAP (Volume Weighted Average Price)
- if np.sum(volumes) > 0:
- prices = np.array([tick.price for tick in ticks])
- vwap = np.sum(prices * volumes) / np.sum(volumes)
- current_price = ticks[-1].price
- vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0
- else:
- vwap_deviation = 0.0
-
- features.append(vwap_deviation)
-
- return np.array(features, dtype=np.float32)
-
- def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray:
- """Extract market microstructure features"""
- features = []
-
- # Trade frequency
- if len(ticks) >= 2:
- time_deltas = [(ticks[i].timestamp - ticks[i-1].timestamp).total_seconds()
- for i in range(1, len(ticks))]
- avg_time_delta = np.mean(time_deltas)
- trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0
- else:
- trade_frequency = 0.0
-
- features.append(trade_frequency)
-
- # Price impact features
- prices = [tick.price for tick in ticks]
- volumes = [tick.volume for tick in ticks]
-
- if len(prices) >= 3:
- # Calculate price changes and corresponding volumes
- price_changes = [(prices[i] - prices[i-1]) / prices[i-1]
- for i in range(1, len(prices)) if prices[i-1] != 0]
- corresponding_volumes = volumes[1:len(price_changes)+1]
-
- if len(price_changes) > 0 and len(corresponding_volumes) > 0:
- # Simple price impact measure
- price_impact = np.corrcoef(np.abs(price_changes), corresponding_volumes)[0, 1]
- if np.isnan(price_impact):
- price_impact = 0.0
- else:
- price_impact = 0.0
- else:
- price_impact = 0.0
-
- features.append(price_impact)
-
- # Bid-ask spread proxy (using price volatility)
- if len(prices) >= 5:
- recent_prices = prices[-5:]
- spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / np.mean(recent_prices)
- else:
- spread_proxy = 0.0
-
- features.append(spread_proxy)
-
- # Order flow imbalance (already calculated in volume features, but different perspective)
- buy_count = sum(1 for tick in ticks if tick.side == 'buy')
- sell_count = len(ticks) - buy_count
- total_trades = len(ticks)
-
- if total_trades > 0:
- order_flow_imbalance = (buy_count - sell_count) / total_trades
- else:
- order_flow_imbalance = 0.0
-
- features.append(order_flow_imbalance)
-
- return np.array(features, dtype=np.float32)
-
- def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures):
- """Notify all feature subscribers of new processed features"""
- for callback in self.feature_subscribers:
- try:
- callback(symbol, features)
- except Exception as e:
- logger.error(f"Error notifying feature subscriber {callback.__name__}: {e}")
-
- def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
- """Get the latest processed features for a symbol"""
- with self.data_lock:
- if symbol in self.processed_features and self.processed_features[symbol]:
- return self.processed_features[symbol][-1]
- return None
-
- def get_processing_stats(self) -> Dict[str, Any]:
- """Get processing performance statistics"""
- stats = {
- 'symbols': self.symbols,
- 'streaming': self.streaming,
- 'tick_counts': dict(self.tick_counts),
- 'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols},
- 'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols},
- 'subscribers': len(self.feature_subscribers)
- }
-
- if self.processing_times:
- stats['processing_performance'] = {
- 'avg_time_ms': np.mean(list(self.processing_times)),
- 'min_time_ms': np.min(list(self.processing_times)),
- 'max_time_ms': np.max(list(self.processing_times)),
- 'std_time_ms': np.std(list(self.processing_times))
- }
-
- return stats
-
- def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100):
- """Train the tick processing neural network"""
- logger.info("Training tick processing neural network...")
-
- self.tick_nn.train()
- optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001)
- criterion = nn.MSELoss()
-
- for epoch in range(epochs):
- total_loss = 0.0
-
- for batch_features, batch_targets in training_data:
- optimizer.zero_grad()
-
- # Convert to tensors
- features_tensor = torch.FloatTensor(batch_features).to(self.device)
- targets_tensor = torch.FloatTensor(batch_targets).to(self.device)
-
- # Forward pass
- outputs, confidence = self.tick_nn(features_tensor)
-
- # Calculate loss
- loss = criterion(outputs, targets_tensor)
-
- # Backward pass
- loss.backward()
- optimizer.step()
-
- total_loss += loss.item()
-
- if epoch % 10 == 0:
- avg_loss = total_loss / len(training_data)
- logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}")
-
- self.tick_nn.eval()
- logger.info("Neural network training completed")
-
-# Integration with existing orchestrator
-def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor):
- """Integrate tick processor with enhanced orchestrator"""
-
- def feature_callback(symbol: str, features: ProcessedTickFeatures):
- """Callback to feed processed features to orchestrator"""
- try:
- # Convert processed features to format expected by orchestrator
- feature_dict = {
- 'symbol': symbol,
- 'timestamp': features.timestamp,
- 'neural_features': features.neural_features,
- 'price_features': features.price_features,
- 'volume_features': features.volume_features,
- 'microstructure_features': features.microstructure_features,
- 'confidence': features.confidence
- }
-
- # Feed to orchestrator's real-time feature processing
- if hasattr(orchestrator, 'process_realtime_features'):
- orchestrator.process_realtime_features(feature_dict)
-
- except Exception as e:
- logger.error(f"Error integrating features with orchestrator: {e}")
-
- # Add the callback to tick processor
- tick_processor.add_feature_subscriber(feature_callback)
- logger.info("Tick processor integrated with orchestrator")
-
-# Factory function for easy creation
-def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor:
- """Create and configure a real-time tick processor"""
- if symbols is None:
- symbols = ['ETH/USDT', 'BTC/USDT']
-
- processor = RealTimeTickProcessor(symbols=symbols)
- logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}")
-
- return processor
\ No newline at end of file
diff --git a/core/retrospective_trainer.py b/core/retrospective_trainer.py
deleted file mode 100644
index f7a0017..0000000
--- a/core/retrospective_trainer.py
+++ /dev/null
@@ -1,453 +0,0 @@
-"""
-Retrospective Training System
-
-This module implements a retrospective training system that:
-1. Triggers training when trades close with known P&L outcomes
-2. Uses captured model inputs from trade entry to train models
-3. Optimizes for profit by learning from profitable vs unprofitable patterns
-4. Supports simultaneous inference and training without weight reloading
-5. Implements reinforcement learning with immediate reward feedback
-"""
-
-import logging
-import threading
-import time
-import queue
-from datetime import datetime
-from typing import Dict, List, Optional, Any, Callable
-from dataclasses import dataclass
-import numpy as np
-from collections import deque
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class TrainingCase:
- """Represents a completed trade case for retrospective training"""
- case_id: str
- symbol: str
- action: str # 'BUY' or 'SELL'
- entry_price: float
- exit_price: float
- entry_time: datetime
- exit_time: datetime
- pnl: float
- fees: float
- confidence: float
- model_inputs: Dict[str, Any]
- market_state: Dict[str, Any]
- outcome_label: int # 1 for profit, 0 for loss, 2 for breakeven
- reward_signal: float # Scaled reward for RL training
- leverage: float = 1.0
-
-class RetrospectiveTrainer:
- """Retrospective training system for real-time model optimization"""
-
- def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
- """Initialize the retrospective trainer"""
- self.orchestrator = orchestrator
- self.config = config or {}
-
- # Training configuration
- self.batch_size = self.config.get('batch_size', 32)
- self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
- self.profit_threshold = self.config.get('profit_threshold', 0.0)
- self.training_frequency = self.config.get('training_frequency_seconds', 120) # 2 minutes
- self.max_training_cases = self.config.get('max_training_cases', 1000)
-
- # Training state
- self.training_queue = queue.Queue()
- self.completed_cases = deque(maxlen=self.max_training_cases)
- self.training_stats = {
- 'total_cases': 0,
- 'profitable_cases': 0,
- 'loss_cases': 0,
- 'breakeven_cases': 0,
- 'avg_profit': 0.0,
- 'last_training_time': datetime.now(),
- 'training_sessions': 0,
- 'model_updates': 0
- }
-
- # Threading
- self.training_thread = None
- self.is_training_active = False
- self.training_lock = threading.Lock()
-
- logger.info("RetrospectiveTrainer initialized")
- logger.info(f"Configuration: batch_size={self.batch_size}, "
- f"min_cases={self.min_cases_for_training}, "
- f"training_freq={self.training_frequency}s")
-
- def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
- """Add a completed trade for retrospective training"""
- try:
- # Create training case from trade record
- case = self._create_training_case(trade_record, model_inputs)
- if case is None:
- return False
-
- # Add to completed cases
- self.completed_cases.append(case)
- self.training_queue.put(case)
-
- # Update statistics
- self.training_stats['total_cases'] += 1
- if case.outcome_label == 1: # Profit
- self.training_stats['profitable_cases'] += 1
- elif case.outcome_label == 0: # Loss
- self.training_stats['loss_cases'] += 1
- else: # Breakeven
- self.training_stats['breakeven_cases'] += 1
-
- # Calculate running average profit
- total_pnl = sum(c.pnl for c in self.completed_cases)
- self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)
-
- logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
- f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")
-
- # Trigger training if we have enough cases
- self._maybe_trigger_training()
-
- return True
-
- except Exception as e:
- logger.error(f"Error adding completed trade for retrospective training: {e}")
- return False
-
- def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
- """Create a training case from trade record and model inputs"""
- try:
- # Extract trade information
- symbol = trade_record.get('symbol', 'UNKNOWN')
- side = trade_record.get('side', 'UNKNOWN')
- pnl = trade_record.get('pnl', 0.0)
- fees = trade_record.get('fees', 0.0)
- confidence = trade_record.get('confidence', 0.0)
-
- # Calculate net P&L after fees
- net_pnl = pnl - fees
-
- # Determine outcome label and reward signal
- if net_pnl > self.profit_threshold:
- outcome_label = 1 # Profitable
- # Scale reward by profit magnitude and confidence
- reward_signal = min(10.0, net_pnl * confidence * 10) # Amplify for training
- elif net_pnl < -self.profit_threshold:
- outcome_label = 0 # Loss
- # Negative reward scaled by loss magnitude
- reward_signal = max(-10.0, net_pnl * confidence * 10) # Negative reward
- else:
- outcome_label = 2 # Breakeven
- reward_signal = 0.0
-
- # Create case ID
- timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
- case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')
-
- # Create training case
- case = TrainingCase(
- case_id=case_id,
- symbol=symbol,
- action=side,
- entry_price=trade_record.get('entry_price', 0.0),
- exit_price=trade_record.get('exit_price', 0.0),
- entry_time=trade_record.get('entry_time', datetime.now()),
- exit_time=trade_record.get('exit_time', datetime.now()),
- pnl=net_pnl,
- fees=fees,
- confidence=confidence,
- model_inputs=model_inputs,
- market_state=model_inputs.get('market_state', {}),
- outcome_label=outcome_label,
- reward_signal=reward_signal,
- leverage=trade_record.get('leverage', 1.0)
- )
-
- return case
-
- except Exception as e:
- logger.error(f"Error creating training case: {e}")
- return None
-
- def _maybe_trigger_training(self):
- """Check if we should trigger a training session"""
- try:
- # Check if we have enough cases
- if len(self.completed_cases) < self.min_cases_for_training:
- return
-
- # Check if enough time has passed since last training
- time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
- if time_since_last < self.training_frequency:
- return
-
- # Check if training thread is not already running
- if self.is_training_active:
- logger.debug("Training already in progress, skipping trigger")
- return
-
- # Start training in background thread
- self._start_training_session()
-
- except Exception as e:
- logger.error(f"Error checking training trigger: {e}")
-
- def _start_training_session(self):
- """Start a training session in background thread"""
- try:
- if self.training_thread and self.training_thread.is_alive():
- logger.debug("Training thread already running")
- return
-
- self.training_thread = threading.Thread(
- target=self._run_training_session,
- daemon=True,
- name="RetrospectiveTrainer"
- )
- self.training_thread.start()
- logger.info("RETROSPECTIVE: Started training session")
-
- except Exception as e:
- logger.error(f"Error starting training session: {e}")
-
- def _run_training_session(self):
- """Run a complete training session"""
- try:
- with self.training_lock:
- self.is_training_active = True
- start_time = time.time()
-
- logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")
-
- # Train models if orchestrator available
- training_results = {}
- if self.orchestrator:
- training_results = self._train_models()
-
- # Update statistics
- self.training_stats['last_training_time'] = datetime.now()
- self.training_stats['training_sessions'] += 1
- self.training_stats['model_updates'] += len(training_results)
-
- elapsed_time = time.time() - start_time
- logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")
-
- except Exception as e:
- logger.error(f"Error in retrospective training session: {e}")
- import traceback
- logger.error(traceback.format_exc())
- finally:
- self.is_training_active = False
-
- def _train_models(self) -> Dict[str, Any]:
- """Train available models using retrospective data"""
- results = {}
-
- try:
- # Prepare training data
- profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
- loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]
-
- if len(profitable_cases) == 0 and len(loss_cases) == 0:
- return {'error': 'No labeled cases for training'}
-
- logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")
-
- # Train DQN agent if available
- if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
- try:
- dqn_result = self._train_dqn_retrospective()
- results['dqn'] = dqn_result
- logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
- except Exception as e:
- logger.warning(f"DQN retrospective training failed: {e}")
- results['dqn'] = {'error': str(e)}
-
- # Train other models
- if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
- try:
- # Update extrema trainer with retrospective feedback
- extrema_feedback = self._create_extrema_feedback()
- if extrema_feedback:
- results['extrema'] = {'feedback_cases': len(extrema_feedback)}
- logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
- except Exception as e:
- logger.warning(f"Extrema retrospective training failed: {e}")
-
- return results
-
- except Exception as e:
- logger.error(f"Error training models retrospectively: {e}")
- return {'error': str(e)}
-
- def _train_dqn_retrospective(self) -> Dict[str, Any]:
- """Train DQN agent using retrospective experience replay"""
- try:
- if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
- return {'error': 'DQN agent not available'}
-
- dqn_agent = self.orchestrator.rl_agent
- experiences_added = 0
-
- # Add retrospective experiences to DQN replay buffer
- for case in self.completed_cases:
- try:
- # Extract state from model inputs
- state = self._extract_state_vector(case.model_inputs)
- if state is None:
- continue
-
- # Action mapping: BUY=0, SELL=1
- action = 0 if case.action == 'BUY' else 1
-
- # Use reward signal as immediate reward
- reward = case.reward_signal
-
- # For retrospective training, next_state is None (terminal)
- next_state = np.zeros_like(state) # Terminal state
- done = True
-
- # Add experience to DQN replay buffer
- if hasattr(dqn_agent, 'add_experience'):
- dqn_agent.add_experience(state, action, reward, next_state, done)
- experiences_added += 1
-
- except Exception as e:
- logger.debug(f"Error adding DQN experience: {e}")
- continue
-
- # Train DQN if we have enough experiences
- if experiences_added > 0 and hasattr(dqn_agent, 'train'):
- try:
- # Perform multiple training steps on retrospective data
- training_steps = min(10, experiences_added // 4) # Conservative training
- for _ in range(training_steps):
- loss = dqn_agent.train()
- if loss is None:
- break
-
- return {
- 'experiences_added': experiences_added,
- 'training_steps': training_steps,
- 'method': 'retrospective_experience_replay'
- }
- except Exception as e:
- logger.warning(f"DQN training step failed: {e}")
- return {'experiences_added': experiences_added, 'training_error': str(e)}
-
- return {'experiences_added': experiences_added, 'training_steps': 0}
-
- except Exception as e:
- logger.error(f"Error in DQN retrospective training: {e}")
- return {'error': str(e)}
-
- def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
- """Extract state vector for DQN training from model inputs"""
- try:
- # Try to get pre-built RL state
- if 'dqn_state' in model_inputs:
- state = model_inputs['dqn_state']
- if isinstance(state, dict) and 'state_vector' in state:
- return np.array(state['state_vector'])
-
- # Build state from market features
- market_state = model_inputs.get('market_state', {})
- features = []
-
- # Price features
- for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
- features.append(market_state.get(key, 0.0))
-
- # Volume features
- for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
- features.append(market_state.get(key, 0.0))
-
- # Technical indicators
- indicators = model_inputs.get('technical_indicators', {})
- for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
- features.append(indicators.get(key, 0.0))
-
- if len(features) < 5: # Minimum required features
- return None
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.debug(f"Error extracting state vector: {e}")
- return None
-
- def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
- """Create feedback data for extrema trainer"""
- feedback = []
-
- try:
- for case in self.completed_cases:
- if case.outcome_label in [0, 1]: # Only profit/loss cases
- feedback_item = {
- 'symbol': case.symbol,
- 'action': case.action,
- 'entry_price': case.entry_price,
- 'exit_price': case.exit_price,
- 'was_profitable': case.outcome_label == 1,
- 'reward_signal': case.reward_signal,
- 'market_state': case.market_state
- }
- feedback.append(feedback_item)
-
- return feedback
-
- except Exception as e:
- logger.error(f"Error creating extrema feedback: {e}")
- return []
-
- def get_training_stats(self) -> Dict[str, Any]:
- """Get current training statistics"""
- stats = self.training_stats.copy()
- stats['total_cases_in_memory'] = len(self.completed_cases)
- stats['training_queue_size'] = self.training_queue.qsize()
- stats['is_training_active'] = self.is_training_active
-
- # Calculate profit metrics
- if len(self.completed_cases) > 0:
- profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
- stats['profit_rate'] = profitable_count / len(self.completed_cases)
- stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
- stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)
-
- return stats
-
- def force_training_session(self) -> bool:
- """Force a training session regardless of timing constraints"""
- try:
- if self.is_training_active:
- logger.warning("Training already in progress")
- return False
-
- if len(self.completed_cases) < 1:
- logger.warning("No completed cases available for training")
- return False
-
- logger.info("RETROSPECTIVE: Forcing training session")
- self._start_training_session()
- return True
-
- except Exception as e:
- logger.error(f"Error forcing training session: {e}")
- return False
-
- def stop(self):
- """Stop the retrospective trainer"""
- try:
- self.is_training_active = False
- if self.training_thread and self.training_thread.is_alive():
- self.training_thread.join(timeout=10)
- logger.info("RetrospectiveTrainer stopped")
- except Exception as e:
- logger.error(f"Error stopping RetrospectiveTrainer: {e}")
-
-
-def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
- """Factory function to create a RetrospectiveTrainer instance"""
- return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
diff --git a/core/rl_training_pipeline.py b/core/rl_training_pipeline.py
deleted file mode 100644
index 5f2fa7a..0000000
--- a/core/rl_training_pipeline.py
+++ /dev/null
@@ -1,529 +0,0 @@
-"""
-RL Training Pipeline with Comprehensive Experience Storage and Replay
-
-This module implements a robust RL training pipeline that:
-1. Stores all training experiences with profitability metrics
-2. Implements profit-weighted experience replay
-3. Tracks gradient information for each training step
-4. Enables retraining on most profitable trading sequences
-5. Maintains comprehensive trading episode analysis
-"""
-
-import logging
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Any
-from dataclasses import dataclass, field
-import json
-import pickle
-from collections import deque
-import threading
-import random
-
-from .training_data_collector import get_training_data_collector
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class RLExperience:
- """Single RL experience with complete state-action-reward information"""
- experience_id: str
- timestamp: datetime
- episode_id: str
-
- # Core RL components
- state: np.ndarray
- action: int # 0=SELL, 1=HOLD, 2=BUY
- reward: float
- next_state: np.ndarray
- done: bool
-
- # Extended state information
- market_context: Dict[str, Any]
- cnn_predictions: Optional[Dict[str, Any]] = None
- confidence_score: float = 0.0
-
- # Actual trading outcome
- actual_profit: Optional[float] = None
- actual_holding_time: Optional[timedelta] = None
- optimal_action: Optional[int] = None
-
- # Experience value for replay
- experience_value: float = 0.0
- profitability_score: float = 0.0
- learning_priority: float = 0.0
-
- # Training metadata
- times_trained: int = 0
- last_trained: Optional[datetime] = None
-
-class ProfitWeightedExperienceBuffer:
- """Experience buffer with profit-weighted sampling for replay"""
-
- def __init__(self, max_size: int = 100000):
- self.max_size = max_size
- self.experiences: Dict[str, RLExperience] = {}
- self.experience_order: deque = deque(maxlen=max_size)
- self.profitable_experiences: List[str] = []
- self.total_experiences = 0
- self.total_profitable = 0
-
- def add_experience(self, experience: RLExperience):
- """Add experience to buffer"""
- try:
- self.experiences[experience.experience_id] = experience
- self.experience_order.append(experience.experience_id)
-
- if experience.actual_profit is not None and experience.actual_profit > 0:
- self.profitable_experiences.append(experience.experience_id)
- self.total_profitable += 1
-
- # Remove oldest if buffer is full
- if len(self.experiences) > self.max_size:
- oldest_id = self.experience_order[0]
- if oldest_id in self.experiences:
- del self.experiences[oldest_id]
- if oldest_id in self.profitable_experiences:
- self.profitable_experiences.remove(oldest_id)
-
- self.total_experiences += 1
-
- except Exception as e:
- logger.error(f"Error adding experience to buffer: {e}")
-
- def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
- """Sample batch with profit-weighted prioritization"""
- try:
- if len(self.experiences) < batch_size:
- return list(self.experiences.values())
-
- if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
- # Sample mix of profitable and all experiences
- profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
- remaining_sample_size = batch_size - profitable_sample_size
-
- profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
- all_ids = list(self.experiences.keys())
- remaining_ids = random.sample(all_ids, remaining_sample_size)
-
- sampled_ids = profitable_ids + remaining_ids
- else:
- # Random sampling from all experiences
- all_ids = list(self.experiences.keys())
- sampled_ids = random.sample(all_ids, batch_size)
-
- sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]
-
- # Update training counts
- for experience in sampled_experiences:
- experience.times_trained += 1
- experience.last_trained = datetime.now()
-
- return sampled_experiences
-
- except Exception as e:
- logger.error(f"Error sampling batch: {e}")
- return list(self.experiences.values())[:batch_size]
-
- def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
- """Get most profitable experiences for targeted training"""
- try:
- profitable_experiences = [
- self.experiences[exp_id] for exp_id in self.profitable_experiences
- if exp_id in self.experiences
- ]
-
- profitable_experiences.sort(
- key=lambda x: x.actual_profit if x.actual_profit else 0,
- reverse=True
- )
-
- return profitable_experiences[:limit]
-
- except Exception as e:
- logger.error(f"Error getting profitable experiences: {e}")
- return []
-
-class RLTradingAgent(nn.Module):
- """RL Trading Agent with comprehensive state processing"""
-
- def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
- super(RLTradingAgent, self).__init__()
-
- self.state_dim = state_dim
- self.action_dim = action_dim
- self.hidden_dim = hidden_dim
-
- # State processing network
- self.state_processor = nn.Sequential(
- nn.Linear(state_dim, hidden_dim),
- nn.LayerNorm(hidden_dim),
- nn.ReLU(),
- nn.Dropout(0.2),
- nn.Linear(hidden_dim, hidden_dim // 2),
- nn.LayerNorm(hidden_dim // 2),
- nn.ReLU()
- )
-
- # Q-value network
- self.q_network = nn.Sequential(
- nn.Linear(hidden_dim // 2, hidden_dim // 4),
- nn.ReLU(),
- nn.Dropout(0.1),
- nn.Linear(hidden_dim // 4, action_dim)
- )
-
- # Policy network
- self.policy_network = nn.Sequential(
- nn.Linear(hidden_dim // 2, hidden_dim // 4),
- nn.ReLU(),
- nn.Dropout(0.1),
- nn.Linear(hidden_dim // 4, action_dim),
- nn.Softmax(dim=-1)
- )
-
- # Value network
- self.value_network = nn.Sequential(
- nn.Linear(hidden_dim // 2, hidden_dim // 4),
- nn.ReLU(),
- nn.Dropout(0.1),
- nn.Linear(hidden_dim // 4, 1)
- )
-
- def forward(self, state):
- """Forward pass through the agent"""
- processed_state = self.state_processor(state)
-
- q_values = self.q_network(processed_state)
- policy_probs = self.policy_network(processed_state)
- state_value = self.value_network(processed_state)
-
- return {
- 'q_values': q_values,
- 'policy_probs': policy_probs,
- 'state_value': state_value,
- 'processed_state': processed_state
- }
-
- def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
- """Select action using epsilon-greedy policy"""
- self.eval()
- with torch.no_grad():
- if isinstance(state, np.ndarray):
- state = torch.from_numpy(state).float().unsqueeze(0)
-
- outputs = self.forward(state)
-
- if random.random() < epsilon:
- action = random.randint(0, self.action_dim - 1)
- confidence = 0.33
- else:
- q_values = outputs['q_values']
- action = torch.argmax(q_values, dim=1).item()
- q_softmax = F.softmax(q_values, dim=1)
- confidence = torch.max(q_softmax).item()
-
- return action, confidence
-
-@dataclass
-class RLTrainingStep:
- """Single RL training step with backpropagation data"""
- step_id: str
- timestamp: datetime
- batch_experiences: List[str]
-
- # Training data
- total_loss: float
- q_loss: float
- policy_loss: float
-
- # Gradients
- gradients: Dict[str, torch.Tensor]
- gradient_norms: Dict[str, float]
-
- # Metadata
- learning_rate: float = 0.001
- batch_size: int = 32
-
- # Performance
- batch_profitability: float = 0.0
- correct_actions: int = 0
- total_actions: int = 0
- step_value: float = 0.0
-
-@dataclass
-class RLTrainingSession:
- """Complete RL training session"""
- session_id: str
- start_timestamp: datetime
- end_timestamp: Optional[datetime] = None
-
- training_mode: str = 'experience_replay'
- symbol: str = ''
-
- training_steps: List[RLTrainingStep] = field(default_factory=list)
-
- total_steps: int = 0
- average_loss: float = 0.0
- best_loss: float = float('inf')
-
- profitable_actions: int = 0
- total_actions: int = 0
- profitability_rate: float = 0.0
- session_value: float = 0.0
-
-class RLTrainer:
- """RL trainer with comprehensive experience storage and replay"""
-
- def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
- self.agent = agent.to(device)
- self.device = device
- self.storage_dir = Path(storage_dir)
- self.storage_dir.mkdir(parents=True, exist_ok=True)
-
- self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
- self.experience_buffer = ProfitWeightedExperienceBuffer()
- self.data_collector = get_training_data_collector()
-
- self.training_sessions: List[RLTrainingSession] = []
- self.current_session: Optional[RLTrainingSession] = None
-
- self.gamma = 0.99
-
- self.training_stats = {
- 'total_sessions': 0,
- 'total_steps': 0,
- 'total_experiences': 0,
- 'profitable_actions': 0,
- 'total_actions': 0,
- 'average_reward': 0.0
- }
-
- logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")
-
- def add_experience(self, state: np.ndarray, action: int, reward: float,
- next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
- cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str:
- """Add experience to the buffer"""
- try:
- experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
-
- experience = RLExperience(
- experience_id=experience_id,
- timestamp=datetime.now(),
- episode_id=market_context.get('episode_id', 'unknown'),
- state=state,
- action=action,
- reward=reward,
- next_state=next_state,
- done=done,
- market_context=market_context,
- cnn_predictions=cnn_predictions,
- confidence_score=confidence_score
- )
-
- self.experience_buffer.add_experience(experience)
- self.training_stats['total_experiences'] += 1
-
- return experience_id
-
- except Exception as e:
- logger.error(f"Error adding experience: {e}")
- return None
-
- def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
- """Train on experiences with comprehensive data storage"""
- try:
- session = RLTrainingSession(
- session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
- start_timestamp=datetime.now(),
- training_mode='experience_replay'
- )
- self.current_session = session
-
- self.agent.train()
- total_loss = 0.0
-
- for batch_idx in range(num_batches):
- experiences = self.experience_buffer.sample_batch(batch_size, True)
-
- if len(experiences) < batch_size:
- continue
-
- # Prepare batch tensors
- states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
- actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
- rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
- next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
- dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)
-
- # Forward pass
- self.optimizer.zero_grad()
-
- current_outputs = self.agent(states)
- current_q_values = current_outputs['q_values']
-
- # Calculate target Q-values
- with torch.no_grad():
- next_outputs = self.agent(next_states)
- next_q_values = next_outputs['q_values']
- max_next_q_values = torch.max(next_q_values, dim=1)[0]
- target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)
-
- # Calculate loss
- current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
- q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)
-
- # Backward pass
- q_loss.backward()
-
- # Store gradients
- gradients = {}
- gradient_norms = {}
- for name, param in self.agent.named_parameters():
- if param.grad is not None:
- gradients[name] = param.grad.clone().detach()
- gradient_norms[name] = param.grad.norm().item()
-
- torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
- self.optimizer.step()
-
- # Create training step record
- step = RLTrainingStep(
- step_id=f"{session.session_id}_step_{batch_idx}",
- timestamp=datetime.now(),
- batch_experiences=[exp.experience_id for exp in experiences],
- total_loss=q_loss.item(),
- q_loss=q_loss.item(),
- policy_loss=0.0,
- gradients=gradients,
- gradient_norms=gradient_norms,
- batch_size=len(experiences)
- )
-
- session.training_steps.append(step)
- total_loss += q_loss.item()
-
- # Finalize session
- session.end_timestamp = datetime.now()
- session.total_steps = num_batches
- session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
-
- self._save_training_session(session)
-
- self.training_stats['total_sessions'] += 1
- self.training_stats['total_steps'] += session.total_steps
-
- logger.info(f"RL training session completed: {session.session_id}")
- logger.info(f"Average loss: {session.average_loss:.4f}")
-
- return {
- 'status': 'success',
- 'session_id': session.session_id,
- 'average_loss': session.average_loss,
- 'total_steps': session.total_steps
- }
-
- except Exception as e:
- logger.error(f"Error in RL training session: {e}")
- return {'status': 'error', 'error': str(e)}
- finally:
- self.current_session = None
-
- def train_on_profitable_experiences(self, min_profitability: float = 0.1,
- max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
- """Train specifically on most profitable experiences"""
- try:
- profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
-
- filtered_experiences = [
- exp for exp in profitable_experiences
- if exp.actual_profit is not None and exp.actual_profit >= min_profitability
- ]
-
- if len(filtered_experiences) < batch_size:
- return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}
-
- logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
-
- num_batches = len(filtered_experiences) // batch_size
-
- # Temporarily replace buffer sampling
- original_sample_method = self.experience_buffer.sample_batch
-
- def profitable_sample_batch(batch_size, prioritize_profitable=True):
- return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))
-
- self.experience_buffer.sample_batch = profitable_sample_batch
-
- try:
- results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
- results['training_mode'] = 'profitable_replay'
- results['experiences_used'] = len(filtered_experiences)
- return results
- finally:
- self.experience_buffer.sample_batch = original_sample_method
-
- except Exception as e:
- logger.error(f"Error training on profitable experiences: {e}")
- return {'status': 'error', 'error': str(e)}
-
- def _save_training_session(self, session: RLTrainingSession):
- """Save training session to disk"""
- try:
- session_dir = self.storage_dir / 'sessions'
- session_dir.mkdir(parents=True, exist_ok=True)
-
- session_file = session_dir / f"{session.session_id}.pkl"
- with open(session_file, 'wb') as f:
- pickle.dump(session, f)
-
- metadata = {
- 'session_id': session.session_id,
- 'start_timestamp': session.start_timestamp.isoformat(),
- 'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
- 'training_mode': session.training_mode,
- 'total_steps': session.total_steps,
- 'average_loss': session.average_loss
- }
-
- metadata_file = session_dir / f"{session.session_id}_metadata.json"
- with open(metadata_file, 'w') as f:
- json.dump(metadata, f, indent=2)
-
- except Exception as e:
- logger.error(f"Error saving training session: {e}")
-
- def get_training_statistics(self) -> Dict[str, Any]:
- """Get comprehensive training statistics"""
- stats = self.training_stats.copy()
-
- if self.training_sessions:
- recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
- stats['recent_sessions'] = [
- {
- 'session_id': s.session_id,
- 'timestamp': s.start_timestamp.isoformat(),
- 'mode': s.training_mode,
- 'average_loss': s.average_loss
- }
- for s in recent_sessions
- ]
-
- return stats
-
-# Global instance
-rl_trainer = None
-
-def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
- """Get global RL trainer instance"""
- global rl_trainer
- if rl_trainer is None:
- if agent is None:
- agent = RLTradingAgent()
- rl_trainer = RLTrainer(agent)
- return rl_trainer
\ No newline at end of file
diff --git a/core/robust_cob_provider.py b/core/robust_cob_provider.py
deleted file mode 100644
index 443aabf..0000000
--- a/core/robust_cob_provider.py
+++ /dev/null
@@ -1,460 +0,0 @@
-"""
-Robust COB (Consolidated Order Book) Provider
-
-This module provides a robust COB data provider that handles:
-- HTTP 418 errors from Binance (rate limiting)
-- Thread safety issues
-- API rate limiting and backoff
-- Fallback data sources
-- Error recovery strategies
-
-Features:
-- Automatic rate limiting and backoff
-- Multiple exchange support with fallbacks
-- Thread-safe operations
-- Comprehensive error handling
-- Data validation and integrity checking
-"""
-
-import asyncio
-import logging
-import time
-import threading
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from dataclasses import dataclass, field
-from collections import deque
-import json
-import numpy as np
-from concurrent.futures import ThreadPoolExecutor, as_completed
-import requests
-
-from .api_rate_limiter import get_rate_limiter, RateLimitConfig
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class COBData:
- """Consolidated Order Book data structure"""
- symbol: str
- timestamp: datetime
- bids: List[Tuple[float, float]] # [(price, quantity), ...]
- asks: List[Tuple[float, float]] # [(price, quantity), ...]
-
- # Derived metrics
- spread: float = 0.0
- mid_price: float = 0.0
- total_bid_volume: float = 0.0
- total_ask_volume: float = 0.0
-
- # Data quality
- data_source: str = 'unknown'
- quality_score: float = 1.0
-
- def __post_init__(self):
- """Calculate derived metrics"""
- if self.bids and self.asks:
- self.spread = self.asks[0][0] - self.bids[0][0]
- self.mid_price = (self.asks[0][0] + self.bids[0][0]) / 2
- self.total_bid_volume = sum(qty for _, qty in self.bids)
- self.total_ask_volume = sum(qty for _, qty in self.asks)
-
- # Calculate quality score based on data completeness
- self.quality_score = min(
- len(self.bids) / 20, # Expect at least 20 bid levels
- len(self.asks) / 20, # Expect at least 20 ask levels
- 1.0
- )
-
-class RobustCOBProvider:
- """Robust COB provider with error handling and rate limiting"""
-
- def __init__(self, symbols: List[str] = None):
- self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
-
- # Rate limiter
- self.rate_limiter = get_rate_limiter()
-
- # Thread safety
- self.lock = threading.RLock()
-
- # Data cache
- self.cob_cache: Dict[str, COBData] = {}
- self.cache_timestamps: Dict[str, datetime] = {}
- self.cache_ttl = timedelta(seconds=5) # 5 second cache TTL
-
- # Error tracking
- self.error_counts: Dict[str, int] = {}
- self.last_successful_fetch: Dict[str, datetime] = {}
-
- # Background fetching
- self.is_running = False
- self.fetch_threads: Dict[str, threading.Thread] = {}
- self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher")
-
- # Fallback data
- self.fallback_data: Dict[str, COBData] = {}
-
- # Performance tracking
- self.fetch_stats = {
- 'total_requests': 0,
- 'successful_requests': 0,
- 'failed_requests': 0,
- 'rate_limited_requests': 0,
- 'cache_hits': 0,
- 'fallback_uses': 0
- }
-
- logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}")
-
- def start_background_fetching(self):
- """Start background COB data fetching"""
- if self.is_running:
- logger.warning("Background fetching already running")
- return
-
- self.is_running = True
-
- # Start fetching thread for each symbol
- for symbol in self.symbols:
- thread = threading.Thread(
- target=self._background_fetch_worker,
- args=(symbol,),
- name=f"COB-{symbol}",
- daemon=True
- )
- self.fetch_threads[symbol] = thread
- thread.start()
-
- logger.info(f"Started background COB fetching for {len(self.symbols)} symbols")
-
- def stop_background_fetching(self):
- """Stop background COB data fetching"""
- self.is_running = False
-
- # Wait for threads to finish
- for symbol, thread in self.fetch_threads.items():
- thread.join(timeout=5)
- logger.debug(f"Stopped COB fetching for {symbol}")
-
- # Shutdown executor
- self.executor.shutdown(wait=True, timeout=10)
-
- logger.info("Stopped background COB fetching")
-
- def _background_fetch_worker(self, symbol: str):
- """Background worker for fetching COB data"""
- logger.info(f"Started COB fetching worker for {symbol}")
-
- while self.is_running:
- try:
- # Fetch COB data
- cob_data = self._fetch_cob_data_safe(symbol)
-
- if cob_data:
- with self.lock:
- self.cob_cache[symbol] = cob_data
- self.cache_timestamps[symbol] = datetime.now()
- self.last_successful_fetch[symbol] = datetime.now()
- self.error_counts[symbol] = 0 # Reset error count on success
-
- logger.debug(f"Updated COB cache for {symbol}")
- else:
- with self.lock:
- self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1
-
- logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}")
-
- # Wait before next fetch (adaptive based on errors)
- error_count = self.error_counts.get(symbol, 0)
- base_interval = 2.0 # Base 2 second interval
- backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0) # Max 60s
-
- time.sleep(backoff_interval)
-
- except Exception as e:
- logger.error(f"Error in COB fetching worker for {symbol}: {e}")
- time.sleep(10) # Wait 10s on unexpected errors
-
- logger.info(f"Stopped COB fetching worker for {symbol}")
-
- def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]:
- """Safely fetch COB data with error handling"""
- try:
- self.fetch_stats['total_requests'] += 1
-
- # Try Binance first
- cob_data = self._fetch_binance_cob(symbol)
- if cob_data:
- self.fetch_stats['successful_requests'] += 1
- return cob_data
-
- # Try MEXC as fallback
- cob_data = self._fetch_mexc_cob(symbol)
- if cob_data:
- self.fetch_stats['successful_requests'] += 1
- cob_data.data_source = 'mexc_fallback'
- return cob_data
-
- # Use cached fallback data if available
- if symbol in self.fallback_data:
- self.fetch_stats['fallback_uses'] += 1
- fallback = self.fallback_data[symbol]
- fallback.timestamp = datetime.now()
- fallback.data_source = 'fallback_cache'
- fallback.quality_score *= 0.5 # Reduce quality score for old data
- return fallback
-
- self.fetch_stats['failed_requests'] += 1
- return None
-
- except Exception as e:
- logger.error(f"Error fetching COB data for {symbol}: {e}")
- self.fetch_stats['failed_requests'] += 1
- return None
-
- def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]:
- """Fetch COB data from Binance with rate limiting"""
- try:
- url = f"https://api.binance.com/api/v3/depth"
- params = {
- 'symbol': symbol,
- 'limit': 100 # Get 100 levels
- }
-
- # Use rate limiter
- response = self.rate_limiter.make_request(
- 'binance_api',
- url,
- method='GET',
- params=params
- )
-
- if not response:
- self.fetch_stats['rate_limited_requests'] += 1
- return None
-
- if response.status_code != 200:
- logger.warning(f"Binance COB API returned {response.status_code} for {symbol}")
- return None
-
- data = response.json()
-
- # Parse order book data
- bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
- asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
-
- if not bids or not asks:
- logger.warning(f"Empty order book data from Binance for {symbol}")
- return None
-
- cob_data = COBData(
- symbol=symbol,
- timestamp=datetime.now(),
- bids=bids,
- asks=asks,
- data_source='binance'
- )
-
- # Store as fallback for future use
- self.fallback_data[symbol] = cob_data
-
- return cob_data
-
- except Exception as e:
- logger.error(f"Error fetching Binance COB for {symbol}: {e}")
- return None
-
- def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]:
- """Fetch COB data from MEXC as fallback"""
- try:
- url = f"https://api.mexc.com/api/v3/depth"
- params = {
- 'symbol': symbol,
- 'limit': 100
- }
-
- response = self.rate_limiter.make_request(
- 'mexc_api',
- url,
- method='GET',
- params=params
- )
-
- if not response or response.status_code != 200:
- return None
-
- data = response.json()
-
- # Parse order book data
- bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
- asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
-
- if not bids or not asks:
- return None
-
- return COBData(
- symbol=symbol,
- timestamp=datetime.now(),
- bids=bids,
- asks=asks,
- data_source='mexc'
- )
-
- except Exception as e:
- logger.debug(f"Error fetching MEXC COB for {symbol}: {e}")
- return None
-
- def get_cob_data(self, symbol: str) -> Optional[COBData]:
- """Get COB data for a symbol (from cache or fresh fetch)"""
- with self.lock:
- # Check cache first
- if symbol in self.cob_cache:
- cached_data = self.cob_cache[symbol]
- cache_time = self.cache_timestamps.get(symbol, datetime.min)
-
- # Return cached data if still fresh
- if datetime.now() - cache_time < self.cache_ttl:
- self.fetch_stats['cache_hits'] += 1
- return cached_data
-
- # If background fetching is running, return cached data even if stale
- if self.is_running and symbol in self.cob_cache:
- return self.cob_cache[symbol]
-
- # Fetch fresh data if not running background fetching
- if not self.is_running:
- return self._fetch_cob_data_safe(symbol)
-
- return None
-
- def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]:
- """
- Get COB features for ML models
-
- Args:
- symbol: Trading symbol
- feature_count: Number of features to return
-
- Returns:
- Numpy array of COB features or None if no data
- """
- cob_data = self.get_cob_data(symbol)
- if not cob_data:
- return None
-
- try:
- features = []
-
- # Basic market metrics
- features.extend([
- cob_data.mid_price,
- cob_data.spread,
- cob_data.total_bid_volume,
- cob_data.total_ask_volume,
- cob_data.quality_score
- ])
-
- # Bid levels (price and volume)
- max_levels = min(len(cob_data.bids), 20)
- for i in range(max_levels):
- price, volume = cob_data.bids[i]
- features.extend([price, volume])
-
- # Pad bid levels if needed
- for i in range(max_levels, 20):
- features.extend([0.0, 0.0])
-
- # Ask levels (price and volume)
- max_levels = min(len(cob_data.asks), 20)
- for i in range(max_levels):
- price, volume = cob_data.asks[i]
- features.extend([price, volume])
-
- # Pad ask levels if needed
- for i in range(max_levels, 20):
- features.extend([0.0, 0.0])
-
- # Calculate additional features
- if len(cob_data.bids) > 0 and len(cob_data.asks) > 0:
- # Volume imbalance
- bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5])
- ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5])
- volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0
- features.append(volume_imbalance)
-
- # Price levels
- bid_price_levels = [price for price, _ in cob_data.bids[:10]]
- ask_price_levels = [price for price, _ in cob_data.asks[:10]]
- features.extend(bid_price_levels + ask_price_levels)
-
- # Pad or truncate to desired feature count
- if len(features) < feature_count:
- features.extend([0.0] * (feature_count - len(features)))
- else:
- features = features[:feature_count]
-
- return np.array(features, dtype=np.float32)
-
- except Exception as e:
- logger.error(f"Error creating COB features for {symbol}: {e}")
- return None
-
- def get_provider_status(self) -> Dict[str, Any]:
- """Get provider status and statistics"""
- with self.lock:
- status = {
- 'is_running': self.is_running,
- 'symbols': self.symbols,
- 'cache_status': {},
- 'error_counts': self.error_counts.copy(),
- 'last_successful_fetch': {
- symbol: timestamp.isoformat()
- for symbol, timestamp in self.last_successful_fetch.items()
- },
- 'fetch_stats': self.fetch_stats.copy(),
- 'rate_limiter_status': self.rate_limiter.get_all_endpoint_status()
- }
-
- # Cache status for each symbol
- for symbol in self.symbols:
- cache_time = self.cache_timestamps.get(symbol)
- status['cache_status'][symbol] = {
- 'has_data': symbol in self.cob_cache,
- 'cache_time': cache_time.isoformat() if cache_time else None,
- 'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None,
- 'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0
- }
-
- return status
-
- def reset_errors(self):
- """Reset error counts and rate limiter"""
- with self.lock:
- self.error_counts.clear()
- self.rate_limiter.reset_all_endpoints()
- logger.info("Reset all error counts and rate limiter")
-
- def force_refresh(self, symbol: str = None):
- """Force refresh COB data for symbol(s)"""
- symbols_to_refresh = [symbol] if symbol else self.symbols
-
- for sym in symbols_to_refresh:
- # Clear cache to force refresh
- with self.lock:
- if sym in self.cob_cache:
- del self.cob_cache[sym]
- if sym in self.cache_timestamps:
- del self.cache_timestamps[sym]
-
- logger.info(f"Forced refresh for {sym}")
-
-# Global COB provider instance
-_global_cob_provider = None
-
-def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider:
- """Get global COB provider instance"""
- global _global_cob_provider
- if _global_cob_provider is None:
- _global_cob_provider = RobustCOBProvider(symbols)
- return _global_cob_provider
\ No newline at end of file
diff --git a/core/shared_cob_service.py b/core/shared_cob_service.py
deleted file mode 100644
index c11aae5..0000000
--- a/core/shared_cob_service.py
+++ /dev/null
@@ -1,350 +0,0 @@
-#!/usr/bin/env python3
-"""
-Shared COB Service - Eliminates Redundant COB Implementations
-
-This service provides a singleton COB integration that can be shared across:
-- Dashboard components
-- RL trading systems
-- Enhanced orchestrators
-- Training pipelines
-
-Instead of each component creating its own COBIntegration instance,
-they all share this single service, eliminating redundant connections.
-"""
-
-import asyncio
-import logging
-import weakref
-from typing import Dict, List, Optional, Any, Callable, Set
-from datetime import datetime
-from threading import Lock
-from dataclasses import dataclass
-
-from .cob_integration import COBIntegration
-from .multi_exchange_cob_provider import COBSnapshot
-from .data_provider import DataProvider
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class COBSubscription:
- """Represents a subscription to COB updates"""
- subscriber_id: str
- callback: Callable
- symbol_filter: Optional[List[str]] = None
- callback_type: str = "general" # general, cnn, dqn, dashboard
-
-class SharedCOBService:
- """
- Shared COB Service - Singleton pattern for unified COB data access
-
- This service eliminates redundant COB integrations by providing a single
- shared instance that all components can subscribe to.
- """
-
- _instance: Optional['SharedCOBService'] = None
- _lock = Lock()
-
- def __new__(cls, *args, **kwargs):
- """Singleton pattern implementation"""
- if cls._instance is None:
- with cls._lock:
- if cls._instance is None:
- cls._instance = super(SharedCOBService, cls).__new__(cls)
- return cls._instance
-
- def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None):
- """Initialize shared COB service (only called once due to singleton)"""
- if hasattr(self, '_initialized'):
- return
-
- self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
- self.data_provider = data_provider
-
- # Single COB integration instance
- self.cob_integration: Optional[COBIntegration] = None
- self.is_running = False
-
- # Subscriber management
- self.subscribers: Dict[str, COBSubscription] = {}
- self.subscriber_counter = 0
- self.subscription_lock = Lock()
-
- # Cached data for immediate access
- self.latest_snapshots: Dict[str, COBSnapshot] = {}
- self.latest_cnn_features: Dict[str, Any] = {}
- self.latest_dqn_states: Dict[str, Any] = {}
-
- # Performance tracking
- self.total_subscribers = 0
- self.update_count = 0
- self.start_time = None
-
- self._initialized = True
- logger.info(f"SharedCOBService initialized for symbols: {self.symbols}")
-
- async def start(self) -> None:
- """Start the shared COB service"""
- if self.is_running:
- logger.warning("SharedCOBService already running")
- return
-
- logger.info("Starting SharedCOBService...")
-
- try:
- # Initialize COB integration if not already done
- if self.cob_integration is None:
- self.cob_integration = COBIntegration(
- data_provider=self.data_provider,
- symbols=self.symbols
- )
-
- # Register internal callbacks
- self.cob_integration.add_cnn_callback(self._on_cob_cnn_update)
- self.cob_integration.add_dqn_callback(self._on_cob_dqn_update)
- self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update)
-
- # Start COB integration
- await self.cob_integration.start()
-
- self.is_running = True
- self.start_time = datetime.now()
-
- logger.info("SharedCOBService started successfully")
- logger.info(f"Active subscribers: {len(self.subscribers)}")
-
- except Exception as e:
- logger.error(f"Error starting SharedCOBService: {e}")
- raise
-
- async def stop(self) -> None:
- """Stop the shared COB service"""
- if not self.is_running:
- return
-
- logger.info("Stopping SharedCOBService...")
-
- try:
- if self.cob_integration:
- await self.cob_integration.stop()
-
- self.is_running = False
-
- # Notify all subscribers of shutdown
- for subscription in self.subscribers.values():
- try:
- if hasattr(subscription.callback, '__call__'):
- subscription.callback("SHUTDOWN", None)
- except Exception as e:
- logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
-
- logger.info("SharedCOBService stopped")
-
- except Exception as e:
- logger.error(f"Error stopping SharedCOBService: {e}")
-
- def subscribe(self,
- callback: Callable,
- callback_type: str = "general",
- symbol_filter: Optional[List[str]] = None,
- subscriber_name: str = None) -> str:
- """
- Subscribe to COB updates
-
- Args:
- callback: Function to call on updates
- callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard')
- symbol_filter: Only receive updates for these symbols (None = all)
- subscriber_name: Optional name for the subscriber
-
- Returns:
- Subscription ID for unsubscribing
- """
- with self.subscription_lock:
- self.subscriber_counter += 1
- subscriber_id = f"{callback_type}_{self.subscriber_counter}"
- if subscriber_name:
- subscriber_id = f"{subscriber_name}_{subscriber_id}"
-
- subscription = COBSubscription(
- subscriber_id=subscriber_id,
- callback=callback,
- symbol_filter=symbol_filter,
- callback_type=callback_type
- )
-
- self.subscribers[subscriber_id] = subscription
- self.total_subscribers += 1
-
- logger.info(f"New subscriber: {subscriber_id} ({callback_type})")
- logger.info(f"Total active subscribers: {len(self.subscribers)}")
-
- return subscriber_id
-
- def unsubscribe(self, subscriber_id: str) -> bool:
- """
- Unsubscribe from COB updates
-
- Args:
- subscriber_id: ID returned from subscribe()
-
- Returns:
- True if successfully unsubscribed
- """
- with self.subscription_lock:
- if subscriber_id in self.subscribers:
- del self.subscribers[subscriber_id]
- logger.info(f"Unsubscribed: {subscriber_id}")
- logger.info(f"Remaining subscribers: {len(self.subscribers)}")
- return True
- else:
- logger.warning(f"Subscriber not found: {subscriber_id}")
- return False
-
- # Internal callback handlers
-
- async def _on_cob_cnn_update(self, symbol: str, data: Dict):
- """Handle CNN feature updates from COB integration"""
- try:
- self.latest_cnn_features[symbol] = data
- await self._notify_subscribers("cnn", symbol, data)
- except Exception as e:
- logger.error(f"Error in CNN update handler: {e}")
-
- async def _on_cob_dqn_update(self, symbol: str, data: Dict):
- """Handle DQN state updates from COB integration"""
- try:
- self.latest_dqn_states[symbol] = data
- await self._notify_subscribers("dqn", symbol, data)
- except Exception as e:
- logger.error(f"Error in DQN update handler: {e}")
-
- async def _on_cob_dashboard_update(self, symbol: str, data: Dict):
- """Handle dashboard updates from COB integration"""
- try:
- # Store snapshot if it's a COBSnapshot
- if hasattr(data, 'volume_weighted_mid'): # Duck typing for COBSnapshot
- self.latest_snapshots[symbol] = data
-
- await self._notify_subscribers("dashboard", symbol, data)
- await self._notify_subscribers("general", symbol, data)
-
- self.update_count += 1
-
- except Exception as e:
- logger.error(f"Error in dashboard update handler: {e}")
-
- async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any):
- """Notify all relevant subscribers of an update"""
- try:
- relevant_subscribers = [
- sub for sub in self.subscribers.values()
- if (sub.callback_type == callback_type or sub.callback_type == "general") and
- (sub.symbol_filter is None or symbol in sub.symbol_filter)
- ]
-
- for subscription in relevant_subscribers:
- try:
- if asyncio.iscoroutinefunction(subscription.callback):
- asyncio.create_task(subscription.callback(symbol, data))
- else:
- subscription.callback(symbol, data)
- except Exception as e:
- logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
-
- except Exception as e:
- logger.error(f"Error notifying subscribers: {e}")
-
- # Public data access methods
-
- def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
- """Get latest COB snapshot for a symbol"""
- if self.cob_integration:
- return self.cob_integration.get_cob_snapshot(symbol)
- return self.latest_snapshots.get(symbol)
-
- def get_cnn_features(self, symbol: str) -> Optional[Any]:
- """Get latest CNN features for a symbol"""
- return self.latest_cnn_features.get(symbol)
-
- def get_dqn_state(self, symbol: str) -> Optional[Any]:
- """Get latest DQN state for a symbol"""
- return self.latest_dqn_states.get(symbol)
-
- def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
- """Get detailed market depth analysis"""
- if self.cob_integration:
- return self.cob_integration.get_market_depth_analysis(symbol)
- return None
-
- def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
- """Get liquidity breakdown by exchange"""
- if self.cob_integration:
- return self.cob_integration.get_exchange_breakdown(symbol)
- return None
-
- def get_price_buckets(self, symbol: str) -> Optional[Dict]:
- """Get fine-grain price buckets"""
- if self.cob_integration:
- return self.cob_integration.get_price_buckets(symbol)
- return None
-
- def get_session_volume_profile(self, symbol: str) -> Optional[Dict]:
- """Get session volume profile"""
- if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'):
- return self.cob_integration.cob_provider.get_session_volume_profile(symbol)
- return None
-
- def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
- """Get real-time statistics formatted for NN models"""
- if self.cob_integration:
- return self.cob_integration.get_realtime_stats_for_nn(symbol)
- return {}
-
- def get_service_statistics(self) -> Dict[str, Any]:
- """Get service statistics"""
- uptime = None
- if self.start_time:
- uptime = (datetime.now() - self.start_time).total_seconds()
-
- base_stats = {
- 'service_name': 'SharedCOBService',
- 'is_running': self.is_running,
- 'symbols': self.symbols,
- 'total_subscribers': len(self.subscribers),
- 'lifetime_subscribers': self.total_subscribers,
- 'update_count': self.update_count,
- 'uptime_seconds': uptime,
- 'subscribers_by_type': {}
- }
-
- # Count subscribers by type
- for subscription in self.subscribers.values():
- callback_type = subscription.callback_type
- if callback_type not in base_stats['subscribers_by_type']:
- base_stats['subscribers_by_type'][callback_type] = 0
- base_stats['subscribers_by_type'][callback_type] += 1
-
- # Get COB integration stats if available
- if self.cob_integration:
- cob_stats = self.cob_integration.get_statistics()
- base_stats.update(cob_stats)
-
- return base_stats
-
-# Global service instance access functions
-
-def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
- """Get the shared COB service instance"""
- return SharedCOBService(symbols=symbols, data_provider=data_provider)
-
-async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
- """Start the shared COB service"""
- service = get_shared_cob_service(symbols=symbols, data_provider=data_provider)
- await service.start()
- return service
-
-async def stop_shared_cob_service():
- """Stop the shared COB service"""
- service = get_shared_cob_service()
- await service.stop()
\ No newline at end of file
diff --git a/core/shared_data_manager.py b/core/shared_data_manager.py
deleted file mode 100644
index 87a21e7..0000000
--- a/core/shared_data_manager.py
+++ /dev/null
@@ -1,425 +0,0 @@
-"""
-Shared Data Manager for UI Stability Fix
-
-Manages data sharing between processes through files with proper locking
-and atomic operations to prevent corruption and conflicts.
-"""
-
-import json
-import os
-import time
-import tempfile
-import platform
-from datetime import datetime
-from dataclasses import dataclass, asdict
-from typing import Dict, Any, Optional, Union
-from pathlib import Path
-import logging
-
-# Windows-compatible file locking
-if platform.system() == "Windows":
- import msvcrt
-else:
- import fcntl
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class ProcessStatus:
- """Model for process status information"""
- name: str
- pid: int
- status: str # 'running', 'stopped', 'error'
- start_time: datetime
- last_heartbeat: datetime
- memory_usage: float
- cpu_usage: float
- error_message: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary with datetime serialization"""
- data = asdict(self)
- data['start_time'] = self.start_time.isoformat()
- data['last_heartbeat'] = self.last_heartbeat.isoformat()
- return data
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus':
- """Create from dictionary with datetime deserialization"""
- data['start_time'] = datetime.fromisoformat(data['start_time'])
- data['last_heartbeat'] = datetime.fromisoformat(data['last_heartbeat'])
- return cls(**data)
-
-@dataclass
-class TrainingStatus:
- """Model for training status information"""
- is_running: bool
- current_epoch: int
- total_epochs: int
- loss: float
- accuracy: float
- last_update: datetime
- model_path: str
- error_message: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary with datetime serialization"""
- data = asdict(self)
- data['last_update'] = self.last_update.isoformat()
- return data
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus':
- """Create from dictionary with datetime deserialization"""
- data['last_update'] = datetime.fromisoformat(data['last_update'])
- return cls(**data)
-
-@dataclass
-class DashboardState:
- """Model for dashboard state information"""
- is_connected: bool
- last_data_update: datetime
- active_connections: int
- error_count: int
- performance_metrics: Dict[str, float]
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary with datetime serialization"""
- data = asdict(self)
- data['last_data_update'] = self.last_data_update.isoformat()
- return data
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState':
- """Create from dictionary with datetime deserialization"""
- data['last_data_update'] = datetime.fromisoformat(data['last_data_update'])
- return cls(**data)
-
-
-class SharedDataManager:
- """
- Manages data sharing between processes through files with proper locking
- and atomic operations to prevent corruption and conflicts.
- """
-
- def __init__(self, data_dir: str = "shared_data"):
- """
- Initialize the shared data manager
-
- Args:
- data_dir: Directory to store shared data files
- """
- self.data_dir = Path(data_dir)
- self.data_dir.mkdir(exist_ok=True)
-
- # Define file paths for different data types
- self.training_status_file = self.data_dir / "training_status.json"
- self.dashboard_state_file = self.data_dir / "dashboard_state.json"
- self.process_status_file = self.data_dir / "process_status.json"
- self.market_data_file = self.data_dir / "market_data.json"
- self.model_metrics_file = self.data_dir / "model_metrics.json"
-
- logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}")
-
- def _lock_file(self, file_handle, exclusive=True):
- """Cross-platform file locking"""
- if platform.system() == "Windows":
- # Windows file locking
- try:
- if exclusive:
- msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
- else:
- msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
- except IOError:
- pass # File locking may not be available in all scenarios
- else:
- # Unix file locking
- lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
- fcntl.flock(file_handle.fileno(), lock_type)
-
- def _unlock_file(self, file_handle):
- """Cross-platform file unlocking"""
- if platform.system() == "Windows":
- try:
- msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
- except IOError:
- pass
- else:
- fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
-
- def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None:
- """
- Write JSON data atomically with file locking
-
- Args:
- file_path: Path to the file to write
- data: Data to write as JSON
- """
- temp_path = None
- try:
- # Create temporary file in the same directory
- temp_fd, temp_path = tempfile.mkstemp(
- dir=file_path.parent,
- prefix=f".{file_path.name}.",
- suffix=".tmp"
- )
-
- with os.fdopen(temp_fd, 'w') as temp_file:
- # Lock the temporary file
- self._lock_file(temp_file, exclusive=True)
-
- # Write data with proper formatting
- json.dump(data, temp_file, indent=2, default=str)
- temp_file.flush()
- os.fsync(temp_file.fileno())
-
- # Unlock before closing
- self._unlock_file(temp_file)
-
- # Atomically replace the original file
- os.replace(temp_path, file_path)
- logger.debug(f"Successfully wrote data to {file_path}")
-
- except Exception as e:
- # Clean up temporary file if it exists
- if temp_path:
- try:
- os.unlink(temp_path)
- except:
- pass
- logger.error(f"Failed to write data to {file_path}: {e}")
- raise
-
- def _read_json_safe(self, file_path: Path) -> Dict[str, Any]:
- """
- Read JSON data safely with file locking
-
- Args:
- file_path: Path to the file to read
-
- Returns:
- Dictionary containing the JSON data
- """
- if not file_path.exists():
- logger.debug(f"File {file_path} does not exist, returning empty dict")
- return {}
-
- try:
- with open(file_path, 'r') as file:
- # Lock the file for reading
- self._lock_file(file, exclusive=False)
- data = json.load(file)
- self._unlock_file(file)
- logger.debug(f"Successfully read data from {file_path}")
- return data
-
- except json.JSONDecodeError as e:
- logger.error(f"Invalid JSON in {file_path}: {e}")
- return {}
- except Exception as e:
- logger.error(f"Failed to read data from {file_path}: {e}")
- return {}
-
- def write_training_status(self, status: TrainingStatus) -> None:
- """
- Write training status to shared file
-
- Args:
- status: TrainingStatus object to write
- """
- try:
- data = status.to_dict()
- self._write_json_atomic(self.training_status_file, data)
- logger.debug("Training status written successfully")
- except Exception as e:
- logger.error(f"Failed to write training status: {e}")
- raise
-
- def read_training_status(self) -> Optional[TrainingStatus]:
- """
- Read training status from shared file
-
- Returns:
- TrainingStatus object or None if not available
- """
- try:
- data = self._read_json_safe(self.training_status_file)
- if not data:
- return None
- return TrainingStatus.from_dict(data)
- except Exception as e:
- logger.error(f"Failed to read training status: {e}")
- return None
-
- def write_dashboard_state(self, state: DashboardState) -> None:
- """
- Write dashboard state to shared file
-
- Args:
- state: DashboardState object to write
- """
- try:
- data = state.to_dict()
- self._write_json_atomic(self.dashboard_state_file, data)
- logger.debug("Dashboard state written successfully")
- except Exception as e:
- logger.error(f"Failed to write dashboard state: {e}")
- raise
-
- def read_dashboard_state(self) -> Optional[DashboardState]:
- """
- Read dashboard state from shared file
-
- Returns:
- DashboardState object or None if not available
- """
- try:
- data = self._read_json_safe(self.dashboard_state_file)
- if not data:
- return None
- return DashboardState.from_dict(data)
- except Exception as e:
- logger.error(f"Failed to read dashboard state: {e}")
- return None
-
- def write_process_status(self, status: ProcessStatus) -> None:
- """
- Write process status to shared file
-
- Args:
- status: ProcessStatus object to write
- """
- try:
- data = status.to_dict()
- self._write_json_atomic(self.process_status_file, data)
- logger.debug("Process status written successfully")
- except Exception as e:
- logger.error(f"Failed to write process status: {e}")
- raise
-
- def read_process_status(self) -> Optional[ProcessStatus]:
- """
- Read process status from shared file
-
- Returns:
- ProcessStatus object or None if not available
- """
- try:
- data = self._read_json_safe(self.process_status_file)
- if not data:
- return None
- return ProcessStatus.from_dict(data)
- except Exception as e:
- logger.error(f"Failed to read process status: {e}")
- return None
-
- def write_market_data(self, data: Dict[str, Any]) -> None:
- """
- Write market data to shared file
-
- Args:
- data: Market data dictionary to write
- """
- try:
- # Add timestamp to market data
- data['timestamp'] = datetime.now().isoformat()
- self._write_json_atomic(self.market_data_file, data)
- logger.debug("Market data written successfully")
- except Exception as e:
- logger.error(f"Failed to write market data: {e}")
- raise
-
- def read_market_data(self) -> Dict[str, Any]:
- """
- Read market data from shared file
-
- Returns:
- Dictionary containing market data
- """
- try:
- return self._read_json_safe(self.market_data_file)
- except Exception as e:
- logger.error(f"Failed to read market data: {e}")
- return {}
-
- def write_model_metrics(self, metrics: Dict[str, Any]) -> None:
- """
- Write model metrics to shared file
-
- Args:
- metrics: Model metrics dictionary to write
- """
- try:
- # Add timestamp to metrics
- metrics['timestamp'] = datetime.now().isoformat()
- self._write_json_atomic(self.model_metrics_file, metrics)
- logger.debug("Model metrics written successfully")
- except Exception as e:
- logger.error(f"Failed to write model metrics: {e}")
- raise
-
- def read_model_metrics(self) -> Dict[str, Any]:
- """
- Read model metrics from shared file
-
- Returns:
- Dictionary containing model metrics
- """
- try:
- return self._read_json_safe(self.model_metrics_file)
- except Exception as e:
- logger.error(f"Failed to read model metrics: {e}")
- return {}
-
- def cleanup(self) -> None:
- """
- Clean up shared data files
- """
- try:
- for file_path in [
- self.training_status_file,
- self.dashboard_state_file,
- self.process_status_file,
- self.market_data_file,
- self.model_metrics_file
- ]:
- if file_path.exists():
- file_path.unlink()
- logger.debug(f"Removed {file_path}")
-
- # Remove directory if empty
- if self.data_dir.exists() and not any(self.data_dir.iterdir()):
- self.data_dir.rmdir()
- logger.debug(f"Removed empty directory {self.data_dir}")
-
- except Exception as e:
- logger.error(f"Failed to cleanup shared data: {e}")
-
- def get_data_age(self, data_type: str) -> Optional[float]:
- """
- Get the age of data in seconds
-
- Args:
- data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics')
-
- Returns:
- Age in seconds or None if file doesn't exist
- """
- file_map = {
- 'training': self.training_status_file,
- 'dashboard': self.dashboard_state_file,
- 'process': self.process_status_file,
- 'market': self.market_data_file,
- 'metrics': self.model_metrics_file
- }
-
- file_path = file_map.get(data_type)
- if not file_path or not file_path.exists():
- return None
-
- try:
- mtime = file_path.stat().st_mtime
- return time.time() - mtime
- except Exception as e:
- logger.error(f"Failed to get data age for {data_type}: {e}")
- return None
\ No newline at end of file
diff --git a/core/trading_action.py b/core/trading_action.py
deleted file mode 100644
index 1d90d37..0000000
--- a/core/trading_action.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""
-Trading Action Module
-
-Defines the TradingAction class used throughout the trading system.
-"""
-
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Dict, Any, List
-
-@dataclass
-class TradingAction:
- """Represents a trading action with full context"""
- symbol: str
- action: str # 'BUY', 'SELL', 'HOLD'
- quantity: float
- confidence: float
- price: float
- timestamp: datetime
- reasoning: Dict[str, Any]
-
- def __post_init__(self):
- """Validate the trading action after initialization"""
- if self.action not in ['BUY', 'SELL', 'HOLD']:
- raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")
-
- if self.confidence < 0.0 or self.confidence > 1.0:
- raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")
-
- if self.quantity < 0:
- raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")
-
- if self.price <= 0:
- raise ValueError(f"Invalid price: {self.price}. Must be positive")
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert trading action to dictionary"""
- return {
- 'symbol': self.symbol,
- 'action': self.action,
- 'quantity': self.quantity,
- 'confidence': self.confidence,
- 'price': self.price,
- 'timestamp': self.timestamp.isoformat(),
- 'reasoning': self.reasoning
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
- """Create trading action from dictionary"""
- return cls(
- symbol=data['symbol'],
- action=data['action'],
- quantity=data['quantity'],
- confidence=data['confidence'],
- price=data['price'],
- timestamp=datetime.fromisoformat(data['timestamp']),
- reasoning=data['reasoning']
- )
\ No newline at end of file
diff --git a/core/trading_executor_fix.py b/core/trading_executor_fix.py
deleted file mode 100644
index 7366f1d..0000000
--- a/core/trading_executor_fix.py
+++ /dev/null
@@ -1,401 +0,0 @@
-"""
-Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
-
-This module provides fixes for:
-1. Identical entry prices issue
-2. Price caching problems
-3. Position tracking reset logic
-4. Trade cooldown implementation
-5. P&L calculation verification
-
-Apply these fixes to the TradingExecutor class to improve trade execution reliability.
-"""
-
-import logging
-import time
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Any, Union
-
-logger = logging.getLogger(__name__)
-
-class TradingExecutorFix:
- """
- Fixes for the TradingExecutor class to address entry/exit price issues
- and improve P&L calculation accuracy.
- """
-
- def __init__(self, trading_executor):
- """
- Initialize the fix with a reference to the trading executor
-
- Args:
- trading_executor: The TradingExecutor instance to fix
- """
- self.trading_executor = trading_executor
-
- # Add cooldown tracking
- self.last_trade_time = {} # {symbol: timestamp}
- self.min_trade_cooldown = 30 # 30 seconds minimum between trades
-
- # Add price history for validation
- self.recent_entry_prices = {} # {symbol: [recent_prices]}
- self.max_price_history = 10 # Keep last 10 entry prices
-
- # Add position reset tracking
- self.position_reset_flags = {} # {symbol: bool}
-
- # Add price update tracking
- self.last_price_update = {} # {symbol: timestamp}
- self.price_update_threshold = 5 # 5 seconds max since last price update
-
- # Add P&L verification
- self.trade_history = {} # {symbol: [trade_records]}
-
- logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")
-
- def apply_fixes(self):
- """Apply all fixes to the trading executor"""
- self._patch_execute_action()
- self._patch_close_position()
- self._patch_calculate_pnl()
- self._patch_update_prices()
-
- logger.info("All trading executor fixes applied successfully")
-
- def _patch_execute_action(self):
- """Patch the execute_action method to add price validation and cooldown"""
- original_execute_action = self.trading_executor.execute_action
-
- def execute_action_with_fixes(decision):
- """Enhanced execute_action with price validation and cooldown"""
- try:
- symbol = decision.symbol
- action = decision.action
- current_time = datetime.now()
-
- # 1. Check cooldown period
- if symbol in self.last_trade_time:
- time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
- if time_since_last_trade < self.min_trade_cooldown:
- logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
- return False
-
- # 2. Validate price freshness
- if symbol in self.last_price_update:
- time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
- if time_since_update > self.price_update_threshold:
- logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
- # Force price refresh
- self._refresh_price(symbol)
- return False
-
- # 3. Validate entry price against recent history
- current_price = self._get_current_price(symbol)
- if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
- # Check if price is identical to any recent entry
- if current_price in self.recent_entry_prices[symbol]:
- logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
- return False
-
- # 4. Ensure position is properly reset before new entry
- if not self._ensure_position_reset(symbol):
- logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
- return False
-
- # Execute the original action
- result = original_execute_action(decision)
-
- # If successful, update tracking
- if result:
- # Update cooldown timestamp
- self.last_trade_time[symbol] = current_time
-
- # Update price history
- if symbol not in self.recent_entry_prices:
- self.recent_entry_prices[symbol] = []
-
- self.recent_entry_prices[symbol].append(current_price)
- # Keep only the most recent prices
- if len(self.recent_entry_prices[symbol]) > self.max_price_history:
- self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]
-
- # Mark position as active
- self.position_reset_flags[symbol] = False
-
- logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
-
- return result
-
- except Exception as e:
- logger.error(f"Error in execute_action_with_fixes: {e}")
- return original_execute_action(decision)
-
- # Replace the original method
- self.trading_executor.execute_action = execute_action_with_fixes
- logger.info("Patched execute_action with price validation and cooldown")
-
- def _patch_close_position(self):
- """Patch the close_position method to ensure proper position reset"""
- original_close_position = self.trading_executor.close_position
-
- def close_position_with_fixes(symbol, **kwargs):
- """Enhanced close_position with proper reset logic"""
- try:
- # Get current price for P&L verification
- exit_price = self._get_current_price(symbol)
-
- # Call original close position
- result = original_close_position(symbol, **kwargs)
-
- if result:
- # Mark position as reset
- self.position_reset_flags[symbol] = True
-
- # Record trade for verification
- if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
- position = self.trading_executor.positions[symbol]
-
- # Create trade record
- trade_record = {
- 'symbol': symbol,
- 'entry_time': getattr(position, 'entry_time', datetime.now()),
- 'exit_time': datetime.now(),
- 'entry_price': getattr(position, 'entry_price', 0),
- 'exit_price': exit_price,
- 'size': getattr(position, 'size', 0),
- 'side': getattr(position, 'side', 'UNKNOWN'),
- 'pnl': self._calculate_verified_pnl(position, exit_price),
- 'fees': getattr(position, 'fees', 0),
- 'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
- }
-
- # Store trade record
- if symbol not in self.trade_history:
- self.trade_history[symbol] = []
- self.trade_history[symbol].append(trade_record)
-
- logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
-
- return result
-
- except Exception as e:
- logger.error(f"Error in close_position_with_fixes: {e}")
- return original_close_position(symbol, **kwargs)
-
- # Replace the original method
- self.trading_executor.close_position = close_position_with_fixes
- logger.info("Patched close_position with proper reset logic")
-
- def _patch_calculate_pnl(self):
- """Patch the calculate_pnl method to ensure accurate P&L calculation"""
- original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
-
- def calculate_pnl_with_fixes(position, current_price=None):
- """Enhanced calculate_pnl with verification"""
- try:
- # If no original method, implement our own
- if original_calculate_pnl is None:
- return self._calculate_verified_pnl(position, current_price)
-
- # Call original method
- original_pnl = original_calculate_pnl(position, current_price)
-
- # Calculate our verified P&L
- verified_pnl = self._calculate_verified_pnl(position, current_price)
-
- # If there's a significant difference, log it
- if abs(original_pnl - verified_pnl) > 0.01:
- logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
- # Use the verified P&L
- return verified_pnl
-
- return original_pnl
-
- except Exception as e:
- logger.error(f"Error in calculate_pnl_with_fixes: {e}")
- if original_calculate_pnl:
- return original_calculate_pnl(position, current_price)
- return 0.0
-
- # Replace the original method if it exists
- if original_calculate_pnl:
- self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
- logger.info("Patched calculate_pnl with verification")
- else:
- # Add the method if it doesn't exist
- self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
- logger.info("Added calculate_pnl method with verification")
-
- def _patch_update_prices(self):
- """Patch the update_prices method to track price updates"""
- original_update_prices = getattr(self.trading_executor, 'update_prices', None)
-
- def update_prices_with_tracking(prices):
- """Enhanced update_prices with timestamp tracking"""
- try:
- # Call original method if it exists
- if original_update_prices:
- result = original_update_prices(prices)
- else:
- # If no original method, update prices directly
- if hasattr(self.trading_executor, 'current_prices'):
- self.trading_executor.current_prices.update(prices)
- result = True
-
- # Track update timestamps
- current_time = datetime.now()
- for symbol in prices:
- self.last_price_update[symbol] = current_time
-
- return result
-
- except Exception as e:
- logger.error(f"Error in update_prices_with_tracking: {e}")
- if original_update_prices:
- return original_update_prices(prices)
- return False
-
- # Replace the original method if it exists
- if original_update_prices:
- self.trading_executor.update_prices = update_prices_with_tracking
- logger.info("Patched update_prices with timestamp tracking")
- else:
- # Add the method if it doesn't exist
- self.trading_executor.update_prices = update_prices_with_tracking
- logger.info("Added update_prices method with timestamp tracking")
-
- def _calculate_verified_pnl(self, position, current_price=None):
- """Calculate verified P&L for a position"""
- try:
- # Get position details
- entry_price = getattr(position, 'entry_price', 0)
- size = getattr(position, 'size', 0)
- side = getattr(position, 'side', 'UNKNOWN')
- leverage = getattr(position, 'leverage', 1.0)
- fees = getattr(position, 'fees', 0.0)
-
- # If current_price is not provided, try to get it
- if current_price is None:
- symbol = getattr(position, 'symbol', None)
- if symbol:
- current_price = self._get_current_price(symbol)
- else:
- return 0.0
-
- # Calculate P&L based on position side
- if side == 'LONG':
- pnl = (current_price - entry_price) * size * leverage
- elif side == 'SHORT':
- pnl = (entry_price - current_price) * size * leverage
- else:
- pnl = 0.0
-
- # Subtract fees for net P&L
- net_pnl = pnl - fees
-
- return net_pnl
-
- except Exception as e:
- logger.error(f"Error calculating verified P&L: {e}")
- return 0.0
-
- def _get_current_price(self, symbol):
- """Get current price for a symbol with fallbacks"""
- try:
- # Try to get from trading executor
- if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
- return self.trading_executor.current_prices[symbol]
-
- # Try to get from data provider
- if hasattr(self.trading_executor, 'data_provider'):
- data_provider = self.trading_executor.data_provider
- if hasattr(data_provider, 'get_current_price'):
- price = data_provider.get_current_price(symbol)
- if price and price > 0:
- return price
-
- # Try to get from COB data
- if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
- cob_data = self.trading_executor.latest_cob_data[symbol]
- if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
- return cob_data.stats['mid_price']
-
- # Default fallback
- return 0.0
-
- except Exception as e:
- logger.error(f"Error getting current price for {symbol}: {e}")
- return 0.0
-
- def _refresh_price(self, symbol):
- """Force a price refresh for a symbol"""
- try:
- # Try to refresh from data provider
- if hasattr(self.trading_executor, 'data_provider'):
- data_provider = self.trading_executor.data_provider
- if hasattr(data_provider, 'fetch_current_price'):
- price = data_provider.fetch_current_price(symbol)
- if price and price > 0:
- # Update trading executor price
- if hasattr(self.trading_executor, 'current_prices'):
- self.trading_executor.current_prices[symbol] = price
-
- # Update timestamp
- self.last_price_update[symbol] = datetime.now()
-
- logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
- return True
-
- logger.warning(f"Failed to refresh price for {symbol}")
- return False
-
- except Exception as e:
- logger.error(f"Error refreshing price for {symbol}: {e}")
- return False
-
- def _ensure_position_reset(self, symbol):
- """Ensure position is properly reset before new entry"""
- try:
- # Check if we have an active position
- if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
- # Position exists, check if it's valid
- position = self.trading_executor.positions[symbol]
- if position and getattr(position, 'active', False):
- logger.warning(f"Position already active for {symbol}, cannot enter new position")
- return False
-
- # Check reset flag
- if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
- # Force position cleanup
- if hasattr(self.trading_executor, 'positions'):
- self.trading_executor.positions.pop(symbol, None)
-
- logger.info(f"Forced position reset for {symbol}")
- self.position_reset_flags[symbol] = True
-
- return True
-
- except Exception as e:
- logger.error(f"Error ensuring position reset for {symbol}: {e}")
- return False
-
- def get_trade_history(self, symbol=None):
- """Get verified trade history"""
- if symbol:
- return self.trade_history.get(symbol, [])
- return self.trade_history
-
- def get_price_update_status(self):
- """Get price update status for all symbols"""
- status = {}
- current_time = datetime.now()
-
- for symbol, timestamp in self.last_price_update.items():
- time_since_update = (current_time - timestamp).total_seconds()
- status[symbol] = {
- 'last_update': timestamp,
- 'seconds_ago': time_since_update,
- 'is_fresh': time_since_update <= self.price_update_threshold
- }
-
- return status
\ No newline at end of file
diff --git a/core/training_data_collector.py b/core/training_data_collector.py
deleted file mode 100644
index 3e43c6d..0000000
--- a/core/training_data_collector.py
+++ /dev/null
@@ -1,795 +0,0 @@
-"""
-Comprehensive Training Data Collection System
-
-This module implements a robust training data collection system that:
-1. Captures all model inputs with validation and completeness checks
-2. Stores training data packages with future outcome validation
-3. Detects rapid price changes for high-value training examples
-4. Enables replay and retraining on most profitable setups
-5. Maintains data integrity and traceability
-
-Key Features:
-- Real-time data package creation with all model inputs
-- Future outcome validation (profitable vs unprofitable predictions)
-- Rapid price change detection for premium training examples
-- Comprehensive data validation and completeness verification
-- Backpropagation data storage for gradient replay
-- Training episode profitability tracking and ranking
-"""
-
-import asyncio
-import json
-import logging
-import numpy as np
-import pandas as pd
-import pickle
-import torch
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Any, Callable
-from dataclasses import dataclass, field, asdict
-from collections import deque
-import hashlib
-import threading
-from concurrent.futures import ThreadPoolExecutor
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class ModelInputPackage:
- """Complete package of all model inputs at a specific timestamp"""
- timestamp: datetime
- symbol: str
-
- # Market data inputs
- ohlcv_data: Dict[str, pd.DataFrame] # {timeframe: DataFrame}
- tick_data: List[Dict[str, Any]] # Raw tick data
- cob_data: Dict[str, Any] # Consolidated Order Book data
- technical_indicators: Dict[str, float] # All technical indicators
- pivot_points: List[Dict[str, Any]] # Detected pivot points
-
- # Model-specific inputs
- cnn_features: np.ndarray # CNN input features
- rl_state: np.ndarray # RL state representation
- orchestrator_context: Dict[str, Any] # Orchestrator context
-
- # Cross-model inputs (outputs from other models)
- cnn_predictions: Optional[Dict[str, Any]] = None
- rl_predictions: Optional[Dict[str, Any]] = None
- orchestrator_decision: Optional[Dict[str, Any]] = None
-
- # Data validation
- data_hash: str = ""
- completeness_score: float = 0.0
- validation_flags: Dict[str, bool] = field(default_factory=dict)
-
- def __post_init__(self):
- """Calculate data hash and completeness after initialization"""
- self.data_hash = self._calculate_hash()
- self.completeness_score = self._calculate_completeness()
- self.validation_flags = self._validate_data()
-
- def _calculate_hash(self) -> str:
- """Calculate hash for data integrity verification"""
- try:
- # Create a string representation of all data
- data_str = f"{self.timestamp}_{self.symbol}"
- data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}"
- data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}"
- data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}"
-
- return hashlib.md5(data_str.encode()).hexdigest()
- except Exception as e:
- logger.warning(f"Error calculating data hash: {e}")
- return "invalid_hash"
-
- def _calculate_completeness(self) -> float:
- """Calculate completeness score (0.0 to 1.0)"""
- try:
- total_fields = 10 # Total expected data fields
- complete_fields = 0
-
- # Check each required field
- if self.ohlcv_data and len(self.ohlcv_data) > 0:
- complete_fields += 1
- if self.tick_data and len(self.tick_data) > 0:
- complete_fields += 1
- if self.cob_data and len(self.cob_data) > 0:
- complete_fields += 1
- if self.technical_indicators and len(self.technical_indicators) > 0:
- complete_fields += 1
- if self.pivot_points and len(self.pivot_points) > 0:
- complete_fields += 1
- if self.cnn_features is not None and self.cnn_features.size > 0:
- complete_fields += 1
- if self.rl_state is not None and self.rl_state.size > 0:
- complete_fields += 1
- if self.orchestrator_context and len(self.orchestrator_context) > 0:
- complete_fields += 1
- if self.cnn_predictions is not None:
- complete_fields += 1
- if self.rl_predictions is not None:
- complete_fields += 1
-
- return complete_fields / total_fields
- except Exception as e:
- logger.warning(f"Error calculating completeness: {e}")
- return 0.0
-
- def _validate_data(self) -> Dict[str, bool]:
- """Validate data integrity and consistency"""
- flags = {}
-
- try:
- # Validate timestamp
- flags['valid_timestamp'] = isinstance(self.timestamp, datetime)
-
- # Validate OHLCV data
- flags['valid_ohlcv'] = (
- self.ohlcv_data is not None and
- len(self.ohlcv_data) > 0 and
- all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values())
- )
-
- # Validate feature arrays
- flags['valid_cnn_features'] = (
- self.cnn_features is not None and
- isinstance(self.cnn_features, np.ndarray) and
- self.cnn_features.size > 0
- )
-
- flags['valid_rl_state'] = (
- self.rl_state is not None and
- isinstance(self.rl_state, np.ndarray) and
- self.rl_state.size > 0
- )
-
- # Validate data consistency
- flags['data_consistent'] = self.completeness_score > 0.7
-
- except Exception as e:
- logger.warning(f"Error validating data: {e}")
- flags['validation_error'] = True
-
- return flags
-
-@dataclass
-class TrainingOutcome:
- """Future outcome validation for training data"""
- input_package_hash: str
- timestamp: datetime
- symbol: str
-
- # Price movement outcomes
- price_change_1m: float
- price_change_5m: float
- price_change_15m: float
- price_change_1h: float
-
- # Profitability metrics
- max_profit_potential: float
- max_loss_potential: float
- optimal_entry_price: float
- optimal_exit_price: float
- optimal_holding_time: timedelta
-
- # Classification labels
- is_profitable: bool
- profitability_score: float # 0.0 to 1.0
- risk_reward_ratio: float
-
- # Rapid price change detection
- is_rapid_change: bool
- change_velocity: float # Price change per minute
- volatility_spike: bool
-
- # Validation
- outcome_validated: bool = False
- validation_timestamp: datetime = field(default_factory=datetime.now)
-
-@dataclass
-class TrainingEpisode:
- """Complete training episode with inputs, predictions, and outcomes"""
- episode_id: str
- input_package: ModelInputPackage
- model_predictions: Dict[str, Any] # Predictions from all models
- actual_outcome: TrainingOutcome
-
- # Training metadata
- episode_type: str # 'normal', 'rapid_change', 'high_profit'
- profitability_rank: float # Ranking among all episodes
- training_priority: float # Priority for replay training
-
- # Backpropagation data storage
- gradient_data: Optional[Dict[str, torch.Tensor]] = None
- loss_components: Optional[Dict[str, float]] = None
- model_states: Optional[Dict[str, Any]] = None
-
- # Episode statistics
- created_timestamp: datetime = field(default_factory=datetime.now)
- last_trained_timestamp: Optional[datetime] = None
- training_count: int = 0
-
- def calculate_training_priority(self) -> float:
- """Calculate training priority based on profitability and characteristics"""
- try:
- priority = 0.0
-
- # Base priority from profitability
- if self.actual_outcome.is_profitable:
- priority += self.actual_outcome.profitability_score * 0.4
-
- # Bonus for rapid changes (high learning value)
- if self.actual_outcome.is_rapid_change:
- priority += 0.3
-
- # Bonus for high risk-reward ratio
- if self.actual_outcome.risk_reward_ratio > 2.0:
- priority += 0.2
-
- # Bonus for data completeness
- priority += self.input_package.completeness_score * 0.1
-
- # Penalty for frequent training (avoid overfitting)
- if self.training_count > 5:
- priority *= 0.8
-
- return min(priority, 1.0)
-
- except Exception as e:
- logger.warning(f"Error calculating training priority: {e}")
- return 0.0
-
-class RapidChangeDetector:
- """Detects rapid price changes for high-value training examples"""
-
- def __init__(self,
- velocity_threshold: float = 0.5, # % per minute
- volatility_multiplier: float = 3.0,
- lookback_minutes: int = 5):
- self.velocity_threshold = velocity_threshold
- self.volatility_multiplier = volatility_multiplier
- self.lookback_minutes = lookback_minutes
-
- # Price history for change detection
- self.price_history: Dict[str, deque] = {}
- self.volatility_baseline: Dict[str, float] = {}
-
- def add_price_point(self, symbol: str, timestamp: datetime, price: float):
- """Add new price point for change detection"""
- if symbol not in self.price_history:
- self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60) # 1 second resolution
- self.volatility_baseline[symbol] = 0.0
-
- self.price_history[symbol].append((timestamp, price))
- self._update_volatility_baseline(symbol)
-
- def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
- """
- Detect rapid price changes
-
- Returns:
- (is_rapid_change, change_velocity, volatility_spike)
- """
- if symbol not in self.price_history or len(self.price_history[symbol]) < 60:
- return False, 0.0, False
-
- try:
- prices = list(self.price_history[symbol])
-
- # Calculate recent velocity (last minute)
- recent_prices = prices[-60:] # Last 60 seconds
- if len(recent_prices) < 2:
- return False, 0.0, False
-
- start_price = recent_prices[0][1]
- end_price = recent_prices[-1][1]
- time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0 # minutes
-
- if time_diff <= 0:
- return False, 0.0, False
-
- # Calculate velocity (% change per minute)
- velocity = abs((end_price - start_price) / start_price * 100) / time_diff
-
- # Check for rapid change
- is_rapid = velocity > self.velocity_threshold
-
- # Check for volatility spike
- current_volatility = self._calculate_current_volatility(symbol)
- baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
- volatility_spike = (
- baseline_volatility > 0 and
- current_volatility > baseline_volatility * self.volatility_multiplier
- )
-
- return is_rapid, velocity, volatility_spike
-
- except Exception as e:
- logger.warning(f"Error detecting rapid change for {symbol}: {e}")
- return False, 0.0, False
-
- def _update_volatility_baseline(self, symbol: str):
- """Update volatility baseline for the symbol"""
- try:
- if len(self.price_history[symbol]) < 120: # Need at least 2 minutes of data
- return
-
- # Calculate rolling volatility over longer period
- prices = [p[1] for p in list(self.price_history[symbol])[-300:]] # Last 5 minutes
- if len(prices) < 2:
- return
-
- # Calculate standard deviation of price changes
- price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
- volatility = np.std(price_changes) * 100 # Convert to percentage
-
- # Update baseline with exponential moving average
- alpha = 0.1
- if self.volatility_baseline[symbol] == 0:
- self.volatility_baseline[symbol] = volatility
- else:
- self.volatility_baseline[symbol] = (
- alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol]
- )
-
- except Exception as e:
- logger.warning(f"Error updating volatility baseline for {symbol}: {e}")
-
- def _calculate_current_volatility(self, symbol: str) -> float:
- """Calculate current volatility for the symbol"""
- try:
- if len(self.price_history[symbol]) < 60:
- return 0.0
-
- # Use last minute of data
- recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]]
- if len(recent_prices) < 2:
- return 0.0
-
- price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
- for i in range(1, len(recent_prices))]
- return np.std(price_changes) * 100
-
- except Exception as e:
- logger.warning(f"Error calculating current volatility for {symbol}: {e}")
- return 0.0
-
-class TrainingDataCollector:
- """Main training data collection system"""
-
- def __init__(self,
- storage_dir: str = "training_data",
- max_episodes_per_symbol: int = 10000,
- outcome_validation_delay: timedelta = timedelta(hours=1)):
-
- self.storage_dir = Path(storage_dir)
- self.storage_dir.mkdir(parents=True, exist_ok=True)
-
- self.max_episodes_per_symbol = max_episodes_per_symbol
- self.outcome_validation_delay = outcome_validation_delay
-
- # Data storage
- self.training_episodes: Dict[str, List[TrainingEpisode]] = {} # {symbol: episodes}
- self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {} # Awaiting outcome validation
-
- # Rapid change detection
- self.rapid_change_detector = RapidChangeDetector()
-
- # Data validation and statistics
- self.collection_stats = {
- 'total_episodes': 0,
- 'profitable_episodes': 0,
- 'rapid_change_episodes': 0,
- 'validation_errors': 0,
- 'data_completeness_avg': 0.0
- }
-
- # Background processing
- self.is_collecting = False
- self.collection_thread = None
- self.outcome_validation_thread = None
-
- # Thread safety
- self.data_lock = threading.Lock()
-
- logger.info(f"Training Data Collector initialized")
- logger.info(f"Storage directory: {self.storage_dir}")
- logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")
-
- def start_collection(self):
- """Start the training data collection system"""
- if self.is_collecting:
- logger.warning("Training data collection already running")
- return
-
- self.is_collecting = True
-
- # Start outcome validation thread
- self.outcome_validation_thread = threading.Thread(
- target=self._outcome_validation_worker,
- daemon=True
- )
- self.outcome_validation_thread.start()
-
- logger.info("Training data collection started")
-
- def stop_collection(self):
- """Stop the training data collection system"""
- self.is_collecting = False
-
- if self.outcome_validation_thread:
- self.outcome_validation_thread.join(timeout=5)
-
- logger.info("Training data collection stopped")
-
- def collect_training_data(self,
- symbol: str,
- ohlcv_data: Dict[str, pd.DataFrame],
- tick_data: List[Dict[str, Any]],
- cob_data: Dict[str, Any],
- technical_indicators: Dict[str, float],
- pivot_points: List[Dict[str, Any]],
- cnn_features: np.ndarray,
- rl_state: np.ndarray,
- orchestrator_context: Dict[str, Any],
- model_predictions: Dict[str, Any] = None) -> str:
- """
- Collect comprehensive training data package
-
- Returns:
- episode_id for tracking
- """
- try:
- # Create input package
- input_package = ModelInputPackage(
- timestamp=datetime.now(),
- symbol=symbol,
- ohlcv_data=ohlcv_data,
- tick_data=tick_data,
- cob_data=cob_data,
- technical_indicators=technical_indicators,
- pivot_points=pivot_points,
- cnn_features=cnn_features,
- rl_state=rl_state,
- orchestrator_context=orchestrator_context
- )
-
- # Validate data completeness
- if input_package.completeness_score < 0.5:
- logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
- self.collection_stats['validation_errors'] += 1
- return None
-
- # Check for rapid price changes
- current_price = self._extract_current_price(ohlcv_data)
- if current_price:
- self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)
-
- # Add to pending outcomes for future validation
- with self.data_lock:
- if symbol not in self.pending_outcomes:
- self.pending_outcomes[symbol] = []
-
- self.pending_outcomes[symbol].append(input_package)
-
- # Limit pending outcomes to prevent memory issues
- if len(self.pending_outcomes[symbol]) > 1000:
- self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:]
-
- # Generate episode ID
- episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
-
- # Update statistics
- self.collection_stats['total_episodes'] += 1
- self.collection_stats['data_completeness_avg'] = (
- (self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) +
- input_package.completeness_score) / self.collection_stats['total_episodes']
- )
-
- logger.debug(f"Collected training data for {symbol}: {episode_id}")
- logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")
-
- return episode_id
-
- except Exception as e:
- logger.error(f"Error collecting training data for {symbol}: {e}")
- self.collection_stats['validation_errors'] += 1
- return None
-
- def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
- """Extract current price from OHLCV data"""
- try:
- # Try to get price from shortest timeframe first
- for timeframe in ['1s', '1m', '5m', '15m', '1h']:
- if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
- return float(ohlcv_data[timeframe]['close'].iloc[-1])
- return None
- except Exception as e:
- logger.warning(f"Error extracting current price: {e}")
- return None
-
- def _outcome_validation_worker(self):
- """Background worker for validating training outcomes"""
- logger.info("Outcome validation worker started")
-
- while self.is_collecting:
- try:
- self._validate_pending_outcomes()
- threading.Event().wait(60) # Check every minute
-
- except Exception as e:
- logger.error(f"Error in outcome validation worker: {e}")
- threading.Event().wait(30) # Wait before retrying
-
- logger.info("Outcome validation worker stopped")
-
- def _validate_pending_outcomes(self):
- """Validate outcomes for pending training data"""
- current_time = datetime.now()
-
- with self.data_lock:
- for symbol in list(self.pending_outcomes.keys()):
- if symbol not in self.pending_outcomes:
- continue
-
- validated_packages = []
- remaining_packages = []
-
- for package in self.pending_outcomes[symbol]:
- # Check if enough time has passed for outcome validation
- if current_time - package.timestamp >= self.outcome_validation_delay:
- outcome = self._calculate_training_outcome(package)
- if outcome:
- self._create_training_episode(package, outcome)
- validated_packages.append(package)
- else:
- remaining_packages.append(package)
- else:
- remaining_packages.append(package)
-
- # Update pending outcomes
- self.pending_outcomes[symbol] = remaining_packages
-
- if validated_packages:
- logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}")
-
- def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
- """Calculate training outcome based on future price movements"""
- try:
- # This would typically fetch recent price data to calculate outcomes
- # For now, we'll create a placeholder implementation
-
- # Extract base price from input package
- base_price = self._extract_current_price(input_package.ohlcv_data)
- if not base_price:
- return None
-
- # Simulate outcome calculation (in real implementation, fetch actual future prices)
- # This is where you would integrate with your data provider to get actual outcomes
-
- # Check for rapid change
- is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
- input_package.symbol
- )
-
- # Create outcome (placeholder values - replace with actual calculation)
- outcome = TrainingOutcome(
- input_package_hash=input_package.data_hash,
- timestamp=input_package.timestamp,
- symbol=input_package.symbol,
- price_change_1m=0.0, # Calculate from actual future data
- price_change_5m=0.0,
- price_change_15m=0.0,
- price_change_1h=0.0,
- max_profit_potential=0.0,
- max_loss_potential=0.0,
- optimal_entry_price=base_price,
- optimal_exit_price=base_price,
- optimal_holding_time=timedelta(minutes=5),
- is_profitable=False, # Determine from actual outcomes
- profitability_score=0.0,
- risk_reward_ratio=1.0,
- is_rapid_change=is_rapid,
- change_velocity=velocity,
- volatility_spike=volatility_spike,
- outcome_validated=True
- )
-
- return outcome
-
- except Exception as e:
- logger.error(f"Error calculating training outcome: {e}")
- return None
-
- def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome):
- """Create complete training episode"""
- try:
- episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
-
- # Determine episode type
- episode_type = 'normal'
- if outcome.is_rapid_change:
- episode_type = 'rapid_change'
- self.collection_stats['rapid_change_episodes'] += 1
- elif outcome.profitability_score > 0.8:
- episode_type = 'high_profit'
-
- if outcome.is_profitable:
- self.collection_stats['profitable_episodes'] += 1
-
- # Create training episode
- episode = TrainingEpisode(
- episode_id=episode_id,
- input_package=input_package,
- model_predictions={}, # Will be filled when models make predictions
- actual_outcome=outcome,
- episode_type=episode_type,
- profitability_rank=0.0, # Will be calculated later
- training_priority=0.0
- )
-
- # Calculate training priority
- episode.training_priority = episode.calculate_training_priority()
-
- # Store episode
- symbol = input_package.symbol
- if symbol not in self.training_episodes:
- self.training_episodes[symbol] = []
-
- self.training_episodes[symbol].append(episode)
-
- # Limit episodes per symbol
- if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol:
- # Keep highest priority episodes
- self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True)
- self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol]
-
- # Save episode to disk
- self._save_episode_to_disk(episode)
-
- logger.debug(f"Created training episode: {episode_id}")
- logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")
-
- except Exception as e:
- logger.error(f"Error creating training episode: {e}")
-
- def _save_episode_to_disk(self, episode: TrainingEpisode):
- """Save training episode to disk for persistence"""
- try:
- symbol_dir = self.storage_dir / episode.input_package.symbol
- symbol_dir.mkdir(parents=True, exist_ok=True)
-
- # Save episode data
- episode_file = symbol_dir / f"{episode.episode_id}.pkl"
- with open(episode_file, 'wb') as f:
- pickle.dump(episode, f)
-
- # Save episode metadata for quick access
- metadata = {
- 'episode_id': episode.episode_id,
- 'timestamp': episode.input_package.timestamp.isoformat(),
- 'episode_type': episode.episode_type,
- 'training_priority': episode.training_priority,
- 'profitability_score': episode.actual_outcome.profitability_score,
- 'is_profitable': episode.actual_outcome.is_profitable,
- 'is_rapid_change': episode.actual_outcome.is_rapid_change,
- 'data_completeness': episode.input_package.completeness_score
- }
-
- metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
- with open(metadata_file, 'w') as f:
- json.dump(metadata, f, indent=2)
-
- except Exception as e:
- logger.error(f"Error saving episode to disk: {e}")
-
- def get_high_priority_episodes(self,
- symbol: str,
- limit: int = 100,
- min_priority: float = 0.5) -> List[TrainingEpisode]:
- """Get high-priority training episodes for replay training"""
- try:
- if symbol not in self.training_episodes:
- return []
-
- # Filter and sort by priority
- high_priority = [
- ep for ep in self.training_episodes[symbol]
- if ep.training_priority >= min_priority
- ]
-
- high_priority.sort(key=lambda x: x.training_priority, reverse=True)
-
- return high_priority[:limit]
-
- except Exception as e:
- logger.error(f"Error getting high priority episodes for {symbol}: {e}")
- return []
-
- def get_collection_statistics(self) -> Dict[str, Any]:
- """Get comprehensive collection statistics"""
- stats = self.collection_stats.copy()
-
- # Add per-symbol statistics
- stats['episodes_per_symbol'] = {
- symbol: len(episodes)
- for symbol, episodes in self.training_episodes.items()
- }
-
- # Add pending outcomes count
- stats['pending_outcomes'] = {
- symbol: len(packages)
- for symbol, packages in self.pending_outcomes.items()
- }
-
- # Calculate profitability rate
- if stats['total_episodes'] > 0:
- stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes']
- stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes']
- else:
- stats['profitability_rate'] = 0.0
- stats['rapid_change_rate'] = 0.0
-
- return stats
-
- def validate_data_integrity(self) -> Dict[str, Any]:
- """Comprehensive data integrity validation"""
- validation_results = {
- 'total_episodes_checked': 0,
- 'hash_mismatches': 0,
- 'completeness_issues': 0,
- 'validation_flag_failures': 0,
- 'corrupted_episodes': [],
- 'integrity_score': 1.0
- }
-
- try:
- for symbol, episodes in self.training_episodes.items():
- for episode in episodes:
- validation_results['total_episodes_checked'] += 1
-
- # Check data hash
- expected_hash = episode.input_package._calculate_hash()
- if expected_hash != episode.input_package.data_hash:
- validation_results['hash_mismatches'] += 1
- validation_results['corrupted_episodes'].append(episode.episode_id)
-
- # Check completeness
- if episode.input_package.completeness_score < 0.7:
- validation_results['completeness_issues'] += 1
-
- # Check validation flags
- if not episode.input_package.validation_flags.get('data_consistent', False):
- validation_results['validation_flag_failures'] += 1
-
- # Calculate integrity score
- total_issues = (
- validation_results['hash_mismatches'] +
- validation_results['completeness_issues'] +
- validation_results['validation_flag_failures']
- )
-
- if validation_results['total_episodes_checked'] > 0:
- validation_results['integrity_score'] = 1.0 - (
- total_issues / validation_results['total_episodes_checked']
- )
-
- logger.info(f"Data integrity validation completed")
- logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")
-
- except Exception as e:
- logger.error(f"Error during data integrity validation: {e}")
- validation_results['validation_error'] = str(e)
-
- return validation_results
-
-# Global instance for easy access
-training_data_collector = None
-
-def get_training_data_collector() -> TrainingDataCollector:
- """Get global training data collector instance"""
- global training_data_collector
- if training_data_collector is None:
- training_data_collector = TrainingDataCollector()
- return training_data_collector
\ No newline at end of file
diff --git a/dataprovider_realtime.py b/dataprovider_realtime.py
deleted file mode 100644
index 2072041..0000000
--- a/dataprovider_realtime.py
+++ /dev/null
@@ -1,2490 +0,0 @@
-# import asyncio
-# import json
-# import logging
-
-# # Fix PIL import issue that causes plotly JSON serialization errors
-# import os
-# os.environ['MPLBACKEND'] = 'Agg' # Use non-interactive backend
-# try:
-# # Try to fix PIL import issue
-# import PIL.Image
-# # Disable PIL in plotly to prevent circular import issues
-# import plotly.io as pio
-# pio.kaleido.scope.default_format = "png"
-# except ImportError:
-# pass
-# except Exception:
-# # Suppress any PIL-related errors during import
-# pass
-
-# from typing import Dict, List, Optional, Tuple, Union
-# import websockets
-# import plotly.graph_objects as go
-# from plotly.subplots import make_subplots
-# import dash
-# from dash import html, dcc
-# from dash.dependencies import Input, Output
-# import pandas as pd
-# import numpy as np
-# from collections import deque
-# import time
-# from threading import Thread
-# import requests
-# import os
-# from datetime import datetime, timedelta
-# import pytz
-# import tzlocal
-# import threading
-# import random
-# import dash_bootstrap_components as dbc
-# import uuid
-# import ta
-# from sklearn.preprocessing import MinMaxScaler
-# import re
-# import psutil
-# import gc
-# import websocket
-
-# # Import psycopg2 with error handling
-# try:
-# import psycopg2
-# PSYCOPG2_AVAILABLE = True
-# except ImportError:
-# PSYCOPG2_AVAILABLE = False
-# psycopg2 = None
-
-# # TimescaleDB configuration from environment variables
-# TIMESCALEDB_ENABLED = os.environ.get('TIMESCALEDB_ENABLED', '1') == '1' and PSYCOPG2_AVAILABLE
-# TIMESCALEDB_HOST = os.environ.get('TIMESCALEDB_HOST', '192.168.0.10')
-# TIMESCALEDB_PORT = int(os.environ.get('TIMESCALEDB_PORT', '5432'))
-# TIMESCALEDB_USER = os.environ.get('TIMESCALEDB_USER', 'postgres')
-# TIMESCALEDB_PASSWORD = os.environ.get('TIMESCALEDB_PASSWORD', 'timescaledbpass')
-# TIMESCALEDB_DB = os.environ.get('TIMESCALEDB_DB', 'candles')
-
-# class TimescaleDBHandler:
-# """Handler for TimescaleDB operations for candle storage and retrieval"""
-
-# def __init__(self):
-# """Initialize TimescaleDB connection if enabled"""
-# self.enabled = TIMESCALEDB_ENABLED
-# self.conn = None
-
-# if not self.enabled:
-# if not PSYCOPG2_AVAILABLE:
-# print("psycopg2 module not available. TimescaleDB integration disabled.")
-# return
-
-# try:
-# # Connect to TimescaleDB
-# self.conn = psycopg2.connect(
-# host=TIMESCALEDB_HOST,
-# port=TIMESCALEDB_PORT,
-# user=TIMESCALEDB_USER,
-# password=TIMESCALEDB_PASSWORD,
-# dbname=TIMESCALEDB_DB
-# )
-# print(f"Connected to TimescaleDB at {TIMESCALEDB_HOST}:{TIMESCALEDB_PORT}")
-
-# # Ensure the candles table exists
-# self._ensure_table()
-
-# print("TimescaleDB integration initialized successfully")
-# except Exception as e:
-# print(f"Error connecting to TimescaleDB: {str(e)}")
-# self.enabled = False
-# self.conn = None
-
-# def _ensure_table(self):
-# """Ensure the candles table exists with TimescaleDB hypertable"""
-# if not self.conn:
-# return
-
-# try:
-# with self.conn.cursor() as cur:
-# # Create the candles table if it doesn't exist
-# cur.execute('''
-# CREATE TABLE IF NOT EXISTS candles (
-# symbol TEXT,
-# interval TEXT,
-# timestamp TIMESTAMPTZ,
-# open DOUBLE PRECISION,
-# high DOUBLE PRECISION,
-# low DOUBLE PRECISION,
-# close DOUBLE PRECISION,
-# volume DOUBLE PRECISION,
-# PRIMARY KEY (symbol, interval, timestamp)
-# );
-# ''')
-
-# # Check if the table is already a hypertable
-# cur.execute('''
-# SELECT EXISTS (
-# SELECT 1 FROM timescaledb_information.hypertables
-# WHERE hypertable_name = 'candles'
-# );
-# ''')
-# is_hypertable = cur.fetchone()[0]
-
-# # Convert to hypertable if not already done
-# if not is_hypertable:
-# cur.execute('''
-# SELECT create_hypertable('candles', 'timestamp',
-# if_not_exists => TRUE,
-# migrate_data => TRUE
-# );
-# ''')
-
-# self.conn.commit()
-# print("TimescaleDB table structure verified")
-# except Exception as e:
-# print(f"Error setting up TimescaleDB tables: {str(e)}")
-# self.enabled = False
-
-# def upsert_candle(self, symbol, interval, candle):
-# """Insert or update a candle in TimescaleDB"""
-# if not self.enabled or not self.conn:
-# return False
-
-# try:
-# with self.conn.cursor() as cur:
-# cur.execute('''
-# INSERT INTO candles (
-# symbol, interval, timestamp,
-# open, high, low, close, volume
-# )
-# VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
-# ON CONFLICT (symbol, interval, timestamp)
-# DO UPDATE SET
-# open = EXCLUDED.open,
-# high = EXCLUDED.high,
-# low = EXCLUDED.low,
-# close = EXCLUDED.close,
-# volume = EXCLUDED.volume
-# ''', (
-# symbol, interval, candle['timestamp'],
-# candle['open'], candle['high'], candle['low'],
-# candle['close'], candle['volume']
-# ))
-# self.conn.commit()
-# return True
-# except Exception as e:
-# print(f"Error upserting candle to TimescaleDB: {str(e)}")
-# # Try to reconnect on error
-# try:
-# self.conn = psycopg2.connect(
-# host=TIMESCALEDB_HOST,
-# port=TIMESCALEDB_PORT,
-# user=TIMESCALEDB_USER,
-# password=TIMESCALEDB_PASSWORD,
-# dbname=TIMESCALEDB_DB
-# )
-# except:
-# pass
-# return False
-
-# def fetch_candles(self, symbol, interval, limit=1000):
-# """Fetch candles from TimescaleDB"""
-# if not self.enabled or not self.conn:
-# return []
-
-# try:
-# with self.conn.cursor() as cur:
-# cur.execute('''
-# SELECT timestamp, open, high, low, close, volume
-# FROM candles
-# WHERE symbol = %s AND interval = %s
-# ORDER BY timestamp DESC
-# LIMIT %s
-# ''', (symbol, interval, limit))
-
-# rows = cur.fetchall()
-
-# # Convert to list of dictionaries (ordered from oldest to newest)
-# candles = []
-# for row in reversed(rows): # Reverse to get oldest first
-# candle = {
-# 'timestamp': row[0],
-# 'open': row[1],
-# 'high': row[2],
-# 'low': row[3],
-# 'close': row[4],
-# 'volume': row[5]
-# }
-# candles.append(candle)
-
-# return candles
-# except Exception as e:
-# print(f"Error fetching candles from TimescaleDB: {str(e)}")
-# # Try to reconnect on error
-# try:
-# self.conn = psycopg2.connect(
-# host=TIMESCALEDB_HOST,
-# port=TIMESCALEDB_PORT,
-# user=TIMESCALEDB_USER,
-# password=TIMESCALEDB_PASSWORD,
-# dbname=TIMESCALEDB_DB
-# )
-# except:
-# pass
-# return []
-
-# class BinanceHistoricalData:
-# """
-# Class for fetching historical price data from Binance.
-# """
-# def __init__(self):
-# self.base_url = "https://api.binance.com/api/v3"
-# self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')
-# if not os.path.exists(self.cache_dir):
-# os.makedirs(self.cache_dir)
-# # Timestamp of last data update
-# self.last_update = None
-
-# def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000):
-# """
-# Fetch historical candles from Binance API.
-
-# Args:
-# symbol (str): Trading pair symbol (e.g., "BTC/USDT")
-# interval_seconds (int): Timeframe in seconds (e.g., 3600 for 1h)
-# limit (int): Number of candles to fetch
-
-# Returns:
-# pd.DataFrame: DataFrame with OHLCV data
-# """
-# # Convert interval_seconds to Binance interval format
-# interval_map = {
-# 1: "1s",
-# 60: "1m",
-# 300: "5m",
-# 900: "15m",
-# 1800: "30m",
-# 3600: "1h",
-# 14400: "4h",
-# 86400: "1d"
-# }
-
-# interval = interval_map.get(interval_seconds, "1h")
-
-# # Format symbol for Binance API (remove slash and make uppercase)
-# formatted_symbol = symbol.replace("/", "").upper()
-
-# # Check if we have cached data first
-# cache_file = self._get_cache_filename(formatted_symbol, interval)
-# cached_data = self._load_from_cache(formatted_symbol, interval)
-
-# # If we have cached data that's recent enough, use it
-# if cached_data is not None and len(cached_data) >= limit:
-# cache_age_minutes = (datetime.now() - self.last_update).total_seconds() / 60 if self.last_update else 60
-# if cache_age_minutes < 15: # Only use cache if it's less than 15 minutes old
-# logger.info(f"Using cached historical data for {symbol} ({interval})")
-# return cached_data
-
-# try:
-# # Build URL for klines endpoint
-# url = f"{self.base_url}/klines"
-# params = {
-# "symbol": formatted_symbol,
-# "interval": interval,
-# "limit": limit
-# }
-
-# # Make the request
-# response = requests.get(url, params=params)
-# response.raise_for_status()
-
-# # Parse the response
-# data = response.json()
-
-# # Create dataframe
-# df = pd.DataFrame(data, columns=[
-# "timestamp", "open", "high", "low", "close", "volume",
-# "close_time", "quote_asset_volume", "number_of_trades",
-# "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore"
-# ])
-
-# # Convert timestamp to datetime
-# df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
-
-# # Convert price columns to float
-# for col in ["open", "high", "low", "close", "volume"]:
-# df[col] = df[col].astype(float)
-
-# # Sort by timestamp
-# df = df.sort_values("timestamp")
-
-# # Save to cache for future use
-# self._save_to_cache(df, formatted_symbol, interval)
-# self.last_update = datetime.now()
-
-# logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})")
-# return df
-
-# except Exception as e:
-# logger.error(f"Error fetching historical data from Binance: {str(e)}")
-# # Return cached data if we have it, even if it's not enough
-# if cached_data is not None:
-# logger.warning(f"Using cached data instead (may be incomplete)")
-# return cached_data
-# # Return empty dataframe on error
-# return pd.DataFrame()
-
-# def _get_cache_filename(self, symbol, interval):
-# """Get filename for cache file"""
-# return os.path.join(self.cache_dir, f"{symbol}_{interval}_candles.csv")
-
-# def _load_from_cache(self, symbol, interval):
-# """Load candles from cache file"""
-# try:
-# cache_file = self._get_cache_filename(symbol, interval)
-# if os.path.exists(cache_file):
-# # For 1s interval, check if the cache is recent (less than 10 minutes old)
-# if interval == "1s" or interval == 1:
-# file_mod_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
-# time_diff = (datetime.now() - file_mod_time).total_seconds() / 60
-# if time_diff > 10:
-# logger.info("1s cache is older than 10 minutes, skipping load")
-# return None
-# logger.info(f"Using recent 1s cache (age: {time_diff:.1f} minutes)")
-
-# df = pd.read_csv(cache_file)
-# df["timestamp"] = pd.to_datetime(df["timestamp"])
-# logger.info(f"Loaded {len(df)} candles from cache: {cache_file}")
-# return df
-# except Exception as e:
-# logger.error(f"Error loading cached data: {str(e)}")
-# return None
-
-# def _save_to_cache(self, df, symbol, interval):
-# """Save candles to cache file"""
-# try:
-# cache_file = self._get_cache_filename(symbol, interval)
-# df.to_csv(cache_file, index=False)
-# logger.info(f"Saved {len(df)} candles to cache: {cache_file}")
-# return True
-# except Exception as e:
-# logger.error(f"Error saving to cache: {str(e)}")
-# return False
-
-# def get_recent_trades(self, symbol, limit=1000):
-# """Get recent trades for a symbol"""
-# formatted_symbol = symbol.replace("/", "")
-
-# try:
-# url = f"{self.base_url}/trades"
-# params = {
-# "symbol": formatted_symbol,
-# "limit": limit
-# }
-
-# response = requests.get(url, params=params)
-# response.raise_for_status()
-
-# data = response.json()
-
-# # Create dataframe
-# df = pd.DataFrame(data)
-# df["time"] = pd.to_datetime(df["time"], unit="ms")
-# df["price"] = df["price"].astype(float)
-# df["qty"] = df["qty"].astype(float)
-
-# return df
-
-# except Exception as e:
-# logger.error(f"Error fetching recent trades: {str(e)}")
-# return pd.DataFrame()
-
-# class MultiTimeframeDataInterface:
-# """
-# Enhanced Data Interface supporting:
-# - Multiple trading pairs
-# - Multiple timeframes per pair (1s, 1m, 1h, 1d + custom)
-# - Technical indicators
-# - Cross-timeframe normalization
-# - Real-time data updates
-# """
-
-# def __init__(self, symbol=None, timeframes=None, data_dir="data"):
-# """
-# Initialize the data interface.
-
-# Args:
-# symbol (str): Trading pair symbol (e.g., "BTC/USDT")
-# timeframes (list): List of timeframes to use (e.g., ['1m', '5m', '1h', '4h', '1d'])
-# data_dir (str): Directory to store/load datasets
-# """
-# self.symbol = symbol
-# self.timeframes = timeframes or ['1h', '4h', '1d']
-# self.data_dir = data_dir
-# self.scalers = {} # Store scalers for each timeframe
-
-# # Initialize the historical data fetcher
-# self.historical_data = BinanceHistoricalData()
-
-# # Create data directory if it doesn't exist
-# os.makedirs(self.data_dir, exist_ok=True)
-
-# # Initialize empty dataframes for each timeframe
-# self.dataframes = {tf: None for tf in self.timeframes}
-
-# # Store timestamps of last updates per timeframe
-# self.last_updates = {tf: None for tf in self.timeframes}
-
-# # Timeframe mapping (string to seconds)
-# self.timeframe_to_seconds = {
-# '1s': 1,
-# '1m': 60,
-# '5m': 300,
-# '15m': 900,
-# '30m': 1800,
-# '1h': 3600,
-# '4h': 14400,
-# '1d': 86400
-# }
-
-# logger.info(f"MultiTimeframeDataInterface initialized for {symbol} with timeframes {timeframes}")
-
-# def get_data(self, timeframe='1h', n_candles=1000, refresh=False, add_indicators=True):
-# """
-# Fetch historical price data for a given timeframe with optional indicators.
-
-# Args:
-# timeframe (str): Timeframe to fetch data for
-# n_candles (int): Number of candles to fetch
-# refresh (bool): Force refresh of the data
-# add_indicators (bool): Whether to add technical indicators
-
-# Returns:
-# pd.DataFrame: DataFrame with OHLCV data and indicators
-# """
-# # Check if we need to refresh
-# current_time = datetime.now()
-
-# if (not refresh and
-# self.dataframes[timeframe] is not None and
-# self.last_updates[timeframe] is not None and
-# (current_time - self.last_updates[timeframe]).total_seconds() < 60):
-# #logger.info(f"Using cached data for {self.symbol} {timeframe}")
-# return self.dataframes[timeframe]
-
-# interval_seconds = self.timeframe_to_seconds.get(timeframe, 3600)
-
-# # Fetch data
-# df = self.historical_data.get_historical_candles(
-# symbol=self.symbol,
-# interval_seconds=interval_seconds,
-# limit=n_candles
-# )
-
-# if df is None or df.empty:
-# logger.error(f"No data available for {self.symbol} {timeframe}")
-# return None
-
-# # Add indicators if requested
-# if add_indicators:
-# df = self.add_indicators(df)
-
-# # Store in cache
-# self.dataframes[timeframe] = df
-# self.last_updates[timeframe] = current_time
-
-# logger.info(f"Fetched and processed {len(df)} candles for {self.symbol} {timeframe}")
-# return df
-
-# def add_indicators(self, df):
-# """
-# Add comprehensive technical indicators to the dataframe.
-
-# Args:
-# df (pd.DataFrame): DataFrame with OHLCV data
-
-# Returns:
-# pd.DataFrame: DataFrame with added technical indicators
-# """
-# # Make a copy to avoid modifying the original
-# df_copy = df.copy()
-
-# # Basic price indicators
-# df_copy['returns'] = df_copy['close'].pct_change()
-# df_copy['log_returns'] = np.log(df_copy['close'] / df_copy['close'].shift(1))
-
-# # Moving Averages
-# df_copy['sma_7'] = ta.trend.sma_indicator(df_copy['close'], window=7)
-# df_copy['sma_25'] = ta.trend.sma_indicator(df_copy['close'], window=25)
-# df_copy['sma_99'] = ta.trend.sma_indicator(df_copy['close'], window=99)
-# df_copy['ema_9'] = ta.trend.ema_indicator(df_copy['close'], window=9)
-# df_copy['ema_21'] = ta.trend.ema_indicator(df_copy['close'], window=21)
-
-# # MACD
-# macd = ta.trend.MACD(df_copy['close'])
-# df_copy['macd'] = macd.macd()
-# df_copy['macd_signal'] = macd.macd_signal()
-# df_copy['macd_diff'] = macd.macd_diff()
-
-# # RSI
-# df_copy['rsi'] = ta.momentum.rsi(df_copy['close'], window=14)
-
-# # Bollinger Bands
-# bollinger = ta.volatility.BollingerBands(df_copy['close'])
-# df_copy['bb_high'] = bollinger.bollinger_hband()
-# df_copy['bb_low'] = bollinger.bollinger_lband()
-# df_copy['bb_pct'] = bollinger.bollinger_pband()
-
-# # Stochastic Oscillator
-# stoch = ta.momentum.StochasticOscillator(df_copy['high'], df_copy['low'], df_copy['close'])
-# df_copy['stoch_k'] = stoch.stoch()
-# df_copy['stoch_d'] = stoch.stoch_signal()
-
-# # ATR - Average True Range
-# df_copy['atr'] = ta.volatility.average_true_range(df_copy['high'], df_copy['low'], df_copy['close'], window=14)
-
-# # Money Flow Index
-# df_copy['mfi'] = ta.volume.money_flow_index(df_copy['high'], df_copy['low'], df_copy['close'], df_copy['volume'], window=14)
-
-# # OBV - On-Balance Volume
-# df_copy['obv'] = ta.volume.on_balance_volume(df_copy['close'], df_copy['volume'])
-
-# # Ichimoku Cloud
-# ichimoku = ta.trend.IchimokuIndicator(df_copy['high'], df_copy['low'])
-# df_copy['ichimoku_a'] = ichimoku.ichimoku_a()
-# df_copy['ichimoku_b'] = ichimoku.ichimoku_b()
-# df_copy['ichimoku_base'] = ichimoku.ichimoku_base_line()
-# df_copy['ichimoku_conv'] = ichimoku.ichimoku_conversion_line()
-
-# # ADX - Average Directional Index
-# adx = ta.trend.ADXIndicator(df_copy['high'], df_copy['low'], df_copy['close'])
-# df_copy['adx'] = adx.adx()
-# df_copy['adx_pos'] = adx.adx_pos()
-# df_copy['adx_neg'] = adx.adx_neg()
-
-# # VWAP - Volume Weighted Average Price (intraday)
-# # Custom calculation since TA library doesn't include VWAP
-# df_copy['vwap'] = (df_copy['volume'] * (df_copy['high'] + df_copy['low'] + df_copy['close']) / 3).cumsum() / df_copy['volume'].cumsum()
-
-# # Fill NaN values
-# df_copy = df_copy.fillna(method='bfill').fillna(0)
-
-# return df_copy
-
-# def get_multi_timeframe_data(self, timeframes=None, n_candles=1000, refresh=False, add_indicators=True):
-# """
-# Fetch data for multiple timeframes.
-
-# Args:
-# timeframes (list): List of timeframes to fetch
-# n_candles (int): Number of candles to fetch for each timeframe
-# refresh (bool): Force refresh of the data
-# add_indicators (bool): Whether to add technical indicators
-
-# Returns:
-# dict: Dictionary of dataframes indexed by timeframe
-# """
-# if timeframes is None:
-# timeframes = self.timeframes
-
-# result = {}
-
-# for tf in timeframes:
-# # For higher timeframes, we need fewer candles
-# tf_candles = n_candles
-# if tf == '4h':
-# tf_candles = max(250, n_candles // 4)
-# elif tf == '1d':
-# tf_candles = max(100, n_candles // 24)
-
-# df = self.get_data(timeframe=tf, n_candles=tf_candles, refresh=refresh, add_indicators=add_indicators)
-# if df is not None and not df.empty:
-# result[tf] = df
-
-# return result
-
-# def prepare_training_data(self, window_size=20, train_ratio=0.8, refresh=False):
-# """
-# Prepare training data from multiple timeframes.
-
-# Args:
-# window_size (int): Size of the sliding window
-# train_ratio (float): Ratio of data to use for training
-# refresh (bool): Whether to refresh the data
-
-# Returns:
-# tuple: (X_train, y_train, X_val, y_val, train_prices, val_prices)
-# """
-# # Get data for all timeframes
-# data_dict = self.get_multi_timeframe_data(refresh=refresh)
-
-# if not data_dict:
-# logger.error("Failed to fetch data for any timeframe")
-# return None, None, None, None, None, None
-
-# # Align all dataframes by timestamp
-# all_dfs = list(data_dict.values())
-# min_date = max([df['timestamp'].min() for df in all_dfs])
-# max_date = min([df['timestamp'].max() for df in all_dfs])
-
-# aligned_dfs = {}
-# for tf, df in data_dict.items():
-# aligned_df = df[(df['timestamp'] >= min_date) & (df['timestamp'] <= max_date)]
-# aligned_dfs[tf] = aligned_df
-
-# # Choose the lowest timeframe as the reference for time alignment
-# reference_tf = min(self.timeframes, key=lambda x: self.timeframe_to_seconds.get(x, 3600))
-# reference_df = aligned_dfs[reference_tf]
-
-# # Create sliding windows for each timeframe
-# X_dict = {}
-# for tf, df in aligned_dfs.items():
-# # Drop timestamp and create numeric features
-# features = df.drop('timestamp', axis=1).values
-
-# # Ensure the feature array is 3D: [samples, window, features]
-# X = np.array([features[i:i+window_size] for i in range(len(features)-window_size)])
-# X_dict[tf] = X
-
-# # Create target labels based on future price movements
-# reference_prices = reference_df['close'].values
-# future_prices = reference_prices[window_size:]
-# current_prices = reference_prices[window_size-1:-1]
-
-# # Calculate returns
-# returns = (future_prices - current_prices) / current_prices
-
-# # Create labels: 0=SELL, 1=HOLD, 2=BUY
-# threshold = 0.0005 # 0.05% threshold
-# y = np.zeros(len(returns), dtype=int)
-# y[returns > threshold] = 2 # BUY
-# y[returns < -threshold] = 0 # SELL
-# y[(returns >= -threshold) & (returns <= threshold)] = 1 # HOLD
-
-# # Split into training and validation sets
-# split_idx = int(len(y) * train_ratio)
-
-# X_train_dict = {tf: X[:split_idx] for tf, X in X_dict.items()}
-# X_val_dict = {tf: X[split_idx:] for tf, X in X_dict.items()}
-
-# y_train = y[:split_idx]
-# y_val = y[split_idx:]
-
-# train_prices = reference_prices[window_size-1:window_size-1+split_idx]
-# val_prices = reference_prices[window_size-1+split_idx:window_size-1+len(y)]
-
-# logger.info(f"Prepared training data - Train: {len(y_train)}, Val: {len(y_val)}")
-
-# return X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices
-
-# def normalize_data(self, data_dict, fit=True):
-# """
-# Normalize data across all timeframes.
-
-# Args:
-# data_dict (dict): Dictionary of data arrays by timeframe
-# fit (bool): Whether to fit new scalers or use existing ones
-
-# Returns:
-# dict: Dictionary of normalized data arrays
-# """
-# result = {}
-
-# for tf, data in data_dict.items():
-# # For 3D data [samples, window, features]
-# if len(data.shape) == 3:
-# samples, window, features = data.shape
-# reshaped = data.reshape(-1, features)
-
-# if fit or tf not in self.scalers:
-# self.scalers[tf] = MinMaxScaler()
-# normalized = self.scalers[tf].fit_transform(reshaped)
-# else:
-# normalized = self.scalers[tf].transform(reshaped)
-
-# result[tf] = normalized.reshape(samples, window, features)
-
-# # For 2D data [samples, features]
-# elif len(data.shape) == 2:
-# if fit or tf not in self.scalers:
-# self.scalers[tf] = MinMaxScaler()
-# result[tf] = self.scalers[tf].fit_transform(data)
-# else:
-# result[tf] = self.scalers[tf].transform(data)
-
-# return result
-
-# def get_realtime_features(self, timeframes=None, window_size=20):
-# """
-# Get the most recent data for real-time prediction.
-
-# Args:
-# timeframes (list): List of timeframes to use
-# window_size (int): Size of the sliding window
-
-# Returns:
-# dict: Dictionary of feature arrays for the latest window
-# """
-# if timeframes is None:
-# timeframes = self.timeframes
-
-# # Get fresh data
-# data_dict = self.get_multi_timeframe_data(timeframes=timeframes, refresh=True)
-
-# result = {}
-# for tf, df in data_dict.items():
-# if len(df) < window_size:
-# logger.warning(f"Not enough data for {tf} (need {window_size}, got {len(df)})")
-# continue
-
-# # Get the latest window
-# latest_data = df.tail(window_size).drop('timestamp', axis=1).values
-
-# # Add extra dimension to match model input shape [1, window_size, features]
-# result[tf] = latest_data.reshape(1, window_size, -1)
-
-# # Apply normalization using existing scalers
-# if self.scalers:
-# result = self.normalize_data(result, fit=False)
-
-# return result
-
-# def calculate_pnl(self, predictions, prices, position_size=1.0, fee_rate=0.0002):
-# """
-# Calculate PnL and win rate from predictions.
-
-# Args:
-# predictions (np.ndarray): Array of predicted actions (0=SELL, 1=HOLD, 2=BUY)
-# prices (np.ndarray): Array of prices
-# position_size (float): Size of each position
-# fee_rate (float): Trading fee rate (default: 0.0002 for 0.02% per trade)
-
-# Returns:
-# tuple: (total_pnl, win_rate, trades)
-# """
-# if len(predictions) < 2 or len(prices) < 2:
-# return 0.0, 0.0, []
-
-# # Ensure arrays are the same length
-# min_len = min(len(predictions), len(prices)-1)
-# actions = predictions[:min_len]
-
-# pnl = 0.0
-# wins = 0
-# trades = []
-
-# for i in range(min_len):
-# current_price = prices[i]
-# next_price = prices[i+1]
-# action = actions[i]
-
-# # Skip HOLD actions
-# if action == 1:
-# continue
-
-# price_change = (next_price - current_price) / current_price
-
-# if action == 2: # BUY
-# # Calculate raw PnL
-# raw_pnl = price_change * position_size
-
-# # Calculate fees (entry and exit)
-# entry_fee = position_size * fee_rate
-# exit_fee = position_size * (1 + price_change) * fee_rate
-# total_fees = entry_fee + exit_fee
-
-# # Net PnL after fees
-# trade_pnl = raw_pnl - total_fees
-
-# trade_type = 'BUY'
-# is_win = trade_pnl > 0
-# elif action == 0: # SELL
-# # Calculate raw PnL
-# raw_pnl = -price_change * position_size
-
-# # Calculate fees (entry and exit)
-# entry_fee = position_size * fee_rate
-# exit_fee = position_size * (1 - price_change) * fee_rate
-# total_fees = entry_fee + exit_fee
-
-# # Net PnL after fees
-# trade_pnl = raw_pnl - total_fees
-
-# trade_type = 'SELL'
-# is_win = trade_pnl > 0
-# else:
-# continue
-
-# pnl += trade_pnl
-# wins += int(is_win)
-
-# trades.append({
-# 'type': trade_type,
-# 'entry': float(current_price), # Ensure serializable
-# 'exit': float(next_price),
-# 'raw_pnl': float(raw_pnl),
-# 'fees': float(total_fees),
-# 'pnl': float(trade_pnl),
-# 'win': bool(is_win),
-# 'timestamp': datetime.now().isoformat() # Add timestamp
-# })
-
-# win_rate = wins / len(trades) if trades else 0.0
-
-# return float(pnl), float(win_rate), trades
-
-# # Configure logging with more detailed format
-# logging.basicConfig(
-# level=logging.INFO, # Changed to DEBUG for more detailed logs
-# format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
-# handlers=[
-# logging.StreamHandler(),
-# logging.FileHandler('realtime_chart.log')
-# ]
-# )
-# logger = logging.getLogger(__name__)
-
-# # Neural Network integration (conditional import)
-# NN_ENABLED = os.environ.get('ENABLE_NN_MODELS', '0') == '1'
-# nn_orchestrator = None
-# nn_inference_thread = None
-
-# if NN_ENABLED:
-# try:
-# import sys
-# # Add project root to sys.path if needed
-# project_root = os.path.dirname(os.path.abspath(__file__))
-# if project_root not in sys.path:
-# sys.path.append(project_root)
-
-# from NN.main import NeuralNetworkOrchestrator
-# logger.info("Neural Network module enabled")
-# except ImportError as e:
-# logger.warning(f"Failed to import Neural Network module, disabling NN features: {str(e)}")
-# NN_ENABLED = False
-
-# # NN utility functions
-# def setup_neural_network():
-# """Initialize the neural network components if enabled"""
-# global nn_orchestrator, NN_ENABLED
-
-# if not NN_ENABLED:
-# return False
-
-# try:
-# # Get configuration from environment variables or use defaults
-# symbol = os.environ.get('NN_SYMBOL', 'ETH/USDT')
-# timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',')
-# output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL
-
-# # Configure the orchestrator
-# config = {
-# 'symbol': symbol,
-# 'timeframes': timeframes,
-# 'window_size': int(os.environ.get('NN_WINDOW_SIZE', '20')),
-# 'n_features': 5, # OHLCV
-# 'output_size': output_size,
-# 'model_dir': 'NN/models/saved',
-# 'data_dir': 'NN/data'
-# }
-
-# # Initialize the orchestrator
-# logger.info(f"Initializing Neural Network Orchestrator with config: {config}")
-# nn_orchestrator = NeuralNetworkOrchestrator(config)
-
-# # Load the model
-# model_loaded = nn_orchestrator.load_model()
-# if not model_loaded:
-# logger.warning("Failed to load neural network model. Using untrained model.")
-
-# return model_loaded
-# except Exception as e:
-# logger.error(f"Error setting up neural network: {str(e)}")
-# NN_ENABLED = False
-# return False
-
-# def start_nn_inference_thread(interval_seconds):
-# """Start a background thread to periodically run inference with the neural network"""
-# global nn_inference_thread
-
-# if not NN_ENABLED or nn_orchestrator is None:
-# logger.warning("Cannot start inference thread - Neural Network not enabled or initialized")
-# return False
-
-# def inference_worker():
-# """Worker function for the inference thread"""
-# model_type = os.environ.get('NN_MODEL_TYPE', 'cnn')
-# timeframe = os.environ.get('NN_TIMEFRAME', '1h')
-
-# logger.info(f"Starting neural network inference thread with {interval_seconds}s interval")
-# logger.info(f"Using model type: {model_type}, timeframe: {timeframe}")
-
-# # Wait a bit for charts to initialize
-# time.sleep(5)
-
-# # Track active charts
-# active_charts = []
-
-# while True:
-# try:
-# # Find active charts if we don't have them yet
-# if not active_charts and 'charts' in globals():
-# active_charts = globals()['charts']
-# logger.info(f"Found {len(active_charts)} active charts for NN signals")
-
-# # Run inference
-# result = nn_orchestrator.run_inference_pipeline(
-# model_type=model_type,
-# timeframe=timeframe
-# )
-
-# if result:
-# # Log the result
-# logger.info(f"Neural network inference result: {result}")
-
-# # Add signal to charts
-# if active_charts:
-# try:
-# if 'action' in result:
-# action = result['action']
-# timestamp = datetime.fromisoformat(result['timestamp'].replace('Z', '+00:00'))
-
-# # Get probability if available
-# probability = None
-# if 'probability' in result:
-# probability = result['probability']
-# elif 'probabilities' in result:
-# probability = result['probabilities'].get(action, None)
-
-# # Add signal to each chart
-# for chart in active_charts:
-# if hasattr(chart, 'add_nn_signal'):
-# chart.add_nn_signal(action, timestamp, probability)
-# except Exception as e:
-# logger.error(f"Error adding NN signal to chart: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-
-# # Sleep for the interval
-# time.sleep(interval_seconds)
-
-# except Exception as e:
-# logger.error(f"Error in inference thread: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-# time.sleep(5) # Wait a bit before retrying
-
-# # Create and start the thread
-# nn_inference_thread = threading.Thread(target=inference_worker, daemon=True)
-# nn_inference_thread.start()
-
-# return True
-
-# # Try to get local timezone, default to Sofia/EET if not available
-# try:
-# local_timezone = tzlocal.get_localzone()
-# # Get timezone name safely
-# try:
-# tz_name = str(local_timezone)
-# # Handle case where it might be zoneinfo.ZoneInfo object instead of pytz timezone
-# if hasattr(local_timezone, 'zone'):
-# tz_name = local_timezone.zone
-# elif hasattr(local_timezone, 'key'):
-# tz_name = local_timezone.key
-# else:
-# tz_name = str(local_timezone)
-# except:
-# tz_name = "Local"
-# logger.info(f"Detected local timezone: {local_timezone} ({tz_name})")
-# except Exception as e:
-# logger.warning(f"Could not detect local timezone: {str(e)}. Defaulting to Sofia/EET")
-# local_timezone = pytz.timezone('Europe/Sofia')
-# tz_name = "Europe/Sofia"
-
-# def convert_to_local_time(timestamp):
-# """Convert timestamp to local timezone"""
-# try:
-# if isinstance(timestamp, pd.Timestamp):
-# dt = timestamp.to_pydatetime()
-# elif isinstance(timestamp, np.datetime64):
-# dt = pd.Timestamp(timestamp).to_pydatetime()
-# elif isinstance(timestamp, str):
-# dt = pd.to_datetime(timestamp).to_pydatetime()
-# else:
-# dt = timestamp
-
-# # If datetime is naive (no timezone), assume it's UTC
-# if dt.tzinfo is None:
-# dt = dt.replace(tzinfo=pytz.UTC)
-
-# # Convert to local timezone
-# local_dt = dt.astimezone(local_timezone)
-# return local_dt
-# except Exception as e:
-# logger.error(f"Error converting timestamp to local time: {str(e)}")
-# return timestamp
-
-# # Initialize TimescaleDB handler - only once, at module level
-# timescaledb_handler = TimescaleDBHandler() if TIMESCALEDB_ENABLED else None
-
-# class TickStorage:
-# def __init__(self, symbol, timeframes=None, use_timescaledb=False):
-# """Initialize the tick storage for a specific symbol"""
-# self.symbol = symbol
-# self.timeframes = timeframes or ["1s", "5m", "15m", "1h", "4h", "1d"]
-# self.ticks = []
-# self.candles = {tf: [] for tf in self.timeframes}
-# self.current_candle = {tf: None for tf in self.timeframes}
-# self.last_candle_timestamp = {tf: None for tf in self.timeframes}
-# self.cache_dir = os.path.join(os.getcwd(), "cache", symbol.replace("/", ""))
-# self.cache_path = os.path.join(self.cache_dir, f"{symbol.replace('/', '')}_ticks.json") # Add missing cache_path
-# self.use_timescaledb = use_timescaledb
-# self.max_ticks = 10000 # Maximum number of ticks to store in memory
-
-# # Create cache directory if it doesn't exist
-# os.makedirs(self.cache_dir, exist_ok=True)
-
-# logger.info(f"Creating new tick storage for {symbol} with timeframes {self.timeframes}")
-# logger.info(f"Cache directory: {self.cache_dir}")
-# logger.info(f"Cache file: {self.cache_path}")
-
-# if use_timescaledb:
-# print(f"TickStorage: TimescaleDB integration is ENABLED for {symbol}")
-# else:
-# logger.info(f"TickStorage: TimescaleDB integration is DISABLED for {symbol}")
-
-# def _save_to_cache(self):
-# """Save ticks to a cache file"""
-# try:
-# # Only save the latest 5000 ticks to avoid giant files
-# ticks_to_save = self.ticks[-5000:] if len(self.ticks) > 5000 else self.ticks
-
-# # Convert pandas Timestamps to ISO strings for JSON serialization
-# serializable_ticks = []
-# for tick in ticks_to_save:
-# serializable_tick = tick.copy()
-# if isinstance(tick['timestamp'], pd.Timestamp):
-# serializable_tick['timestamp'] = tick['timestamp'].isoformat()
-# elif hasattr(tick['timestamp'], 'isoformat'):
-# serializable_tick['timestamp'] = tick['timestamp'].isoformat()
-# else:
-# # Keep as is if it's already a string or number
-# serializable_tick['timestamp'] = tick['timestamp']
-# serializable_ticks.append(serializable_tick)
-
-# with open(self.cache_path, 'w') as f:
-# json.dump(serializable_ticks, f)
-# logger.debug(f"Saved {len(serializable_ticks)} ticks to cache")
-# except Exception as e:
-# logger.error(f"Error saving ticks to cache: {e}")
-
-# def _load_from_cache(self):
-# """Load ticks from cache if available"""
-# if os.path.exists(self.cache_path):
-# try:
-# # Check if the cache file is recent (< 10 minutes old)
-# cache_age = time.time() - os.path.getmtime(self.cache_path)
-# if cache_age > 600: # 10 minutes in seconds
-# logger.warning(f"Cache file is {cache_age:.1f} seconds old (>10 min). Not using it.")
-# return False
-
-# with open(self.cache_path, 'r') as f:
-# cached_ticks = json.load(f)
-
-# if cached_ticks:
-# # Convert ISO strings back to pandas Timestamps
-# processed_ticks = []
-# for tick in cached_ticks:
-# processed_tick = tick.copy()
-# if isinstance(tick['timestamp'], str):
-# try:
-# processed_tick['timestamp'] = pd.Timestamp(tick['timestamp'])
-# except:
-# # If parsing fails, use current time
-# processed_tick['timestamp'] = pd.Timestamp.now()
-# else:
-# # Convert to pandas Timestamp if it's a number (milliseconds)
-# processed_tick['timestamp'] = pd.Timestamp(tick['timestamp'], unit='ms')
-# processed_ticks.append(processed_tick)
-
-# self.ticks = processed_ticks
-# logger.info(f"Loaded {len(cached_ticks)} ticks from cache")
-# return True
-# except Exception as e:
-# logger.error(f"Error loading ticks from cache: {e}")
-# return False
-
-# def add_tick(self, tick=None, price=None, volume=None, timestamp=None):
-# """
-# Add a tick to the storage and update candles for all timeframes
-
-# Args:
-# tick (dict, optional): A tick object containing price, quantity and timestamp
-# price (float, optional): Price of the tick (used in older interface)
-# volume (float, optional): Volume of the tick (used in older interface)
-# timestamp (datetime, optional): Timestamp of the tick (used in older interface)
-# """
-# # Handle tick as a dict or separate parameters for backward compatibility
-# if tick is not None and isinstance(tick, dict):
-# # Using the new interface with a tick object
-# price = tick['price']
-# volume = tick.get('quantity', 0)
-# timestamp = tick['timestamp']
-# elif price is not None:
-# # Using the old interface with separate parameters
-# # Convert datetime to pd.Timestamp if needed
-# if timestamp is not None and not isinstance(timestamp, pd.Timestamp):
-# timestamp = pd.Timestamp(timestamp)
-# else:
-# logger.error("Invalid tick: must provide either a tick dict or price")
-# return
-
-# # Ensure timestamp is a pandas Timestamp
-# if not isinstance(timestamp, pd.Timestamp):
-# if isinstance(timestamp, (int, float)):
-# # Assume it's milliseconds
-# timestamp = pd.Timestamp(timestamp, unit='ms')
-# else:
-# # Try to parse as string or datetime
-# timestamp = pd.Timestamp(timestamp)
-
-# # Create tick object with consistent pandas Timestamp
-# tick_obj = {
-# 'price': float(price),
-# 'quantity': float(volume) if volume is not None else 0.0,
-# 'timestamp': timestamp
-# }
-
-# # Add to the list of ticks
-# self.ticks.append(tick_obj)
-
-# # Limit the number of ticks to avoid memory issues
-# if len(self.ticks) > self.max_ticks:
-# self.ticks = self.ticks[-self.max_ticks:]
-
-# # Update candles for all timeframes
-# for timeframe in self.timeframes:
-# if timeframe == "1s":
-# self._update_1s_candle(tick_obj)
-# else:
-# self._update_candles_for_timeframe(timeframe, tick_obj)
-
-# # Cache to disk periodically
-# self._try_cache_ticks()
-
-# def _update_1s_candle(self, tick):
-# """Update the 1-second candle with the new tick"""
-# # Get timestamp for the start of the current second
-# tick_timestamp = tick['timestamp']
-# candle_timestamp = pd.Timestamp(int(tick_timestamp.timestamp() // 1 * 1_000_000_000))
-
-# # Check if we need to create a new candle
-# if self.current_candle["1s"] is None or self.current_candle["1s"]["timestamp"] != candle_timestamp:
-# # If we have a current candle, finalize it and add to candles list
-# if self.current_candle["1s"] is not None:
-# # Add the completed candle to the list
-# self.candles["1s"].append(self.current_candle["1s"])
-
-# # Limit the number of stored candles to prevent memory issues
-# if len(self.candles["1s"]) > 3600: # Keep last hour of 1s candles
-# self.candles["1s"] = self.candles["1s"][-3600:]
-
-# # Store in TimescaleDB if enabled
-# if self.use_timescaledb:
-# timescaledb_handler.upsert_candle(
-# self.symbol, "1s", self.current_candle["1s"]
-# )
-
-# # Log completed candle for debugging
-# logger.debug(f"Completed 1s candle: {self.current_candle['1s']['timestamp']} - Close: {self.current_candle['1s']['close']}")
-
-# # Create a new candle
-# self.current_candle["1s"] = {
-# "timestamp": candle_timestamp,
-# "open": float(tick["price"]),
-# "high": float(tick["price"]),
-# "low": float(tick["price"]),
-# "close": float(tick["price"]),
-# "volume": float(tick["quantity"]) if "quantity" in tick else 0.0
-# }
-
-# # Update last candle timestamp
-# self.last_candle_timestamp["1s"] = candle_timestamp
-# logger.debug(f"Created new 1s candle at {candle_timestamp}")
-# else:
-# # Update the current candle
-# current = self.current_candle["1s"]
-# price = float(tick["price"])
-
-# # Update high and low
-# if price > current["high"]:
-# current["high"] = price
-# if price < current["low"]:
-# current["low"] = price
-
-# # Update close price and add volume
-# current["close"] = price
-# current["volume"] += float(tick["quantity"]) if "quantity" in tick else 0.0
-
-# def _update_candles_for_timeframe(self, timeframe, tick):
-# """Update candles for a specific timeframe"""
-# # Skip 1s as it's handled separately
-# if timeframe == "1s":
-# return
-
-# # Convert timeframe to seconds
-# timeframe_seconds = self._timeframe_to_seconds(timeframe)
-
-# # Get the timestamp truncated to the timeframe interval
-# # e.g., for a 5m candle, the timestamp should be truncated to the nearest 5-minute mark
-# # Convert timestamp to datetime if it's not already
-# tick_timestamp = tick['timestamp']
-# if isinstance(tick_timestamp, pd.Timestamp):
-# ts = tick_timestamp
-# else:
-# ts = pd.Timestamp(tick_timestamp)
-
-# # Truncate timestamp to nearest timeframe interval
-# timestamp = pd.Timestamp(
-# int(ts.timestamp() // timeframe_seconds * timeframe_seconds * 1_000_000_000)
-# )
-
-# # Get the current candle for this timeframe
-# current_candle = self.current_candle[timeframe]
-
-# # If we have no current candle or the timestamp is different (new candle)
-# if current_candle is None or current_candle['timestamp'] != timestamp:
-# # If we have a current candle, add it to the candles list
-# if current_candle:
-# self.candles[timeframe].append(current_candle)
-
-# # Save to TimescaleDB if enabled
-# if self.use_timescaledb:
-# timescaledb_handler.upsert_candle(self.symbol, timeframe, current_candle)
-
-# # Create a new candle
-# current_candle = {
-# 'timestamp': timestamp,
-# 'open': tick['price'],
-# 'high': tick['price'],
-# 'low': tick['price'],
-# 'close': tick['price'],
-# 'volume': tick.get('quantity', 0)
-# }
-
-# # Update current candle
-# self.current_candle[timeframe] = current_candle
-# self.last_candle_timestamp[timeframe] = timestamp
-
-# else:
-# # Update existing candle
-# current_candle['high'] = max(current_candle['high'], tick['price'])
-# current_candle['low'] = min(current_candle['low'], tick['price'])
-# current_candle['close'] = tick['price']
-# current_candle['volume'] += tick.get('quantity', 0)
-
-# # Limit the number of candles to avoid memory issues
-# max_candles = 1000
-# if len(self.candles[timeframe]) > max_candles:
-# self.candles[timeframe] = self.candles[timeframe][-max_candles:]
-
-# def _timeframe_to_seconds(self, timeframe):
-# """Convert a timeframe string (e.g., '1m', '1h') to seconds"""
-# if timeframe == "1s":
-# return 1
-
-# try:
-# # Extract the number and unit
-# match = re.match(r'(\d+)([smhdw])', timeframe)
-# if not match:
-# return None
-
-# num, unit = match.groups()
-# num = int(num)
-
-# # Convert to seconds
-# if unit == 's':
-# return num
-# elif unit == 'm':
-# return num * 60
-# elif unit == 'h':
-# return num * 3600
-# elif unit == 'd':
-# return num * 86400
-# elif unit == 'w':
-# return num * 604800
-
-# return None
-# except:
-# return None
-
-# def get_candles(self, timeframe, limit=None):
-# """Get candles for a given timeframe"""
-# if timeframe in self.candles:
-# candles = self.candles[timeframe]
-
-# # Add the current candle if it exists and isn't None
-# if timeframe in self.current_candle and self.current_candle[timeframe] is not None:
-# # Make a copy of the current candle
-# current_candle_copy = self.current_candle[timeframe].copy()
-
-# # Check if the current candle is newer than the last candle in the list
-# if not candles or current_candle_copy["timestamp"] > candles[-1]["timestamp"]:
-# candles = candles + [current_candle_copy]
-
-# # Apply limit if provided
-# if limit and len(candles) > limit:
-# return candles[-limit:]
-# return candles
-# return []
-
-# def get_last_price(self):
-# """Get the last known price"""
-# if self.ticks:
-# return float(self.ticks[-1]["price"])
-# return None
-
-# def load_historical_data(self, symbol, limit=1000):
-# """Load historical data for all timeframes"""
-# logger.info(f"Starting historical data load for {symbol} with limit {limit}")
-
-# # Clear existing data
-# self.ticks = []
-# self.candles = {tf: [] for tf in self.timeframes}
-# self.current_candle = {tf: None for tf in self.timeframes}
-
-# # Try to load ticks from cache first
-# logger.info("Attempting to load from cache...")
-# cache_loaded = self._load_from_cache()
-# if cache_loaded:
-# logger.info("Successfully loaded data from cache")
-# else:
-# logger.info("No valid cache data found")
-
-# # Check if we have TimescaleDB enabled
-# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled:
-# logger.info("Attempting to fetch historical data from TimescaleDB")
-# loaded_from_db = False
-
-# # Load candles for each timeframe from TimescaleDB
-# for tf in self.timeframes:
-# try:
-# candles = timescaledb_handler.fetch_candles(symbol, tf, limit)
-# if candles:
-# self.candles[tf] = candles
-# loaded_from_db = True
-# logger.info(f"Loaded {len(candles)} {tf} candles from TimescaleDB")
-# else:
-# logger.info(f"No {tf} candles found in TimescaleDB")
-# except Exception as e:
-# logger.error(f"Error loading {tf} candles from TimescaleDB: {str(e)}")
-
-# if loaded_from_db:
-# logger.info("Successfully loaded historical data from TimescaleDB")
-# return True
-# else:
-# logger.info("TimescaleDB not available or disabled")
-
-# # If no TimescaleDB data and no cache, we need to get from Binance API
-# if not cache_loaded:
-# logger.info("Loading data from Binance API...")
-# # Create a BinanceHistoricalData instance
-# historical_data = BinanceHistoricalData()
-
-# # Load data for each timeframe
-# success_count = 0
-# for tf in self.timeframes:
-# if tf != "1s": # Skip 1s since we'll generate it from ticks
-# try:
-# logger.info(f"Fetching {tf} candles for {symbol}...")
-# df = historical_data.get_historical_candles(symbol, self._timeframe_to_seconds(tf), limit)
-# if df is not None and not df.empty:
-# logger.info(f"Loaded {len(df)} {tf} candles from Binance API")
-
-# # Convert to our candle format and store
-# candles = []
-# for _, row in df.iterrows():
-# candle = {
-# 'timestamp': row['timestamp'],
-# 'open': row['open'],
-# 'high': row['high'],
-# 'low': row['low'],
-# 'close': row['close'],
-# 'volume': row['volume']
-# }
-# candles.append(candle)
-
-# # Also save to TimescaleDB if enabled
-# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled:
-# timescaledb_handler.upsert_candle(symbol, tf, candle)
-
-# self.candles[tf] = candles
-# success_count += 1
-# else:
-# logger.warning(f"No data returned for {tf} candles")
-# except Exception as e:
-# logger.error(f"Error loading {tf} candles: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-
-# logger.info(f"Successfully loaded {success_count} timeframes from Binance API")
-
-# # For 1s, load from API if possible or compute from first available timeframe
-# if "1s" in self.timeframes:
-# logger.info("Loading 1s candles...")
-# # Try to get 1s data from Binance
-# try:
-# df_1s = historical_data.get_historical_candles(symbol, 1, 300) # Only need recent 1s data
-# if df_1s is not None and not df_1s.empty:
-# logger.info(f"Loaded {len(df_1s)} recent 1s candles from Binance API")
-
-# # Convert to our candle format and store
-# candles_1s = []
-# for _, row in df_1s.iterrows():
-# candle = {
-# 'timestamp': row['timestamp'],
-# 'open': row['open'],
-# 'high': row['high'],
-# 'low': row['low'],
-# 'close': row['close'],
-# 'volume': row['volume']
-# }
-# candles_1s.append(candle)
-
-# # Also save to TimescaleDB if enabled
-# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled:
-# timescaledb_handler.upsert_candle(symbol, "1s", candle)
-
-# self.candles["1s"] = candles_1s
-# except Exception as e:
-# logger.error(f"Error loading 1s candles: {str(e)}")
-
-# # If 1s data not available or failed to load, approximate from 1m data
-# if not self.candles.get("1s"):
-# logger.info("1s data not available, trying to approximate from 1m data...")
-# # If 1s data not available, we can approximate from 1m data
-# if "1m" in self.timeframes and self.candles["1m"]:
-# # For demonstration, just use the 1m candles as placeholders for 1s
-# # In a real implementation, you might want more sophisticated interpolation
-# logger.info("Using 1m candles as placeholders for 1s timeframe")
-# self.candles["1s"] = []
-
-# # Take the most recent 5 minutes of 1m candles
-# recent_1m = self.candles["1m"][-5:] if self.candles["1m"] else []
-# logger.info(f"Creating 1s approximations from {len(recent_1m)} 1m candles")
-# for candle_1m in recent_1m:
-# # Create 60 1s candles for each 1m candle
-# ts_base = candle_1m["timestamp"].timestamp()
-# for i in range(60):
-# # Create a 1s candle with interpolated values
-# candle_1s = {
-# 'timestamp': pd.Timestamp(int((ts_base + i) * 1_000_000_000)),
-# 'open': candle_1m['open'],
-# 'high': candle_1m['high'],
-# 'low': candle_1m['low'],
-# 'close': candle_1m['close'],
-# 'volume': candle_1m['volume'] / 60.0 # Distribute volume evenly
-# }
-# self.candles["1s"].append(candle_1s)
-
-# # Also save to TimescaleDB if enabled
-# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled:
-# timescaledb_handler.upsert_candle(symbol, "1s", candle_1s)
-
-# logger.info(f"Created {len(self.candles['1s'])} approximated 1s candles")
-# else:
-# logger.warning("No 1m data available to approximate 1s candles from")
-
-# # Set the last candle of each timeframe as the current candle
-# for tf in self.timeframes:
-# if self.candles[tf]:
-# self.current_candle[tf] = self.candles[tf][-1].copy()
-# self.last_candle_timestamp[tf] = self.current_candle[tf]["timestamp"]
-# logger.debug(f"Set current candle for {tf}: {self.current_candle[tf]['timestamp']}")
-
-# # If we loaded ticks from cache, rebuild candles
-# if cache_loaded:
-# logger.info("Rebuilding candles from cached ticks...")
-# # Clear candles
-# self.candles = {tf: [] for tf in self.timeframes}
-# self.current_candle = {tf: None for tf in self.timeframes}
-
-# # Process each tick to rebuild the candles
-# for tick in self.ticks:
-# for tf in self.timeframes:
-# if tf == "1s":
-# self._update_1s_candle(tick)
-# else:
-# self._update_candles_for_timeframe(tf, tick)
-
-# logger.info("Finished rebuilding candles from ticks")
-
-# # Log final results
-# for tf in self.timeframes:
-# count = len(self.candles[tf])
-# logger.info(f"Final {tf} candle count: {count}")
-
-# has_data = cache_loaded or any(self.candles[tf] for tf in self.timeframes)
-# logger.info(f"Historical data loading completed. Has data: {has_data}")
-# return has_data
-
-# def _try_cache_ticks(self):
-# """Try to save ticks to cache periodically"""
-# # Only save to cache every 100 ticks to avoid excessive disk I/O
-# if len(self.ticks) % 100 == 0:
-# try:
-# self._save_to_cache()
-# except Exception as e:
-# # Don't spam logs with cache errors, just log once every 1000 ticks
-# if len(self.ticks) % 1000 == 0:
-# logger.warning(f"Cache save failed at {len(self.ticks)} ticks: {str(e)}")
-# pass # Continue even if cache fails
-
-# class Position:
-# """Represents a trading position"""
-
-# def __init__(self, action, entry_price, amount, timestamp=None, trade_id=None, fee_rate=0.0002):
-# self.action = action
-# self.entry_price = entry_price
-# self.amount = amount
-# self.entry_timestamp = timestamp or datetime.now()
-# self.exit_timestamp = None
-# self.exit_price = None
-# self.pnl = None
-# self.is_open = True
-# self.trade_id = trade_id or str(uuid.uuid4())[:8]
-# self.fee_rate = fee_rate
-# self.paid_fee = entry_price * amount * fee_rate # Calculate entry fee
-
-# def close(self, exit_price, exit_timestamp=None):
-# """Close an open position"""
-# self.exit_price = exit_price
-# self.exit_timestamp = exit_timestamp or datetime.now()
-# self.is_open = False
-
-# # Calculate P&L
-# if self.action == "BUY":
-# price_diff = self.exit_price - self.entry_price
-# # Calculate fee for exit trade
-# exit_fee = exit_price * self.amount * self.fee_rate
-# self.paid_fee += exit_fee # Add exit fee to total paid fee
-# self.pnl = (price_diff * self.amount) - self.paid_fee
-# else: # SELL
-# price_diff = self.entry_price - self.exit_price
-# # Calculate fee for exit trade
-# exit_fee = exit_price * self.amount * self.fee_rate
-# self.paid_fee += exit_fee # Add exit fee to total paid fee
-# self.pnl = (price_diff * self.amount) - self.paid_fee
-
-# return self.pnl
-
-# class RealTimeChart:
-# def __init__(self, app=None, symbol='BTCUSDT', timeframe='1m', standalone=True, chart_title=None,
-# run_signal_interpreter=False, debug_mode=False, historical_candles=None,
-# extended_hours=False, enable_logging=True, agent=None, trading_env=None,
-# max_memory_usage=90, memory_check_interval=10, tick_update_interval=0.5,
-# chart_update_interval=1, performance_monitoring=False, show_volume=True,
-# show_indicators=True, custom_trades=None, port=8050, height=900, width=1200,
-# positions_callback=None, allow_synthetic_data=True, tick_storage=None):
-# """Initialize a real-time chart with support for multiple indicators and backtesting."""
-
-# # Store parameters
-# self.symbol = symbol
-# self.timeframe = timeframe
-# self.debug_mode = debug_mode
-# self.standalone = standalone
-# self.chart_title = chart_title or f"{symbol} Real-Time Chart"
-# self.extended_hours = extended_hours
-# self.enable_logging = enable_logging
-# self.run_signal_interpreter = run_signal_interpreter
-# self.historical_candles = historical_candles
-# self.performance_monitoring = performance_monitoring
-# self.max_memory_usage = max_memory_usage
-# self.memory_check_interval = memory_check_interval
-# self.tick_update_interval = tick_update_interval
-# self.chart_update_interval = chart_update_interval
-# self.show_volume = show_volume
-# self.show_indicators = show_indicators
-# self.custom_trades = custom_trades
-# self.port = port
-# self.height = height
-# self.width = width
-# self.positions_callback = positions_callback
-# self.allow_synthetic_data = allow_synthetic_data
-
-# # Initialize interval store
-# self.interval_store = {'interval': 1} # Default to 1s timeframe
-
-# # Initialize trading components
-# self.agent = agent
-# self.trading_env = trading_env
-
-# # Initialize button styles for timeframe selection
-# self.button_style = {
-# 'background': '#343a40',
-# 'color': 'white',
-# 'border': 'none',
-# 'padding': '10px 20px',
-# 'margin': '0 5px',
-# 'borderRadius': '4px',
-# 'cursor': 'pointer'
-# }
-
-# self.active_button_style = {
-# 'background': '#007bff',
-# 'color': 'white',
-# 'border': 'none',
-# 'padding': '10px 20px',
-# 'margin': '0 5px',
-# 'borderRadius': '4px',
-# 'cursor': 'pointer',
-# 'fontWeight': 'bold'
-# }
-
-# # Initialize color schemes
-# self.colors = {
-# 'background': '#1e1e1e',
-# 'text': '#ffffff',
-# 'grid': '#333333',
-# 'candle_up': '#26a69a',
-# 'candle_down': '#ef5350',
-# 'volume_up': 'rgba(38, 166, 154, 0.3)',
-# 'volume_down': 'rgba(239, 83, 80, 0.3)',
-# 'ma': '#ffeb3b',
-# 'ema': '#29b6f6',
-# 'bollinger_bands': '#ff9800',
-# 'trades_buy': '#00e676',
-# 'trades_sell': '#ff1744'
-# }
-
-# # Initialize data storage
-# self.all_trades = [] # Store trades
-# self.positions = [] # Store open positions
-# self.latest_price = 0.0
-# self.latest_volume = 0.0
-# self.latest_timestamp = datetime.now()
-# self.current_balance = 100.0 # Starting balance
-# self.accumulative_pnl = 0.0 # Accumulated profit/loss
-
-# # Initialize trade rate counter
-# self.trade_count = 0
-# self.start_time = time.time()
-# self.trades_per_second = 0
-# self.trades_per_minute = 0
-# self.trades_per_hour = 0
-
-# # Initialize trade rate tracking variables
-# self.trade_times = [] # Store timestamps of recent trades for rate calculation
-# self.last_trade_rate_calculation = datetime.now()
-# self.trade_rate = {"per_second": 0, "per_minute": 0, "per_hour": 0}
-
-# # Initialize interactive components
-# self.app = app
-
-# # Create a new app if not provided
-# if self.app is None and standalone:
-# self.app = dash.Dash(
-# __name__,
-# external_stylesheets=[dbc.themes.DARKLY],
-# suppress_callback_exceptions=True
-# )
-
-# # Initialize tick storage if not provided
-# if tick_storage is None:
-# # Check if TimescaleDB integration is enabled
-# use_timescaledb = TIMESCALEDB_ENABLED and timescaledb_handler is not None
-
-# # Create a new tick storage
-# self.tick_storage = TickStorage(
-# symbol=symbol,
-# timeframes=["1s", "1m", "5m", "15m", "1h", "4h", "1d"],
-# use_timescaledb=use_timescaledb
-# )
-
-# # Load historical data immediately for cold start
-# logger.info(f"Loading historical data for {symbol} during chart initialization")
-# try:
-# data_loaded = self.tick_storage.load_historical_data(symbol)
-# if data_loaded:
-# logger.info(f"Successfully loaded historical data for {symbol}")
-# # Log what we have
-# for tf in ["1s", "1m", "5m", "15m", "1h"]:
-# candle_count = len(self.tick_storage.candles.get(tf, []))
-# logger.info(f" {tf}: {candle_count} candles")
-# else:
-# logger.warning(f"Failed to load historical data for {symbol}")
-# except Exception as e:
-# logger.error(f"Error loading historical data during initialization: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-# else:
-# self.tick_storage = tick_storage
-
-# # Create layout and callbacks if app is provided
-# if self.app is not None:
-# # Create the layout
-# self.app.layout = self._create_layout()
-
-# # Register callbacks
-# self._setup_callbacks()
-
-# # Log initialization
-# if self.enable_logging:
-# logger.info(f"RealTimeChart initialized: {self.symbol} ({self.timeframe}) ")
-
-# def _create_layout(self):
-# return html.Div([
-# # Header section with title and current price
-# html.Div([
-# html.H1(f"{self.symbol} Real-Time Chart", className="display-4"),
-
-# # Current price ticker
-# html.Div([
-# html.H4("Current Price:", style={"display": "inline-block", "marginRight": "10px"}),
-# html.H3(id="current-price", style={"display": "inline-block", "color": "#17a2b8"}),
-# html.Div([
-# html.H5("Balance:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}),
-# html.H5(id="current-balance", style={"display": "inline-block", "color": "#28a745"}),
-# ], style={"display": "inline-block", "marginLeft": "40px"}),
-# html.Div([
-# html.H5("Accumulated PnL:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}),
-# html.H5(id="accumulated-pnl", style={"display": "inline-block", "color": "#ffc107"}),
-# ], style={"display": "inline-block", "marginLeft": "40px"}),
-
-# # Add trade rate display
-# html.Div([
-# html.H5("Trade Rate:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}),
-# html.Span([
-# html.Span(id="trade-rate-second", style={"color": "#ff7f0e"}),
-# html.Span("/s, "),
-# html.Span(id="trade-rate-minute", style={"color": "#ff7f0e"}),
-# html.Span("/m, "),
-# html.Span(id="trade-rate-hour", style={"color": "#ff7f0e"}),
-# html.Span("/h")
-# ], style={"display": "inline-block"}),
-# ], style={"display": "inline-block", "marginLeft": "40px"}),
-# ], style={"textAlign": "center", "margin": "20px 0"}),
-# ], style={"textAlign": "center", "marginBottom": "20px"}),
-
-# # Add interval component for periodic updates
-# dcc.Interval(
-# id='interval-component',
-# interval=500, # in milliseconds
-# n_intervals=0
-# ),
-
-# # Add timeframe selection buttons
-# html.Div([
-# html.Button('1s', id='btn-1s', n_clicks=0, style=self.active_button_style),
-# html.Button('5s', id='btn-5s', n_clicks=0, style=self.button_style),
-# html.Button('15s', id='btn-15s', n_clicks=0, style=self.button_style),
-# html.Button('1m', id='btn-1m', n_clicks=0, style=self.button_style),
-# html.Button('5m', id='btn-5m', n_clicks=0, style=self.button_style),
-# html.Button('15m', id='btn-15m', n_clicks=0, style=self.button_style),
-# html.Button('1h', id='btn-1h', n_clicks=0, style=self.button_style),
-# ], style={"textAlign": "center", "marginBottom": "20px"}),
-
-# # Store for the selected timeframe
-# dcc.Store(id='interval-store', data={'interval': 1}),
-
-# # Chart content (without wrapper div to avoid callback issues)
-# dcc.Graph(id='live-chart', style={"height": "600px"}),
-# dcc.Graph(id='secondary-charts', style={"height": "500px"}),
-# html.Div(id='positions-list')
-# ])
-
-# def _create_chart_and_controls(self):
-# """Create the chart and controls for the dashboard."""
-# try:
-# # Get selected interval from the dashboard (default to 1s if not available)
-# interval_seconds = 1
-# if hasattr(self, 'interval_store') and self.interval_store:
-# interval_seconds = self.interval_store.get('interval', 1)
-
-# # Create chart components
-# chart_div = html.Div([
-# # Update chart with data for the selected interval
-# dcc.Graph(
-# id='live-chart',
-# figure=self._update_main_chart(interval_seconds),
-# style={"height": "600px"}
-# ),
-
-# # Update secondary charts
-# dcc.Graph(
-# id='secondary-charts',
-# figure=self._update_secondary_charts(),
-# style={"height": "500px"}
-# ),
-
-# # Update positions list
-# html.Div(
-# id='positions-list',
-# children=self._get_position_list_rows()
-# )
-# ])
-
-# return chart_div
-
-# except Exception as e:
-# logger.error(f"Error creating chart and controls: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-# # Return a simple error message as fallback
-# return html.Div(f"Error loading chart: {str(e)}", style={"color": "red", "padding": "20px"})
-
-# def _setup_callbacks(self):
-# """Setup Dash callbacks for the real-time chart"""
-# if self.app is None:
-# return
-
-# try:
-# # Update chart with all components based on interval
-# @self.app.callback(
-# [
-# Output('live-chart', 'figure'),
-# Output('secondary-charts', 'figure'),
-# Output('positions-list', 'children'),
-# Output('current-price', 'children'),
-# Output('current-balance', 'children'),
-# Output('accumulated-pnl', 'children'),
-# Output('trade-rate-second', 'children'),
-# Output('trade-rate-minute', 'children'),
-# Output('trade-rate-hour', 'children')
-# ],
-# [
-# Input('interval-component', 'n_intervals'),
-# Input('interval-store', 'data')
-# ]
-# )
-# def update_all(n_intervals, interval_data):
-# """Update all chart components"""
-# try:
-# # Get selected interval
-# interval_seconds = interval_data.get('interval', 1) if interval_data else 1
-
-# # Update main chart - limit data for performance
-# main_chart = self._update_main_chart(interval_seconds)
-
-# # Update secondary charts - limit data for performance
-# secondary_charts = self._update_secondary_charts()
-
-# # Update positions list
-# positions_list = self._get_position_list_rows()
-
-# # Update current price and balance
-# current_price = f"${self.latest_price:.2f}" if self.latest_price else "Error"
-# current_balance = f"${self.current_balance:.2f}"
-# accumulated_pnl = f"${self.accumulative_pnl:.2f}"
-
-# # Calculate trade rates
-# trade_rate = self._calculate_trade_rate()
-# trade_rate_second = f"{trade_rate['per_second']:.1f}"
-# trade_rate_minute = f"{trade_rate['per_minute']:.1f}"
-# trade_rate_hour = f"{trade_rate['per_hour']:.1f}"
-
-# return (main_chart, secondary_charts, positions_list,
-# current_price, current_balance, accumulated_pnl,
-# trade_rate_second, trade_rate_minute, trade_rate_hour)
-
-# except Exception as e:
-# logger.error(f"Error in update_all callback: {str(e)}")
-# # Return empty/error states
-# import plotly.graph_objects as go
-# empty_fig = go.Figure()
-# empty_fig.add_annotation(text="Chart Loading...", xref="paper", yref="paper", x=0.5, y=0.5)
-
-# return (empty_fig, empty_fig, [], "Loading...", "$0.00", "$0.00", "0.0", "0.0", "0.0")
-
-# # Timeframe selection callbacks
-# @self.app.callback(
-# [Output('interval-store', 'data'),
-# Output('btn-1s', 'style'), Output('btn-5s', 'style'), Output('btn-15s', 'style'),
-# Output('btn-1m', 'style'), Output('btn-5m', 'style'), Output('btn-15m', 'style'),
-# Output('btn-1h', 'style')],
-# [Input('btn-1s', 'n_clicks'), Input('btn-5s', 'n_clicks'), Input('btn-15s', 'n_clicks'),
-# Input('btn-1m', 'n_clicks'), Input('btn-5m', 'n_clicks'), Input('btn-15m', 'n_clicks'),
-# Input('btn-1h', 'n_clicks')]
-# )
-# def update_timeframe(n1s, n5s, n15s, n1m, n5m, n15m, n1h):
-# """Update selected timeframe based on button clicks"""
-# ctx = dash.callback_context
-# if not ctx.triggered:
-# # Default to 1s
-# styles = [self.active_button_style] + [self.button_style] * 6
-# return {'interval': 1}, *styles
-
-# button_id = ctx.triggered[0]['prop_id'].split('.')[0]
-
-# # Map button to interval seconds
-# interval_map = {
-# 'btn-1s': 1, 'btn-5s': 5, 'btn-15s': 15,
-# 'btn-1m': 60, 'btn-5m': 300, 'btn-15m': 900, 'btn-1h': 3600
-# }
-
-# selected_interval = interval_map.get(button_id, 1)
-
-# # Create styles - active for selected, normal for others
-# button_names = ['btn-1s', 'btn-5s', 'btn-15s', 'btn-1m', 'btn-5m', 'btn-15m', 'btn-1h']
-# styles = []
-# for name in button_names:
-# if name == button_id:
-# styles.append(self.active_button_style)
-# else:
-# styles.append(self.button_style)
-
-# return {'interval': selected_interval}, *styles
-
-# logger.info("Dash callbacks registered successfully")
-
-# except Exception as e:
-# logger.error(f"Error setting up callbacks: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-
-# def _calculate_trade_rate(self):
-# """Calculate trading rate per second, minute, and hour"""
-# try:
-# now = datetime.now()
-# current_time = time.time()
-
-# # Filter trades within different time windows
-# trades_last_second = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 1)
-# trades_last_minute = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 60)
-# trades_last_hour = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 3600)
-
-# return {
-# "per_second": trades_last_second,
-# "per_minute": trades_last_minute,
-# "per_hour": trades_last_hour
-# }
-# except Exception as e:
-# logger.warning(f"Error calculating trade rate: {str(e)}")
-# return {"per_second": 0.0, "per_minute": 0.0, "per_hour": 0.0}
-
-# def _update_secondary_charts(self):
-# """Create secondary charts for volume and indicators"""
-# try:
-# # Create subplots for secondary charts
-# fig = make_subplots(
-# rows=2, cols=1,
-# subplot_titles=['Volume', 'Technical Indicators'],
-# shared_xaxes=True,
-# vertical_spacing=0.1,
-# row_heights=[0.3, 0.7]
-# )
-
-# # Get latest candles (limit for performance)
-# candles = self.tick_storage.candles.get("1m", [])[-100:] # Last 100 candles for performance
-
-# if not candles:
-# fig.add_annotation(text="No data available", xref="paper", yref="paper", x=0.5, y=0.5)
-# fig.update_layout(
-# title="Secondary Charts",
-# template="plotly_dark",
-# height=400
-# )
-# return fig
-
-# # Extract data
-# timestamps = [candle['timestamp'] for candle in candles]
-# volumes = [candle['volume'] for candle in candles]
-# closes = [candle['close'] for candle in candles]
-
-# # Volume chart
-# colors = ['#26a69a' if i == 0 or closes[i] >= closes[i-1] else '#ef5350' for i in range(len(closes))]
-# fig.add_trace(
-# go.Bar(
-# x=timestamps,
-# y=volumes,
-# name='Volume',
-# marker_color=colors,
-# showlegend=False
-# ),
-# row=1, col=1
-# )
-
-# # Technical indicators
-# if len(closes) >= 20:
-# # Simple moving average
-# sma_20 = pd.Series(closes).rolling(window=20).mean()
-# fig.add_trace(
-# go.Scatter(
-# x=timestamps,
-# y=sma_20,
-# name='SMA 20',
-# line=dict(color='#ffeb3b', width=2)
-# ),
-# row=2, col=1
-# )
-
-# # RSI calculation
-# if len(closes) >= 14:
-# rsi = self._calculate_rsi(closes, 14)
-# fig.add_trace(
-# go.Scatter(
-# x=timestamps,
-# y=rsi,
-# name='RSI',
-# line=dict(color='#29b6f6', width=2),
-# yaxis='y3'
-# ),
-# row=2, col=1
-# )
-
-# # Update layout
-# fig.update_layout(
-# title="Volume & Technical Indicators",
-# template="plotly_dark",
-# height=400,
-# showlegend=True,
-# legend=dict(x=0, y=1, bgcolor='rgba(0,0,0,0)')
-# )
-
-# # Update y-axes
-# fig.update_yaxes(title="Volume", row=1, col=1)
-# fig.update_yaxes(title="Price", row=2, col=1)
-
-# return fig
-
-# except Exception as e:
-# logger.error(f"Error creating secondary charts: {str(e)}")
-# # Return empty figure on error
-# fig = go.Figure()
-# fig.add_annotation(text=f"Error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5)
-# fig.update_layout(template="plotly_dark", height=400)
-# return fig
-
-# def _calculate_rsi(self, prices, period=14):
-# """Calculate RSI indicator"""
-# try:
-# prices = pd.Series(prices)
-# delta = prices.diff()
-# gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
-# loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
-# rs = gain / loss
-# rsi = 100 - (100 / (1 + rs))
-# return rsi.fillna(50).tolist() # Fill NaN with neutral RSI value
-# except Exception:
-# return [50] * len(prices) # Return neutral RSI on error
-
-# def _get_position_list_rows(self):
-# """Get list of current positions for display"""
-# try:
-# if not self.positions:
-# return [html.Div("No open positions", style={"color": "#888", "padding": "10px"})]
-
-# rows = []
-# for i, position in enumerate(self.positions):
-# try:
-# # Calculate current PnL
-# current_pnl = (self.latest_price - position.entry_price) * position.amount
-# if position.action.upper() == 'SELL':
-# current_pnl = -current_pnl
-
-# # Create position row
-# row = html.Div([
-# html.Span(f"#{i+1}: ", style={"fontWeight": "bold"}),
-# html.Span(f"{position.action.upper()} ",
-# style={"color": "#00e676" if position.action.upper() == "BUY" else "#ff1744"}),
-# html.Span(f"{position.amount:.4f} @ ${position.entry_price:.2f} "),
-# html.Span(f"PnL: ${current_pnl:.2f}",
-# style={"color": "#00e676" if current_pnl >= 0 else "#ff1744"})
-# ], style={"padding": "5px", "borderBottom": "1px solid #333"})
-
-# rows.append(row)
-# except Exception as e:
-# logger.warning(f"Error formatting position {i}: {str(e)}")
-
-# return rows
-
-# except Exception as e:
-# logger.error(f"Error getting position list: {str(e)}")
-# return [html.Div("Error loading positions", style={"color": "red", "padding": "10px"})]
-
-# def add_trade(self, action, price, amount, timestamp=None, trade_id=None):
-# """Add a trade to the chart and update tracking"""
-# try:
-# if timestamp is None:
-# timestamp = datetime.now()
-
-# # Create trade record
-# trade = {
-# 'id': trade_id or str(uuid.uuid4()),
-# 'action': action.upper(),
-# 'price': float(price),
-# 'amount': float(amount),
-# 'timestamp': timestamp,
-# 'value': float(price) * float(amount)
-# }
-
-# # Add to trades list
-# self.all_trades.append(trade)
-
-# # Update trade rate tracking
-# self.trade_times.append(time.time())
-# # Keep only last hour of trade times
-# cutoff_time = time.time() - 3600
-# self.trade_times = [t for t in self.trade_times if t > cutoff_time]
-
-# # Update positions
-# if action.upper() in ['BUY', 'SELL']:
-# position = Position(
-# action=action.upper(),
-# entry_price=float(price),
-# amount=float(amount),
-# timestamp=timestamp,
-# trade_id=trade['id']
-# )
-# self.positions.append(position)
-
-# # Update balance and PnL
-# if action.upper() == 'BUY':
-# self.current_balance -= trade['value']
-# else: # SELL
-# self.current_balance += trade['value']
-
-# # Calculate PnL for this trade
-# if len(self.all_trades) > 1:
-# # Simple PnL calculation - more sophisticated logic could be added
-# last_opposite_trades = [t for t in reversed(self.all_trades[:-1])
-# if t['action'] != action.upper()]
-# if last_opposite_trades:
-# last_trade = last_opposite_trades[0]
-# if action.upper() == 'SELL':
-# pnl = (float(price) - last_trade['price']) * float(amount)
-# else: # BUY
-# pnl = (last_trade['price'] - float(price)) * float(amount)
-# self.accumulative_pnl += pnl
-
-# logger.info(f"Added trade: {action.upper()} {amount} @ ${price:.2f}")
-
-# except Exception as e:
-# logger.error(f"Error adding trade: {str(e)}")
-
-# def _get_interval_key(self, interval_seconds):
-# """Convert interval seconds to timeframe key"""
-# if interval_seconds <= 1:
-# return "1s"
-# elif interval_seconds <= 5:
-# return "5s" if "5s" in self.tick_storage.timeframes else "1s"
-# elif interval_seconds <= 15:
-# return "15s" if "15s" in self.tick_storage.timeframes else "1m"
-# elif interval_seconds <= 60:
-# return "1m"
-# elif interval_seconds <= 300:
-# return "5m"
-# elif interval_seconds <= 900:
-# return "15m"
-# elif interval_seconds <= 3600:
-# return "1h"
-# elif interval_seconds <= 14400:
-# return "4h"
-# else:
-# return "1d"
-
-# def _update_main_chart(self, interval_seconds):
-# """Update the main chart for the specified interval"""
-# try:
-# # Convert interval seconds to timeframe key
-# interval_key = self._get_interval_key(interval_seconds)
-
-# # Get candles for this timeframe (limit to last 100 for performance)
-# candles = self.tick_storage.candles.get(interval_key, [])[-100:]
-
-# if not candles:
-# logger.warning(f"No candle data available for {interval_key}")
-# # Return empty figure with a message
-# fig = go.Figure()
-# fig.add_annotation(
-# text=f"No data available for {interval_key}",
-# xref="paper", yref="paper",
-# x=0.5, y=0.5,
-# showarrow=False,
-# font=dict(size=16, color="white")
-# )
-# fig.update_layout(
-# title=f"{self.symbol} - {interval_key} Chart",
-# template="plotly_dark",
-# height=600
-# )
-# return fig
-
-# # Extract data from candles
-# timestamps = [candle['timestamp'] for candle in candles]
-# opens = [candle['open'] for candle in candles]
-# highs = [candle['high'] for candle in candles]
-# lows = [candle['low'] for candle in candles]
-# closes = [candle['close'] for candle in candles]
-# volumes = [candle['volume'] for candle in candles]
-
-# # Create candlestick chart
-# fig = go.Figure()
-
-# # Add candlestick trace
-# fig.add_trace(go.Candlestick(
-# x=timestamps,
-# open=opens,
-# high=highs,
-# low=lows,
-# close=closes,
-# name="Price",
-# increasing_line_color='#26a69a',
-# decreasing_line_color='#ef5350',
-# increasing_fillcolor='#26a69a',
-# decreasing_fillcolor='#ef5350'
-# ))
-
-# # Add trade markers if we have trades
-# if self.all_trades:
-# # Filter trades to match the current timeframe window
-# start_time = timestamps[0] if timestamps else datetime.now() - timedelta(hours=1)
-# end_time = timestamps[-1] if timestamps else datetime.now()
-
-# filtered_trades = [
-# trade for trade in self.all_trades
-# if start_time <= trade['timestamp'] <= end_time
-# ]
-
-# if filtered_trades:
-# buy_trades = [t for t in filtered_trades if t['action'] == 'BUY']
-# sell_trades = [t for t in filtered_trades if t['action'] == 'SELL']
-
-# # Add BUY markers
-# if buy_trades:
-# fig.add_trace(go.Scatter(
-# x=[t['timestamp'] for t in buy_trades],
-# y=[t['price'] for t in buy_trades],
-# mode='markers',
-# marker=dict(
-# symbol='triangle-up',
-# size=12,
-# color='#00e676',
-# line=dict(color='white', width=1)
-# ),
-# name='BUY',
-# text=[f"BUY {t['amount']:.4f} @ ${t['price']:.2f}" for t in buy_trades],
-# hovertemplate='%{text}
Time: %{x}'
-# ))
-
-# # Add SELL markers
-# if sell_trades:
-# fig.add_trace(go.Scatter(
-# x=[t['timestamp'] for t in sell_trades],
-# y=[t['price'] for t in sell_trades],
-# mode='markers',
-# marker=dict(
-# symbol='triangle-down',
-# size=12,
-# color='#ff1744',
-# line=dict(color='white', width=1)
-# ),
-# name='SELL',
-# text=[f"SELL {t['amount']:.4f} @ ${t['price']:.2f}" for t in sell_trades],
-# hovertemplate='%{text}
Time: %{x}'
-# ))
-
-# # Add moving averages if we have enough data
-# if len(closes) >= 20:
-# # 20-period SMA
-# sma_20 = pd.Series(closes).rolling(window=20).mean()
-# fig.add_trace(go.Scatter(
-# x=timestamps,
-# y=sma_20,
-# name='SMA 20',
-# line=dict(color='#ffeb3b', width=1),
-# opacity=0.7
-# ))
-
-# if len(closes) >= 50:
-# # 50-period SMA
-# sma_50 = pd.Series(closes).rolling(window=50).mean()
-# fig.add_trace(go.Scatter(
-# x=timestamps,
-# y=sma_50,
-# name='SMA 50',
-# line=dict(color='#ff9800', width=1),
-# opacity=0.7
-# ))
-
-# # Update layout
-# fig.update_layout(
-# title=f"{self.symbol} - {interval_key} Chart ({len(candles)} candles)",
-# template="plotly_dark",
-# height=600,
-# xaxis_title="Time",
-# yaxis_title="Price ($)",
-# legend=dict(
-# yanchor="top",
-# y=0.99,
-# xanchor="left",
-# x=0.01,
-# bgcolor="rgba(0,0,0,0.5)"
-# ),
-# hovermode='x unified',
-# dragmode='pan'
-# )
-
-# # Remove range slider for better performance
-# fig.update_layout(xaxis_rangeslider_visible=False)
-
-# # Update the latest price
-# if closes:
-# self.latest_price = closes[-1]
-# self.latest_timestamp = timestamps[-1]
-
-# return fig
-
-# except Exception as e:
-# logger.error(f"Error updating main chart: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-
-# # Return error figure
-# fig = go.Figure()
-# fig.add_annotation(
-# text=f"Chart Error: {str(e)}",
-# xref="paper", yref="paper",
-# x=0.5, y=0.5,
-# showarrow=False,
-# font=dict(size=16, color="red")
-# )
-# fig.update_layout(
-# title="Chart Error",
-# template="plotly_dark",
-# height=600
-# )
-# return fig
-
-# def set_trading_env(self, trading_env):
-# """Set the trading environment to monitor for new trades"""
-# self.trading_env = trading_env
-# if hasattr(trading_env, 'add_trade_callback'):
-# trading_env.add_trade_callback(self.add_trade)
-# logger.info("Trading environment integrated with chart")
-
-# def set_agent(self, agent):
-# """Set the agent to monitor for trading decisions"""
-# self.agent = agent
-# logger.info("Agent integrated with chart")
-
-# def update_from_env(self, env_data):
-# """Update chart data from trading environment"""
-# try:
-# if 'latest_price' in env_data:
-# self.latest_price = env_data['latest_price']
-
-# if 'balance' in env_data:
-# self.current_balance = env_data['balance']
-
-# if 'pnl' in env_data:
-# self.accumulative_pnl = env_data['pnl']
-
-# if 'trades' in env_data:
-# # Add any new trades
-# for trade in env_data['trades']:
-# if trade not in self.all_trades:
-# self.add_trade(
-# action=trade.get('action', 'HOLD'),
-# price=trade.get('price', self.latest_price),
-# amount=trade.get('amount', 0.1),
-# timestamp=trade.get('timestamp', datetime.now()),
-# trade_id=trade.get('id')
-# )
-# except Exception as e:
-# logger.error(f"Error updating from environment: {str(e)}")
-
-# def get_latest_data(self):
-# """Get the latest data for external systems"""
-# return {
-# 'latest_price': self.latest_price,
-# 'latest_volume': self.latest_volume,
-# 'latest_timestamp': self.latest_timestamp,
-# 'current_balance': self.current_balance,
-# 'accumulative_pnl': self.accumulative_pnl,
-# 'positions': len(self.positions),
-# 'trade_count': len(self.all_trades),
-# 'trade_rate': self._calculate_trade_rate()
-# }
-
-# async def start_websocket(self):
-# """Start the websocket connection for real-time data"""
-# try:
-# logger.info("Starting websocket connection for real-time data")
-
-# # Start the websocket data fetching
-# websocket_url = "wss://stream.binance.com:9443/ws/ethusdt@ticker"
-
-# async def websocket_handler():
-# """Handle websocket connection and data updates"""
-# try:
-# async with websockets.connect(websocket_url) as websocket:
-# logger.info(f"WebSocket connected for {self.symbol}")
-# message_count = 0
-
-# async for message in websocket:
-# try:
-# data = json.loads(message)
-
-# # Update tick storage with new price data
-# tick = {
-# 'price': float(data['c']), # Current price
-# 'volume': float(data['v']), # Volume
-# 'timestamp': pd.Timestamp.now()
-# }
-
-# self.tick_storage.add_tick(tick)
-
-# # Update chart's latest price and volume
-# self.latest_price = float(data['c'])
-# self.latest_volume = float(data['v'])
-# self.latest_timestamp = pd.Timestamp.now()
-
-# message_count += 1
-
-# # Log periodic updates
-# if message_count % 100 == 0:
-# logger.info(f"Received message #{message_count}")
-# logger.info(f"Processed {message_count} ticks, current price: ${self.latest_price:.2f}")
-
-# # Log candle counts
-# candle_count = len(self.tick_storage.candles.get("1s", []))
-# logger.info(f"Current 1s candles count: {candle_count}")
-
-# except json.JSONDecodeError as e:
-# logger.warning(f"Failed to parse websocket message: {str(e)}")
-# except Exception as e:
-# logger.error(f"Error processing websocket message: {str(e)}")
-
-# except websockets.exceptions.ConnectionClosed:
-# logger.warning("WebSocket connection closed")
-# except Exception as e:
-# logger.error(f"WebSocket error: {str(e)}")
-
-# # Start the websocket handler in the background
-# await websocket_handler()
-
-# except Exception as e:
-# logger.error(f"Error starting websocket: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
-
-# def run(self, host='127.0.0.1', port=8050, debug=False):
-# """Run the Dash app"""
-# try:
-# if self.app is None:
-# logger.error("No Dash app instance available")
-# return
-
-# logger.info("="*60)
-# logger.info("🔗 ACCESS WEB UI AT: http://localhost:8050/")
-# logger.info("📊 View live trading data and charts in your browser")
-# logger.info("="*60)
-
-# # Run the app - FIXED: Updated for newer Dash versions
-# self.app.run(
-# host=host,
-# port=port,
-# debug=debug,
-# use_reloader=False, # Disable reloader to avoid conflicts
-# threaded=True # Enable threading for better performance
-# )
-# except Exception as e:
-# logger.error(f"Error running Dash app: {str(e)}")
-# import traceback
-# logger.error(traceback.format_exc())
diff --git a/debug/test_fixed_issues.py b/debug/test_fixed_issues.py
deleted file mode 100644
index e4bc8f6..0000000
--- a/debug/test_fixed_issues.py
+++ /dev/null
@@ -1,164 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify that both model prediction and trading statistics issues are fixed
-"""
-
-import sys
-import os
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
-
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-from core.trading_executor import TradingExecutor
-import asyncio
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-async def test_model_predictions():
- """Test that model predictions are working correctly"""
-
- logger.info("=" * 60)
- logger.info("TESTING MODEL PREDICTIONS")
- logger.info("=" * 60)
-
- # Initialize components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Check model registration
- logger.info("1. Checking model registration...")
- models = orchestrator.model_registry.get_all_models()
- logger.info(f" Registered models: {list(models.keys()) if models else 'None'}")
-
- # Test making a decision
- logger.info("2. Testing trading decision generation...")
- decision = await orchestrator.make_trading_decision('ETH/USDT')
-
- if decision:
- logger.info(f" ✅ Decision generated: {decision.action} (confidence: {decision.confidence:.3f})")
- logger.info(f" ✅ Reasoning: {decision.reasoning}")
- return True
- else:
- logger.error(" ❌ No decision generated")
- return False
-
-def test_trading_statistics():
- """Test that trading statistics calculations are working correctly"""
-
- logger.info("=" * 60)
- logger.info("TESTING TRADING STATISTICS")
- logger.info("=" * 60)
-
- # Initialize trading executor
- trading_executor = TradingExecutor()
-
- # Check if we have any trades
- trade_history = trading_executor.get_trade_history()
- logger.info(f"1. Current trade history: {len(trade_history)} trades")
-
- # Get daily stats
- daily_stats = trading_executor.get_daily_stats()
- logger.info("2. Daily statistics from trading executor:")
- logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
- logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
- logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
- logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
- logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
- logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
- logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
-
- # Simulate some trades if we don't have any
- if daily_stats.get('total_trades', 0) == 0:
- logger.info("3. No trades found - simulating some test trades...")
-
- # Add some mock trades to the trade history
- from core.trading_executor import TradeRecord
- from datetime import datetime
-
- # Add a winning trade
- winning_trade = TradeRecord(
- symbol='ETH/USDT',
- side='LONG',
- quantity=0.01,
- entry_price=2500.0,
- exit_price=2550.0,
- entry_time=datetime.now(),
- exit_time=datetime.now(),
- pnl=0.50, # $0.50 profit
- fees=0.01,
- confidence=0.8
- )
- trading_executor.trade_history.append(winning_trade)
-
- # Add a losing trade
- losing_trade = TradeRecord(
- symbol='ETH/USDT',
- side='LONG',
- quantity=0.01,
- entry_price=2500.0,
- exit_price=2480.0,
- entry_time=datetime.now(),
- exit_time=datetime.now(),
- pnl=-0.20, # $0.20 loss
- fees=0.01,
- confidence=0.7
- )
- trading_executor.trade_history.append(losing_trade)
-
- # Get updated stats
- daily_stats = trading_executor.get_daily_stats()
- logger.info(" Updated statistics after adding test trades:")
- logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
- logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
- logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
- logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
- logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
- logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
- logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
-
- # Verify calculations
- expected_win_rate = 1/2 # 1 win out of 2 trades = 50%
- expected_avg_win = 0.50
- expected_avg_loss = -0.20
-
- actual_win_rate = daily_stats.get('win_rate', 0.0)
- actual_avg_win = daily_stats.get('avg_winning_trade', 0.0)
- actual_avg_loss = daily_stats.get('avg_losing_trade', 0.0)
-
- logger.info("4. Verifying calculations:")
- logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ✅" if abs(actual_win_rate - expected_win_rate) < 0.01 else f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ❌")
- logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ✅" if abs(actual_avg_win - expected_avg_win) < 0.01 else f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ❌")
- logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ✅" if abs(actual_avg_loss - expected_avg_loss) < 0.01 else f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ❌")
-
- return True
-
- return True
-
-async def main():
- """Run all tests"""
-
- logger.info("🚀 STARTING COMPREHENSIVE FIXES TEST")
- logger.info("Testing both model prediction fixes and trading statistics fixes")
-
- # Test model predictions
- prediction_success = await test_model_predictions()
-
- # Test trading statistics
- stats_success = test_trading_statistics()
-
- logger.info("=" * 60)
- logger.info("TEST SUMMARY")
- logger.info("=" * 60)
- logger.info(f"Model Predictions: {'✅ FIXED' if prediction_success else '❌ STILL BROKEN'}")
- logger.info(f"Trading Statistics: {'✅ FIXED' if stats_success else '❌ STILL BROKEN'}")
-
- if prediction_success and stats_success:
- logger.info("🎉 ALL ISSUES FIXED! The system should now work correctly.")
- else:
- logger.error("❌ Some issues remain. Check the logs above for details.")
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/debug/test_trading_fixes.py b/debug/test_trading_fixes.py
deleted file mode 100644
index 79eca49..0000000
--- a/debug/test_trading_fixes.py
+++ /dev/null
@@ -1,250 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify trading fixes:
-1. Position sizes with leverage
-2. ETH-only trading
-3. Correct win rate calculations
-4. Meaningful P&L values
-"""
-
-import sys
-import os
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
-
-from core.trading_executor import TradingExecutor
-from core.trading_executor import TradeRecord
-from datetime import datetime
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_position_sizing():
- """Test that position sizing now includes leverage and meaningful amounts"""
-
- logger.info("=" * 60)
- logger.info("TESTING POSITION SIZING WITH LEVERAGE")
- logger.info("=" * 60)
-
- # Initialize trading executor
- trading_executor = TradingExecutor()
-
- # Test position calculation
- confidence = 0.8
- current_price = 2500.0 # ETH price
-
- position_value = trading_executor._calculate_position_size(confidence, current_price)
- quantity = position_value / current_price
-
- logger.info(f"1. Position calculation test:")
- logger.info(f" Confidence: {confidence}")
- logger.info(f" ETH Price: ${current_price}")
- logger.info(f" Position Value: ${position_value:.2f}")
- logger.info(f" Quantity: {quantity:.6f} ETH")
-
- # Check if position is meaningful
- if position_value > 1000: # Should be >$1000 with 10x leverage
- logger.info(" ✅ Position size is meaningful (>$1000)")
- else:
- logger.error(f" ❌ Position size too small: ${position_value:.2f}")
-
- # Test different confidence levels
- logger.info("2. Testing different confidence levels:")
- for conf in [0.2, 0.5, 0.8, 1.0]:
- pos_val = trading_executor._calculate_position_size(conf, current_price)
- qty = pos_val / current_price
- logger.info(f" Confidence {conf}: ${pos_val:.2f} ({qty:.6f} ETH)")
-
-def test_eth_only_restriction():
- """Test that only ETH trades are allowed"""
-
- logger.info("=" * 60)
- logger.info("TESTING ETH-ONLY TRADING RESTRICTION")
- logger.info("=" * 60)
-
- trading_executor = TradingExecutor()
-
- # Test ETH trade (should be allowed)
- logger.info("1. Testing ETH/USDT trade (should be allowed):")
- eth_allowed = trading_executor._check_safety_conditions('ETH/USDT', 'BUY')
- logger.info(f" ETH/USDT allowed: {'✅ YES' if eth_allowed else '❌ NO'}")
-
- # Test BTC trade (should be blocked)
- logger.info("2. Testing BTC/USDT trade (should be blocked):")
- btc_allowed = trading_executor._check_safety_conditions('BTC/USDT', 'BUY')
- logger.info(f" BTC/USDT allowed: {'❌ YES (ERROR!)' if btc_allowed else '✅ NO (CORRECT)'}")
-
-def test_win_rate_calculation():
- """Test that win rate calculations are correct"""
-
- logger.info("=" * 60)
- logger.info("TESTING WIN RATE CALCULATIONS")
- logger.info("=" * 60)
-
- trading_executor = TradingExecutor()
-
- # Clear existing trades
- trading_executor.trade_history = []
-
- # Add test trades with meaningful P&L
- logger.info("1. Adding test trades with meaningful P&L:")
-
- # Add 3 winning trades
- for i in range(3):
- winning_trade = TradeRecord(
- symbol='ETH/USDT',
- side='LONG',
- quantity=1.0,
- entry_price=2500.0,
- exit_price=2550.0,
- entry_time=datetime.now(),
- exit_time=datetime.now(),
- pnl=50.0, # $50 profit with leverage
- fees=1.0,
- confidence=0.8,
- hold_time_seconds=30.0 # 30 second hold
- )
- trading_executor.trade_history.append(winning_trade)
- logger.info(f" Added winning trade #{i+1}: +$50.00 (30s hold)")
-
- # Add 2 losing trades
- for i in range(2):
- losing_trade = TradeRecord(
- symbol='ETH/USDT',
- side='LONG',
- quantity=1.0,
- entry_price=2500.0,
- exit_price=2475.0,
- entry_time=datetime.now(),
- exit_time=datetime.now(),
- pnl=-25.0, # $25 loss with leverage
- fees=1.0,
- confidence=0.7,
- hold_time_seconds=15.0 # 15 second hold
- )
- trading_executor.trade_history.append(losing_trade)
- logger.info(f" Added losing trade #{i+1}: -$25.00 (15s hold)")
-
- # Get statistics
- stats = trading_executor.get_daily_stats()
-
- logger.info("2. Calculated statistics:")
- logger.info(f" Total trades: {stats['total_trades']}")
- logger.info(f" Winning trades: {stats['winning_trades']}")
- logger.info(f" Losing trades: {stats['losing_trades']}")
- logger.info(f" Win rate: {stats['win_rate']*100:.1f}%")
- logger.info(f" Avg winning trade: ${stats['avg_winning_trade']:.2f}")
- logger.info(f" Avg losing trade: ${stats['avg_losing_trade']:.2f}")
- logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")
-
- # Verify calculations
- expected_win_rate = 3/5 # 3 wins out of 5 trades = 60%
- expected_avg_win = 50.0
- expected_avg_loss = -25.0
-
- logger.info("3. Verification:")
- win_rate_ok = abs(stats['win_rate'] - expected_win_rate) < 0.01
- avg_win_ok = abs(stats['avg_winning_trade'] - expected_avg_win) < 0.01
- avg_loss_ok = abs(stats['avg_losing_trade'] - expected_avg_loss) < 0.01
-
- logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {stats['win_rate']*100:.1f}% {'✅' if win_rate_ok else '❌'}")
- logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${stats['avg_winning_trade']:.2f} {'✅' if avg_win_ok else '❌'}")
- logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${stats['avg_losing_trade']:.2f} {'✅' if avg_loss_ok else '❌'}")
-
- return win_rate_ok and avg_win_ok and avg_loss_ok
-
-def test_new_features():
- """Test new features: hold time, leverage, percentage-based sizing"""
-
- logger.info("=" * 60)
- logger.info("TESTING NEW FEATURES")
- logger.info("=" * 60)
-
- trading_executor = TradingExecutor()
-
- # Test account info
- account_info = trading_executor.get_account_info()
- logger.info(f"1. Account Information:")
- logger.info(f" Account Balance: ${account_info['account_balance']:.2f}")
- logger.info(f" Leverage: {account_info['leverage']:.0f}x")
- logger.info(f" Trading Mode: {account_info['trading_mode']}")
- logger.info(f" Position Sizing: {account_info['position_sizing']['base_percent']:.1f}% base")
-
- # Test leverage setting
- logger.info("2. Testing leverage control:")
- old_leverage = trading_executor.get_leverage()
- logger.info(f" Current leverage: {old_leverage:.0f}x")
-
- success = trading_executor.set_leverage(100.0)
- new_leverage = trading_executor.get_leverage()
- logger.info(f" Set to 100x: {'✅ SUCCESS' if success and new_leverage == 100.0 else '❌ FAILED'}")
-
- # Reset leverage
- trading_executor.set_leverage(old_leverage)
-
- # Test percentage-based position sizing
- logger.info("3. Testing percentage-based position sizing:")
- confidence = 0.8
- eth_price = 2500.0
-
- position_value = trading_executor._calculate_position_size(confidence, eth_price)
- account_balance = trading_executor._get_account_balance_for_sizing()
- base_percent = trading_executor.mexc_config.get('base_position_percent', 5.0)
- leverage = trading_executor.get_leverage()
-
- expected_base = account_balance * (base_percent / 100.0) * confidence
- expected_leveraged = expected_base * leverage
-
- logger.info(f" Account: ${account_balance:.2f}")
- logger.info(f" Base %: {base_percent:.1f}%")
- logger.info(f" Confidence: {confidence:.1f}")
- logger.info(f" Leverage: {leverage:.0f}x")
- logger.info(f" Expected base: ${expected_base:.2f}")
- logger.info(f" Expected leveraged: ${expected_leveraged:.2f}")
- logger.info(f" Actual: ${position_value:.2f}")
-
- sizing_ok = abs(position_value - expected_leveraged) < 0.01
- logger.info(f" Percentage sizing: {'✅ CORRECT' if sizing_ok else '❌ INCORRECT'}")
-
- return sizing_ok
-
-def main():
- """Run all tests"""
-
- logger.info("🚀 TESTING TRADING FIXES AND NEW FEATURES")
- logger.info("Testing position sizing, ETH-only trading, win rate calculations, and new features")
-
- # Test position sizing
- test_position_sizing()
-
- # Test ETH-only restriction
- test_eth_only_restriction()
-
- # Test win rate calculation
- calculation_success = test_win_rate_calculation()
-
- # Test new features
- features_success = test_new_features()
-
- logger.info("=" * 60)
- logger.info("TEST SUMMARY")
- logger.info("=" * 60)
- logger.info(f"Position Sizing: ✅ Updated with percentage-based leverage")
- logger.info(f"ETH-Only Trading: ✅ Configured in config")
- logger.info(f"Win Rate Calculation: {'✅ FIXED' if calculation_success else '❌ STILL BROKEN'}")
- logger.info(f"New Features: {'✅ WORKING' if features_success else '❌ ISSUES FOUND'}")
-
- if calculation_success and features_success:
- logger.info("🎉 ALL FEATURES WORKING! Now you should see:")
- logger.info(" - Percentage-based position sizing (2-20% of account)")
- logger.info(" - 50x leverage (adjustable in UI)")
- logger.info(" - Hold time in seconds for each trade")
- logger.info(" - Total fees in trading statistics")
- logger.info(" - Only ETH/USDT trades")
- logger.info(" - Correct win rate calculations")
- else:
- logger.error("❌ Some issues remain. Check the logs above for details.")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/debug/trade_audit.py b/debug/trade_audit.py
deleted file mode 100644
index 04efd8b..0000000
--- a/debug/trade_audit.py
+++ /dev/null
@@ -1,344 +0,0 @@
-#!/usr/bin/env python3
-"""
-Trade Audit Tool
-
-This tool analyzes trade data to identify potential issues with:
-- Duplicate entry prices
-- Rapid consecutive trades
-- P&L calculation accuracy
-- Position tracking problems
-
-Usage:
- python debug/trade_audit.py [--trades-file path/to/trades.json]
-"""
-
-import argparse
-import json
-import pandas as pd
-import numpy as np
-from datetime import datetime, timedelta
-import matplotlib.pyplot as plt
-import os
-import sys
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent.parent
-sys.path.insert(0, str(project_root))
-
-def parse_trade_time(time_str):
- """Parse trade time string to datetime object"""
- try:
- # Try HH:MM:SS format
- return datetime.strptime(time_str, "%H:%M:%S")
- except ValueError:
- try:
- # Try full datetime format
- return datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
- except ValueError:
- # Return as is if parsing fails
- return time_str
-
-def load_trades_from_file(file_path):
- """Load trades from JSON file"""
- try:
- with open(file_path, 'r') as f:
- return json.load(f)
- except FileNotFoundError:
- print(f"Error: File {file_path} not found")
- return []
- except json.JSONDecodeError:
- print(f"Error: File {file_path} is not valid JSON")
- return []
-
-def load_trades_from_dashboard_cache():
- """Load trades from dashboard cache file if available"""
- cache_paths = [
- "cache/dashboard_trades.json",
- "cache/closed_trades.json",
- "data/trades_history.json"
- ]
-
- for path in cache_paths:
- if os.path.exists(path):
- print(f"Loading trades from cache: {path}")
- return load_trades_from_file(path)
-
- print("No trade cache files found")
- return []
-
-def parse_trade_data(trades_data):
- """Parse trade data into a pandas DataFrame for analysis"""
- parsed_trades = []
-
- for trade in trades_data:
- # Handle different trade data formats
- parsed_trade = {}
-
- # Time field might be named entry_time or time
- if 'entry_time' in trade:
- parsed_trade['time'] = parse_trade_time(trade['entry_time'])
- elif 'time' in trade:
- parsed_trade['time'] = parse_trade_time(trade['time'])
- else:
- parsed_trade['time'] = None
-
- # Side might be named side or action
- parsed_trade['side'] = trade.get('side', trade.get('action', 'UNKNOWN'))
-
- # Size might be named size or quantity
- parsed_trade['size'] = float(trade.get('size', trade.get('quantity', 0)))
-
- # Entry and exit prices
- parsed_trade['entry_price'] = float(trade.get('entry_price', trade.get('entry', 0)))
- parsed_trade['exit_price'] = float(trade.get('exit_price', trade.get('exit', 0)))
-
- # Hold time in seconds
- parsed_trade['hold_time'] = float(trade.get('hold_time_seconds', trade.get('hold', 0)))
-
- # P&L and fees
- parsed_trade['pnl'] = float(trade.get('pnl', 0))
- parsed_trade['fees'] = float(trade.get('fees', 0))
-
- # Calculate expected P&L for verification
- if parsed_trade['side'] == 'LONG' or parsed_trade['side'] == 'BUY':
- expected_pnl = (parsed_trade['exit_price'] - parsed_trade['entry_price']) * parsed_trade['size']
- else: # SHORT or SELL
- expected_pnl = (parsed_trade['entry_price'] - parsed_trade['exit_price']) * parsed_trade['size']
-
- parsed_trade['expected_pnl'] = expected_pnl
- parsed_trade['pnl_difference'] = parsed_trade['pnl'] - expected_pnl
-
- parsed_trades.append(parsed_trade)
-
- # Convert to DataFrame
- if parsed_trades:
- df = pd.DataFrame(parsed_trades)
- return df
- else:
- return pd.DataFrame()
-
-def analyze_trades(df):
- """Analyze trades for potential issues"""
- if df.empty:
- print("No trades to analyze")
- return
-
- print(f"\n{'='*50}")
- print("TRADE AUDIT RESULTS")
- print(f"{'='*50}")
- print(f"Total trades analyzed: {len(df)}")
-
- # Check for duplicate entry prices
- entry_price_counts = df['entry_price'].value_counts()
- duplicate_entries = entry_price_counts[entry_price_counts > 1]
-
- print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}")
- if not duplicate_entries.empty:
- print(f"Found {len(duplicate_entries)} prices with multiple entries:")
- for price, count in duplicate_entries.items():
- print(f" ${price:.2f}: {count} trades")
-
- # Analyze the duplicate entry trades in more detail
- for price in duplicate_entries.index:
- duplicate_df = df[df['entry_price'] == price].copy()
- duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds()
-
- print(f"\nDetailed analysis for entry price ${price:.2f}:")
- print(f" Time gaps between consecutive trades:")
- for i, (_, row) in enumerate(duplicate_df.iterrows()):
- if i > 0: # Skip first row as it has no previous trade
- time_diff = row['time_diff']
- if pd.notna(time_diff):
- print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade")
- else:
- print("No duplicate entry prices found")
-
- # Check for rapid consecutive trades
- df = df.sort_values('time')
- df['time_since_last'] = df['time'].diff().dt.total_seconds()
-
- rapid_trades = df[df['time_since_last'] < 30].copy()
-
- print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}")
- if not rapid_trades.empty:
- print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:")
- for _, row in rapid_trades.iterrows():
- if pd.notna(row['time_since_last']):
- print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous")
- else:
- print("No rapid consecutive trades found")
-
- # Check for P&L calculation accuracy
- pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy()
-
- print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}")
- if not pnl_diff.empty:
- print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:")
- for _, row in pnl_diff.iterrows():
- print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}")
- else:
- print("No P&L calculation issues found")
-
- # Check for side distribution
- side_counts = df['side'].value_counts()
-
- print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}")
- for side, count in side_counts.items():
- print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)")
-
- # Check for hold time distribution
- print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}")
- print(f" Min hold time: {df['hold_time'].min():.0f} seconds")
- print(f" Max hold time: {df['hold_time'].max():.0f} seconds")
- print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds")
- print(f" Median hold time: {df['hold_time'].median():.0f} seconds")
-
- # Hold time buckets
- hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')]
- hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+']
-
- df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels)
- hold_dist = df['hold_bucket'].value_counts().sort_index()
-
- for bucket, count in hold_dist.items():
- print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)")
-
- # Generate summary statistics
- print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}")
- winning_trades = df[df['pnl'] > 0]
- losing_trades = df[df['pnl'] < 0]
-
- print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)")
- print(f" Avg win: ${winning_trades['pnl'].mean():.2f}")
- print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}")
- print(f" Total P&L: ${df['pnl'].sum():.2f}")
- print(f" Total fees: ${df['fees'].sum():.2f}")
- print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}")
-
- # Plot entry price distribution
- plt.figure(figsize=(10, 6))
- plt.hist(df['entry_price'], bins=20, alpha=0.7)
- plt.title('Entry Price Distribution')
- plt.xlabel('Entry Price ($)')
- plt.ylabel('Number of Trades')
- plt.grid(True, alpha=0.3)
- plt.savefig('debug/entry_price_distribution.png')
-
- # Plot P&L distribution
- plt.figure(figsize=(10, 6))
- plt.hist(df['pnl'], bins=20, alpha=0.7)
- plt.title('P&L Distribution')
- plt.xlabel('P&L ($)')
- plt.ylabel('Number of Trades')
- plt.grid(True, alpha=0.3)
- plt.savefig('debug/pnl_distribution.png')
-
- print(f"\n{'='*20} AUDIT COMPLETE {'='*20}")
- print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png")
-
-def analyze_manual_trades(trades_data):
- """Analyze manually provided trade data"""
- # Parse the trade data into a structured format
- parsed_trades = []
-
- for line in trades_data.strip().split('\n'):
- if not line or line.startswith('from last session') or line.startswith('Recent Closed Trades') or line.startswith('Trading Performance'):
- continue
-
- if line.startswith('Win Rate:'):
- # This is the summary line, skip it
- continue
-
- try:
- # Parse trade line format: Time Side Size Entry Exit Hold P&L Fees
- parts = line.split('$')
-
- time_side = parts[0].strip().split()
- time = time_side[0]
- side = time_side[1]
-
- size = float(parts[1].split()[0])
- entry = float(parts[2].split()[0])
- exit = float(parts[3].split()[0])
-
- # The hold time and P&L are in the last parts
- remaining = parts[3].split()
- hold = int(remaining[1])
- pnl = float(parts[4].split()[0])
-
- # Fees might be in a different format
- if len(parts) > 5:
- fees = float(parts[5].strip())
- else:
- fees = 0.0
-
- parsed_trade = {
- 'time': parse_trade_time(time),
- 'side': side,
- 'size': size,
- 'entry_price': entry,
- 'exit_price': exit,
- 'hold_time': hold,
- 'pnl': pnl,
- 'fees': fees
- }
-
- # Calculate expected P&L
- if side == 'LONG' or side == 'BUY':
- expected_pnl = (exit - entry) * size
- else: # SHORT or SELL
- expected_pnl = (entry - exit) * size
-
- parsed_trade['expected_pnl'] = expected_pnl
- parsed_trade['pnl_difference'] = pnl - expected_pnl
-
- parsed_trades.append(parsed_trade)
-
- except Exception as e:
- print(f"Error parsing trade line: {line}")
- print(f"Error details: {e}")
-
- # Convert to DataFrame
- if parsed_trades:
- df = pd.DataFrame(parsed_trades)
- return df
- else:
- return pd.DataFrame()
-
-def main():
- parser = argparse.ArgumentParser(description='Trade Audit Tool')
- parser.add_argument('--trades-file', type=str, help='Path to trades JSON file')
- parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades')
- args = parser.parse_args()
-
- # Create debug directory if it doesn't exist
- os.makedirs('debug', exist_ok=True)
-
- if args.trades_file:
- trades_data = load_trades_from_file(args.trades_file)
- df = parse_trade_data(trades_data)
- elif args.manual_trades:
- try:
- with open(args.manual_trades, 'r') as f:
- manual_trades = f.read()
- df = analyze_manual_trades(manual_trades)
- except Exception as e:
- print(f"Error reading manual trades file: {e}")
- df = pd.DataFrame()
- else:
- # Try to load from dashboard cache
- trades_data = load_trades_from_dashboard_cache()
- if trades_data:
- df = parse_trade_data(trades_data)
- else:
- print("No trade data provided. Use --trades-file or --manual-trades")
- return
-
- if not df.empty:
- analyze_trades(df)
- else:
- print("No valid trade data to analyze")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/debug_training_methods.py b/debug_training_methods.py
deleted file mode 100644
index f7e5813..0000000
--- a/debug_training_methods.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/usr/bin/env python3
-"""
-Debug Training Methods
-
-This script checks what training methods are available on each model.
-"""
-
-import asyncio
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-async def debug_training_methods():
- """Debug the available training methods on each model"""
- print("=== Debugging Training Methods ===")
-
- # Initialize orchestrator
- print("1. Initializing orchestrator...")
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Wait for initialization
- await asyncio.sleep(2)
-
- print("\n2. Checking available training methods on each model:")
-
- for model_name, model_interface in orchestrator.model_registry.models.items():
- print(f"\n--- {model_name} ---")
- print(f"Interface type: {type(model_interface).__name__}")
-
- # Get underlying model
- underlying_model = getattr(model_interface, 'model', None)
- if underlying_model:
- print(f"Underlying model type: {type(underlying_model).__name__}")
- else:
- print("No underlying model found")
- continue
-
- # Check for training methods
- training_methods = []
- for method in ['train_on_outcome', 'add_experience', 'remember', 'replay', 'add_training_sample', 'train', 'train_with_reward', 'update_loss']:
- if hasattr(underlying_model, method):
- training_methods.append(method)
-
- print(f"Available training methods: {training_methods}")
-
- # Check for specific attributes
- attributes = []
- for attr in ['memory', 'batch_size', 'training_data']:
- if hasattr(underlying_model, attr):
- attr_value = getattr(underlying_model, attr)
- if attr == 'memory' and hasattr(attr_value, '__len__'):
- attributes.append(f"{attr}(len={len(attr_value)})")
- elif attr == 'training_data' and hasattr(attr_value, '__len__'):
- attributes.append(f"{attr}(len={len(attr_value)})")
- else:
- attributes.append(f"{attr}={attr_value}")
-
- print(f"Relevant attributes: {attributes}")
-
- # Check if it's an RL agent
- if hasattr(underlying_model, 'act') and hasattr(underlying_model, 'remember'):
- print("✅ Detected as RL Agent")
- elif hasattr(underlying_model, 'predict') and hasattr(underlying_model, 'add_training_sample'):
- print("✅ Detected as CNN Model")
- else:
- print("❓ Unknown model type")
-
- print("\n3. Testing a simple training attempt:")
-
- # Get a prediction first
- predictions = await orchestrator._get_all_predictions('ETH/USDT')
- print(f"Got {len(predictions)} predictions")
-
- # Try to trigger training for each model
- for model_name in orchestrator.model_registry.models.keys():
- print(f"\nTesting training for {model_name}...")
- try:
- await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT')
- print(f"✅ Training attempt completed for {model_name}")
- except Exception as e:
- print(f"❌ Training failed for {model_name}: {e}")
-
-if __name__ == "__main__":
- asyncio.run(debug_training_methods())
\ No newline at end of file
diff --git a/docs/exchanges/bybit/examples.py b/docs/exchanges/bybit/examples.py
deleted file mode 100644
index 4827533..0000000
--- a/docs/exchanges/bybit/examples.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""
-Bybit Integration Examples
-Based on official pybit library documentation and examples
-"""
-
-import os
-from pybit.unified_trading import HTTP
-import logging
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def create_bybit_session(testnet=True):
- """Create a Bybit HTTP session.
-
- Args:
- testnet (bool): Use testnet if True, live if False
-
- Returns:
- HTTP: Bybit session object
- """
- api_key = os.getenv('BYBIT_API_KEY')
- api_secret = os.getenv('BYBIT_API_SECRET')
-
- if not api_key or not api_secret:
- raise ValueError("BYBIT_API_KEY and BYBIT_API_SECRET must be set in environment")
-
- session = HTTP(
- testnet=testnet,
- api_key=api_key,
- api_secret=api_secret,
- )
-
- logger.info(f"Created Bybit session (testnet: {testnet})")
- return session
-
-def get_account_info(session):
- """Get account information and balances."""
- try:
- # Get account info
- account_info = session.get_wallet_balance(accountType="UNIFIED")
- logger.info(f"Account info: {account_info}")
-
- return account_info
- except Exception as e:
- logger.error(f"Error getting account info: {e}")
- return None
-
-def get_ticker_info(session, symbol="BTCUSDT"):
- """Get ticker information for a symbol.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol (default: BTCUSDT)
- """
- try:
- ticker = session.get_tickers(category="linear", symbol=symbol)
- logger.info(f"Ticker for {symbol}: {ticker}")
- return ticker
- except Exception as e:
- logger.error(f"Error getting ticker for {symbol}: {e}")
- return None
-
-def get_orderbook(session, symbol="BTCUSDT", limit=25):
- """Get orderbook for a symbol.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- limit: Number of price levels to return
- """
- try:
- orderbook = session.get_orderbook(
- category="linear",
- symbol=symbol,
- limit=limit
- )
- logger.info(f"Orderbook for {symbol}: {orderbook}")
- return orderbook
- except Exception as e:
- logger.error(f"Error getting orderbook for {symbol}: {e}")
- return None
-
-def place_limit_order(session, symbol="BTCUSDT", side="Buy", qty="0.001", price="50000"):
- """Place a limit order.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- side: "Buy" or "Sell"
- qty: Order quantity as string
- price: Order price as string
- """
- try:
- order = session.place_order(
- category="linear",
- symbol=symbol,
- side=side,
- orderType="Limit",
- qty=qty,
- price=price,
- timeInForce="GTC" # Good Till Cancelled
- )
- logger.info(f"Placed order: {order}")
- return order
- except Exception as e:
- logger.error(f"Error placing order: {e}")
- return None
-
-def place_market_order(session, symbol="BTCUSDT", side="Buy", qty="0.001"):
- """Place a market order.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- side: "Buy" or "Sell"
- qty: Order quantity as string
- """
- try:
- order = session.place_order(
- category="linear",
- symbol=symbol,
- side=side,
- orderType="Market",
- qty=qty
- )
- logger.info(f"Placed market order: {order}")
- return order
- except Exception as e:
- logger.error(f"Error placing market order: {e}")
- return None
-
-def get_open_orders(session, symbol=None):
- """Get open orders.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol (optional, gets all if None)
- """
- try:
- params = {"category": "linear", "openOnly": True}
- if symbol:
- params["symbol"] = symbol
-
- orders = session.get_open_orders(**params)
- logger.info(f"Open orders: {orders}")
- return orders
- except Exception as e:
- logger.error(f"Error getting open orders: {e}")
- return None
-
-def cancel_order(session, symbol, order_id):
- """Cancel an order.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- order_id: Order ID to cancel
- """
- try:
- result = session.cancel_order(
- category="linear",
- symbol=symbol,
- orderId=order_id
- )
- logger.info(f"Cancelled order {order_id}: {result}")
- return result
- except Exception as e:
- logger.error(f"Error cancelling order {order_id}: {e}")
- return None
-
-def get_position(session, symbol="BTCUSDT"):
- """Get position information.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- """
- try:
- positions = session.get_positions(
- category="linear",
- symbol=symbol
- )
- logger.info(f"Position for {symbol}: {positions}")
- return positions
- except Exception as e:
- logger.error(f"Error getting position for {symbol}: {e}")
- return None
-
-def get_trade_history(session, symbol="BTCUSDT", limit=50):
- """Get trade history.
-
- Args:
- session: Bybit HTTP session
- symbol: Trading symbol
- limit: Number of trades to return
- """
- try:
- trades = session.get_executions(
- category="linear",
- symbol=symbol,
- limit=limit
- )
- logger.info(f"Trade history for {symbol}: {trades}")
- return trades
- except Exception as e:
- logger.error(f"Error getting trade history for {symbol}: {e}")
- return None
-
-# Example usage
-if __name__ == "__main__":
- # Create session (testnet by default)
- session = create_bybit_session(testnet=True)
-
- # Get account info
- account_info = get_account_info(session)
-
- # Get ticker
- ticker = get_ticker_info(session, "BTCUSDT")
-
- # Get orderbook
- orderbook = get_orderbook(session, "BTCUSDT")
-
- # Get open orders
- open_orders = get_open_orders(session)
-
- # Get position
- position = get_position(session, "BTCUSDT")
-
- # Note: Uncomment below to actually place orders (use with caution)
- # order = place_limit_order(session, "BTCUSDT", "Buy", "0.001", "30000")
- # market_order = place_market_order(session, "BTCUSDT", "Buy", "0.001")
\ No newline at end of file
diff --git a/example_usage_simplified_data_provider.py b/example_usage_simplified_data_provider.py
deleted file mode 100644
index a3a3d14..0000000
--- a/example_usage_simplified_data_provider.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#!/usr/bin/env python3
-"""
-Example usage of the simplified data provider
-"""
-
-import time
-import logging
-from core.data_provider import DataProvider
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def main():
- """Demonstrate the simplified data provider usage"""
-
- # Initialize data provider (starts automatic maintenance)
- logger.info("Initializing DataProvider...")
- dp = DataProvider()
-
- # Wait for initial data load (happens automatically in background)
- logger.info("Waiting for initial data load...")
- time.sleep(15) # Give it time to load data
-
- # Example 1: Get cached historical data (no API calls)
- logger.info("\n=== Example 1: Getting Historical Data ===")
- eth_1m_data = dp.get_historical_data('ETH/USDT', '1m', limit=50)
- if eth_1m_data is not None:
- logger.info(f"ETH/USDT 1m data: {len(eth_1m_data)} candles")
- logger.info(f"Latest candle: {eth_1m_data.iloc[-1]['close']}")
-
- # Example 2: Get current prices
- logger.info("\n=== Example 2: Current Prices ===")
- eth_price = dp.get_current_price('ETH/USDT')
- btc_price = dp.get_current_price('BTC/USDT')
- logger.info(f"ETH current price: ${eth_price}")
- logger.info(f"BTC current price: ${btc_price}")
-
- # Example 3: Check cache status
- logger.info("\n=== Example 3: Cache Status ===")
- cache_summary = dp.get_cached_data_summary()
- for symbol in cache_summary['cached_data']:
- logger.info(f"\n{symbol}:")
- for timeframe, info in cache_summary['cached_data'][symbol].items():
- if 'candle_count' in info and info['candle_count'] > 0:
- logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
- else:
- logger.info(f" {timeframe}: {info.get('status', 'no data')}")
-
- # Example 4: Multiple timeframe data
- logger.info("\n=== Example 4: Multiple Timeframes ===")
- for tf in ['1s', '1m', '1h', '1d']:
- data = dp.get_historical_data('ETH/USDT', tf, limit=5)
- if data is not None and not data.empty:
- logger.info(f"ETH {tf}: {len(data)} candles, range: ${data['close'].min():.2f} - ${data['close'].max():.2f}")
-
- # Example 5: Health check
- logger.info("\n=== Example 5: Health Check ===")
- health = dp.health_check()
- logger.info(f"Data maintenance active: {health['data_maintenance_active']}")
- logger.info(f"Symbols: {health['symbols']}")
- logger.info(f"Timeframes: {health['timeframes']}")
-
- # Example 6: Wait and show automatic updates
- logger.info("\n=== Example 6: Automatic Updates ===")
- logger.info("Waiting 30 seconds to show automatic data updates...")
-
- # Get initial timestamp
- initial_data = dp.get_historical_data('ETH/USDT', '1s', limit=1)
- initial_time = initial_data.index[-1] if initial_data is not None else None
-
- time.sleep(30)
-
- # Check if data was updated
- updated_data = dp.get_historical_data('ETH/USDT', '1s', limit=1)
- updated_time = updated_data.index[-1] if updated_data is not None else None
-
- if initial_time and updated_time and updated_time > initial_time:
- logger.info(f"✅ Data automatically updated! New timestamp: {updated_time}")
- else:
- logger.info("⏳ Data update in progress...")
-
- # Clean shutdown
- logger.info("\n=== Shutting Down ===")
- dp.stop_automatic_data_maintenance()
- logger.info("DataProvider stopped successfully")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/kill_stale_processes.py b/kill_stale_processes.py
deleted file mode 100644
index de74b8c..0000000
--- a/kill_stale_processes.py
+++ /dev/null
@@ -1,331 +0,0 @@
-"""
-Kill Stale Processes
-
-This script identifies and kills stale Python processes that might be causing
-the dashboard startup freeze. It looks for:
-1. Hanging dashboard processes
-2. Stale COB data collection threads
-3. Matplotlib GUI processes
-4. Blocked network connections
-
-Usage:
- python kill_stale_processes.py
-"""
-
-import os
-import sys
-import psutil
-import signal
-import time
-from datetime import datetime
-
-def find_python_processes():
- """Find all Python processes"""
- python_processes = []
-
- try:
- for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'create_time', 'status']):
- try:
- if proc.info['name'] and 'python' in proc.info['name'].lower():
- # Get command line to identify dashboard processes
- cmdline = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else ''
-
- python_processes.append({
- 'pid': proc.info['pid'],
- 'name': proc.info['name'],
- 'cmdline': cmdline,
- 'create_time': proc.info['create_time'],
- 'status': proc.info['status'],
- 'process': proc
- })
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- continue
-
- except Exception as e:
- print(f"Error finding Python processes: {e}")
-
- return python_processes
-
-def identify_dashboard_processes(python_processes):
- """Identify processes related to the dashboard"""
- dashboard_processes = []
-
- dashboard_keywords = [
- 'clean_dashboard',
- 'run_clean_dashboard',
- 'dashboard',
- 'trading',
- 'cob_data',
- 'orchestrator',
- 'data_provider'
- ]
-
- for proc_info in python_processes:
- cmdline = proc_info['cmdline'].lower()
-
- # Check if this is a dashboard-related process
- is_dashboard = any(keyword in cmdline for keyword in dashboard_keywords)
-
- if is_dashboard:
- dashboard_processes.append(proc_info)
-
- return dashboard_processes
-
-def identify_stale_processes(python_processes):
- """Identify potentially stale processes"""
- stale_processes = []
- current_time = time.time()
-
- for proc_info in python_processes:
- try:
- proc = proc_info['process']
-
- # Check if process is in a problematic state
- if proc_info['status'] in ['zombie', 'stopped']:
- stale_processes.append({
- **proc_info,
- 'reason': f"Process status: {proc_info['status']}"
- })
- continue
-
- # Check if process has been running for a very long time without activity
- age_hours = (current_time - proc_info['create_time']) / 3600
- if age_hours > 24: # Running for more than 24 hours
- try:
- # Check CPU usage
- cpu_percent = proc.cpu_percent(interval=1)
- if cpu_percent < 0.1: # Very low CPU usage
- stale_processes.append({
- **proc_info,
- 'reason': f"Old process ({age_hours:.1f}h) with low CPU usage ({cpu_percent:.1f}%)"
- })
- except:
- pass
-
- # Check for processes with high memory usage but no activity
- try:
- memory_info = proc.memory_info()
- memory_mb = memory_info.rss / 1024 / 1024
-
- if memory_mb > 500: # More than 500MB
- cpu_percent = proc.cpu_percent(interval=1)
- if cpu_percent < 0.1:
- stale_processes.append({
- **proc_info,
- 'reason': f"High memory usage ({memory_mb:.1f}MB) with low CPU usage ({cpu_percent:.1f}%)"
- })
- except:
- pass
-
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- continue
-
- return stale_processes
-
-def kill_process_safely(proc_info, force=False):
- """Kill a process safely"""
- try:
- proc = proc_info['process']
- pid = proc_info['pid']
-
- print(f"Attempting to {'force kill' if force else 'terminate'} PID {pid}: {proc_info['name']}")
-
- if force:
- # Force kill
- if os.name == 'nt': # Windows
- os.system(f"taskkill /F /PID {pid}")
- else: # Unix/Linux
- os.kill(pid, signal.SIGKILL)
- else:
- # Graceful termination
- proc.terminate()
-
- # Wait for termination
- try:
- proc.wait(timeout=5)
- print(f"✅ Process {pid} terminated gracefully")
- return True
- except psutil.TimeoutExpired:
- print(f"⚠️ Process {pid} didn't terminate gracefully, will force kill")
- return False
-
- print(f"✅ Process {pid} killed")
- return True
-
- except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
- print(f"⚠️ Could not kill process {proc_info['pid']}: {e}")
- return False
- except Exception as e:
- print(f"❌ Error killing process {proc_info['pid']}: {e}")
- return False
-
-def check_port_usage():
- """Check if dashboard port is in use"""
- try:
- import socket
-
- # Check if port 8050 is in use
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- result = sock.connect_ex(('localhost', 8050))
- sock.close()
-
- if result == 0:
- print("⚠️ Port 8050 is in use")
-
- # Find process using the port
- for conn in psutil.net_connections():
- if conn.laddr.port == 8050:
- try:
- proc = psutil.Process(conn.pid)
- print(f" Port 8050 used by PID {conn.pid}: {proc.name()}")
- return conn.pid
- except:
- pass
- else:
- print("✅ Port 8050 is available")
- return None
-
- except Exception as e:
- print(f"Error checking port usage: {e}")
- return None
-
-def main():
- """Main function"""
- print("🔍 Stale Process Killer")
- print("=" * 50)
-
- try:
- # Step 1: Find all Python processes
- print("🔍 Finding Python processes...")
- python_processes = find_python_processes()
- print(f"Found {len(python_processes)} Python processes")
-
- # Step 2: Identify dashboard processes
- print("\n🎯 Identifying dashboard processes...")
- dashboard_processes = identify_dashboard_processes(python_processes)
-
- if dashboard_processes:
- print(f"Found {len(dashboard_processes)} dashboard-related processes:")
- for proc in dashboard_processes:
- age_hours = (time.time() - proc['create_time']) / 3600
- print(f" PID {proc['pid']}: {proc['name']} (age: {age_hours:.1f}h, status: {proc['status']})")
- print(f" Command: {proc['cmdline'][:100]}...")
- else:
- print("No dashboard processes found")
-
- # Step 3: Check port usage
- print("\n🌐 Checking port usage...")
- port_pid = check_port_usage()
-
- # Step 4: Identify stale processes
- print("\n🕵️ Identifying stale processes...")
- stale_processes = identify_stale_processes(python_processes)
-
- if stale_processes:
- print(f"Found {len(stale_processes)} potentially stale processes:")
- for proc in stale_processes:
- print(f" PID {proc['pid']}: {proc['name']} - {proc['reason']}")
- else:
- print("No stale processes identified")
-
- # Step 5: Ask user what to do
- if dashboard_processes or stale_processes or port_pid:
- print("\n🤔 What would you like to do?")
- print("1. Kill all dashboard processes")
- print("2. Kill only stale processes")
- print("3. Kill process using port 8050")
- print("4. Kill all identified processes")
- print("5. Show process details and exit")
- print("6. Exit without killing anything")
-
- try:
- choice = input("\nEnter your choice (1-6): ").strip()
-
- if choice == '1':
- # Kill dashboard processes
- print("\n🔫 Killing dashboard processes...")
- for proc in dashboard_processes:
- if not kill_process_safely(proc):
- kill_process_safely(proc, force=True)
-
- elif choice == '2':
- # Kill stale processes
- print("\n🔫 Killing stale processes...")
- for proc in stale_processes:
- if not kill_process_safely(proc):
- kill_process_safely(proc, force=True)
-
- elif choice == '3':
- # Kill process using port 8050
- if port_pid:
- print(f"\n🔫 Killing process using port 8050 (PID {port_pid})...")
- try:
- proc = psutil.Process(port_pid)
- proc_info = {
- 'pid': port_pid,
- 'name': proc.name(),
- 'process': proc
- }
- if not kill_process_safely(proc_info):
- kill_process_safely(proc_info, force=True)
- except:
- print(f"❌ Could not kill process {port_pid}")
- else:
- print("No process found using port 8050")
-
- elif choice == '4':
- # Kill all identified processes
- print("\n🔫 Killing all identified processes...")
- all_processes = dashboard_processes + stale_processes
- if port_pid:
- try:
- proc = psutil.Process(port_pid)
- all_processes.append({
- 'pid': port_pid,
- 'name': proc.name(),
- 'process': proc
- })
- except:
- pass
-
- for proc in all_processes:
- if not kill_process_safely(proc):
- kill_process_safely(proc, force=True)
-
- elif choice == '5':
- # Show details
- print("\n📋 Process Details:")
- all_processes = dashboard_processes + stale_processes
- for proc in all_processes:
- print(f"\nPID {proc['pid']}: {proc['name']}")
- print(f" Status: {proc['status']}")
- print(f" Command: {proc['cmdline']}")
- print(f" Created: {datetime.fromtimestamp(proc['create_time'])}")
-
- elif choice == '6':
- print("👋 Exiting without killing processes")
-
- else:
- print("❌ Invalid choice")
-
- except KeyboardInterrupt:
- print("\n👋 Cancelled by user")
- else:
- print("\n✅ No problematic processes found")
-
- print("\n" + "=" * 50)
- print("💡 After killing processes, you can try:")
- print(" python run_lightweight_dashboard.py")
- print(" or")
- print(" python fix_startup_freeze.py")
-
- return True
-
- except Exception as e:
- print(f"❌ Error in main function: {e}")
- return False
-
-if __name__ == "__main__":
- success = main()
- if not success:
- sys.exit(1)
\ No newline at end of file
diff --git a/launch_training.py b/launch_training.py
deleted file mode 100644
index 5eb92e6..0000000
--- a/launch_training.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""
-Launch training with optimized short-term models only
-"""
-
-import os
-import sys
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import load_config
-from core.training import TrainingManager
-from core.models import OptimizedShortTermModel
-
-def main():
- """Main training function using only optimized models"""
- config = load_config()
-
- # Initialize model
- model = OptimizedShortTermModel()
-
- # Load best model if exists
- best_model_path = config.model_paths.get('ticks_model')
- if os.path.exists(best_model_path):
- model.load_state_dict(torch.load(best_model_path))
-
- # Initialize training
- trainer = TrainingManager(
- model=model,
- config=config,
- use_ticks=True,
- use_realtime=True
- )
-
- # Start training
- trainer.train()
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/main.py b/main.py
deleted file mode 100644
index ff3a43d..0000000
--- a/main.py
+++ /dev/null
@@ -1,458 +0,0 @@
-#!/usr/bin/env python3
-"""
-Streamlined Trading System - Web Dashboard + Training
-
-Integrated system with both training loop and web dashboard:
-- Training Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution
-- Web Dashboard: Real-time monitoring and control interface
-- 2-Action System: BUY/SELL with intelligent position management
-- Always invested approach with smart risk/reward setup detection
-
-Usage:
- python main.py [--symbol ETH/USDT] [--port 8050]
-"""
-
-import os
-# Fix OpenMP library conflicts before importing other modules
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '4'
-
-import asyncio
-import argparse
-import logging
-import sys
-from pathlib import Path
-from threading import Thread
-import time
-from safe_logging import setup_safe_logging
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config, setup_logging, Config
-from core.data_provider import DataProvider
-
-# Import checkpoint management
-from utils.checkpoint_manager import get_checkpoint_manager
-from utils.training_integration import get_training_integration
-
-logger = logging.getLogger(__name__)
-
-async def run_web_dashboard():
- """Run the streamlined web dashboard with 2-action system and always-invested approach"""
- try:
- logger.info("Starting Streamlined Trading Dashboard...")
- logger.info("2-Action System: BUY/SELL with intelligent position management")
- logger.info("Always Invested Approach: Smart risk/reward setup detection")
- logger.info("Integrated Training Pipeline: Live data -> Models -> Trading")
-
- # Get configuration
- config = get_config()
-
- # Initialize core components for streamlined pipeline
- from core.standardized_data_provider import StandardizedDataProvider
- from core.orchestrator import TradingOrchestrator
- from core.trading_executor import TradingExecutor
-
- # Create standardized data provider (validated BaseDataInput, pivots, COB)
- data_provider = StandardizedDataProvider()
-
- # Start real-time streaming for BOM caching
- try:
- await data_provider.start_real_time_streaming()
- logger.info("[SUCCESS] Real-time data streaming started for BOM caching")
- except Exception as e:
- logger.warning(f"[WARNING] Real-time streaming failed: {e}")
-
- # Verify data connection with retry mechanism
- logger.info("[DATA] Verifying live data connection...")
- symbol = config.get('symbols', ['ETH/USDT'])[0]
-
- # Wait for data provider to initialize and fetch initial data
- max_retries = 10
- retry_delay = 2
-
- for attempt in range(max_retries):
- test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
- if test_df is not None and len(test_df) > 0:
- logger.info("[SUCCESS] Data connection verified")
- logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
- break
- else:
- if attempt < max_retries - 1:
- logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
- await asyncio.sleep(retry_delay)
- else:
- logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
- logger.warning("The system will attempt to fetch data as needed during operation")
-
- # Load model registry for integrated pipeline
- try:
- from models import get_model_registry
- model_registry = {} # Use simple dict for now
- logger.info("[MODELS] Model registry initialized for training")
- except ImportError:
- model_registry = {}
- logger.warning("Model registry not available, using empty registry")
-
- # Initialize checkpoint management
- checkpoint_manager = get_checkpoint_manager()
- training_integration = get_training_integration()
- logger.info("Checkpoint management initialized for training pipeline")
-
- # Create unified orchestrator with full ML pipeline
- orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True,
- model_registry={}
- )
- logger.info("Unified Trading Orchestrator initialized with full ML pipeline")
- logger.info("Data Bus -> Models (DQN + CNN + COB) -> Decision Model -> Trading Signals")
-
- # Checkpoint management will be handled in the training loop
- logger.info("Checkpoint management will be initialized in training loop")
-
- # Unified orchestrator includes COB integration as part of data bus
- logger.info("COB Integration available - feeds into unified data bus")
-
- # Create trading executor for live execution
- trading_executor = TradingExecutor()
-
- # Start the training and monitoring loop
- logger.info(f"Starting Enhanced Training Pipeline")
- logger.info("Live Data Processing: ENABLED")
- logger.info("COB Integration: ENABLED (Real-time market microstructure)")
- logger.info("Integrated CNN Training: ENABLED")
- logger.info("Integrated RL Training: ENABLED")
- logger.info("Real-time Indicators & Pivots: ENABLED")
- logger.info("Live Trading Execution: ENABLED")
- logger.info("2-Action System: BUY/SELL with position intelligence")
- logger.info("Always Invested: Different thresholds for entry/exit")
- logger.info("Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
- logger.info("Starting training loop...")
-
- # Start the training loop
- logger.info("About to start training loop...")
- await start_training_loop(orchestrator, trading_executor)
-
- except Exception as e:
- logger.error(f"Error in streamlined dashboard: {e}")
- logger.error("Training stopped")
- import traceback
- logger.error(traceback.format_exc())
-
-def start_web_ui(port=8051):
- """Start the main TradingDashboard UI in a separate thread"""
- try:
- logger.info("=" * 50)
- logger.info("Starting Main Trading Dashboard UI...")
- logger.info(f"Trading Dashboard: http://127.0.0.1:{port}")
- logger.info("COB Integration: ENABLED (Real-time order book visualization)")
- logger.info("=" * 50)
-
- # Import and create the Clean Trading Dashboard
- from web.clean_dashboard import CleanTradingDashboard
- from core.standardized_data_provider import StandardizedDataProvider
- from core.orchestrator import TradingOrchestrator
- from core.trading_executor import TradingExecutor
-
- # Initialize components for the dashboard
- config = get_config()
- data_provider = StandardizedDataProvider()
-
- # Start real-time streaming for BOM caching (non-blocking)
- try:
- import threading
- def start_streaming():
- import asyncio
- asyncio.run(data_provider.start_real_time_streaming())
-
- streaming_thread = threading.Thread(target=start_streaming, daemon=True)
- streaming_thread.start()
- logger.info("[SUCCESS] Real-time streaming thread started for dashboard")
- except Exception as e:
- logger.warning(f"[WARNING] Dashboard streaming setup failed: {e}")
-
- # Load model registry for enhanced features
- try:
- from models import get_model_registry
- model_registry = {} # Use simple dict for now
- except ImportError:
- model_registry = {}
-
- # Initialize checkpoint management for dashboard
- dashboard_checkpoint_manager = get_checkpoint_manager()
- dashboard_training_integration = get_training_integration()
-
- # Create unified orchestrator for the dashboard
- dashboard_orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True,
- model_registry={}
- )
-
- trading_executor = TradingExecutor("config.yaml")
-
- # Create the clean trading dashboard with enhanced features
- dashboard = CleanTradingDashboard(
- data_provider=data_provider,
- orchestrator=dashboard_orchestrator,
- trading_executor=trading_executor
- )
-
- logger.info("Clean Trading Dashboard created successfully")
- logger.info("Features: Live trading, COB visualization, ML pipeline monitoring, Position management")
- logger.info("✅ Unified orchestrator with decision-making model and checkpoint management")
-
- # Run the dashboard server (COB integration will start automatically)
- dashboard.run_server(host='127.0.0.1', port=port, debug=False)
-
- except Exception as e:
- logger.error(f"Error starting main trading dashboard UI: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
-async def start_training_loop(orchestrator, trading_executor):
- """Start the main training and monitoring loop with checkpoint management"""
- logger.info("=" * 70)
- logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
- logger.info("=" * 70)
-
- logger.info("Training loop function entered successfully")
-
- # Initialize checkpoint management for training loop
- checkpoint_manager = get_checkpoint_manager()
- training_integration = get_training_integration()
-
- # Training statistics for checkpoint management
- training_stats = {
- 'iteration_count': 0,
- 'total_decisions': 0,
- 'successful_trades': 0,
- 'best_performance': 0.0,
- 'last_checkpoint_iteration': 0
- }
-
- try:
- # Start real-time processing (Basic orchestrator doesn't have this method)
- logger.info("Checking for real-time processing capabilities...")
- try:
- if hasattr(orchestrator, 'start_realtime_processing'):
- logger.info("Starting real-time processing...")
- await orchestrator.start_realtime_processing()
- logger.info("Real-time processing started")
- else:
- logger.info("Basic orchestrator - no real-time processing method available")
- except Exception as e:
- logger.warning(f"Real-time processing not available: {e}")
-
- logger.info("About to enter main training loop...")
-
- # Main training loop
- iteration = 0
- while True:
- iteration += 1
- training_stats['iteration_count'] = iteration
-
- logger.info(f"Training iteration {iteration}")
-
- # Make trading decisions using Basic orchestrator (single symbol method)
- decisions = {}
- symbols = ['ETH/USDT'] # Focus on ETH only for training
-
- for symbol in symbols:
- try:
- decision = await orchestrator.make_trading_decision(symbol)
- decisions[symbol] = decision
- except Exception as e:
- logger.warning(f"Error making decision for {symbol}: {e}")
- decisions[symbol] = None
-
- # Process decisions and collect training metrics
- iteration_decisions = 0
- iteration_performance = 0.0
-
- # Log decisions and performance
- for symbol, decision in decisions.items():
- if decision:
- iteration_decisions += 1
- logger.info(f"{symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
-
- # Track performance for checkpoint management
- iteration_performance += decision.confidence
-
- # Execute if confidence is high enough
- if decision.confidence > 0.7:
- logger.info(f"Executing {symbol}: {decision.action}")
- training_stats['successful_trades'] += 1
- # trading_executor.execute_action(decision)
-
- # Update training statistics
- training_stats['total_decisions'] += iteration_decisions
- if iteration_performance > training_stats['best_performance']:
- training_stats['best_performance'] = iteration_performance
-
- # Save checkpoint every 50 iterations or when performance improves significantly
- should_save_checkpoint = (
- iteration % 50 == 0 or # Regular interval
- iteration_performance > training_stats['best_performance'] * 1.1 or # 10% improvement
- iteration - training_stats['last_checkpoint_iteration'] >= 100 # Force save every 100 iterations
- )
-
- if should_save_checkpoint:
- try:
- # Create performance metrics for checkpoint
- performance_metrics = {
- 'avg_confidence': iteration_performance / max(iteration_decisions, 1),
- 'success_rate': training_stats['successful_trades'] / max(training_stats['total_decisions'], 1),
- 'total_decisions': training_stats['total_decisions'],
- 'iteration': iteration
- }
-
- # Save orchestrator state (if it has models)
- if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
- saved = orchestrator.rl_agent.save_checkpoint(iteration_performance)
- if saved:
- logger.info(f"✅ RL Agent checkpoint saved at iteration {iteration}")
-
- if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
- # Simulate CNN checkpoint save
- logger.info(f"✅ CNN Model training state saved at iteration {iteration}")
-
- if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
- saved = orchestrator.extrema_trainer.save_checkpoint()
- if saved:
- logger.info(f"✅ ExtremaTrainer checkpoint saved at iteration {iteration}")
-
- training_stats['last_checkpoint_iteration'] = iteration
- logger.info(f"📊 Checkpoint management completed for iteration {iteration}")
-
- except Exception as e:
- logger.warning(f"Checkpoint saving failed at iteration {iteration}: {e}")
-
- # Log performance metrics every 10 iterations
- if iteration % 10 == 0:
- metrics = orchestrator.get_performance_metrics()
- logger.info(f"Performance metrics: {metrics}")
-
- # Log training statistics
- logger.info(f"Training stats: {training_stats}")
-
- # Log checkpoint statistics
- checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
- logger.info(f"Checkpoints: {checkpoint_stats['total_checkpoints']} total, "
- f"{checkpoint_stats['total_size_mb']:.2f} MB")
-
- # Log COB integration status (Basic orchestrator doesn't have COB features)
- symbols = getattr(orchestrator, 'symbols', ['ETH/USDT'])
- if hasattr(orchestrator, 'latest_cob_features'):
- for symbol in symbols:
- cob_features = orchestrator.latest_cob_features.get(symbol)
- cob_state = orchestrator.latest_cob_state.get(symbol)
- if cob_features is not None:
- logger.info(f"{symbol} COB: CNN features {cob_features.shape}, DQN state {cob_state.shape if cob_state is not None else 'None'}")
- else:
- logger.debug("Basic orchestrator - no COB integration features available")
-
- # Sleep between iterations
- await asyncio.sleep(5) # 5 second intervals
-
- except KeyboardInterrupt:
- logger.info("Training interrupted by user")
- except Exception as e:
- logger.error(f"Error in training loop: {e}")
- import traceback
- logger.error(traceback.format_exc())
- finally:
- # Save final checkpoints before shutdown
- try:
- logger.info("Saving final checkpoints before shutdown...")
-
- if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
- orchestrator.rl_agent.save_checkpoint(0.0, force_save=True)
- logger.info("✅ Final RL Agent checkpoint saved")
-
- if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
- orchestrator.extrema_trainer.save_checkpoint(force_save=True)
- logger.info("✅ Final ExtremaTrainer checkpoint saved")
-
- # Log final checkpoint statistics
- final_stats = checkpoint_manager.get_checkpoint_stats()
- logger.info(f"📊 Final checkpoint stats: {final_stats['total_checkpoints']} checkpoints, "
- f"{final_stats['total_size_mb']:.2f} MB total")
-
- except Exception as e:
- logger.warning(f"Error saving final checkpoints: {e}")
-
- # Stop real-time processing (Basic orchestrator doesn't have these methods)
- try:
- if hasattr(orchestrator, 'stop_realtime_processing'):
- await orchestrator.stop_realtime_processing()
- except Exception as e:
- logger.warning(f"Error stopping real-time processing: {e}")
-
- try:
- if hasattr(orchestrator, 'stop_cob_integration'):
- await orchestrator.stop_cob_integration()
- except Exception as e:
- logger.warning(f"Error stopping COB integration: {e}")
- logger.info("Training loop stopped with checkpoint management")
-
-async def main():
- """Main entry point with both training loop and web dashboard"""
- parser = argparse.ArgumentParser(description='Streamlined Trading System - Training + Web Dashboard')
- parser.add_argument('--symbol', type=str, default='ETH/USDT',
- help='Primary trading symbol (default: ETH/USDT)')
- parser.add_argument('--port', type=int, default=8050,
- help='Web dashboard port (default: 8050)')
- parser.add_argument('--debug', action='store_true',
- help='Enable debug mode')
-
- args = parser.parse_args()
-
- # Setup logging and ensure directories exist
- Path("logs").mkdir(exist_ok=True)
- Path("NN/models/saved").mkdir(parents=True, exist_ok=True)
- setup_safe_logging()
-
- try:
- logger.info("=" * 70)
- logger.info("STREAMLINED TRADING SYSTEM - TRAINING + MAIN DASHBOARD")
- logger.info(f"Primary Symbol: {args.symbol}")
- logger.info(f"Training Port: {args.port}")
- logger.info(f"Main Trading Dashboard: http://127.0.0.1:{args.port}")
- logger.info("2-Action System: BUY/SELL with intelligent position management")
- logger.info("Always Invested: Learning to spot high risk/reward setups")
- logger.info("Flow: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
- logger.info("Main Dashboard: Live trading, RL monitoring, Position management")
- logger.info("🔄 Checkpoint Management: Automatic training state persistence")
- # logger.info("📊 W&B Integration: Optional experiment tracking")
- logger.info("💾 Model Rotation: Keep best 5 checkpoints per model")
- logger.info("=" * 70)
-
- # Start main trading dashboard UI in a separate thread
- web_thread = Thread(target=lambda: start_web_ui(args.port), daemon=True)
- web_thread.start()
- logger.info("Main trading dashboard UI thread started")
-
- # Give web UI time to start
- await asyncio.sleep(2)
-
- # Run the training loop (this will run indefinitely)
- await run_web_dashboard()
-
- logger.info("[SUCCESS] Operation completed successfully!")
-
- except KeyboardInterrupt:
- logger.info("System shutdown requested by user")
- except Exception as e:
- logger.error(f"Fatal error: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return 1
-
- return 0
-
-if __name__ == "__main__":
- sys.exit(asyncio.run(main()))
\ No newline at end of file
diff --git a/main_clean.py b/main_clean.py
deleted file mode 100644
index 3965ef2..0000000
--- a/main_clean.py
+++ /dev/null
@@ -1,133 +0,0 @@
-#!/usr/bin/env python3
-"""
-Clean Main Entry Point for Enhanced Trading Dashboard
-
-This is the main entry point that safely launches the clean dashboard
-with proper error handling and optimized settings.
-"""
-
-import os
-import sys
-import logging
-import argparse
-from typing import Optional
-
-# Add project root to path
-sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
-# Import core components
-try:
- from core.config import setup_logging
- from core.data_provider import DataProvider
- from core.orchestrator import TradingOrchestrator
- from core.trading_executor import TradingExecutor
- from web.clean_dashboard import create_clean_dashboard
-except ImportError as e:
- print(f"Error importing core modules: {e}")
- sys.exit(1)
-
-logger = logging.getLogger(__name__)
-
-def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
- """Create orchestrator with safe CNN model handling"""
- try:
- # Create orchestrator with basic configuration (uses correct constructor parameters)
- orchestrator = TradingOrchestrator(
- enhanced_rl_training=False # Disable problematic training initially
- )
-
- logger.info("Trading orchestrator created successfully")
- return orchestrator
-
- except Exception as e:
- logger.error(f"Error creating orchestrator: {e}")
- logger.info("Continuing without orchestrator - dashboard will run in view-only mode")
- return None
-
-def create_safe_trading_executor() -> Optional[TradingExecutor]:
- """Create trading executor with safe configuration"""
- try:
- # TradingExecutor only accepts config_path parameter
- trading_executor = TradingExecutor(config_path="config.yaml")
-
- logger.info("Trading executor created successfully")
- return trading_executor
-
- except Exception as e:
- logger.error(f"Error creating trading executor: {e}")
- logger.info("Continuing without trading executor - dashboard will be view-only")
- return None
-
-def main():
- """Main entry point for clean dashboard"""
- parser = argparse.ArgumentParser(description='Enhanced Trading Dashboard')
- parser.add_argument('--port', type=int, default=8050, help='Dashboard port (default: 8050)')
- parser.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host (default: 127.0.0.1)')
- parser.add_argument('--debug', action='store_true', help='Enable debug mode')
- parser.add_argument('--no-training', action='store_true', help='Disable ML training for stability')
-
- args = parser.parse_args()
-
- # Setup logging
- try:
- setup_logging()
- logger.info("================================================================================")
- logger.info("CLEAN ENHANCED TRADING DASHBOARD")
- logger.info("================================================================================")
- logger.info(f"Starting on http://{args.host}:{args.port}")
- logger.info("Features: Real-time Charts, Trading Interface, Model Monitoring")
- logger.info("================================================================================")
- except Exception as e:
- print(f"Error setting up logging: {e}")
- # Continue without logging setup
-
- # Set environment variables for optimization
- os.environ['ENABLE_REALTIME_CHARTS'] = '1'
- if not args.no_training:
- os.environ['ENABLE_NN_MODELS'] = '1'
-
- try:
- # Create data provider
- logger.info("Initializing data provider...")
- data_provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])
-
- # Create orchestrator (with safe CNN handling)
- logger.info("Initializing trading orchestrator...")
- orchestrator = create_safe_orchestrator()
-
- # Create trading executor
- logger.info("Initializing trading executor...")
- trading_executor = create_safe_trading_executor()
-
- # Create and run dashboard
- logger.info("Creating clean dashboard...")
- dashboard = create_clean_dashboard(
- data_provider=data_provider,
- orchestrator=orchestrator,
- trading_executor=trading_executor
- )
-
- # Start the dashboard server
- logger.info(f"Starting dashboard server on http://{args.host}:{args.port}")
- dashboard.run_server(
- host=args.host,
- port=args.port,
- debug=args.debug
- )
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- except Exception as e:
- logger.error(f"Error running dashboard: {e}")
-
- # Try to provide helpful error message
- if "model.fit" in str(e) or "CNN" in str(e):
- logger.error("CNN model training error detected. Try running with --no-training flag")
- logger.error("Command: python main_clean.py --no-training")
-
- sys.exit(1)
- finally:
- logger.info("Clean dashboard shutdown complete")
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/migrate_existing_models.py b/migrate_existing_models.py
deleted file mode 100644
index 0950b37..0000000
--- a/migrate_existing_models.py
+++ /dev/null
@@ -1,204 +0,0 @@
-#!/usr/bin/env python3
-"""
-Migrate Existing Models to Checkpoint System
-
-This script migrates existing model files to the new checkpoint system
-and creates proper database metadata entries.
-"""
-
-import os
-import shutil
-import logging
-from datetime import datetime
-from pathlib import Path
-from utils.database_manager import get_database_manager, CheckpointMetadata
-from utils.checkpoint_manager import save_checkpoint
-from utils.text_logger import get_text_logger
-
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def migrate_existing_models():
- """Migrate existing models to checkpoint system"""
- print("=== Migrating Existing Models to Checkpoint System ===")
-
- db_manager = get_database_manager()
- text_logger = get_text_logger()
-
- # Define model migrations
- migrations = [
- {
- 'model_name': 'enhanced_cnn',
- 'model_type': 'cnn',
- 'source_file': 'models/enhanced_cnn/ETH_USDT_cnn.pth',
- 'performance_metrics': {'loss': 0.0187, 'accuracy': 0.92},
- 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True}
- },
- {
- 'model_name': 'dqn_agent',
- 'model_type': 'rl',
- 'source_file': 'models/enhanced_rl/ETH_USDT_dqn_policy.pth',
- 'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
- 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'policy'}
- },
- {
- 'model_name': 'dqn_agent_target',
- 'model_type': 'rl',
- 'source_file': 'models/enhanced_rl/ETH_USDT_dqn_target.pth',
- 'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
- 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'target'}
- }
- ]
-
- migrated_count = 0
-
- for migration in migrations:
- source_path = Path(migration['source_file'])
-
- if not source_path.exists():
- logger.warning(f"Source file not found: {source_path}")
- continue
-
- try:
- # Create checkpoint directory
- checkpoint_dir = Path("models/checkpoints") / migration['model_name']
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
- # Create checkpoint filename
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- checkpoint_id = f"{migration['model_name']}_{timestamp}"
- checkpoint_file = checkpoint_dir / f"{checkpoint_id}.pt"
-
- # Copy model file to checkpoint location
- shutil.copy2(source_path, checkpoint_file)
- logger.info(f"Copied {source_path} -> {checkpoint_file}")
-
- # Calculate file size
- file_size_mb = checkpoint_file.stat().st_size / (1024 * 1024)
-
- # Create checkpoint metadata
- metadata = CheckpointMetadata(
- checkpoint_id=checkpoint_id,
- model_name=migration['model_name'],
- model_type=migration['model_type'],
- timestamp=datetime.now(),
- performance_metrics=migration['performance_metrics'],
- training_metadata=migration['training_metadata'],
- file_path=str(checkpoint_file),
- file_size_mb=file_size_mb,
- is_active=True
- )
-
- # Save to database
- if db_manager.save_checkpoint_metadata(metadata):
- logger.info(f"Saved checkpoint metadata: {checkpoint_id}")
-
- # Log to text file
- text_logger.log_checkpoint_event(
- model_name=migration['model_name'],
- event_type="MIGRATED",
- checkpoint_id=checkpoint_id,
- details=f"from {source_path}, size={file_size_mb:.1f}MB"
- )
-
- migrated_count += 1
- else:
- logger.error(f"Failed to save checkpoint metadata: {checkpoint_id}")
-
- except Exception as e:
- logger.error(f"Failed to migrate {migration['model_name']}: {e}")
-
- print(f"\nMigration completed: {migrated_count} models migrated")
-
- # Show current checkpoint status
- print("\n=== Current Checkpoint Status ===")
- for model_name in ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']:
- checkpoints = db_manager.list_checkpoints(model_name)
- if checkpoints:
- print(f"{model_name}: {len(checkpoints)} checkpoints")
- for checkpoint in checkpoints[:2]: # Show first 2
- print(f" - {checkpoint.checkpoint_id} ({checkpoint.file_size_mb:.1f}MB)")
- else:
- print(f"{model_name}: No checkpoints")
-
-def verify_checkpoint_system():
- """Verify the checkpoint system is working"""
- print("\n=== Verifying Checkpoint System ===")
-
- db_manager = get_database_manager()
-
- # Test loading checkpoints
- for model_name in ['dqn_agent', 'enhanced_cnn']:
- metadata = db_manager.get_best_checkpoint_metadata(model_name)
- if metadata:
- file_exists = Path(metadata.file_path).exists()
- print(f"{model_name}: ✅ Metadata found, File exists: {file_exists}")
- if file_exists:
- print(f" -> {metadata.checkpoint_id} ({metadata.file_size_mb:.1f}MB)")
- else:
- print(f" -> ERROR: File missing: {metadata.file_path}")
- else:
- print(f"{model_name}: ❌ No checkpoint metadata found")
-
-def create_test_checkpoint():
- """Create a test checkpoint to verify saving works"""
- print("\n=== Testing Checkpoint Saving ===")
-
- try:
- import torch
- import torch.nn as nn
-
- # Create a simple test model
- class TestModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.linear = nn.Linear(10, 1)
-
- def forward(self, x):
- return self.linear(x)
-
- test_model = TestModel()
-
- # Save using the checkpoint system
- from utils.checkpoint_manager import save_checkpoint
-
- result = save_checkpoint(
- model=test_model,
- model_name="test_model",
- model_type="test",
- performance_metrics={"loss": 0.1, "accuracy": 0.95},
- training_metadata={"test": True, "created": datetime.now().isoformat()}
- )
-
- if result:
- print(f"✅ Test checkpoint saved successfully: {result.checkpoint_id}")
-
- # Verify it exists
- db_manager = get_database_manager()
- metadata = db_manager.get_best_checkpoint_metadata("test_model")
- if metadata and Path(metadata.file_path).exists():
- print(f"✅ Test checkpoint verified: {metadata.file_path}")
-
- # Clean up test checkpoint
- Path(metadata.file_path).unlink()
- print("🧹 Test checkpoint cleaned up")
- else:
- print("❌ Test checkpoint verification failed")
- else:
- print("❌ Test checkpoint saving failed")
-
- except Exception as e:
- print(f"❌ Test checkpoint creation failed: {e}")
-
-def main():
- """Main migration process"""
- migrate_existing_models()
- verify_checkpoint_system()
- create_test_checkpoint()
-
- print("\n=== Migration Complete ===")
- print("The checkpoint system should now work properly!")
- print("Existing models have been migrated and the system is ready for use.")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/model_manager.py b/model_manager.py
deleted file mode 100644
index b09ddfc..0000000
--- a/model_manager.py
+++ /dev/null
@@ -1,558 +0,0 @@
-"""
-Enhanced Model Management System for Trading Dashboard
-
-This system provides:
-- Automatic cleanup of old model checkpoints
-- Best model tracking with performance metrics
-- Configurable retention policies
-- Startup model loading
-- Performance-based model selection
-"""
-
-import os
-import json
-import shutil
-import logging
-import torch
-import glob
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any
-from dataclasses import dataclass, asdict
-from pathlib import Path
-import numpy as np
-
-logger = logging.getLogger(__name__)
-
-@dataclass
-class ModelMetrics:
- """Performance metrics for model evaluation"""
- accuracy: float = 0.0
- profit_factor: float = 0.0
- win_rate: float = 0.0
- sharpe_ratio: float = 0.0
- max_drawdown: float = 0.0
- total_trades: int = 0
- avg_trade_duration: float = 0.0
- confidence_score: float = 0.0
-
- def get_composite_score(self) -> float:
- """Calculate composite performance score"""
- # Weighted composite score
- weights = {
- 'profit_factor': 0.3,
- 'sharpe_ratio': 0.25,
- 'win_rate': 0.2,
- 'accuracy': 0.15,
- 'confidence_score': 0.1
- }
-
- # Normalize values to 0-1 range
- normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0
- normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1
- normalized_win_rate = self.win_rate
- normalized_accuracy = self.accuracy
- normalized_confidence = self.confidence_score
-
- # Apply penalties for poor performance
- drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown
-
- score = (
- weights['profit_factor'] * normalized_pf +
- weights['sharpe_ratio'] * normalized_sharpe +
- weights['win_rate'] * normalized_win_rate +
- weights['accuracy'] * normalized_accuracy +
- weights['confidence_score'] * normalized_confidence
- ) * drawdown_penalty
-
- return min(max(score, 0), 1)
-
-@dataclass
-class ModelInfo:
- """Complete model information and metadata"""
- model_type: str # 'cnn', 'rl', 'transformer'
- model_name: str
- file_path: str
- creation_time: datetime
- last_updated: datetime
- file_size_mb: float
- metrics: ModelMetrics
- training_episodes: int = 0
- model_version: str = "1.0"
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization"""
- data = asdict(self)
- data['creation_time'] = self.creation_time.isoformat()
- data['last_updated'] = self.last_updated.isoformat()
- return data
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
- """Create from dictionary"""
- data['creation_time'] = datetime.fromisoformat(data['creation_time'])
- data['last_updated'] = datetime.fromisoformat(data['last_updated'])
- data['metrics'] = ModelMetrics(**data['metrics'])
- return cls(**data)
-
-class ModelManager:
- """Enhanced model management system"""
-
- def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
- self.base_dir = Path(base_dir)
- self.config = config or self._get_default_config()
-
- # Model directories
- self.models_dir = self.base_dir / "models"
- self.nn_models_dir = self.base_dir / "NN" / "models"
- self.registry_file = self.models_dir / "model_registry.json"
- self.best_models_dir = self.models_dir / "best_models"
-
- # Create directories
- self.best_models_dir.mkdir(parents=True, exist_ok=True)
-
- # Model registry
- self.model_registry: Dict[str, ModelInfo] = {}
- self._load_registry()
-
- logger.info(f"Model Manager initialized - Base: {self.base_dir}")
- logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")
-
- def _get_default_config(self) -> Dict[str, Any]:
- """Get default configuration"""
- return {
- 'max_models_per_type': 3, # Keep top 3 models per type
- 'max_total_models': 10, # Maximum total models to keep
- 'cleanup_frequency_hours': 24, # Cleanup every 24 hours
- 'min_performance_threshold': 0.3, # Minimum composite score
- 'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days
- 'auto_cleanup_enabled': True,
- 'backup_before_cleanup': True,
- 'model_size_limit_mb': 100, # Individual model size limit
- 'total_storage_limit_gb': 5.0 # Total storage limit
- }
-
- def _load_registry(self):
- """Load model registry from file"""
- try:
- if self.registry_file.exists():
- with open(self.registry_file, 'r') as f:
- data = json.load(f)
- self.model_registry = {
- k: ModelInfo.from_dict(v) for k, v in data.items()
- }
- logger.info(f"Loaded {len(self.model_registry)} models from registry")
- else:
- logger.info("No existing model registry found")
- except Exception as e:
- logger.error(f"Error loading model registry: {e}")
- self.model_registry = {}
-
- def _save_registry(self):
- """Save model registry to file"""
- try:
- self.models_dir.mkdir(parents=True, exist_ok=True)
- with open(self.registry_file, 'w') as f:
- data = {k: v.to_dict() for k, v in self.model_registry.items()}
- json.dump(data, f, indent=2, default=str)
- logger.info(f"Saved registry with {len(self.model_registry)} models")
- except Exception as e:
- logger.error(f"Error saving model registry: {e}")
-
- def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
- """
- Clean up all existing model files and prepare for 2-action system training
-
- Args:
- confirm: If True, perform the cleanup. If False, return what would be cleaned
-
- Returns:
- Dict with cleanup statistics
- """
- cleanup_stats = {
- 'files_found': 0,
- 'files_deleted': 0,
- 'directories_cleaned': 0,
- 'space_freed_mb': 0.0,
- 'errors': []
- }
-
- # Model file patterns for both 2-action and legacy 3-action systems
- model_patterns = [
- "**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
- "**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
- ]
-
- # Directories to clean
- model_directories = [
- "models/saved",
- "NN/models/saved",
- "NN/models/saved/checkpoints",
- "NN/models/saved/realtime_checkpoints",
- "NN/models/saved/realtime_ticks_checkpoints",
- "model_backups"
- ]
-
- try:
- # Scan for files to be cleaned
- for directory in model_directories:
- dir_path = Path(self.base_dir) / directory
- if dir_path.exists():
- for pattern in model_patterns:
- for file_path in dir_path.glob(pattern):
- if file_path.is_file():
- cleanup_stats['files_found'] += 1
- file_size = file_path.stat().st_size / (1024 * 1024) # MB
- cleanup_stats['space_freed_mb'] += file_size
-
- if confirm:
- try:
- file_path.unlink()
- cleanup_stats['files_deleted'] += 1
- logger.info(f"Deleted model file: {file_path}")
- except Exception as e:
- cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
-
- # Clean up empty checkpoint directories
- for directory in model_directories:
- dir_path = Path(self.base_dir) / directory
- if dir_path.exists():
- for subdir in dir_path.rglob("*"):
- if subdir.is_dir() and not any(subdir.iterdir()):
- if confirm:
- try:
- subdir.rmdir()
- cleanup_stats['directories_cleaned'] += 1
- logger.info(f"Removed empty directory: {subdir}")
- except Exception as e:
- cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")
-
- if confirm:
- # Clear the registry for fresh start with 2-action system
- self.model_registry = {
- 'models': {},
- 'metadata': {
- 'last_updated': datetime.now().isoformat(),
- 'total_models': 0,
- 'system_type': '2_action', # Mark as 2-action system
- 'action_space': ['SELL', 'BUY'],
- 'version': '2.0'
- }
- }
- self._save_registry()
-
- logger.info("=" * 60)
- logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
- logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
- logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
- logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
- logger.info("Registry reset for 2-action system (BUY/SELL)")
- logger.info("Ready for fresh training with intelligent position management")
- logger.info("=" * 60)
- else:
- logger.info("=" * 60)
- logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
- logger.info(f"Files to delete: {cleanup_stats['files_found']}")
- logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
- logger.info("Run with confirm=True to perform cleanup")
- logger.info("=" * 60)
-
- except Exception as e:
- cleanup_stats['errors'].append(f"Cleanup error: {e}")
- logger.error(f"Error during model cleanup: {e}")
-
- return cleanup_stats
-
- def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
- """
- Register a new model in the 2-action system
-
- Args:
- model_path: Path to the model file
- model_type: Type of model ('cnn', 'rl', 'transformer')
- metrics: Performance metrics
-
- Returns:
- str: Unique model name/ID
- """
- if not Path(model_path).exists():
- raise FileNotFoundError(f"Model file not found: {model_path}")
-
- # Generate unique model name
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- model_name = f"{model_type}_2action_{timestamp}"
-
- # Get file info
- file_path = Path(model_path)
- file_size_mb = file_path.stat().st_size / (1024 * 1024)
-
- # Default metrics for 2-action system
- if metrics is None:
- metrics = ModelMetrics(
- accuracy=0.0,
- profit_factor=1.0,
- win_rate=0.5,
- sharpe_ratio=0.0,
- max_drawdown=0.0,
- confidence_score=0.5
- )
-
- # Create model info
- model_info = ModelInfo(
- model_type=model_type,
- model_name=model_name,
- file_path=str(file_path.absolute()),
- creation_time=datetime.now(),
- last_updated=datetime.now(),
- file_size_mb=file_size_mb,
- metrics=metrics,
- model_version="2.0" # 2-action system version
- )
-
- # Add to registry
- self.model_registry['models'][model_name] = model_info.to_dict()
- self.model_registry['metadata']['total_models'] = len(self.model_registry['models'])
- self.model_registry['metadata']['last_updated'] = datetime.now().isoformat()
- self.model_registry['metadata']['system_type'] = '2_action'
- self.model_registry['metadata']['action_space'] = ['SELL', 'BUY']
-
- self._save_registry()
-
- # Cleanup old models if necessary
- self._cleanup_models_by_type(model_type)
-
- logger.info(f"Registered 2-action model: {model_name}")
- logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
- logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
-
- return model_name
-
- def _should_keep_model(self, model_info: ModelInfo) -> bool:
- """Determine if model should be kept based on performance"""
- score = model_info.metrics.get_composite_score()
-
- # Check minimum threshold
- if score < self.config['min_performance_threshold']:
- return False
-
- # Check size limit
- if model_info.file_size_mb > self.config['model_size_limit_mb']:
- logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
- return False
-
- # Check if better than existing models of same type
- existing_models = self.get_models_by_type(model_info.model_type)
- if len(existing_models) >= self.config['max_models_per_type']:
- # Find worst performing model
- worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
- if score <= worst_model.metrics.get_composite_score():
- return False
-
- return True
-
- def _cleanup_models_by_type(self, model_type: str):
- """Cleanup old models of specific type, keeping only the best ones"""
- models_of_type = self.get_models_by_type(model_type)
- max_keep = self.config['max_models_per_type']
-
- if len(models_of_type) <= max_keep:
- return
-
- # Sort by performance score
- sorted_models = sorted(
- models_of_type.items(),
- key=lambda x: x[1].metrics.get_composite_score(),
- reverse=True
- )
-
- # Keep only the best models
- models_to_keep = sorted_models[:max_keep]
- models_to_remove = sorted_models[max_keep:]
-
- for model_name, model_info in models_to_remove:
- try:
- # Remove file
- model_path = Path(model_info.file_path)
- if model_path.exists():
- model_path.unlink()
-
- # Remove from registry
- del self.model_registry[model_name]
-
- logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
-
- except Exception as e:
- logger.error(f"Error removing model {model_name}: {e}")
-
- def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
- """Get all models of a specific type"""
- return {
- name: info for name, info in self.model_registry.items()
- if info.model_type == model_type
- }
-
- def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
- """Get the best performing model of a specific type"""
- models_of_type = self.get_models_by_type(model_type)
-
- if not models_of_type:
- return None
-
- return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())
-
- def load_best_models(self) -> Dict[str, Any]:
- """Load the best models for each type"""
- loaded_models = {}
-
- for model_type in ['cnn', 'rl', 'transformer']:
- best_model = self.get_best_model(model_type)
-
- if best_model:
- try:
- model_path = Path(best_model.file_path)
- if model_path.exists():
- # Load the model
- model_data = torch.load(model_path, map_location='cpu')
- loaded_models[model_type] = {
- 'model': model_data,
- 'info': best_model,
- 'path': str(model_path)
- }
- logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
- f"(Score: {best_model.metrics.get_composite_score():.3f})")
- else:
- logger.warning(f"Best {model_type} model file not found: {model_path}")
- except Exception as e:
- logger.error(f"Error loading {model_type} model: {e}")
- else:
- logger.info(f"No {model_type} model available")
-
- return loaded_models
-
- def update_model_performance(self, model_name: str, metrics: ModelMetrics):
- """Update performance metrics for a model"""
- if model_name in self.model_registry:
- self.model_registry[model_name].metrics = metrics
- self.model_registry[model_name].last_updated = datetime.now()
- self._save_registry()
-
- logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
- else:
- logger.warning(f"Model {model_name} not found in registry")
-
- def get_storage_stats(self) -> Dict[str, Any]:
- """Get storage usage statistics"""
- total_size_mb = 0
- model_count = 0
-
- for model_info in self.model_registry.values():
- total_size_mb += model_info.file_size_mb
- model_count += 1
-
- # Check actual storage usage
- actual_size_mb = 0
- if self.best_models_dir.exists():
- actual_size_mb = sum(
- f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
- ) / 1024 / 1024
-
- return {
- 'total_models': model_count,
- 'registered_size_mb': total_size_mb,
- 'actual_size_mb': actual_size_mb,
- 'storage_limit_gb': self.config['total_storage_limit_gb'],
- 'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
- 'models_by_type': {
- model_type: len(self.get_models_by_type(model_type))
- for model_type in ['cnn', 'rl', 'transformer']
- }
- }
-
- def get_model_leaderboard(self) -> List[Dict[str, Any]]:
- """Get model performance leaderboard"""
- leaderboard = []
-
- for model_name, model_info in self.model_registry.items():
- leaderboard.append({
- 'name': model_name,
- 'type': model_info.model_type,
- 'score': model_info.metrics.get_composite_score(),
- 'profit_factor': model_info.metrics.profit_factor,
- 'win_rate': model_info.metrics.win_rate,
- 'sharpe_ratio': model_info.metrics.sharpe_ratio,
- 'size_mb': model_info.file_size_mb,
- 'age_days': (datetime.now() - model_info.creation_time).days,
- 'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
- })
-
- # Sort by score
- leaderboard.sort(key=lambda x: x['score'], reverse=True)
-
- return leaderboard
-
- def cleanup_checkpoints(self) -> Dict[str, Any]:
- """Clean up old checkpoint files"""
- cleanup_summary = {
- 'deleted_files': 0,
- 'freed_space_mb': 0,
- 'errors': []
- }
-
- cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])
-
- # Search for checkpoint files
- checkpoint_patterns = [
- "**/checkpoint_*.pt",
- "**/model_*.pt",
- "**/*checkpoint*",
- "**/epoch_*.pt"
- ]
-
- for pattern in checkpoint_patterns:
- for file_path in self.base_dir.rglob(pattern):
- if "best_models" not in str(file_path) and file_path.is_file():
- try:
- file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
- if file_time < cutoff_date:
- size_mb = file_path.stat().st_size / 1024 / 1024
- file_path.unlink()
- cleanup_summary['deleted_files'] += 1
- cleanup_summary['freed_space_mb'] += size_mb
- except Exception as e:
- error_msg = f"Error deleting checkpoint {file_path}: {e}"
- logger.error(error_msg)
- cleanup_summary['errors'].append(error_msg)
-
- if cleanup_summary['deleted_files'] > 0:
- logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
- f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
-
- return cleanup_summary
-
-def create_model_manager() -> ModelManager:
- """Create and initialize the global model manager"""
- return ModelManager()
-
-# Example usage
-if __name__ == "__main__":
- # Configure logging
- logging.basicConfig(level=logging.INFO)
-
- # Create model manager
- manager = ModelManager()
-
- # Clean up all existing models (with confirmation)
- print("WARNING: This will delete ALL existing models!")
- print("Type 'CONFIRM' to proceed:")
- user_input = input().strip()
-
- if user_input == "CONFIRM":
- cleanup_result = manager.cleanup_all_existing_models(confirm=True)
- print(f"\nCleanup complete:")
- print(f"- Deleted {cleanup_result['files_deleted']} files")
- print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
- print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")
-
- if cleanup_result['errors']:
- print(f"- {len(cleanup_result['errors'])} errors occurred")
- else:
- print("Cleanup cancelled")
\ No newline at end of file
diff --git a/position_sync_enhancement.py b/position_sync_enhancement.py
deleted file mode 100644
index 2020d94..0000000
--- a/position_sync_enhancement.py
+++ /dev/null
@@ -1,193 +0,0 @@
-#!/usr/bin/env python3
-"""
-Position Sync Enhancement - Fix P&L and Win Rate Calculation
-
-This script enhances the position synchronization and P&L calculation
-to properly account for leverage in the trading system.
-"""
-
-import os
-import sys
-import logging
-from pathlib import Path
-from datetime import datetime
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config, setup_logging
-from core.trading_executor import TradingExecutor, TradeRecord
-
-# Setup logging
-setup_logging()
-logger = logging.getLogger(__name__)
-
-def analyze_trade_records():
- """Analyze trade records for P&L calculation issues"""
- logger.info("Analyzing trade records for P&L calculation issues...")
-
- # Initialize trading executor
- trading_executor = TradingExecutor()
-
- # Get trade records
- trade_records = trading_executor.trade_records
-
- if not trade_records:
- logger.warning("No trade records found.")
- return
-
- logger.info(f"Found {len(trade_records)} trade records.")
-
- # Analyze P&L calculation
- total_pnl = 0.0
- total_gross_pnl = 0.0
- total_fees = 0.0
- winning_trades = 0
- losing_trades = 0
- breakeven_trades = 0
-
- for trade in trade_records:
- # Calculate correct P&L with leverage
- entry_value = trade.entry_price * trade.quantity
- exit_value = trade.exit_price * trade.quantity
-
- if trade.side == 'LONG':
- gross_pnl = (exit_value - entry_value) * trade.leverage
- else: # SHORT
- gross_pnl = (entry_value - exit_value) * trade.leverage
-
- # Calculate fees
- fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit
-
- # Calculate net P&L
- net_pnl = gross_pnl - fees
-
- # Compare with stored values
- pnl_diff = abs(net_pnl - trade.pnl)
- if pnl_diff > 0.01: # More than 1 cent difference
- logger.warning(f"P&L calculation issue detected for trade {trade.entry_time}:")
- logger.warning(f" Stored P&L: ${trade.pnl:.2f}")
- logger.warning(f" Calculated P&L: ${net_pnl:.2f}")
- logger.warning(f" Difference: ${pnl_diff:.2f}")
- logger.warning(f" Leverage used: {trade.leverage}x")
-
- # Update statistics
- total_pnl += net_pnl
- total_gross_pnl += gross_pnl
- total_fees += fees
-
- if net_pnl > 0.01: # More than 1 cent profit
- winning_trades += 1
- elif net_pnl < -0.01: # More than 1 cent loss
- losing_trades += 1
- else:
- breakeven_trades += 1
-
- # Calculate win rate
- total_trades = winning_trades + losing_trades + breakeven_trades
- win_rate = (winning_trades / total_trades * 100) if total_trades > 0 else 0.0
-
- logger.info("\nTrade Analysis Results:")
- logger.info(f" Total trades: {total_trades}")
- logger.info(f" Winning trades: {winning_trades}")
- logger.info(f" Losing trades: {losing_trades}")
- logger.info(f" Breakeven trades: {breakeven_trades}")
- logger.info(f" Win rate: {win_rate:.1f}%")
- logger.info(f" Total P&L: ${total_pnl:.2f}")
- logger.info(f" Total gross P&L: ${total_gross_pnl:.2f}")
- logger.info(f" Total fees: ${total_fees:.2f}")
-
- # Check for leverage issues
- leverage_issues = False
- for trade in trade_records:
- if trade.leverage <= 1.0:
- leverage_issues = True
- logger.warning(f"Low leverage detected: {trade.leverage}x for trade at {trade.entry_time}")
-
- if leverage_issues:
- logger.warning("\nLeverage issues detected. Consider fixing the leverage calculation.")
- logger.info("Recommended fix: Ensure leverage is properly set in the trading executor.")
- else:
- logger.info("\nNo leverage issues detected.")
-
-def fix_leverage_calculation():
- """Fix leverage calculation in the trading executor"""
- logger.info("Fixing leverage calculation in the trading executor...")
-
- # Initialize trading executor
- trading_executor = TradingExecutor()
-
- # Get current leverage
- current_leverage = trading_executor.current_leverage
- logger.info(f"Current leverage setting: {current_leverage}x")
-
- # Check if leverage is properly set
- if current_leverage <= 1:
- logger.warning("Leverage is set too low. Updating to 20x...")
- trading_executor.current_leverage = 20
- logger.info(f"Updated leverage to {trading_executor.current_leverage}x")
- else:
- logger.info("Leverage is already set correctly.")
-
- # Update trade records with correct leverage
- updated_count = 0
- for i, trade in enumerate(trading_executor.trade_records):
- if trade.leverage <= 1.0:
- # Create updated trade record
- updated_trade = TradeRecord(
- symbol=trade.symbol,
- side=trade.side,
- quantity=trade.quantity,
- entry_price=trade.entry_price,
- exit_price=trade.exit_price,
- entry_time=trade.entry_time,
- exit_time=trade.exit_time,
- pnl=trade.pnl,
- fees=trade.fees,
- confidence=trade.confidence,
- hold_time_seconds=trade.hold_time_seconds,
- leverage=trading_executor.current_leverage, # Use current leverage setting
- position_size_usd=trade.position_size_usd,
- gross_pnl=trade.gross_pnl,
- net_pnl=trade.net_pnl
- )
-
- # Recalculate P&L with correct leverage
- entry_value = updated_trade.entry_price * updated_trade.quantity
- exit_value = updated_trade.exit_price * updated_trade.quantity
-
- if updated_trade.side == 'LONG':
- updated_trade.gross_pnl = (exit_value - entry_value) * updated_trade.leverage
- else: # SHORT
- updated_trade.gross_pnl = (entry_value - exit_value) * updated_trade.leverage
-
- # Recalculate fees
- updated_trade.fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit
-
- # Recalculate net P&L
- updated_trade.net_pnl = updated_trade.gross_pnl - updated_trade.fees
- updated_trade.pnl = updated_trade.net_pnl
-
- # Update trade record
- trading_executor.trade_records[i] = updated_trade
- updated_count += 1
-
- logger.info(f"Updated {updated_count} trade records with correct leverage.")
-
- # Save updated trade records
- # Note: This is a placeholder. In a real implementation, you would need to
- # persist the updated trade records to storage.
- logger.info("Changes will take effect on next dashboard restart.")
-
- return updated_count > 0
-
-if __name__ == "__main__":
- logger.info("=" * 70)
- logger.info("POSITION SYNC ENHANCEMENT")
- logger.info("=" * 70)
-
- if len(sys.argv) > 1 and sys.argv[1] == 'fix':
- fix_leverage_calculation()
- else:
- analyze_trade_records()
\ No newline at end of file
diff --git a/read_logs.py b/read_logs.py
deleted file mode 100644
index e120f51..0000000
--- a/read_logs.py
+++ /dev/null
@@ -1,124 +0,0 @@
-#!/usr/bin/env python
-"""
-Log Reader Utility
-
-This script provides a convenient way to read and filter log files during
-development.
-"""
-
-import os
-import sys
-import time
-import argparse
-from datetime import datetime
-
-def parse_args():
- """Parse command line arguments"""
- parser = argparse.ArgumentParser(description='Read and filter log files')
- parser.add_argument('--file', type=str, help='Log file to read (defaults to most recent .log file)')
- parser.add_argument('--tail', type=int, default=50, help='Number of lines to show from the end')
- parser.add_argument('--follow', '-f', action='store_true', help='Follow the file as it grows')
- parser.add_argument('--filter', type=str, help='Only show lines containing this string')
- parser.add_argument('--list', action='store_true', help='List all log files sorted by modification time')
- return parser.parse_args()
-
-def get_most_recent_log():
- """Find the most recently modified log file"""
- log_files = [f for f in os.listdir('.') if f.endswith('.log')]
- if not log_files:
- print("No log files found in current directory.")
- sys.exit(1)
-
- # Sort by modification time (newest first)
- log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
- return log_files[0]
-
-def list_log_files():
- """List all log files sorted by modification time"""
- log_files = [f for f in os.listdir('.') if f.endswith('.log')]
- if not log_files:
- print("No log files found in current directory.")
- sys.exit(1)
-
- # Sort by modification time (newest first)
- log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
-
- print(f"{'LAST MODIFIED':<20} {'SIZE':<10} FILENAME")
- print("-" * 60)
- for log_file in log_files:
- mtime = datetime.fromtimestamp(os.path.getmtime(log_file))
- size = os.path.getsize(log_file)
- size_str = f"{size / 1024:.1f} KB" if size > 1024 else f"{size} B"
- print(f"{mtime.strftime('%Y-%m-%d %H:%M:%S'):<20} {size_str:<10} {log_file}")
-
-def read_log_tail(file_path, num_lines, filter_text=None):
- """Read the last N lines of a file"""
- try:
- with open(file_path, 'r', encoding='utf-8') as f:
- # Read all lines (inefficient but simple)
- lines = f.readlines()
-
- # Filter if needed
- if filter_text:
- lines = [line for line in lines if filter_text in line]
-
- # Get the last N lines
- last_lines = lines[-num_lines:] if len(lines) > num_lines else lines
- return last_lines
- except Exception as e:
- print(f"Error reading file: {str(e)}")
- sys.exit(1)
-
-def follow_log(file_path, filter_text=None):
- """Follow the log file as it grows (like tail -f)"""
- try:
- with open(file_path, 'r', encoding='utf-8') as f:
- # Go to the end of the file
- f.seek(0, 2)
-
- while True:
- line = f.readline()
- if line:
- if not filter_text or filter_text in line:
- # Remove newlines at the end to avoid double spacing
- print(line.rstrip())
- else:
- time.sleep(0.1) # Sleep briefly to avoid consuming CPU
- except KeyboardInterrupt:
- print("\nLog reading stopped.")
- except Exception as e:
- print(f"Error following file: {str(e)}")
- sys.exit(1)
-
-def main():
- """Main function"""
- args = parse_args()
-
- # List all log files if requested
- if args.list:
- list_log_files()
- return
-
- # Determine which file to read
- file_path = args.file
- if not file_path:
- file_path = get_most_recent_log()
- print(f"Reading most recent log file: {file_path}")
-
- # Follow mode (like tail -f)
- if args.follow:
- print(f"Following {file_path} (Press Ctrl+C to stop)...")
- # First print the tail
- for line in read_log_tail(file_path, args.tail, args.filter):
- print(line.rstrip())
- print("-" * 80)
- print("Waiting for new content...")
- # Then follow
- follow_log(file_path, args.filter)
- else:
- # Just print the tail
- for line in read_log_tail(file_path, args.tail, args.filter):
- print(line.rstrip())
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/reset_db_manager.py b/reset_db_manager.py
deleted file mode 100644
index cec1b0a..0000000
--- a/reset_db_manager.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python3
-"""
-Script to reset the database manager instance to trigger migration in running system
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from utils.database_manager import reset_database_manager
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def main():
- """Reset the database manager to trigger migration"""
- try:
- logger.info("Resetting database manager to trigger migration...")
- reset_database_manager()
- logger.info("✅ Database manager reset successfully!")
- logger.info("The migration will run automatically on the next database access.")
- return True
- except Exception as e:
- logger.error(f"❌ Failed to reset database manager: {e}")
- return False
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/reset_models_and_fix_mapping.py b/reset_models_and_fix_mapping.py
deleted file mode 100644
index 59871b2..0000000
--- a/reset_models_and_fix_mapping.py
+++ /dev/null
@@ -1,204 +0,0 @@
-#!/usr/bin/env python3
-"""
-Reset Models and Fix Action Mapping
-
-This script:
-1. Deletes existing model files
-2. Creates new model files with consistent action mapping
-3. Updates action mapping in key files
-"""
-
-import os
-import shutil
-import logging
-import sys
-import torch
-import numpy as np
-from datetime import datetime
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def ensure_directory(directory):
- """Ensure directory exists"""
- if not os.path.exists(directory):
- os.makedirs(directory)
- logger.info(f"Created directory: {directory}")
-
-def delete_directory_contents(directory):
- """Delete all files in a directory"""
- if os.path.exists(directory):
- for filename in os.listdir(directory):
- file_path = os.path.join(directory, filename)
- try:
- if os.path.isfile(file_path) or os.path.islink(file_path):
- os.unlink(file_path)
- elif os.path.isdir(file_path):
- shutil.rmtree(file_path)
- logger.info(f"Deleted: {file_path}")
- except Exception as e:
- logger.error(f"Failed to delete {file_path}. Reason: {e}")
-
-def create_backup_directory():
- """Create a backup directory with timestamp"""
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- backup_dir = f"models/backup_{timestamp}"
- ensure_directory(backup_dir)
- return backup_dir
-
-def backup_models():
- """Backup existing models"""
- backup_dir = create_backup_directory()
-
- # List of model directories to backup
- model_dirs = [
- "models/enhanced_rl",
- "models/enhanced_cnn",
- "models/realtime_rl_cob",
- "models/rl",
- "models/cnn"
- ]
-
- for model_dir in model_dirs:
- if os.path.exists(model_dir):
- dest_dir = os.path.join(backup_dir, os.path.basename(model_dir))
- ensure_directory(dest_dir)
-
- # Copy files
- for filename in os.listdir(model_dir):
- file_path = os.path.join(model_dir, filename)
- if os.path.isfile(file_path):
- shutil.copy2(file_path, dest_dir)
- logger.info(f"Backed up: {file_path} to {dest_dir}")
-
- return backup_dir
-
-def initialize_dqn_model():
- """Initialize a new DQN model with consistent action mapping"""
- try:
- # Import necessary modules
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
- from NN.models.dqn_agent import DQNAgent
-
- # Define state shape for BTC and ETH
- state_shape = (100,) # Default feature dimension
-
- # Create models directory
- ensure_directory("models/enhanced_rl")
-
- # Initialize DQN with 3 actions (BUY=0, SELL=1, HOLD=2)
- dqn_btc = DQNAgent(
- state_shape=state_shape,
- n_actions=3, # BUY=0, SELL=1, HOLD=2
- learning_rate=0.001,
- epsilon=0.5, # Start with moderate exploration
- epsilon_min=0.01,
- epsilon_decay=0.995,
- model_name="BTC_USDT_dqn"
- )
-
- dqn_eth = DQNAgent(
- state_shape=state_shape,
- n_actions=3, # BUY=0, SELL=1, HOLD=2
- learning_rate=0.001,
- epsilon=0.5, # Start with moderate exploration
- epsilon_min=0.01,
- epsilon_decay=0.995,
- model_name="ETH_USDT_dqn"
- )
-
- # Save initial models
- torch.save(dqn_btc.policy_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_policy.pth")
- torch.save(dqn_btc.target_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_target.pth")
- torch.save(dqn_eth.policy_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_policy.pth")
- torch.save(dqn_eth.target_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_target.pth")
-
- logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
- return True
- except Exception as e:
- logger.error(f"Failed to initialize DQN models: {e}")
- return False
-
-def initialize_cnn_model():
- """Initialize a new CNN model with consistent action mapping"""
- try:
- # Import necessary modules
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
- from NN.models.enhanced_cnn import EnhancedCNN
-
- # Define input dimension and number of actions
- input_dim = 100 # Default feature dimension
- n_actions = 3 # BUY=0, SELL=1, HOLD=2
-
- # Create models directory
- ensure_directory("models/enhanced_cnn")
-
- # Initialize CNN models for BTC and ETH
- cnn_btc = EnhancedCNN(input_dim, n_actions)
- cnn_eth = EnhancedCNN(input_dim, n_actions)
-
- # Save initial models
- torch.save(cnn_btc.state_dict(), "models/enhanced_cnn/BTC_USDT_cnn.pth")
- torch.save(cnn_eth.state_dict(), "models/enhanced_cnn/ETH_USDT_cnn.pth")
-
- logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
- return True
- except Exception as e:
- logger.error(f"Failed to initialize CNN models: {e}")
- return False
-
-def initialize_realtime_rl_model():
- """Initialize a new realtime RL model with consistent action mapping"""
- try:
- # Create models directory
- ensure_directory("models/realtime_rl_cob")
-
- # Create empty model files to ensure directory is not empty
- with open("models/realtime_rl_cob/README.txt", "w") as f:
- f.write("Realtime RL COB models will be saved here.\n")
- f.write("Action mapping: BUY=0, SELL=1, HOLD=2\n")
-
- logger.info("Initialized realtime RL model directory")
- return True
- except Exception as e:
- logger.error(f"Failed to initialize realtime RL models: {e}")
- return False
-
-def main():
- """Main function to reset models and fix action mapping"""
- logger.info("Starting model reset and action mapping fix")
-
- # Backup existing models
- backup_dir = backup_models()
- logger.info(f"Backed up existing models to {backup_dir}")
-
- # Delete existing model files
- model_dirs = [
- "models/enhanced_rl",
- "models/enhanced_cnn",
- "models/realtime_rl_cob"
- ]
-
- for model_dir in model_dirs:
- delete_directory_contents(model_dir)
- logger.info(f"Deleted contents of {model_dir}")
-
- # Initialize new models with consistent action mapping
- dqn_success = initialize_dqn_model()
- cnn_success = initialize_cnn_model()
- rl_success = initialize_realtime_rl_model()
-
- if dqn_success and cnn_success and rl_success:
- logger.info("Successfully reset models and fixed action mapping")
- logger.info("New action mapping: BUY=0, SELL=1, HOLD=2")
- else:
- logger.error("Failed to reset models and fix action mapping")
-
- logger.info("Model reset complete")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_clean_dashboard.py b/run_clean_dashboard.py
index 40202b5..d1b38b3 100644
--- a/run_clean_dashboard.py
+++ b/run_clean_dashboard.py
@@ -1,325 +1,26 @@
-#!/usr/bin/env python3
-"""
-Run Clean Trading Dashboard with Full Training Pipeline
-Integrated system with both training loop and clean web dashboard
-"""
-
-import os
-# Fix OpenMP library conflicts before importing other modules
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '4'
-
-# Fix matplotlib backend issue - set non-interactive backend before any imports
-import matplotlib
-matplotlib.use('Agg') # Use non-interactive Agg backend
-
-import asyncio
import logging
-import sys
-import platform
-from safe_logging import setup_safe_logging
-import threading
-import time
-from pathlib import Path
-
-# Windows-specific async event loop configuration
-if platform.system() == "Windows":
- # Use ProactorEventLoop on Windows for better I/O handling
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config, setup_logging
-from core.data_provider import DataProvider
-
-# Import checkpoint management
-from utils.checkpoint_manager import get_checkpoint_manager
-from utils.training_integration import get_training_integration
-
-# Setup logging
-setup_safe_logging()
-logger = logging.getLogger(__name__)
-
-async def start_training_pipeline(orchestrator, trading_executor):
- """Start the training pipeline in the background with comprehensive error handling"""
- logger.info("=" * 70)
- logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD")
- logger.info("=" * 70)
-
- # Set up async exception handler
- def handle_async_exception(loop, context):
- """Handle uncaught async exceptions"""
- exception = context.get('exception')
- if exception:
- logger.error(f"Uncaught async exception: {exception}")
- logger.error(f"Context: {context}")
- else:
- logger.error(f"Async error: {context.get('message', 'Unknown error')}")
-
- # Get current event loop and set exception handler
- loop = asyncio.get_running_loop()
- loop.set_exception_handler(handle_async_exception)
-
- # Initialize checkpoint management
- checkpoint_manager = get_checkpoint_manager()
- training_integration = get_training_integration()
-
- # Training statistics
- training_stats = {
- 'iteration_count': 0,
- 'total_decisions': 0,
- 'successful_trades': 0,
- 'best_performance': 0.0,
- 'last_checkpoint_iteration': 0
- }
-
- try:
- # Start real-time processing with error handling
- try:
- if hasattr(orchestrator, 'start_realtime_processing'):
- await orchestrator.start_realtime_processing()
- logger.info("Real-time processing started")
- except Exception as e:
- logger.error(f"Error starting real-time processing: {e}")
-
- # Start COB integration with error handling
- try:
- if hasattr(orchestrator, 'start_cob_integration'):
- await orchestrator.start_cob_integration()
- logger.info("COB integration started - 5-minute data matrix active")
- else:
- logger.info("COB integration not available")
- except Exception as e:
- logger.error(f"Error starting COB integration: {e}")
-
- # Main training loop
- iteration = 0
- last_checkpoint_time = time.time()
-
- while True:
- try:
- iteration += 1
- training_stats['iteration_count'] = iteration
-
- # Get symbols to process
- symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']
-
- # Process each symbol
- for symbol in symbols:
- try:
- # Make trading decision (this triggers model training)
- decision = await orchestrator.make_trading_decision(symbol)
- if decision:
- training_stats['total_decisions'] += 1
- logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")
-
- except Exception as e:
- logger.warning(f"Error processing {symbol}: {e}")
-
- # Status logging every 100 iterations
- if iteration % 100 == 0:
- current_time = time.time()
- elapsed = current_time - last_checkpoint_time
-
- logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")
-
- # Models will save their own checkpoints when performance improves
- training_stats['last_checkpoint_iteration'] = iteration
- last_checkpoint_time = current_time
-
- # Brief pause to prevent overwhelming the system
- await asyncio.sleep(0.1) # 100ms between iterations
-
- except Exception as e:
- logger.error(f"Training loop error: {e}")
- await asyncio.sleep(5) # Wait longer on error
-
- except Exception as e:
- logger.error(f"Training pipeline error: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
-def start_clean_dashboard_with_training():
- """Start clean dashboard with full training pipeline"""
- try:
- logger.info("=" * 80)
- logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
- logger.info("=" * 80)
- logger.info("Features: Real-time Training, COB Integration, Clean UI")
- logger.info("Universal Data Stream: ENABLED")
- logger.info("Neural Decision Fusion: ENABLED")
- logger.info("COB Integration: ENABLED")
- logger.info("GPU Training: ENABLED")
- logger.info("TensorBoard Integration: ENABLED")
- logger.info("Multi-symbol: ETH/USDT, BTC/USDT")
-
- # Get port from environment or use default
- dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
- tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
- logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
- logger.info(f"TensorBoard: http://127.0.0.1:{tensorboard_port}")
- logger.info("=" * 80)
-
- # Check environment variables
- enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
- enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
- enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'
-
- logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
- logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
- logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")
-
- # Get configuration
- config = get_config()
-
- # Initialize core components with standardized versions
- from core.standardized_data_provider import StandardizedDataProvider
- from core.orchestrator import TradingOrchestrator
- from core.trading_executor import TradingExecutor
-
- # Create standardized data provider
- data_provider = StandardizedDataProvider()
- logger.info("StandardizedDataProvider created with BaseDataInput support")
-
- # Create enhanced orchestrator with standardized data provider
- orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
- logger.info("Enhanced Trading Orchestrator created with COB integration")
-
- # Create trading executor
- trading_executor = TradingExecutor(config_path="config.yaml")
- logger.info(f"Creating trading executor with {trading_executor.primary_name} configuration...")
-
-
- # Connect trading executor to orchestrator
- orchestrator.trading_executor = trading_executor
- logger.info("Trading Executor connected to Orchestrator")
-
- # Initialize system resource monitoring
- from utils.system_monitor import start_system_monitoring
- system_monitor = start_system_monitoring()
-
- # Set up cleanup callback for memory management
- def cleanup_callback():
- """Custom cleanup for memory management"""
- try:
- # Clear orchestrator caches
- if hasattr(orchestrator, 'recent_decisions'):
- for symbol in orchestrator.recent_decisions:
- if len(orchestrator.recent_decisions[symbol]) > 50:
- orchestrator.recent_decisions[symbol] = orchestrator.recent_decisions[symbol][-25:]
-
- # Clear data provider caches
- if hasattr(data_provider, 'clear_old_data'):
- data_provider.clear_old_data()
-
- logger.info("Custom memory cleanup completed")
- except Exception as e:
- logger.error(f"Error in custom cleanup: {e}")
-
- system_monitor.set_callbacks(cleanup=cleanup_callback)
- logger.info("System resource monitoring started with memory cleanup")
-
- # Import clean dashboard
- from web.clean_dashboard import create_clean_dashboard
-
- # Create clean dashboard
- logger.info("Creating clean dashboard...")
- dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
- logger.info("Clean Trading Dashboard created")
-
- # Add memory cleanup method to dashboard
- def cleanup_dashboard_memory():
- """Clean up dashboard memory caches"""
- try:
- if hasattr(dashboard, 'recent_decisions'):
- dashboard.recent_decisions = dashboard.recent_decisions[-50:] # Keep last 50
- if hasattr(dashboard, 'closed_trades'):
- dashboard.closed_trades = dashboard.closed_trades[-100:] # Keep last 100
- if hasattr(dashboard, 'tick_cache'):
- dashboard.tick_cache = dashboard.tick_cache[-1000:] # Keep last 1000
- logger.debug("Dashboard memory cleanup completed")
- except Exception as e:
- logger.error(f"Error in dashboard memory cleanup: {e}")
-
- # Set cleanup method on dashboard
- dashboard.cleanup_memory = cleanup_dashboard_memory
-
- # Start training pipeline in background thread with enhanced error handling
- def training_worker():
- """Run training pipeline in background with comprehensive error handling"""
- try:
- asyncio.run(start_training_pipeline(orchestrator, trading_executor))
- except KeyboardInterrupt:
- logger.info("Training worker stopped by user")
- except Exception as e:
- logger.error(f"Training worker error: {e}")
- import traceback
- logger.error(f"Training worker traceback: {traceback.format_exc()}")
- # Don't exit - let main thread handle restart
-
- training_thread = threading.Thread(target=training_worker, daemon=True)
- training_thread.start()
- logger.info("Training pipeline started in background with error handling")
-
- # Wait a moment for training to initialize
- time.sleep(3)
-
- # Start TensorBoard in background
- from web.tensorboard_integration import get_tensorboard_integration
- tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
- tensorboard_integration = get_tensorboard_integration(log_dir="runs", port=tensorboard_port)
-
- # Start TensorBoard server
- tensorboard_started = tensorboard_integration.start_tensorboard(open_browser=False)
- if tensorboard_started:
- logger.info(f"TensorBoard started at {tensorboard_integration.get_tensorboard_url()}")
- else:
- logger.warning("Failed to start TensorBoard - training metrics will not be visualized")
-
- # Start dashboard server with error handling (this blocks)
- logger.info("Starting Clean Dashboard Server with error handling...")
- try:
- dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
- except Exception as e:
- logger.error(f"Dashboard server error: {e}")
- import traceback
- logger.error(f"Dashboard server traceback: {traceback.format_exc()}")
- raise # Re-raise to trigger main error handling
-
- except KeyboardInterrupt:
- logger.info("System stopped by user")
- # Stop TensorBoard
- try:
- tensorboard_integration = get_tensorboard_integration()
- tensorboard_integration.stop_tensorboard()
- except:
- pass
- except Exception as e:
- logger.error(f"Error running clean dashboard with training: {e}")
- import traceback
- traceback.print_exc()
- sys.exit(1)
+from core.config import setup_logging, get_config
+from core.trading_executor import TradingExecutor
+from core.orchestrator import TradingOrchestrator
+from core.standardized_data_provider import StandardizedDataProvider
+from web.clean_dashboard import CleanTradingDashboard
def main():
- """Main function with comprehensive error handling"""
- try:
- start_clean_dashboard_with_training()
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user (Ctrl+C)")
- sys.exit(0)
- except Exception as e:
- logger.error(f"Critical error in main: {e}")
- import traceback
- logger.error(traceback.format_exc())
- sys.exit(1)
+ setup_logging()
+ cfg = get_config()
+
+ data_provider = StandardizedDataProvider()
+ trading_executor = TradingExecutor()
+ orchestrator = TradingOrchestrator(data_provider=data_provider)
+
+ dashboard = CleanTradingDashboard(
+ data_provider=data_provider,
+ orchestrator=orchestrator,
+ trading_executor=trading_executor
+ )
+ logging.getLogger(__name__).info("Starting Clean Trading Dashboard at http://127.0.0.1:8050")
+ dashboard.run_server(host='127.0.0.1', port=8050, debug=False)
+
+if __name__ == '__main__':
+ main()
-if __name__ == "__main__":
- # Ensure logging is flushed on exit
- import atexit
- def flush_logs():
- logging.shutdown()
- atexit.register(flush_logs)
-
- main()
\ No newline at end of file
diff --git a/run_continuous_training.py b/run_continuous_training.py
deleted file mode 100644
index 86c5c69..0000000
--- a/run_continuous_training.py
+++ /dev/null
@@ -1,501 +0,0 @@
-#!/usr/bin/env python3
-"""
-Continuous Full Training System (RL + CNN)
-
-This system runs continuous training for both RL and CNN models using the enhanced
-DataProvider for consistent data streaming to both models and the dashboard.
-
-Features:
-- Single DataProvider instance for all data needs
-- Continuous RL training with real-time market data
-- CNN training with perfect move detection
-- Real-time performance monitoring
-- Automatic model checkpointing
-- Integration with live trading dashboard
-"""
-
-import asyncio
-import logging
-import time
-import signal
-import sys
-from datetime import datetime, timedelta
-from threading import Thread, Event
-from typing import Dict, Any
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/continuous_training.log'),
- logging.StreamHandler()
- ]
-)
-logger = logging.getLogger(__name__)
-
-# Import our components
-from core.config import get_config
-from core.data_provider import DataProvider, MarketTick
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
-
-# Import checkpoint management
-from utils.checkpoint_manager import get_checkpoint_manager
-from utils.training_integration import get_training_integration
-
-class ContinuousTrainingSystem:
- """Comprehensive continuous training system for RL + CNN models"""
-
- def __init__(self):
- """Initialize the continuous training system"""
- self.config = get_config()
-
- # Single DataProvider instance for all data needs
- self.data_provider = DataProvider(
- symbols=['ETH/USDT', 'BTC/USDT'],
- timeframes=['1s', '1m', '1h', '1d']
- )
-
- # Enhanced orchestrator for AI trading
- self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
-
- # Dashboard for monitoring
- self.dashboard = None
-
- # Training control
- self.running = False
- self.shutdown_event = Event()
-
- # Checkpoint management
- self.checkpoint_manager = get_checkpoint_manager()
- self.training_integration = get_training_integration()
-
- # Performance tracking
- self.training_stats = {
- 'start_time': None,
- 'rl_training_cycles': 0,
- 'cnn_training_cycles': 0,
- 'perfect_moves_detected': 0,
- 'total_ticks_processed': 0,
- 'models_saved': 0,
- 'last_checkpoint': None,
- 'best_rl_reward': float('-inf'),
- 'best_cnn_accuracy': 0.0
- }
-
- # Training intervals
- self.rl_training_interval = 300 # 5 minutes
- self.cnn_training_interval = 600 # 10 minutes
- self.checkpoint_interval = 1800 # 30 minutes
-
- logger.info("Continuous Training System initialized with checkpoint management")
- logger.info(f"RL training interval: {self.rl_training_interval}s")
- logger.info(f"CNN training interval: {self.cnn_training_interval}s")
- logger.info(f"Checkpoint interval: {self.checkpoint_interval}s")
-
- async def start(self, run_dashboard: bool = True):
- """Start the continuous training system"""
- logger.info("Starting Continuous Training System...")
- self.running = True
- self.training_stats['start_time'] = datetime.now()
-
- try:
- # Start DataProvider streaming
- logger.info("Starting DataProvider real-time streaming...")
- await self.data_provider.start_real_time_streaming()
-
- # Subscribe to tick data for training
- subscriber_id = self.data_provider.subscribe_to_ticks(
- callback=self._handle_training_tick,
- symbols=['ETH/USDT', 'BTC/USDT'],
- subscriber_name="ContinuousTraining"
- )
- logger.info(f"Subscribed to training tick stream: {subscriber_id}")
-
- # Start training threads
- training_tasks = [
- asyncio.create_task(self._rl_training_loop()),
- asyncio.create_task(self._cnn_training_loop()),
- asyncio.create_task(self._checkpoint_loop()),
- asyncio.create_task(self._monitoring_loop())
- ]
-
- # Start dashboard if requested
- if run_dashboard:
- dashboard_task = asyncio.create_task(self._run_dashboard())
- training_tasks.append(dashboard_task)
-
- logger.info("All training components started successfully")
-
- # Wait for shutdown signal
- await self._wait_for_shutdown()
-
- except Exception as e:
- logger.error(f"Error in continuous training system: {e}")
- raise
- finally:
- await self.stop()
-
- def _handle_training_tick(self, tick: MarketTick):
- """Handle incoming tick data for training"""
- try:
- self.training_stats['total_ticks_processed'] += 1
-
- # Process tick through orchestrator for RL training
- if self.orchestrator and hasattr(self.orchestrator, 'process_tick'):
- self.orchestrator.process_tick(tick)
-
- # Log every 1000 ticks
- if self.training_stats['total_ticks_processed'] % 1000 == 0:
- logger.info(f"Processed {self.training_stats['total_ticks_processed']} training ticks")
-
- except Exception as e:
- logger.warning(f"Error processing training tick: {e}")
-
- async def _rl_training_loop(self):
- """Continuous RL training loop"""
- logger.info("Starting RL training loop...")
-
- while self.running:
- try:
- start_time = time.time()
-
- # Perform RL training cycle
- if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
- logger.info("Starting RL training cycle...")
-
- # Get recent market data for training
- training_data = self._prepare_rl_training_data()
-
- if training_data is not None:
- # Train RL agent
- training_results = await self._train_rl_agent(training_data)
-
- if training_results:
- self.training_stats['rl_training_cycles'] += 1
- logger.info(f"RL training cycle {self.training_stats['rl_training_cycles']} completed")
- logger.info(f"Training results: {training_results}")
- else:
- logger.warning("No training data available for RL agent")
-
- # Wait for next training cycle
- elapsed = time.time() - start_time
- sleep_time = max(0, self.rl_training_interval - elapsed)
- await asyncio.sleep(sleep_time)
-
- except Exception as e:
- logger.error(f"Error in RL training loop: {e}")
- await asyncio.sleep(60) # Wait before retrying
-
- async def _cnn_training_loop(self):
- """Continuous CNN training loop"""
- logger.info("Starting CNN training loop...")
-
- while self.running:
- try:
- start_time = time.time()
-
- # Perform CNN training cycle
- if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
- logger.info("Starting CNN training cycle...")
-
- # Detect perfect moves for CNN training
- perfect_moves = self._detect_perfect_moves()
-
- if perfect_moves:
- self.training_stats['perfect_moves_detected'] += len(perfect_moves)
-
- # Train CNN with perfect moves
- training_results = await self._train_cnn_model(perfect_moves)
-
- if training_results:
- self.training_stats['cnn_training_cycles'] += 1
- logger.info(f"CNN training cycle {self.training_stats['cnn_training_cycles']} completed")
- logger.info(f"Perfect moves processed: {len(perfect_moves)}")
- else:
- logger.info("No perfect moves detected for CNN training")
-
- # Wait for next training cycle
- elapsed = time.time() - start_time
- sleep_time = max(0, self.cnn_training_interval - elapsed)
- await asyncio.sleep(sleep_time)
-
- except Exception as e:
- logger.error(f"Error in CNN training loop: {e}")
- await asyncio.sleep(60) # Wait before retrying
-
- async def _checkpoint_loop(self):
- """Automatic model checkpointing loop"""
- logger.info("Starting checkpoint loop...")
-
- while self.running:
- try:
- await asyncio.sleep(self.checkpoint_interval)
-
- logger.info("Creating model checkpoints...")
-
- # Save RL model
- if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
- rl_checkpoint = await self._save_rl_checkpoint()
- if rl_checkpoint:
- logger.info(f"RL checkpoint saved: {rl_checkpoint}")
-
- # Save CNN model
- if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
- cnn_checkpoint = await self._save_cnn_checkpoint()
- if cnn_checkpoint:
- logger.info(f"CNN checkpoint saved: {cnn_checkpoint}")
-
- self.training_stats['models_saved'] += 1
- self.training_stats['last_checkpoint'] = datetime.now()
-
- except Exception as e:
- logger.error(f"Error in checkpoint loop: {e}")
-
- async def _monitoring_loop(self):
- """System monitoring and performance tracking loop"""
- logger.info("Starting monitoring loop...")
-
- while self.running:
- try:
- await asyncio.sleep(300) # Monitor every 5 minutes
-
- # Log system statistics
- uptime = datetime.now() - self.training_stats['start_time']
-
- logger.info("=== CONTINUOUS TRAINING SYSTEM STATUS ===")
- logger.info(f"Uptime: {uptime}")
- logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
- logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
- logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
- logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
- logger.info(f"Models saved: {self.training_stats['models_saved']}")
-
- # DataProvider statistics
- if hasattr(self.data_provider, 'get_subscriber_stats'):
- subscriber_stats = self.data_provider.get_subscriber_stats()
- logger.info(f"Active subscribers: {subscriber_stats.get('active_subscribers', 0)}")
- logger.info(f"Total ticks distributed: {subscriber_stats.get('distribution_stats', {}).get('total_ticks_distributed', 0)}")
-
- # Orchestrator performance
- if hasattr(self.orchestrator, 'get_performance_metrics'):
- perf_metrics = self.orchestrator.get_performance_metrics()
- logger.info(f"Orchestrator performance: {perf_metrics}")
-
- logger.info("==========================================")
-
- except Exception as e:
- logger.error(f"Error in monitoring loop: {e}")
-
- async def _run_dashboard(self):
- """Run the dashboard in a separate thread"""
- try:
- logger.info("Starting live trading dashboard...")
-
- def run_dashboard():
- self.dashboard = RealTimeScalpingDashboard(
- data_provider=self.data_provider,
- orchestrator=self.orchestrator
- )
- self.dashboard.run(host='127.0.0.1', port=8051, debug=False)
-
- dashboard_thread = Thread(target=run_dashboard, daemon=True)
- dashboard_thread.start()
-
- logger.info("Dashboard started at http://127.0.0.1:8051")
-
- # Keep dashboard thread alive
- while self.running:
- await asyncio.sleep(10)
-
- except Exception as e:
- logger.error(f"Error running dashboard: {e}")
-
- def _prepare_rl_training_data(self) -> Dict[str, Any]:
- """Prepare training data for RL agent"""
- try:
- # Get recent market data from DataProvider
- eth_data = self.data_provider.get_latest_candles('ETH/USDT', '1m', limit=1000)
- btc_data = self.data_provider.get_latest_candles('BTC/USDT', '1m', limit=1000)
-
- if eth_data is not None and not eth_data.empty:
- return {
- 'eth_data': eth_data,
- 'btc_data': btc_data,
- 'timestamp': datetime.now()
- }
-
- return None
-
- except Exception as e:
- logger.error(f"Error preparing RL training data: {e}")
- return None
-
- def _detect_perfect_moves(self) -> list:
- """Detect perfect trading moves for CNN training"""
- try:
- # Get recent tick data
- recent_ticks = self.data_provider.get_recent_ticks('ETHUSDT', count=500)
-
- if not recent_ticks:
- return []
-
- # Simple perfect move detection (can be enhanced)
- perfect_moves = []
-
- for i in range(1, len(recent_ticks) - 1):
- prev_tick = recent_ticks[i-1]
- curr_tick = recent_ticks[i]
- next_tick = recent_ticks[i+1]
-
- # Detect significant price movements
- price_change = (next_tick.price - curr_tick.price) / curr_tick.price
-
- if abs(price_change) > 0.001: # 0.1% movement
- perfect_moves.append({
- 'timestamp': curr_tick.timestamp,
- 'price': curr_tick.price,
- 'action': 'BUY' if price_change > 0 else 'SELL',
- 'confidence': min(abs(price_change) * 100, 1.0)
- })
-
- return perfect_moves[-10:] # Return last 10 perfect moves
-
- except Exception as e:
- logger.error(f"Error detecting perfect moves: {e}")
- return []
-
- async def _train_rl_agent(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
- """Train the RL agent with market data"""
- try:
- # Placeholder for RL training logic
- # This would integrate with the actual RL agent
-
- logger.info("Training RL agent with market data...")
-
- # Simulate training time
- await asyncio.sleep(1)
-
- return {
- 'loss': 0.05,
- 'reward': 0.75,
- 'episodes': 100
- }
-
- except Exception as e:
- logger.error(f"Error training RL agent: {e}")
- return None
-
- async def _train_cnn_model(self, perfect_moves: list) -> Dict[str, Any]:
- """Train the CNN model with perfect moves"""
- try:
- # Placeholder for CNN training logic
- # This would integrate with the actual CNN model
-
- logger.info(f"Training CNN model with {len(perfect_moves)} perfect moves...")
-
- # Simulate training time
- await asyncio.sleep(2)
-
- return {
- 'accuracy': 0.92,
- 'loss': 0.08,
- 'perfect_moves_processed': len(perfect_moves)
- }
-
- except Exception as e:
- logger.error(f"Error training CNN model: {e}")
- return None
-
- async def _save_rl_checkpoint(self) -> str:
- """Save RL model checkpoint"""
- try:
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- checkpoint_path = f"models/rl/checkpoint_rl_{timestamp}.pt"
-
- # Placeholder for actual model saving
- logger.info(f"Saving RL checkpoint to {checkpoint_path}")
-
- return checkpoint_path
-
- except Exception as e:
- logger.error(f"Error saving RL checkpoint: {e}")
- return None
-
- async def _save_cnn_checkpoint(self) -> str:
- """Save CNN model checkpoint"""
- try:
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- checkpoint_path = f"models/cnn/checkpoint_cnn_{timestamp}.pt"
-
- # Placeholder for actual model saving
- logger.info(f"Saving CNN checkpoint to {checkpoint_path}")
-
- return checkpoint_path
-
- except Exception as e:
- logger.error(f"Error saving CNN checkpoint: {e}")
- return None
-
- async def _wait_for_shutdown(self):
- """Wait for shutdown signal"""
- def signal_handler(signum, frame):
- logger.info(f"Received signal {signum}, shutting down...")
- self.shutdown_event.set()
-
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- # Wait for shutdown event
- while not self.shutdown_event.is_set():
- await asyncio.sleep(1)
-
- async def stop(self):
- """Stop the continuous training system"""
- logger.info("Stopping Continuous Training System...")
- self.running = False
-
- try:
- # Stop DataProvider streaming
- if self.data_provider:
- await self.data_provider.stop_real_time_streaming()
-
- # Final checkpoint
- logger.info("Creating final checkpoints...")
- await self._save_rl_checkpoint()
- await self._save_cnn_checkpoint()
-
- # Log final statistics
- uptime = datetime.now() - self.training_stats['start_time']
- logger.info("=== FINAL TRAINING STATISTICS ===")
- logger.info(f"Total uptime: {uptime}")
- logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
- logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
- logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
- logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
- logger.info(f"Models saved: {self.training_stats['models_saved']}")
- logger.info("=================================")
-
- except Exception as e:
- logger.error(f"Error during shutdown: {e}")
-
- logger.info("Continuous Training System stopped")
-
-async def main():
- """Main entry point"""
- logger.info("Starting Continuous Full Training System (RL + CNN)")
-
- # Create and start the training system
- training_system = ContinuousTrainingSystem()
-
- try:
- await training_system.start(run_dashboard=True)
- except KeyboardInterrupt:
- logger.info("Interrupted by user")
- except Exception as e:
- logger.error(f"Fatal error: {e}")
- sys.exit(1)
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/run_crash_safe_dashboard.py b/run_crash_safe_dashboard.py
deleted file mode 100644
index 2c228cb..0000000
--- a/run_crash_safe_dashboard.py
+++ /dev/null
@@ -1,269 +0,0 @@
-#!/usr/bin/env python3
-"""
-Crash-Safe Dashboard Runner
-
-This runner is designed to prevent crashes by:
-1. Isolating imports with try/except blocks
-2. Minimal initialization
-3. Graceful error handling
-4. No complex training loops
-5. Safe component loading
-"""
-
-import os
-import sys
-import logging
-import traceback
-from pathlib import Path
-
-# Fix environment issues before any imports
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '1' # Minimal threads
-os.environ['MPLBACKEND'] = 'Agg'
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-# Setup basic logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Reduce noise from other loggers
-logging.getLogger('werkzeug').setLevel(logging.ERROR)
-logging.getLogger('dash').setLevel(logging.ERROR)
-logging.getLogger('matplotlib').setLevel(logging.ERROR)
-
-class CrashSafeDashboard:
- """Crash-safe dashboard with minimal dependencies"""
-
- def __init__(self):
- """Initialize with safe error handling"""
- self.components = {}
- self.dashboard_app = None
- self.initialization_errors = []
-
- logger.info("Initializing crash-safe dashboard...")
-
- def safe_import(self, module_name, class_name=None):
- """Safely import modules with error handling"""
- try:
- if class_name:
- module = __import__(module_name, fromlist=[class_name])
- return getattr(module, class_name)
- else:
- return __import__(module_name)
- except Exception as e:
- error_msg = f"Failed to import {module_name}.{class_name if class_name else ''}: {e}"
- logger.error(error_msg)
- self.initialization_errors.append(error_msg)
- return None
-
- def initialize_core_components(self):
- """Initialize core components safely"""
- logger.info("Initializing core components...")
-
- # Try to import and initialize config
- try:
- from core.config import get_config, setup_logging
- setup_logging()
- self.components['config'] = get_config()
- logger.info("✓ Config loaded")
- except Exception as e:
- logger.error(f"✗ Config failed: {e}")
- self.initialization_errors.append(f"Config: {e}")
-
- # Try to initialize data provider
- try:
- from core.data_provider import DataProvider
- self.components['data_provider'] = DataProvider()
- logger.info("✓ Data provider initialized")
- except Exception as e:
- logger.error(f"✗ Data provider failed: {e}")
- self.initialization_errors.append(f"Data provider: {e}")
-
- # Try to initialize trading executor
- try:
- from core.trading_executor import TradingExecutor
- self.components['trading_executor'] = TradingExecutor()
- logger.info("✓ Trading executor initialized")
- except Exception as e:
- logger.error(f"✗ Trading executor failed: {e}")
- self.initialization_errors.append(f"Trading executor: {e}")
-
- # Try to initialize orchestrator (WITHOUT training to avoid crashes)
- try:
- from core.orchestrator import TradingOrchestrator
- self.components['orchestrator'] = TradingOrchestrator(
- data_provider=self.components.get('data_provider'),
- enhanced_rl_training=False # DISABLED to prevent crashes
- )
- logger.info("✓ Orchestrator initialized (training disabled)")
- except Exception as e:
- logger.error(f"✗ Orchestrator failed: {e}")
- self.initialization_errors.append(f"Orchestrator: {e}")
-
- def create_minimal_dashboard(self):
- """Create minimal dashboard without complex features"""
- try:
- import dash
- from dash import html, dcc
-
- # Create minimal Dash app
- self.dashboard_app = dash.Dash(__name__)
-
- # Create simple layout
- self.dashboard_app.layout = html.Div([
- html.H1("Trading Dashboard - Safe Mode", style={'textAlign': 'center'}),
- html.Hr(),
-
- # Status section
- html.Div([
- html.H3("System Status"),
- html.Div(id="system-status", children=self._get_system_status()),
- ], style={'margin': '20px'}),
-
- # Error section
- html.Div([
- html.H3("Initialization Status"),
- html.Div(id="init-status", children=self._get_init_status()),
- ], style={'margin': '20px'}),
-
- # Simple refresh interval
- dcc.Interval(
- id='interval-component',
- interval=5000, # Update every 5 seconds
- n_intervals=0
- )
- ])
-
- # Add simple callback
- @self.dashboard_app.callback(
- [dash.dependencies.Output('system-status', 'children'),
- dash.dependencies.Output('init-status', 'children')],
- [dash.dependencies.Input('interval-component', 'n_intervals')]
- )
- def update_status(n):
- try:
- return self._get_system_status(), self._get_init_status()
- except Exception as e:
- logger.error(f"Callback error: {e}")
- return f"Callback error: {e}", "Error in callback"
-
- logger.info("✓ Minimal dashboard created")
- return True
-
- except Exception as e:
- logger.error(f"✗ Dashboard creation failed: {e}")
- logger.error(traceback.format_exc())
- return False
-
- def _get_system_status(self):
- """Get system status for display"""
- try:
- status_items = []
-
- # Check components
- for name, component in self.components.items():
- if component is not None:
- status_items.append(html.P(f"✓ {name.replace('_', ' ').title()}: OK",
- style={'color': 'green'}))
- else:
- status_items.append(html.P(f"✗ {name.replace('_', ' ').title()}: Failed",
- style={'color': 'red'}))
-
- # Add timestamp
- status_items.append(html.P(f"Last update: {datetime.now().strftime('%H:%M:%S')}",
- style={'color': 'gray', 'fontSize': '12px'}))
-
- return status_items
-
- except Exception as e:
- return [html.P(f"Status error: {e}", style={'color': 'red'})]
-
- def _get_init_status(self):
- """Get initialization status for display"""
- try:
- if not self.initialization_errors:
- return [html.P("✓ All components initialized successfully", style={'color': 'green'})]
-
- error_items = [html.P("⚠️ Some components failed to initialize:", style={'color': 'orange'})]
-
- for error in self.initialization_errors:
- error_items.append(html.P(f"• {error}", style={'color': 'red', 'fontSize': '12px'}))
-
- return error_items
-
- except Exception as e:
- return [html.P(f"Init status error: {e}", style={'color': 'red'})]
-
- def run(self, port=8051):
- """Run the crash-safe dashboard"""
- try:
- logger.info("=" * 60)
- logger.info("CRASH-SAFE DASHBOARD")
- logger.info("=" * 60)
- logger.info("Mode: Safe mode with minimal features")
- logger.info("Training: Completely disabled")
- logger.info("Focus: System stability and basic monitoring")
- logger.info("=" * 60)
-
- # Initialize components
- self.initialize_core_components()
-
- # Create dashboard
- if not self.create_minimal_dashboard():
- logger.error("Failed to create dashboard")
- return False
-
- # Report initialization status
- if self.initialization_errors:
- logger.warning(f"Dashboard starting with {len(self.initialization_errors)} component failures")
- for error in self.initialization_errors:
- logger.warning(f" - {error}")
- else:
- logger.info("All components initialized successfully")
-
- # Start dashboard
- logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
- logger.info("Press Ctrl+C to stop")
-
- self.dashboard_app.run_server(
- host='127.0.0.1',
- port=port,
- debug=False,
- use_reloader=False,
- threaded=True
- )
-
- return True
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- return True
- except Exception as e:
- logger.error(f"Dashboard failed: {e}")
- logger.error(traceback.format_exc())
- return False
-
-def main():
- """Main function with comprehensive error handling"""
- try:
- dashboard = CrashSafeDashboard()
- success = dashboard.run()
-
- if success:
- logger.info("Dashboard completed successfully")
- else:
- logger.error("Dashboard failed")
-
- except Exception as e:
- logger.error(f"Fatal error: {e}")
- logger.error(traceback.format_exc())
- sys.exit(1)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_enhanced_rl_training.py b/run_enhanced_rl_training.py
deleted file mode 100644
index 743d924..0000000
--- a/run_enhanced_rl_training.py
+++ /dev/null
@@ -1,525 +0,0 @@
-#!/usr/bin/env python3
-"""
-Enhanced RL Training Launcher with Real Data Integration
-
-This script launches the comprehensive RL training system that uses:
-- Real-time tick data (300s window for momentum detection)
-- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
-- BTC reference data for correlation
-- CNN hidden features and predictions
-- Williams Market Structure pivot points
-- Market microstructure analysis
-
-The RL model will receive ~13,400 features instead of the previous ~100 basic features.
-Training metrics are automatically logged to TensorBoard for visualization.
-"""
-
-import asyncio
-import logging
-import time
-import signal
-import sys
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('enhanced_rl_training.log'),
- logging.StreamHandler(sys.stdout)
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-# Import our enhanced components
-from core.config import get_config
-from core.data_provider import DataProvider
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from training.enhanced_rl_trainer import EnhancedRLTrainer
-from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
-from training.williams_market_structure import WilliamsMarketStructure
-from training.cnn_rl_bridge import CNNRLBridge
-from utils.tensorboard_logger import TensorBoardLogger
-
-class EnhancedRLTrainingSystem:
- """Comprehensive RL training system with real data integration"""
-
- def __init__(self):
- """Initialize the enhanced RL training system"""
- self.config = get_config()
- self.running = False
- self.data_provider = None
- self.orchestrator = None
- self.rl_trainer = None
-
- # Performance tracking
- self.training_stats = {
- 'training_sessions': 0,
- 'total_experiences': 0,
- 'avg_state_size': 0,
- 'data_quality_score': 0.0,
- 'last_training_time': None
- }
-
- # Initialize TensorBoard logger
- experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
- self.tb_logger = TensorBoardLogger(
- log_dir="runs",
- experiment_name=experiment_name,
- enabled=True
- )
-
- logger.info("Enhanced RL Training System initialized")
- logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
- logger.info("Features:")
- logger.info("- Real-time tick data processing (300s window)")
- logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
- logger.info("- BTC correlation analysis")
- logger.info("- CNN feature integration")
- logger.info("- Williams Market Structure pivot points")
- logger.info("- ~13,400 feature state vector (vs previous ~100)")
-
-# async def initialize(self):
-# """Initialize all components"""
-# try:
-# logger.info("Initializing enhanced RL training components...")
-
-# # Initialize data provider with real-time streaming
-# logger.info("Setting up data provider with real-time streaming...")
-# self.data_provider = DataProvider(
-# symbols=self.config.symbols,
-# timeframes=self.config.timeframes
-# )
-
-# # Start real-time data streaming
-# await self.data_provider.start_real_time_streaming()
-# logger.info("Real-time data streaming started")
-
-# # Wait for initial data collection
-# logger.info("Collecting initial market data...")
-# await asyncio.sleep(30) # Allow 30 seconds for data collection
-
-# # Initialize enhanced orchestrator
-# logger.info("Initializing enhanced orchestrator...")
-# self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
-
-# # Initialize enhanced RL trainer with comprehensive state building
-# logger.info("Initializing enhanced RL trainer...")
-# self.rl_trainer = EnhancedRLTrainer(
-# config=self.config,
-# orchestrator=self.orchestrator
-# )
-
-# # Verify data availability
-# data_status = await self._verify_data_availability()
-# if not data_status['has_sufficient_data']:
-# logger.warning("Insufficient data detected. Continuing with limited training.")
-# logger.warning(f"Data status: {data_status}")
-# else:
-# logger.info("Sufficient data available for comprehensive RL training")
-# logger.info(f"Tick data: {data_status['tick_count']} ticks")
-# logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
-
-# self.running = True
-# logger.info("Enhanced RL training system initialized successfully")
-
-# except Exception as e:
-# logger.error(f"Error during initialization: {e}")
-# raise
-
-# async def _verify_data_availability(self) -> Dict[str, any]:
-# """Verify that we have sufficient data for training"""
-# try:
-# data_status = {
-# 'has_sufficient_data': False,
-# 'tick_count': 0,
-# 'ohlcv_bars': 0,
-# 'symbols_with_data': [],
-# 'missing_data': []
-# }
-
-# for symbol in self.config.symbols:
-# # Check tick data
-# recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
-# tick_count = len(recent_ticks)
-
-# # Check OHLCV data
-# ohlcv_bars = 0
-# for timeframe in ['1s', '1m', '1h', '1d']:
-# try:
-# df = self.data_provider.get_historical_data(
-# symbol=symbol,
-# timeframe=timeframe,
-# limit=50,
-# refresh=True
-# )
-# if df is not None and not df.empty:
-# ohlcv_bars += len(df)
-# except Exception as e:
-# logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
-
-# data_status['tick_count'] += tick_count
-# data_status['ohlcv_bars'] += ohlcv_bars
-
-# if tick_count >= 50 and ohlcv_bars >= 100:
-# data_status['symbols_with_data'].append(symbol)
-# else:
-# data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
-
-# # Consider data sufficient if we have at least one symbol with good data
-# data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
-
-# return data_status
-
-# except Exception as e:
-# logger.error(f"Error verifying data availability: {e}")
-# return {'has_sufficient_data': False, 'error': str(e)}
-
-# async def run_training_loop(self):
-# """Run the main training loop with real data"""
-# logger.info("Starting enhanced RL training loop...")
-
-# training_cycle = 0
-# last_state_size_log = time.time()
-
-# try:
-# while self.running:
-# training_cycle += 1
-# cycle_start_time = time.time()
-
-# logger.info(f"Training cycle {training_cycle} started")
-
-# # Get comprehensive market states with real data
-# market_states = await self._get_comprehensive_market_states()
-
-# if not market_states:
-# logger.warning("No market states available. Waiting for data...")
-# await asyncio.sleep(60)
-# continue
-
-# # Train RL agents with comprehensive states
-# training_results = await self._train_rl_agents(market_states)
-
-# # Update performance tracking
-# self._update_training_stats(training_results, market_states)
-
-# # Log training progress
-# cycle_duration = time.time() - cycle_start_time
-# logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
-
-# # Log state size periodically
-# if time.time() - last_state_size_log > 300: # Every 5 minutes
-# self._log_state_size_info(market_states)
-# last_state_size_log = time.time()
-
-# # Save models periodically
-# if training_cycle % 10 == 0:
-# await self._save_training_progress()
-
-# # Wait before next training cycle
-# await asyncio.sleep(300) # Train every 5 minutes
-
-# except Exception as e:
-# logger.error(f"Error in training loop: {e}")
-# raise
-
-# async def _get_comprehensive_market_states(self) -> Dict[str, any]:
-# """Get comprehensive market states with all required data"""
-# try:
-# # Get market states from orchestrator
-# universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
-# market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
-
-# # Verify data quality
-# quality_score = self._calculate_data_quality(market_states)
-# self.training_stats['data_quality_score'] = quality_score
-
-# if quality_score < 0.5:
-# logger.warning(f"Low data quality detected: {quality_score:.2f}")
-
-# return market_states
-
-# except Exception as e:
-# logger.error(f"Error getting comprehensive market states: {e}")
-# return {}
-
-# def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
-# """Calculate data quality score based on available data"""
-# try:
-# if not market_states:
-# return 0.0
-
-# total_score = 0.0
-# total_symbols = len(market_states)
-
-# for symbol, state in market_states.items():
-# symbol_score = 0.0
-
-# # Score based on tick data availability
-# if hasattr(state, 'raw_ticks') and state.raw_ticks:
-# tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks
-# symbol_score += tick_score * 0.3
-
-# # Score based on OHLCV data availability
-# if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
-# ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes
-# symbol_score += min(ohlcv_score, 1.0) * 0.4
-
-# # Score based on CNN features
-# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
-# symbol_score += 0.15
-
-# # Score based on pivot points
-# if hasattr(state, 'pivot_points') and state.pivot_points:
-# symbol_score += 0.15
-
-# total_score += symbol_score
-
-# return total_score / total_symbols if total_symbols > 0 else 0.0
-
-# except Exception as e:
-# logger.warning(f"Error calculating data quality: {e}")
-# return 0.5 # Default to medium quality
-
- async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]:
- """Train RL agents with comprehensive market states"""
- try:
- training_results = {
- 'symbols_trained': [],
- 'total_experiences': 0,
- 'avg_state_size': 0,
- 'training_errors': [],
- 'losses': {},
- 'rewards': {}
- }
-
- for symbol, market_state in market_states.items():
- try:
- # Convert market state to comprehensive RL state
- rl_state = self.rl_trainer._market_state_to_rl_state(market_state)
-
- if rl_state is not None and len(rl_state) > 0:
- # Record state size
- state_size = len(rl_state)
- training_results['avg_state_size'] += state_size
-
- # Log state size to TensorBoard
- self.tb_logger.log_scalar(
- f'State/{symbol}/Size',
- state_size,
- self.training_stats['training_sessions']
- )
-
- # Simulate trading action for experience generation
- # In real implementation, this would be actual trading decisions
- action = self._simulate_trading_action(symbol, rl_state)
-
- # Generate reward based on market outcome
- reward = self._calculate_training_reward(symbol, market_state, action)
-
- # Store reward for TensorBoard logging
- training_results['rewards'][symbol] = reward
-
- # Log action and reward to TensorBoard
- self.tb_logger.log_scalars(f'Actions/{symbol}', {
- 'action': action,
- 'reward': reward
- }, self.training_stats['training_sessions'])
-
- # Add experience to RL agent
- agent = self.rl_trainer.agents.get(symbol)
- if agent:
- # Create next state (would be actual next market state in real scenario)
- next_state = rl_state # Simplified for now
-
- agent.remember(
- state=rl_state,
- action=action,
- reward=reward,
- next_state=next_state,
- done=False
- )
-
- # Train agent if enough experiences
- if len(agent.replay_buffer) >= agent.batch_size:
- loss = agent.replay()
- if loss is not None:
- logger.debug(f"Agent {symbol} training loss: {loss:.4f}")
-
- # Store loss for TensorBoard logging
- training_results['losses'][symbol] = loss
-
- # Log loss to TensorBoard
- self.tb_logger.log_scalar(
- f'Training/{symbol}/Loss',
- loss,
- self.training_stats['training_sessions']
- )
-
- training_results['symbols_trained'].append(symbol)
- training_results['total_experiences'] += 1
-
- except Exception as e:
- error_msg = f"Error training {symbol}: {e}"
- logger.warning(error_msg)
- training_results['training_errors'].append(error_msg)
-
- # Calculate average state size
- if len(training_results['symbols_trained']) > 0:
- training_results['avg_state_size'] /= len(training_results['symbols_trained'])
-
- # Log overall training metrics to TensorBoard
- self.tb_logger.log_scalars('Training/Overall', {
- 'symbols_trained': len(training_results['symbols_trained']),
- 'experiences': training_results['total_experiences'],
- 'avg_state_size': training_results['avg_state_size'],
- 'errors': len(training_results['training_errors'])
- }, self.training_stats['training_sessions'])
-
- return training_results
-
- except Exception as e:
- logger.error(f"Error training RL agents: {e}")
- return {'error': str(e)}
-
-# def _simulate_trading_action(self, symbol: str, rl_state) -> int:
-# """Simulate trading action for training (would be real decision in production)"""
-# # Simple simulation based on state features
-# if len(rl_state) > 100:
-# # Use momentum features to decide action
-# momentum_features = rl_state[:100] # First 100 features assumed to be momentum
-# avg_momentum = sum(momentum_features) / len(momentum_features)
-
-# if avg_momentum > 0.6:
-# return 1 # BUY
-# elif avg_momentum < 0.4:
-# return 2 # SELL
-# else:
-# return 0 # HOLD
-# else:
-# return 0 # HOLD as default
-
-# def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
-# """Calculate training reward based on market state and action"""
-# try:
-# # Simple reward calculation based on market conditions
-# base_reward = 0.0
-
-# # Reward based on volatility alignment
-# if hasattr(market_state, 'volatility'):
-# if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility
-# base_reward += 0.1
-# elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility
-# base_reward += 0.1
-
-# # Reward based on trend alignment
-# if hasattr(market_state, 'trend_strength'):
-# if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend
-# base_reward += 0.2
-# elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend
-# base_reward += 0.2
-
-# return base_reward
-
-# except Exception as e:
-# logger.warning(f"Error calculating reward for {symbol}: {e}")
-# return 0.0
-
-# def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
-# """Update training statistics"""
-# self.training_stats['training_sessions'] += 1
-# self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
-# self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
-# self.training_stats['last_training_time'] = datetime.now()
-
-# # Log statistics periodically
-# if self.training_stats['training_sessions'] % 10 == 0:
-# logger.info("Training Statistics:")
-# logger.info(f" Sessions: {self.training_stats['training_sessions']}")
-# logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
-# logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
-# logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")
-
-# def _log_state_size_info(self, market_states: Dict[str, any]):
-# """Log information about state sizes for debugging"""
-# for symbol, state in market_states.items():
-# info = []
-
-# if hasattr(state, 'raw_ticks'):
-# info.append(f"ticks: {len(state.raw_ticks)}")
-
-# if hasattr(state, 'ohlcv_data'):
-# total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
-# info.append(f"OHLCV bars: {total_bars}")
-
-# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
-# info.append("CNN features: available")
-
-# if hasattr(state, 'pivot_points') and state.pivot_points:
-# info.append("pivot points: available")
-
-# logger.info(f"{symbol} state data: {', '.join(info)}")
-
-# async def _save_training_progress(self):
-# """Save training progress and models"""
-# try:
-# if self.rl_trainer:
-# self.rl_trainer._save_all_models()
-# logger.info("Training progress saved")
-# except Exception as e:
-# logger.error(f"Error saving training progress: {e}")
-
-# async def shutdown(self):
-# """Graceful shutdown"""
-# logger.info("Shutting down enhanced RL training system...")
-# self.running = False
-
-# # Save final state
-# await self._save_training_progress()
-
-# # Stop data provider
-# if self.data_provider:
-# await self.data_provider.stop_real_time_streaming()
-
-# logger.info("Enhanced RL training system shutdown complete")
-
-# async def main():
-# """Main function to run enhanced RL training"""
-# system = None
-
-# def signal_handler(signum, frame):
-# logger.info("Received shutdown signal")
-# if system:
-# asyncio.create_task(system.shutdown())
-
-# # Set up signal handlers
-# signal.signal(signal.SIGINT, signal_handler)
-# signal.signal(signal.SIGTERM, signal_handler)
-
-# try:
-# # Create and initialize the training system
-# system = EnhancedRLTrainingSystem()
-# await system.initialize()
-
-# logger.info("Enhanced RL Training System is now running...")
-# logger.info("The RL model now receives ~13,400 features instead of ~100!")
-# logger.info("Press Ctrl+C to stop")
-
-# # Run the training loop
-# await system.run_training_loop()
-
-# except KeyboardInterrupt:
-# logger.info("Training interrupted by user")
-# except Exception as e:
-# logger.error(f"Error in main training loop: {e}")
-# raise
-# finally:
-# if system:
-# await system.shutdown()
-
-# if __name__ == "__main__":
-# asyncio.run(main())
\ No newline at end of file
diff --git a/run_enhanced_training_dashboard.py b/run_enhanced_training_dashboard.py
deleted file mode 100644
index 3422636..0000000
--- a/run_enhanced_training_dashboard.py
+++ /dev/null
@@ -1,95 +0,0 @@
-#!/usr/bin/env python3
-"""
-Run Dashboard with Enhanced Training System Enabled
-
-This script starts the trading dashboard with the enhanced real-time
-training system automatically enabled and running.
-"""
-
-import sys
-import os
-import asyncio
-import logging
-from datetime import datetime
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-from core.trading_executor import TradingExecutor
-from web.clean_dashboard import create_clean_dashboard
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-async def main():
- """Start dashboard with enhanced training enabled"""
- try:
- logger.info("=" * 70)
- logger.info("STARTING DASHBOARD WITH ENHANCED TRAINING SYSTEM")
- logger.info("=" * 70)
-
- # 1. Initialize components with enhanced training
- logger.info("1. Initializing components...")
- data_provider = DataProvider()
- trading_executor = TradingExecutor()
-
- # 2. Create orchestrator with enhanced training ENABLED
- logger.info("2. Creating orchestrator with enhanced training...")
- orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True # 🔥 THIS ENABLES ENHANCED TRAINING
- )
-
- # 3. Verify enhanced training is available
- logger.info("3. Verifying enhanced training system...")
- if orchestrator.enhanced_training_system:
- logger.info("✅ Enhanced training system available")
- logger.info(f" - Training enabled: {orchestrator.training_enabled}")
-
- # 4. Start enhanced training
- logger.info("4. Starting enhanced training system...")
- start_result = orchestrator.start_enhanced_training()
- if start_result:
- logger.info("✅ Enhanced training started successfully")
- else:
- logger.warning("⚠️ Enhanced training start failed")
- else:
- logger.warning("⚠️ Enhanced training system not available")
-
- # 5. Create dashboard
- logger.info("5. Creating dashboard...")
- dashboard = create_clean_dashboard(
- data_provider=data_provider,
- orchestrator=orchestrator,
- trading_executor=trading_executor
- )
-
- # 6. Connect training system to dashboard
- logger.info("6. Connecting training system to dashboard...")
- orchestrator.set_training_dashboard(dashboard)
-
- # 7. Start dashboard
- logger.info("7. Starting dashboard...")
- logger.info("🎉 Dashboard with enhanced training is now running!")
- logger.info(" - Enhanced training: ENABLED")
- logger.info(" - Real-time learning: ACTIVE")
- logger.info(" - Dashboard URL: http://127.0.0.1:8051")
-
- # Keep running
- await asyncio.sleep(3600) # Run for 1 hour
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- except Exception as e:
- logger.error(f"Error starting dashboard: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/run_integrated_rl_cob_dashboard.py b/run_integrated_rl_cob_dashboard.py
deleted file mode 100644
index 2620117..0000000
--- a/run_integrated_rl_cob_dashboard.py
+++ /dev/null
@@ -1,510 +0,0 @@
-#!/usr/bin/env python3
-"""
-Integrated Real-time RL COB Trading System with Dashboard
-
-This script starts both:
-1. RealtimeRLCOBTrader - 1B parameter RL model with real-time training
-2. COB Dashboard - Real-time visualization with RL predictions
-
-The RL predictions are integrated into the dashboard for live visualization.
-"""
-
-import asyncio
-import logging
-import signal
-import sys
-import json
-import os
-from datetime import datetime
-from typing import Dict, Any
-from aiohttp import web
-
-# Local imports
-from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult
-from core.trading_executor import TradingExecutor
-from core.config import load_config
-from web.cob_realtime_dashboard import COBDashboardServer
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/integrated_rl_cob_system.log'),
- logging.StreamHandler(sys.stdout)
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-class IntegratedRLCOBSystem:
- """
- Integrated Real-time RL COB Trading System with Dashboard
- """
-
- def __init__(self, config_path: str = "config.yaml"):
- """Initialize integrated system with configuration"""
- self.config = load_config(config_path)
- self.trader = None
- self.dashboard = None
- self.trading_executor = None
- self.running = False
-
- # RL prediction storage for dashboard
- self.rl_predictions: Dict[str, list] = {}
- self.prediction_history: Dict[str, list] = {}
-
- # Setup signal handlers for graceful shutdown
- signal.signal(signal.SIGINT, self._signal_handler)
- signal.signal(signal.SIGTERM, self._signal_handler)
-
- logger.info("IntegratedRLCOBSystem initialized")
-
- def _signal_handler(self, signum, frame):
- """Handle shutdown signals"""
- logger.info(f"Received signal {signum}, initiating graceful shutdown...")
- self.running = False
-
- async def start(self):
- """Start the integrated RL COB trading system with dashboard"""
- try:
- logger.info("=" * 60)
- logger.info("INTEGRATED RL COB SYSTEM STARTING")
- logger.info("Real-time RL Trading + Dashboard")
- logger.info("=" * 60)
-
- # Initialize trading executor
- await self._initialize_trading_executor()
-
- # Initialize RL trader with prediction callback
- await self._initialize_rl_trader()
-
- # Initialize dashboard with RL integration
- await self._initialize_dashboard()
-
- # Start the integrated system
- await self._start_integrated_system()
-
- # Run main loop
- await self._run_main_loop()
-
- except Exception as e:
- logger.error(f"Critical error in integrated system: {e}")
- raise
- finally:
- await self.stop()
-
- async def _initialize_trading_executor(self):
- """Initialize the trading executor"""
- logger.info("Initializing Trading Executor...")
-
- # Get trading configuration
- trading_config = self.config.get('trading', {})
- mexc_config = self.config.get('mexc', {})
-
- # Determine if we should run in simulation mode
- simulation_mode = mexc_config.get('simulation_mode', True)
-
- if simulation_mode:
- logger.info("Running in SIMULATION mode - no real trades will be executed")
- else:
- logger.warning("Running in LIVE TRADING mode - real money at risk!")
-
- # Add safety confirmation for live trading
- confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
- if confirmation != 'CONFIRM_LIVE_TRADING':
- logger.info("Live trading not confirmed, switching to simulation mode")
- simulation_mode = True
-
- # Initialize trading executor with config path
- self.trading_executor = TradingExecutor("config.yaml")
-
- logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")
-
- async def _initialize_rl_trader(self):
- """Initialize the RL trader with prediction callbacks"""
- logger.info("Initializing Real-time RL COB Trader...")
-
- # Get RL configuration
- rl_config = self.config.get('realtime_rl', {})
-
- # Trading symbols
- symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
-
- # Initialize prediction storage
- for symbol in symbols:
- self.rl_predictions[symbol] = []
- self.prediction_history[symbol] = []
-
- # RL parameters
- inference_interval_ms = rl_config.get('inference_interval_ms', 200)
- min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
- required_confident_predictions = rl_config.get('required_confident_predictions', 3)
- model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
-
- # Initialize RL trader
- self.trader = RealtimeRLCOBTrader(
- symbols=symbols,
- trading_executor=self.trading_executor,
- model_checkpoint_dir=model_checkpoint_dir,
- inference_interval_ms=inference_interval_ms,
- min_confidence_threshold=min_confidence_threshold,
- required_confident_predictions=required_confident_predictions
- )
-
- # Monkey-patch the trader to capture predictions
- original_add_signal = self.trader._add_signal
- def enhanced_add_signal(symbol: str, prediction: PredictionResult):
- # Call original method
- original_add_signal(symbol, prediction)
- # Capture prediction for dashboard
- self._on_rl_prediction(symbol, prediction)
-
- self.trader._add_signal = enhanced_add_signal
-
- logger.info(f"RL Trader initialized for symbols: {symbols}")
- logger.info(f"Inference interval: {inference_interval_ms}ms")
- logger.info(f"Confidence threshold: {min_confidence_threshold}")
- logger.info(f"Required predictions: {required_confident_predictions}")
-
- def _on_rl_prediction(self, symbol: str, prediction: PredictionResult):
- """Handle RL predictions for dashboard integration"""
- try:
- # Convert prediction to dashboard format
- prediction_data = {
- 'timestamp': prediction.timestamp.isoformat(),
- 'direction': prediction.predicted_direction, # 0=DOWN, 1=SIDEWAYS, 2=UP
- 'confidence': prediction.confidence,
- 'predicted_change': prediction.predicted_change,
- 'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][prediction.predicted_direction],
- 'color': ['red', 'gray', 'green'][prediction.predicted_direction]
- }
-
- # Add to current predictions (for live display)
- self.rl_predictions[symbol].append(prediction_data)
- if len(self.rl_predictions[symbol]) > 100: # Keep last 100
- self.rl_predictions[symbol] = self.rl_predictions[symbol][-100:]
-
- # Add to history (for chart overlay)
- self.prediction_history[symbol].append(prediction_data)
- if len(self.prediction_history[symbol]) > 1000: # Keep last 1000
- self.prediction_history[symbol] = self.prediction_history[symbol][-1000:]
-
- logger.debug(f"Captured RL prediction for {symbol}: {prediction.predicted_direction} "
- f"(confidence: {prediction.confidence:.3f})")
-
- except Exception as e:
- logger.error(f"Error capturing RL prediction: {e}")
-
- async def _initialize_dashboard(self):
- """Initialize the COB dashboard with RL integration"""
- logger.info("Initializing COB Dashboard with RL Integration...")
-
- # Get dashboard configuration
- dashboard_config = self.config.get('dashboard', {})
- host = dashboard_config.get('host', 'localhost')
- port = dashboard_config.get('port', 8053)
-
- # Create enhanced dashboard server
- self.dashboard = EnhancedCOBDashboardServer(
- host=host,
- port=port,
- rl_system=self # Pass reference to get predictions
- )
-
- logger.info(f"COB Dashboard initialized at http://{host}:{port}")
-
- async def _start_integrated_system(self):
- """Start the complete integrated system"""
- logger.info("Starting Integrated RL COB System...")
-
- # Start RL trader first (this initializes COB integration)
- await self.trader.start()
- logger.info("RL Trader started")
-
- # Start dashboard (uses same COB integration)
- await self.dashboard.start()
- logger.info("COB Dashboard started")
-
- self.running = True
-
- logger.info("INTEGRATED SYSTEM FULLY OPERATIONAL!")
- logger.info("1B parameter RL model: ACTIVE")
- logger.info("Real-time COB data: STREAMING")
- logger.info("Signal accumulation: ACTIVE")
- logger.info("Live predictions: VISIBLE IN DASHBOARD")
- logger.info("Continuous training: ACTIVE")
- logger.info(f"Dashboard URL: http://{self.dashboard.host}:{self.dashboard.port}")
-
- async def _run_main_loop(self):
- """Main monitoring and statistics loop"""
- logger.info("Starting integrated system monitoring...")
-
- last_stats_time = datetime.now()
- stats_interval = 60 # Print stats every 60 seconds
-
- while self.running:
- try:
- # Sleep for a bit
- await asyncio.sleep(10)
-
- # Print periodic statistics
- current_time = datetime.now()
- if (current_time - last_stats_time).total_seconds() >= stats_interval:
- await self._print_integrated_stats()
- last_stats_time = current_time
-
- except asyncio.CancelledError:
- break
- except Exception as e:
- logger.error(f"Error in main loop: {e}")
- await asyncio.sleep(5)
-
- logger.info("Integrated system monitoring stopped")
-
- async def _print_integrated_stats(self):
- """Print comprehensive integrated system statistics"""
- try:
- logger.info("=" * 80)
- logger.info("INTEGRATED RL COB SYSTEM STATISTICS")
- logger.info("=" * 80)
-
- # RL Trader Statistics
- if self.trader:
- rl_stats = self.trader.get_performance_stats()
- logger.info("\n🤖 RL TRADER PERFORMANCE:")
-
- for symbol in self.trader.symbols:
- training_stats = rl_stats.get('training_stats', {}).get(symbol, {})
- inference_stats = rl_stats.get('inference_stats', {}).get(symbol, {})
- signal_stats = rl_stats.get('signal_stats', {}).get(symbol, {})
-
- logger.info(f"\n 📈 {symbol}:")
- logger.info(f" Predictions: {training_stats.get('total_predictions', 0)}")
- logger.info(f" Success Rate: {signal_stats.get('success_rate', 0):.1%}")
- logger.info(f" Avg Inference: {inference_stats.get('average_inference_time_ms', 0):.1f}ms")
- logger.info(f" Current Signals: {signal_stats.get('current_signals', 0)}")
-
- # RL prediction stats for dashboard
- recent_predictions = len(self.rl_predictions.get(symbol, []))
- total_predictions = len(self.prediction_history.get(symbol, []))
- logger.info(f" Dashboard Predictions: {recent_predictions} recent, {total_predictions} total")
-
- # Dashboard Statistics
- if self.dashboard:
- logger.info(f"\nDASHBOARD STATISTICS:")
- logger.info(f" Active Connections: {len(self.dashboard.websocket_connections)}")
- logger.info(f" Server Status: {'RUNNING' if self.dashboard.site else 'STOPPED'}")
- logger.info(f" URL: http://{self.dashboard.host}:{self.dashboard.port}")
-
- # Trading Executor Statistics
- if self.trading_executor:
- positions = self.trading_executor.get_positions()
- trade_history = self.trading_executor.get_trade_history()
-
- logger.info(f"\n💰 TRADING STATISTICS:")
- logger.info(f" Active Positions: {len(positions)}")
- logger.info(f" Total Trades: {len(trade_history)}")
-
- if trade_history:
- total_pnl = sum(trade.pnl for trade in trade_history)
- profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
- win_rate = (profitable_trades / len(trade_history)) * 100
-
- logger.info(f" Total P&L: ${total_pnl:.2f}")
- logger.info(f" Win Rate: {win_rate:.1f}%")
-
- logger.info("=" * 80)
-
- except Exception as e:
- logger.error(f"Error printing integrated stats: {e}")
-
- async def stop(self):
- """Stop the integrated system gracefully"""
- if not self.running:
- return
-
- logger.info("Stopping Integrated RL COB System...")
-
- self.running = False
-
- # Stop dashboard
- if self.dashboard:
- await self.dashboard.stop()
- logger.info("Dashboard stopped")
-
- # Stop RL trader
- if self.trader:
- await self.trader.stop()
- logger.info("RL Trader stopped")
-
- logger.info("🏁 Integrated system stopped successfully")
-
- def get_rl_predictions(self, symbol: str) -> Dict[str, Any]:
- """Get RL predictions for dashboard display"""
- return {
- 'recent_predictions': self.rl_predictions.get(symbol, []),
- 'prediction_history': self.prediction_history.get(symbol, []),
- 'total_predictions': len(self.prediction_history.get(symbol, [])),
- 'recent_count': len(self.rl_predictions.get(symbol, []))
- }
-
-class EnhancedCOBDashboardServer(COBDashboardServer):
- """Enhanced COB Dashboard with RL prediction integration"""
-
- def __init__(self, host: str = 'localhost', port: int = 8053, rl_system: IntegratedRLCOBSystem = None):
- super().__init__(host, port)
- self.rl_system = rl_system
-
- # Add RL prediction routes
- self._setup_rl_routes()
-
- logger.info("Enhanced COB Dashboard with RL predictions initialized")
-
- async def serve_dashboard(self, request):
- """Serve the enhanced dashboard HTML with RL predictions"""
- try:
- # Read the enhanced dashboard HTML
- dashboard_path = os.path.join(os.path.dirname(__file__), 'enhanced_cob_dashboard.html')
-
- if os.path.exists(dashboard_path):
- with open(dashboard_path, 'r', encoding='utf-8') as f:
- html_content = f.read()
- return web.Response(text=html_content, content_type='text/html')
- else:
- # Fallback to basic dashboard
- logger.warning("Enhanced dashboard HTML not found, using basic dashboard")
- return await super().serve_dashboard(request)
-
- except Exception as e:
- logger.error(f"Error serving enhanced dashboard: {e}")
- return web.Response(text="Dashboard error", status=500)
-
- def _setup_rl_routes(self):
- """Setup additional routes for RL predictions"""
- self.app.router.add_get('/api/rl-predictions/{symbol}', self.get_rl_predictions)
- self.app.router.add_get('/api/rl-status', self.get_rl_status)
-
- async def get_rl_predictions(self, request):
- """Get RL predictions for a symbol"""
- try:
- symbol = request.match_info['symbol']
- symbol = symbol.replace('%2F', '/')
-
- if symbol not in self.symbols:
- return web.json_response({
- 'error': f'Symbol {symbol} not supported'
- }, status=400)
-
- if not self.rl_system:
- return web.json_response({
- 'error': 'RL system not available'
- }, status=503)
-
- predictions = self.rl_system.get_rl_predictions(symbol)
-
- return web.json_response({
- 'symbol': symbol,
- 'timestamp': datetime.now().isoformat(),
- 'predictions': predictions
- })
-
- except Exception as e:
- logger.error(f"Error getting RL predictions: {e}")
- return web.json_response({
- 'error': str(e)
- }, status=500)
-
- async def get_rl_status(self, request):
- """Get RL system status"""
- try:
- if not self.rl_system or not self.rl_system.trader:
- return web.json_response({
- 'status': 'inactive',
- 'error': 'RL system not available'
- })
-
- rl_stats = self.rl_system.trader.get_performance_stats()
-
- status = {
- 'status': 'active',
- 'symbols': self.rl_system.trader.symbols,
- 'model_info': rl_stats.get('model_info', {}),
- 'inference_interval_ms': self.rl_system.trader.inference_interval_ms,
- 'confidence_threshold': self.rl_system.trader.min_confidence_threshold,
- 'required_predictions': self.rl_system.trader.required_confident_predictions,
- 'device': str(self.rl_system.trader.device),
- 'running': self.rl_system.trader.running
- }
-
- return web.json_response(status)
-
- except Exception as e:
- logger.error(f"Error getting RL status: {e}")
- return web.json_response({
- 'status': 'error',
- 'error': str(e)
- }, status=500)
-
- async def _broadcast_cob_update(self, symbol: str, data: Dict):
- """Enhanced COB update broadcast with RL predictions"""
- try:
- # Get RL predictions if available
- rl_data = {}
- if self.rl_system:
- rl_predictions = self.rl_system.get_rl_predictions(symbol)
- rl_data = {
- 'rl_predictions': rl_predictions.get('recent_predictions', [])[-10:], # Last 10
- 'prediction_count': rl_predictions.get('total_predictions', 0)
- }
-
- # Enhanced data with RL predictions
- enhanced_data = {
- **data,
- 'rl_data': rl_data
- }
-
- # Broadcast to all WebSocket connections
- message = json.dumps({
- 'type': 'cob_update',
- 'symbol': symbol,
- 'timestamp': datetime.now().isoformat(),
- 'data': enhanced_data
- }, default=str)
-
- # Send to all connected clients
- disconnected = []
- for ws in self.websocket_connections:
- try:
- await ws.send_str(message)
- except Exception as e:
- logger.debug(f"WebSocket send failed: {e}")
- disconnected.append(ws)
-
- # Remove disconnected clients
- for ws in disconnected:
- self.websocket_connections.discard(ws)
-
- except Exception as e:
- logger.error(f"Error broadcasting enhanced COB update: {e}")
-
-async def main():
- """Main entry point for integrated RL COB system"""
- # Create logs directory
- os.makedirs('logs', exist_ok=True)
-
- # Initialize and start integrated system
- system = IntegratedRLCOBSystem()
-
- try:
- await system.start()
- except KeyboardInterrupt:
- logger.info("Received keyboard interrupt")
- except Exception as e:
- logger.error(f"System error: {e}")
- raise
- finally:
- await system.stop()
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/run_mexc_browser.py b/run_mexc_browser.py
deleted file mode 100644
index 76eb3dc..0000000
--- a/run_mexc_browser.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python3
-"""
-One-Click MEXC Browser Launcher
-
-Simply run this script to start capturing MEXC futures trading requests.
-"""
-
-import sys
-import os
-
-# Add project root to path
-project_root = os.path.dirname(os.path.abspath(__file__))
-sys.path.insert(0, project_root)
-
-def main():
- """Launch MEXC browser automation"""
- print("🚀 MEXC Futures Request Interceptor")
- print("=" * 50)
- print("This will automatically:")
- print("✅ Install ChromeDriver")
- print("✅ Open MEXC futures page")
- print("✅ Capture all API requests")
- print("✅ Extract session cookies")
- print("✅ Save data to JSON files")
- print("\nRequirements will be installed automatically if missing.")
-
- try:
- # First try to run the auto browser directly
- from core.mexc_webclient.auto_browser import main as run_auto_browser
- run_auto_browser()
-
- except ImportError as e:
- print(f"\n⚠️ Import error: {e}")
- print("Installing requirements first...")
-
- # Try to install requirements and run setup
- try:
- from setup_mexc_browser import main as setup_main
- setup_main()
- except ImportError:
- print("❌ Could not find setup script")
- print("Please run: pip install selenium webdriver-manager")
-
- except Exception as e:
- print(f"❌ Error: {e}")
- print("\nTroubleshooting:")
- print("1. Make sure you have Chrome browser installed")
- print("2. Check your internet connection")
- print("3. Try running: pip install selenium webdriver-manager")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_optimized_cob_system.py b/run_optimized_cob_system.py
deleted file mode 100644
index 405ef11..0000000
--- a/run_optimized_cob_system.py
+++ /dev/null
@@ -1,451 +0,0 @@
-#!/usr/bin/env python3
-"""
-Optimized COB System - Eliminates Redundant Implementations
-
-This optimized script runs both the COB dashboard and 1B RL trading system
-in a single process with shared data sources to eliminate redundancies:
-
-BEFORE (Redundant):
-- Dashboard: Own COBIntegration instance
-- RL Trader: Own COBIntegration instance
-- Training: Own COBIntegration instance
-= 3x WebSocket connections, 3x order book processing
-
-AFTER (Optimized):
-- Shared COBIntegration instance
-- Single WebSocket connection per exchange
-- Shared order book processing and caching
-= 1x connections, 1x processing, shared memory
-
-Resource savings: ~60% memory, ~70% network bandwidth
-"""
-
-import asyncio
-import logging
-import signal
-import sys
-import json
-import os
-from datetime import datetime
-from typing import Dict, Any, List, Optional
-from aiohttp import web
-import threading
-
-# Local imports
-from core.cob_integration import COBIntegration
-from core.data_provider import DataProvider
-from core.trading_executor import TradingExecutor
-from core.config import load_config
-from web.cob_realtime_dashboard import COBDashboardServer
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/optimized_cob_system.log'),
- logging.StreamHandler(sys.stdout)
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-class OptimizedCOBSystem:
- """
- Optimized COB System - Single COB instance shared across all components
- """
-
- def __init__(self, config_path: str = "config.yaml"):
- """Initialize optimized system with shared resources"""
- self.config = load_config(config_path)
- self.running = False
-
- # Shared components (eliminate redundancy)
- self.data_provider = DataProvider()
- self.shared_cob_integration: Optional[COBIntegration] = None
- self.trading_executor: Optional[TradingExecutor] = None
-
- # Dashboard using shared COB
- self.dashboard_server: Optional[COBDashboardServer] = None
-
- # Performance tracking
- self.performance_stats = {
- 'start_time': None,
- 'cob_updates_processed': 0,
- 'dashboard_connections': 0,
- 'memory_saved_mb': 0
- }
-
- # Setup signal handlers
- signal.signal(signal.SIGINT, self._signal_handler)
- signal.signal(signal.SIGTERM, self._signal_handler)
-
- logger.info("OptimizedCOBSystem initialized - Eliminating redundant implementations")
-
- def _signal_handler(self, signum, frame):
- """Handle shutdown signals"""
- logger.info(f"Received signal {signum}, initiating graceful shutdown...")
- self.running = False
-
- async def start(self):
- """Start the optimized COB system"""
- try:
- logger.info("=" * 70)
- logger.info("🚀 OPTIMIZED COB SYSTEM STARTING")
- logger.info("=" * 70)
- logger.info("Eliminating redundant COB implementations...")
- logger.info("Single shared COB integration for all components")
- logger.info("=" * 70)
-
- # Initialize shared components
- await self._initialize_shared_components()
-
- # Initialize dashboard with shared COB
- await self._initialize_optimized_dashboard()
-
- # Start the integrated system
- await self._start_optimized_system()
-
- # Run main monitoring loop
- await self._run_optimized_loop()
-
- except Exception as e:
- logger.error(f"Critical error in optimized system: {e}")
- import traceback
- logger.error(traceback.format_exc())
- raise
- finally:
- await self.stop()
-
- async def _initialize_shared_components(self):
- """Initialize shared components (eliminates redundancy)"""
- logger.info("1. Initializing shared COB integration...")
-
- # Single COB integration instance for entire system
- self.shared_cob_integration = COBIntegration(
- data_provider=self.data_provider,
- symbols=['BTC/USDT', 'ETH/USDT']
- )
-
- # Start the shared COB integration
- await self.shared_cob_integration.start()
-
- logger.info("2. Initializing trading executor...")
-
- # Trading executor configuration
- trading_config = self.config.get('trading', {})
- mexc_config = self.config.get('mexc', {})
- simulation_mode = mexc_config.get('simulation_mode', True)
-
- self.trading_executor = TradingExecutor()
-
- logger.info("✅ Shared components initialized")
- logger.info(f" Single COB integration: {len(self.shared_cob_integration.symbols)} symbols")
- logger.info(f" Trading mode: {'SIMULATION' if simulation_mode else 'LIVE'}")
-
- async def _initialize_optimized_dashboard(self):
- """Initialize dashboard that uses shared COB (no redundant instance)"""
- logger.info("3. Initializing optimized dashboard...")
-
- # Create dashboard and replace its COB with our shared one
- self.dashboard_server = COBDashboardServer(host='localhost', port=8053)
-
- # Replace the dashboard's COB integration with our shared one
- self.dashboard_server.cob_integration = self.shared_cob_integration
-
- logger.info("✅ Optimized dashboard initialized with shared COB")
-
- async def _start_optimized_system(self):
- """Start the optimized system with shared resources"""
- logger.info("4. Starting optimized system...")
-
- self.running = True
- self.performance_stats['start_time'] = datetime.now()
-
- # Start dashboard server with shared COB
- await self.dashboard_server.start()
-
- # Estimate memory savings
- # Start RL trader
- await self.rl_trader.start()
-
- # Estimate memory savings
- estimated_savings = self._calculate_memory_savings()
- self.performance_stats['memory_saved_mb'] = estimated_savings
-
- logger.info("🚀 Optimized COB System started successfully!")
- logger.info(f"💾 Estimated memory savings: {estimated_savings:.0f} MB")
- logger.info(f"🌐 Dashboard: http://localhost:8053")
- logger.info(f"🤖 RL Training: Active with 1B parameters")
- logger.info(f"📊 Shared COB: Single integration for all components")
- logger.info("🔄 System Status: OPTIMIZED - No redundant implementations")
-
- def _calculate_memory_savings(self) -> float:
- """Calculate estimated memory savings from eliminating redundancy"""
- # Estimates based on typical COB memory usage
- cob_integration_memory_mb = 512 # Order books, caches, connections
- websocket_connection_memory_mb = 64 # Per exchange connection
-
- # Before: 3 separate COB integrations (dashboard + RL trader + training)
- before_memory = 3 * cob_integration_memory_mb + 3 * websocket_connection_memory_mb
-
- # After: 1 shared COB integration
- after_memory = 1 * cob_integration_memory_mb + 1 * websocket_connection_memory_mb
-
- savings = before_memory - after_memory
- return savings
-
- async def _run_optimized_loop(self):
- """Main optimized monitoring loop"""
- logger.info("Starting optimized monitoring loop...")
-
- last_stats_time = datetime.now()
- stats_interval = 60 # Print stats every 60 seconds
-
- while self.running:
- try:
- # Sleep for a bit
- await asyncio.sleep(10)
-
- # Update performance stats
- self._update_performance_stats()
-
- # Print periodic statistics
- current_time = datetime.now()
- if (current_time - last_stats_time).total_seconds() >= stats_interval:
- await self._print_optimized_stats()
- last_stats_time = current_time
-
- except asyncio.CancelledError:
- break
- except Exception as e:
- logger.error(f"Error in optimized loop: {e}")
- await asyncio.sleep(5)
-
- logger.info("Optimized monitoring loop stopped")
-
- def _update_performance_stats(self):
- """Update performance statistics"""
- try:
- # Get stats from shared COB integration
- if self.shared_cob_integration:
- cob_stats = self.shared_cob_integration.get_statistics()
- self.performance_stats['cob_updates_processed'] = cob_stats.get('total_signals', {}).get('BTC/USDT', 0)
-
- # Get stats from dashboard
- if self.dashboard_server:
- dashboard_stats = self.dashboard_server.get_stats()
- self.performance_stats['dashboard_connections'] = dashboard_stats.get('active_connections', 0)
-
- # Get stats from RL trader
- if self.rl_trader:
- rl_stats = self.rl_trader.get_stats()
- self.performance_stats['rl_predictions'] = rl_stats.get('total_predictions', 0)
-
- # Get stats from trading executor
- if self.trading_executor:
- trade_history = self.trading_executor.get_trade_history()
- self.performance_stats['trades_executed'] = len(trade_history)
-
- except Exception as e:
- logger.warning(f"Error updating performance stats: {e}")
-
- async def _print_optimized_stats(self):
- """Print comprehensive optimized system statistics"""
- try:
- stats = self.performance_stats
- uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0
-
- logger.info("=" * 80)
- logger.info("🚀 OPTIMIZED COB SYSTEM PERFORMANCE STATISTICS")
- logger.info("=" * 80)
-
- logger.info("📊 Resource Optimization:")
- logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
- logger.info(f" Uptime: {uptime:.0f} seconds")
- logger.info(f" COB Updates: {stats['cob_updates_processed']}")
-
- logger.info("\n🌐 Dashboard Statistics:")
- logger.info(f" Active Connections: {stats['dashboard_connections']}")
- logger.info(f" Server Status: {'RUNNING' if self.dashboard_server else 'STOPPED'}")
-
- logger.info("\n🤖 RL Trading Statistics:")
- logger.info(f" Total Predictions: {stats['rl_predictions']}")
- logger.info(f" Trades Executed: {stats['trades_executed']}")
- logger.info(f" Trainer Status: {'ACTIVE' if self.rl_trader else 'STOPPED'}")
-
- # Shared COB statistics
- if self.shared_cob_integration:
- cob_stats = self.shared_cob_integration.get_statistics()
- logger.info("\n📈 Shared COB Integration:")
- logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}")
- logger.info(f" Streaming: {cob_stats.get('is_streaming', False)}")
- logger.info(f" CNN Callbacks: {cob_stats.get('cnn_callbacks', 0)}")
- logger.info(f" DQN Callbacks: {cob_stats.get('dqn_callbacks', 0)}")
- logger.info(f" Dashboard Callbacks: {cob_stats.get('dashboard_callbacks', 0)}")
-
- logger.info("=" * 80)
- logger.info("✅ OPTIMIZATION STATUS: Redundancy eliminated, shared resources active")
-
- except Exception as e:
- logger.error(f"Error printing optimized stats: {e}")
-
- async def stop(self):
- """Stop the optimized system gracefully"""
- if not self.running:
- return
-
- logger.info("Stopping Optimized COB System...")
-
- self.running = False
-
- # Stop RL trader
- if self.rl_trader:
- await self.rl_trader.stop()
- logger.info("✅ RL Trader stopped")
-
- # Stop dashboard
- if self.dashboard_server:
- await self.dashboard_server.stop()
- logger.info("✅ Dashboard stopped")
-
- # Stop shared COB integration (last, as others depend on it)
- if self.shared_cob_integration:
- await self.shared_cob_integration.stop()
- logger.info("✅ Shared COB integration stopped")
-
- # Print final optimization report
- await self._print_final_optimization_report()
-
- logger.info("Optimized COB System stopped successfully")
-
- async def _print_final_optimization_report(self):
- """Print final optimization report"""
- stats = self.performance_stats
- uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0
-
- logger.info("\n📊 FINAL OPTIMIZATION REPORT:")
- logger.info(f" Total Runtime: {uptime:.0f} seconds")
- logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
- logger.info(f" COB Updates Processed: {stats['cob_updates_processed']}")
- logger.info(f" RL Predictions Made: {stats['rl_predictions']}")
- logger.info(f" Trades Executed: {stats['trades_executed']}")
- logger.info(" ✅ Redundant implementations eliminated")
- logger.info(" ✅ Shared COB integration successful")
-
-
-# Simplified components that use shared COB (no redundant integrations)
-
-class EnhancedCOBDashboard(COBDashboardServer):
- """Enhanced dashboard that uses shared COB integration"""
-
- def __init__(self, host: str = 'localhost', port: int = 8053,
- shared_cob: COBIntegration = None, performance_tracker: Dict = None):
- # Initialize parent without creating new COB integration
- self.shared_cob = shared_cob
- self.performance_tracker = performance_tracker or {}
- super().__init__(host, port)
-
- # Use shared COB instead of creating new one
- self.cob_integration = shared_cob
- logger.info("Enhanced dashboard using shared COB integration (no redundancy)")
-
- def get_stats(self) -> Dict[str, Any]:
- """Get dashboard statistics"""
- return {
- 'active_connections': len(self.websocket_connections),
- 'using_shared_cob': self.shared_cob is not None,
- 'server_running': self.runner is not None
- }
-
-class OptimizedRLTrader:
- """Optimized RL trader that uses shared COB integration"""
-
- def __init__(self, symbols: List[str], shared_cob: COBIntegration,
- trading_executor: TradingExecutor, performance_tracker: Dict = None):
- self.symbols = symbols
- self.shared_cob = shared_cob
- self.trading_executor = trading_executor
- self.performance_tracker = performance_tracker or {}
- self.running = False
-
- # Subscribe to shared COB updates instead of creating new integration
- self.subscription_id = None
- self.prediction_count = 0
-
- logger.info("Optimized RL trader using shared COB integration (no redundancy)")
-
- async def start(self):
- """Start RL trader with shared COB"""
- self.running = True
-
- # Subscribe to shared COB updates
- self.subscription_id = self.shared_cob.add_dqn_callback(self._on_cob_update)
-
- # Start prediction loop
- asyncio.create_task(self._prediction_loop())
-
- logger.info("Optimized RL trader started with shared COB subscription")
-
- async def stop(self):
- """Stop RL trader"""
- self.running = False
- logger.info("Optimized RL trader stopped")
-
- async def _on_cob_update(self, symbol: str, data: Dict):
- """Handle COB updates from shared integration"""
- try:
- # Process RL prediction using shared data
- self.prediction_count += 1
-
- # Simple prediction logic (placeholder)
- confidence = 0.75 # Example confidence
-
- if self.prediction_count % 100 == 0:
- logger.info(f"RL Prediction #{self.prediction_count} for {symbol} (confidence: {confidence:.2f})")
-
- except Exception as e:
- logger.error(f"Error in RL update: {e}")
-
- async def _prediction_loop(self):
- """Main prediction loop"""
- while self.running:
- try:
- # RL model inference would go here
- await asyncio.sleep(0.2) # 200ms inference interval
- except Exception as e:
- logger.error(f"Error in prediction loop: {e}")
- await asyncio.sleep(1)
-
- def get_stats(self) -> Dict[str, Any]:
- """Get RL trader statistics"""
- return {
- 'total_predictions': self.prediction_count,
- 'using_shared_cob': self.shared_cob is not None,
- 'subscription_active': self.subscription_id is not None
- }
-
-
-async def main():
- """Main entry point for optimized COB system"""
- try:
- # Create logs directory
- os.makedirs('logs', exist_ok=True)
-
- # Initialize and start optimized system
- system = OptimizedCOBSystem()
- await system.start()
-
- except KeyboardInterrupt:
- logger.info("Received keyboard interrupt, shutting down...")
- except Exception as e:
- logger.error(f"Critical error: {e}")
- import traceback
- traceback.print_exc()
-
-if __name__ == "__main__":
- # Set event loop policy for Windows compatibility
- if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-
- asyncio.run(main())
\ No newline at end of file
diff --git a/run_realtime_rl_cob_trader.py b/run_realtime_rl_cob_trader.py
deleted file mode 100644
index 57e425e..0000000
--- a/run_realtime_rl_cob_trader.py
+++ /dev/null
@@ -1,324 +0,0 @@
-#!/usr/bin/env python3
-"""
-Real-time RL COB Trader Launcher
-
-Launch script for the real-time reinforcement learning trader that:
-1. Uses COB data for training a 1B parameter model
-2. Performs inference every 200ms
-3. Accumulates confident signals for trade execution
-4. Trains continuously in real-time based on outcomes
-
-This script provides a complete trading system integration.
-"""
-
-import asyncio
-import logging
-import signal
-import sys
-import json
-import os
-from datetime import datetime
-from typing import Dict, Any, Optional
-
-# Local imports
-from core.realtime_rl_cob_trader import RealtimeRLCOBTrader
-from core.trading_executor import TradingExecutor
-from core.config import load_config
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/realtime_rl_cob_trader.log'),
- logging.StreamHandler(sys.stdout)
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-class RealtimeRLCOBTraderLauncher:
- """
- Launcher for Real-time RL COB Trader system
- """
-
- def __init__(self, config_path: str = "config.yaml"):
- """Initialize launcher with configuration"""
- self.config_path = config_path
- self.config = load_config(config_path)
- self.trader: Optional[RealtimeRLCOBTrader] = None
- self.trading_executor: Optional[TradingExecutor] = None
- self.running = False
-
- # Setup signal handlers for graceful shutdown
- signal.signal(signal.SIGINT, self._signal_handler)
- signal.signal(signal.SIGTERM, self._signal_handler)
-
- logger.info("RealtimeRLCOBTraderLauncher initialized")
-
- def _signal_handler(self, signum, frame):
- """Handle shutdown signals"""
- logger.info(f"Received signal {signum}, initiating graceful shutdown...")
- self.running = False
-
- async def start(self):
- """Start the real-time RL COB trading system"""
- try:
- logger.info("=" * 60)
- logger.info("REAL-TIME RL COB TRADER SYSTEM STARTING")
- logger.info("=" * 60)
-
- # Initialize trading executor
- await self._initialize_trading_executor()
-
- # Initialize RL trader
- await self._initialize_rl_trader()
-
- # Start the trading system
- await self._start_trading_system()
-
- # Run main loop
- await self._run_main_loop()
-
- except Exception as e:
- logger.error(f"Critical error in trader launcher: {e}")
- raise
- finally:
- await self.stop()
-
- async def _initialize_trading_executor(self):
- """Initialize the trading executor"""
- logger.info("Initializing Trading Executor...")
-
- # Get trading configuration
- trading_config = self.config.get('trading', {})
- mexc_config = self.config.get('mexc', {})
-
- # Determine if we should run in simulation mode
- simulation_mode = mexc_config.get('simulation_mode', True)
-
- if simulation_mode:
- logger.info("Running in SIMULATION mode - no real trades will be executed")
- else:
- logger.warning("Running in LIVE TRADING mode - real money at risk!")
-
- # Add safety confirmation for live trading
- confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
- if confirmation != 'CONFIRM_LIVE_TRADING':
- logger.info("Live trading not confirmed, switching to simulation mode")
- simulation_mode = True
-
- # Initialize trading executor
- self.trading_executor = TradingExecutor(self.config_path)
-
- logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")
-
- async def _initialize_rl_trader(self):
- """Initialize the RL trader"""
- logger.info("Initializing Real-time RL COB Trader...")
-
- # Get RL configuration
- rl_config = self.config.get('realtime_rl', {})
-
- # Trading symbols
- symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
-
- # RL parameters
- inference_interval_ms = rl_config.get('inference_interval_ms', 200)
- min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
- required_confident_predictions = rl_config.get('required_confident_predictions', 3)
- model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
-
- # Initialize RL trader
- if self.trading_executor is None:
- raise RuntimeError("Trading executor not initialized")
-
- self.trader = RealtimeRLCOBTrader(
- symbols=symbols,
- trading_executor=self.trading_executor,
- model_checkpoint_dir=model_checkpoint_dir,
- inference_interval_ms=inference_interval_ms,
- min_confidence_threshold=min_confidence_threshold,
- required_confident_predictions=required_confident_predictions
- )
-
- logger.info(f"RL Trader initialized for symbols: {symbols}")
- logger.info(f"Inference interval: {inference_interval_ms}ms")
- logger.info(f"Confidence threshold: {min_confidence_threshold}")
- logger.info(f"Required predictions: {required_confident_predictions}")
-
- async def _start_trading_system(self):
- """Start the complete trading system"""
- logger.info("Starting Real-time RL COB Trading System...")
-
- # Start RL trader (this will start COB integration internally)
- if self.trader is None:
- raise RuntimeError("RL trader not initialized")
- await self.trader.start()
-
- self.running = True
-
- logger.info("✅ Real-time RL COB Trading System started successfully!")
- logger.info("🔥 1B parameter model training and inference active")
- logger.info("📊 COB data streaming and processing")
- logger.info("🎯 Signal accumulation and trade execution ready")
- logger.info("⚡ Real-time training on prediction outcomes")
-
- async def _run_main_loop(self):
- """Main monitoring and statistics loop"""
- logger.info("Starting main monitoring loop...")
-
- last_stats_time = datetime.now()
- stats_interval = 60 # Print stats every 60 seconds
-
- while self.running:
- try:
- # Sleep for a bit
- await asyncio.sleep(10)
-
- # Print periodic statistics
- current_time = datetime.now()
- if (current_time - last_stats_time).total_seconds() >= stats_interval:
- await self._print_performance_stats()
- last_stats_time = current_time
-
- except asyncio.CancelledError:
- break
- except Exception as e:
- logger.error(f"Error in main loop: {e}")
- await asyncio.sleep(5)
-
- logger.info("Main monitoring loop stopped")
-
- async def _print_performance_stats(self):
- """Print comprehensive performance statistics"""
- try:
- if not self.trader:
- return
-
- stats = self.trader.get_performance_stats()
-
- logger.info("=" * 80)
- logger.info("🔥 REAL-TIME RL COB TRADER PERFORMANCE STATISTICS")
- logger.info("=" * 80)
-
- # Model information
- logger.info("📊 Model Information:")
- for symbol, model_info in stats.get('model_info', {}).items():
- total_params = model_info.get('total_parameters', 0)
- logger.info(f" {symbol}: {total_params:,} parameters ({total_params/1e9:.2f}B)")
-
- # Training statistics
- logger.info("\n🧠 Training Statistics:")
- for symbol, training_stats in stats.get('training_stats', {}).items():
- total_preds = training_stats.get('total_predictions', 0)
- successful_preds = training_stats.get('successful_predictions', 0)
- success_rate = (successful_preds / max(1, total_preds)) * 100
- avg_loss = training_stats.get('average_loss', 0.0)
- training_steps = training_stats.get('total_training_steps', 0)
- last_training = training_stats.get('last_training_time')
-
- logger.info(f" {symbol}:")
- logger.info(f" Predictions: {total_preds} (Success: {success_rate:.1f}%)")
- logger.info(f" Training Steps: {training_steps}")
- logger.info(f" Average Loss: {avg_loss:.6f}")
- if last_training:
- logger.info(f" Last Training: {last_training}")
-
- # Inference statistics
- logger.info("\n⚡ Inference Statistics:")
- for symbol, inference_stats in stats.get('inference_stats', {}).items():
- total_inferences = inference_stats.get('total_inferences', 0)
- avg_time = inference_stats.get('average_inference_time_ms', 0.0)
- last_inference = inference_stats.get('last_inference_time')
-
- logger.info(f" {symbol}:")
- logger.info(f" Total Inferences: {total_inferences}")
- logger.info(f" Average Time: {avg_time:.1f}ms")
- if last_inference:
- logger.info(f" Last Inference: {last_inference}")
-
- # Signal statistics
- logger.info("\n🎯 Signal Accumulation:")
- for symbol, signal_stats in stats.get('signal_stats', {}).items():
- current_signals = signal_stats.get('current_signals', 0)
- confidence_sum = signal_stats.get('confidence_sum', 0.0)
- success_rate = signal_stats.get('success_rate', 0.0) * 100
-
- logger.info(f" {symbol}:")
- logger.info(f" Current Signals: {current_signals}")
- logger.info(f" Confidence Sum: {confidence_sum:.2f}")
- logger.info(f" Historical Success Rate: {success_rate:.1f}%")
-
- # Trading executor statistics
- if self.trading_executor:
- positions = self.trading_executor.get_positions()
- trade_history = self.trading_executor.get_trade_history()
-
- logger.info("\n💰 Trading Statistics:")
- logger.info(f" Active Positions: {len(positions)}")
- logger.info(f" Total Trades: {len(trade_history)}")
-
- if trade_history:
- # Calculate P&L statistics
- total_pnl = sum(trade.pnl for trade in trade_history)
- profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
- win_rate = (profitable_trades / len(trade_history)) * 100
-
- logger.info(f" Total P&L: ${total_pnl:.2f}")
- logger.info(f" Win Rate: {win_rate:.1f}%")
-
- # Show active positions
- if positions:
- logger.info("\n📍 Active Positions:")
- for symbol, position in positions.items():
- logger.info(f" {symbol}: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
-
- logger.info("=" * 80)
-
- except Exception as e:
- logger.error(f"Error printing performance stats: {e}")
-
- async def stop(self):
- """Stop the trading system gracefully"""
- if not self.running:
- return
-
- logger.info("Stopping Real-time RL COB Trading System...")
-
- self.running = False
-
- # Stop RL trader
- if self.trader:
- await self.trader.stop()
- logger.info("✅ RL Trader stopped")
-
- # Print final statistics
- if self.trader:
- logger.info("\n📊 Final Performance Summary:")
- await self._print_performance_stats()
-
- logger.info("Real-time RL COB Trading System stopped successfully")
-
-async def main():
- """Main entry point"""
- try:
- # Create logs directory if it doesn't exist
- os.makedirs('logs', exist_ok=True)
-
- # Initialize and start launcher
- launcher = RealtimeRLCOBTraderLauncher()
- await launcher.start()
-
- except KeyboardInterrupt:
- logger.info("Received keyboard interrupt, shutting down...")
- except Exception as e:
- logger.error(f"Critical error: {e}")
- raise
-
-if __name__ == "__main__":
- # Set event loop policy for Windows compatibility
- if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-
- asyncio.run(main())
\ No newline at end of file
diff --git a/run_simple_dashboard.py b/run_simple_dashboard.py
deleted file mode 100644
index 9bf4909..0000000
--- a/run_simple_dashboard.py
+++ /dev/null
@@ -1,218 +0,0 @@
-#!/usr/bin/env python3
-"""
-Simple Dashboard Runner - Fixed version for testing
-"""
-
-import os
-import sys
-import logging
-import time
-import threading
-from pathlib import Path
-
-# Fix OpenMP library conflicts
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '4'
-
-# Fix matplotlib backend
-import matplotlib
-matplotlib.use('Agg')
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def create_simple_dashboard():
- """Create a simple working dashboard"""
- try:
- import dash
- from dash import html, dcc, Input, Output
- import plotly.graph_objs as go
- import pandas as pd
- import numpy as np
- from datetime import datetime, timedelta
-
- # Create Dash app
- app = dash.Dash(__name__)
-
- # Simple layout
- app.layout = html.Div([
- html.H1("Trading System Dashboard", style={'textAlign': 'center', 'color': '#2c3e50'}),
-
- html.Div([
- html.Div([
- html.H3("System Status", style={'color': '#27ae60'}),
- html.P(id='system-status', children="System: RUNNING", style={'fontSize': '18px'}),
- html.P(id='current-time', children=f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"),
- ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}),
-
- html.Div([
- html.H3("Trading Stats", style={'color': '#3498db'}),
- html.P("Total Trades: 0"),
- html.P("Success Rate: 0%"),
- html.P("Current PnL: $0.00"),
- ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}),
- ]),
-
- html.Div([
- dcc.Graph(id='price-chart'),
- ], style={'padding': '20px'}),
-
- html.Div([
- dcc.Graph(id='performance-chart'),
- ], style={'padding': '20px'}),
-
- # Auto-refresh component
- dcc.Interval(
- id='interval-component',
- interval=5000, # Update every 5 seconds
- n_intervals=0
- )
- ])
-
- # Callback for updating time
- @app.callback(
- Output('current-time', 'children'),
- Input('interval-component', 'n_intervals')
- )
- def update_time(n):
- return f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-
- # Callback for price chart
- @app.callback(
- Output('price-chart', 'figure'),
- Input('interval-component', 'n_intervals')
- )
- def update_price_chart(n):
- # Generate sample data
- dates = pd.date_range(start=datetime.now() - timedelta(hours=24),
- end=datetime.now(), freq='1H')
- prices = 3000 + np.cumsum(np.random.randn(len(dates)) * 10)
-
- fig = go.Figure()
- fig.add_trace(go.Scatter(
- x=dates,
- y=prices,
- mode='lines',
- name='ETH/USDT',
- line=dict(color='#3498db', width=2)
- ))
-
- fig.update_layout(
- title='ETH/USDT Price Chart (24H)',
- xaxis_title='Time',
- yaxis_title='Price (USD)',
- template='plotly_white',
- height=400
- )
-
- return fig
-
- # Callback for performance chart
- @app.callback(
- Output('performance-chart', 'figure'),
- Input('interval-component', 'n_intervals')
- )
- def update_performance_chart(n):
- # Generate sample performance data
- dates = pd.date_range(start=datetime.now() - timedelta(days=7),
- end=datetime.now(), freq='1D')
- performance = np.cumsum(np.random.randn(len(dates)) * 0.02) * 100
-
- fig = go.Figure()
- fig.add_trace(go.Scatter(
- x=dates,
- y=performance,
- mode='lines+markers',
- name='Portfolio Performance',
- line=dict(color='#27ae60', width=3),
- marker=dict(size=6)
- ))
-
- fig.update_layout(
- title='Portfolio Performance (7 Days)',
- xaxis_title='Date',
- yaxis_title='Performance (%)',
- template='plotly_white',
- height=400
- )
-
- return fig
-
- return app
-
- except Exception as e:
- logger.error(f"Error creating dashboard: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return None
-
-def test_data_provider():
- """Test data provider in background"""
- try:
- from core.data_provider import DataProvider
- from core.api_rate_limiter import get_rate_limiter
-
- logger.info("Testing data provider...")
-
- # Create data provider
- data_provider = DataProvider(
- symbols=['ETH/USDT'],
- timeframes=['1m', '5m']
- )
-
- # Test getting data
- df = data_provider.get_historical_data('ETH/USDT', '1m', limit=10)
- if df is not None and len(df) > 0:
- logger.info(f"✓ Data provider working: {len(df)} candles retrieved")
- else:
- logger.warning("⚠ Data provider returned no data (rate limiting)")
-
- # Test rate limiter status
- rate_limiter = get_rate_limiter()
- status = rate_limiter.get_all_endpoint_status()
- logger.info(f"Rate limiter status: {status}")
-
- except Exception as e:
- logger.error(f"Data provider test error: {e}")
-
-def main():
- """Main function"""
- logger.info("=" * 60)
- logger.info("SIMPLE DASHBOARD RUNNER - TESTING SYSTEM")
- logger.info("=" * 60)
-
- # Test data provider in background
- data_thread = threading.Thread(target=test_data_provider, daemon=True)
- data_thread.start()
-
- # Create and run dashboard
- app = create_simple_dashboard()
- if app is None:
- logger.error("Failed to create dashboard")
- return
-
- try:
- logger.info("Starting dashboard server...")
- logger.info("Dashboard URL: http://127.0.0.1:8050")
- logger.info("Press Ctrl+C to stop")
-
- # Run the dashboard
- app.run(debug=False, host='127.0.0.1', port=8050, use_reloader=False)
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- except Exception as e:
- logger.error(f"Dashboard error: {e}")
- import traceback
- logger.error(traceback.format_exc())
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_stable_dashboard.py b/run_stable_dashboard.py
deleted file mode 100644
index a98e2d8..0000000
--- a/run_stable_dashboard.py
+++ /dev/null
@@ -1,275 +0,0 @@
-#!/usr/bin/env python3
-"""
-Stable Dashboard Runner - Prioritizes System Stability
-
-This runner focuses on:
-1. System stability and reliability
-2. Core trading functionality
-3. Minimal resource usage
-4. Robust error handling
-5. Graceful degradation
-
-Deferred features (until stability is achieved):
-- TensorBoard integration
-- Complex training loops
-- Advanced visualizations
-"""
-
-import os
-import sys
-import time
-import logging
-import threading
-import signal
-from pathlib import Path
-
-# Fix environment issues before imports
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '2' # Reduced from 4 for stability
-
-# Fix matplotlib backend
-import matplotlib
-matplotlib.use('Agg')
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import setup_logging, get_config
-from system_stability_audit import SystemStabilityAuditor
-
-# Setup logging with reduced verbosity for stability
-setup_logging()
-logger = logging.getLogger(__name__)
-
-# Reduce logging noise from other modules
-logging.getLogger('werkzeug').setLevel(logging.ERROR)
-logging.getLogger('dash').setLevel(logging.ERROR)
-logging.getLogger('matplotlib').setLevel(logging.ERROR)
-
-class StableDashboardRunner:
- """
- Stable dashboard runner with focus on reliability
- """
-
- def __init__(self):
- """Initialize stable dashboard runner"""
- self.config = get_config()
- self.running = False
- self.dashboard = None
- self.stability_auditor = None
-
- # Core components
- self.data_provider = None
- self.orchestrator = None
- self.trading_executor = None
-
- # Stability monitoring
- self.last_health_check = time.time()
- self.health_check_interval = 30 # Check every 30 seconds
-
- logger.info("Stable Dashboard Runner initialized")
-
- def initialize_components(self):
- """Initialize core components with error handling"""
- try:
- logger.info("Initializing core components...")
-
- # Initialize data provider
- from core.data_provider import DataProvider
- self.data_provider = DataProvider()
- logger.info("✓ Data provider initialized")
-
- # Initialize trading executor
- from core.trading_executor import TradingExecutor
- self.trading_executor = TradingExecutor()
- logger.info("✓ Trading executor initialized")
-
- # Initialize orchestrator with minimal features for stability
- from core.orchestrator import TradingOrchestrator
- self.orchestrator = TradingOrchestrator(
- data_provider=self.data_provider,
- enhanced_rl_training=False # Disabled for stability
- )
- logger.info("✓ Orchestrator initialized (training disabled for stability)")
-
- # Initialize dashboard
- from web.clean_dashboard import CleanTradingDashboard
- self.dashboard = CleanTradingDashboard(
- data_provider=self.data_provider,
- orchestrator=self.orchestrator,
- trading_executor=self.trading_executor
- )
- logger.info("✓ Dashboard initialized")
-
- return True
-
- except Exception as e:
- logger.error(f"Error initializing components: {e}")
- return False
-
- def start_stability_monitoring(self):
- """Start system stability monitoring"""
- try:
- self.stability_auditor = SystemStabilityAuditor()
- self.stability_auditor.start_monitoring()
- logger.info("✓ Stability monitoring started")
- except Exception as e:
- logger.error(f"Error starting stability monitoring: {e}")
-
- def health_check(self):
- """Perform system health check"""
- try:
- current_time = time.time()
- if current_time - self.last_health_check < self.health_check_interval:
- return
-
- self.last_health_check = current_time
-
- # Check stability score
- if self.stability_auditor:
- report = self.stability_auditor.get_stability_report()
- stability_score = report.get('stability_score', 0)
-
- if stability_score < 50:
- logger.warning(f"Low stability score: {stability_score:.1f}/100")
- # Attempt to fix issues
- self.stability_auditor.fix_common_issues()
- elif stability_score < 80:
- logger.info(f"Moderate stability: {stability_score:.1f}/100")
- else:
- logger.debug(f"Good stability: {stability_score:.1f}/100")
-
- # Check component health
- if self.dashboard and hasattr(self.dashboard, 'app'):
- logger.debug("✓ Dashboard responsive")
-
- if self.data_provider:
- logger.debug("✓ Data provider active")
-
- if self.orchestrator:
- logger.debug("✓ Orchestrator active")
-
- except Exception as e:
- logger.error(f"Error in health check: {e}")
-
- def run(self):
- """Run the stable dashboard"""
- try:
- logger.info("=" * 60)
- logger.info("STABLE TRADING DASHBOARD")
- logger.info("=" * 60)
- logger.info("Priority: System Stability & Core Functionality")
- logger.info("Training: Disabled (will be enabled after stability)")
- logger.info("TensorBoard: Deferred (documented in design)")
- logger.info("Focus: Dashboard, Data, Basic Trading")
- logger.info("=" * 60)
-
- # Initialize components
- if not self.initialize_components():
- logger.error("Failed to initialize components")
- return False
-
- # Start stability monitoring
- self.start_stability_monitoring()
-
- # Start health check thread
- health_thread = threading.Thread(target=self._health_check_loop, daemon=True)
- health_thread.start()
-
- # Get dashboard port
- port = int(os.environ.get('DASHBOARD_PORT', '8051'))
-
- logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
- logger.info("Press Ctrl+C to stop")
-
- self.running = True
-
- # Start dashboard (this blocks)
- if self.dashboard and hasattr(self.dashboard, 'app'):
- self.dashboard.app.run_server(
- host='127.0.0.1',
- port=port,
- debug=False,
- use_reloader=False, # Disable reloader for stability
- threaded=True
- )
- else:
- logger.error("Dashboard not properly initialized")
- return False
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- self.shutdown()
- return True
- except Exception as e:
- logger.error(f"Error running dashboard: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return False
-
- def _health_check_loop(self):
- """Health check loop running in background"""
- while self.running:
- try:
- self.health_check()
- time.sleep(self.health_check_interval)
- except Exception as e:
- logger.error(f"Error in health check loop: {e}")
- time.sleep(60) # Wait longer on error
-
- def shutdown(self):
- """Graceful shutdown"""
- try:
- logger.info("Shutting down stable dashboard...")
- self.running = False
-
- # Stop stability monitoring
- if self.stability_auditor:
- self.stability_auditor.stop_monitoring()
- logger.info("✓ Stability monitoring stopped")
-
- # Stop components
- if self.orchestrator and hasattr(self.orchestrator, 'stop'):
- self.orchestrator.stop()
- logger.info("✓ Orchestrator stopped")
-
- if self.data_provider and hasattr(self.data_provider, 'stop'):
- self.data_provider.stop()
- logger.info("✓ Data provider stopped")
-
- logger.info("Stable dashboard shutdown complete")
-
- except Exception as e:
- logger.error(f"Error during shutdown: {e}")
-
-def signal_handler(signum, frame):
- """Handle shutdown signals"""
- logger.info("Received shutdown signal")
- sys.exit(0)
-
-def main():
- """Main function"""
- # Setup signal handlers
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- try:
- runner = StableDashboardRunner()
- success = runner.run()
-
- if success:
- logger.info("Dashboard completed successfully")
- sys.exit(0)
- else:
- logger.error("Dashboard failed")
- sys.exit(1)
-
- except Exception as e:
- logger.error(f"Fatal error: {e}")
- import traceback
- logger.error(traceback.format_exc())
- sys.exit(1)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_templated_dashboard.py b/run_templated_dashboard.py
deleted file mode 100644
index 3d6d390..0000000
--- a/run_templated_dashboard.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-"""
-Run Templated Trading Dashboard
-Demonstrates the new MVC template-based architecture
-"""
-import logging
-import sys
-import os
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from web.templated_dashboard import create_templated_dashboard
-from web.dashboard_model import create_sample_dashboard_data
-from web.template_renderer import DashboardTemplateRenderer
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-
-def main():
- """Main function to run the templated dashboard"""
- try:
- logger.info("=== TEMPLATED DASHBOARD DEMO ===")
-
- # Test the template system first
- logger.info("Testing template system...")
-
- # Create sample data
- sample_data = create_sample_dashboard_data()
- logger.info(f"Created sample data with {len(sample_data.metrics)} metrics")
-
- # Test template renderer
- renderer = DashboardTemplateRenderer()
- logger.info("Template renderer initialized")
-
- # Create templated dashboard
- logger.info("Creating templated dashboard...")
- dashboard = create_templated_dashboard()
-
- logger.info("Dashboard created successfully!")
- logger.info("Template-based MVC architecture features:")
- logger.info(" ✓ HTML templates separated from Python code")
- logger.info(" ✓ Data models for structured data")
- logger.info(" ✓ Template renderer for clean separation")
- logger.info(" ✓ Easy to modify HTML without touching Python")
- logger.info(" ✓ Reusable components and templates")
-
- # Run the dashboard
- logger.info("Starting templated dashboard server...")
- dashboard.run_server(host='127.0.0.1', port=8052, debug=False)
-
- except Exception as e:
- logger.error(f"Error running templated dashboard: {e}")
- import traceback
- traceback.print_exc()
-
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/run_tensorboard.py b/run_tensorboard.py
deleted file mode 100644
index 13abf27..0000000
--- a/run_tensorboard.py
+++ /dev/null
@@ -1,155 +0,0 @@
-#!/usr/bin/env python3
-"""
-TensorBoard Launch Script
-
-Starts TensorBoard server for monitoring training progress.
-Visualizes training metrics, rewards, state information, and model performance.
-
-This script can be run standalone or integrated with the dashboard.
-"""
-
-import subprocess
-import sys
-import os
-import time
-import webbrowser
-import argparse
-from pathlib import Path
-import logging
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def start_tensorboard(logdir="runs", port=6006, open_browser=True):
- """
- Start TensorBoard server programmatically
-
- Args:
- logdir: Directory containing TensorBoard logs
- port: Port to run TensorBoard on
- open_browser: Whether to open browser automatically
-
- Returns:
- subprocess.Popen: TensorBoard process
- """
- # Set log directory
- runs_dir = Path(logdir)
- if not runs_dir.exists():
- logger.warning(f"No '{logdir}' directory found. Creating it.")
- runs_dir.mkdir(parents=True, exist_ok=True)
-
- # Check if there are any log directories
- log_dirs = list(runs_dir.glob("*"))
- if not log_dirs:
- logger.warning(f"No training logs found in '{logdir}' directory.")
- else:
- logger.info(f"Found {len(log_dirs)} training sessions")
-
- # List available sessions
- logger.info("Available training sessions:")
- for i, log_dir in enumerate(sorted(log_dirs), 1):
- logger.info(f" {i}. {log_dir.name}")
-
- try:
- logger.info(f"Starting TensorBoard on port {port}...")
-
- # Try to open browser automatically if requested
- if open_browser:
- try:
- webbrowser.open(f"http://localhost:{port}")
- logger.info("Browser opened automatically")
- except Exception as e:
- logger.warning(f"Could not open browser automatically: {e}")
-
- # Start TensorBoard process with enhanced options
- cmd = [
- sys.executable,
- "-m",
- "tensorboard.main",
- "--logdir", str(runs_dir),
- "--port", str(port),
- "--samples_per_plugin", "images=100,audio=100,text=100",
- "--reload_interval", "5", # Reload data every 5 seconds
- "--reload_multifile", "true" # Better handling of multiple log files
- ]
-
- logger.info("TensorBoard is running with enhanced training visualization!")
- logger.info(f"View training metrics at: http://localhost:{port}")
- logger.info("Available dashboards:")
- logger.info(" - SCALARS: Training metrics, rewards, and losses")
- logger.info(" - HISTOGRAMS: Feature distributions and model weights")
- logger.info(" - TIME SERIES: Training progress over time")
-
- # Start TensorBoard process
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- text=True
- )
-
- # Return process for management
- return process
-
- except FileNotFoundError:
- logger.error("TensorBoard not found. Install with: pip install tensorboard")
- return None
- except Exception as e:
- logger.error(f"Error starting TensorBoard: {e}")
- return None
-
-def main():
- """Launch TensorBoard with enhanced visualization options"""
-
- # Parse command line arguments
- parser = argparse.ArgumentParser(description="Launch TensorBoard for training visualization")
- parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
- parser.add_argument("--logdir", type=str, default="runs", help="Directory containing TensorBoard logs")
- parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
- parser.add_argument("--dashboard-integration", action="store_true", help="Run in dashboard integration mode")
- args = parser.parse_args()
-
- # Start TensorBoard
- process = start_tensorboard(
- logdir=args.logdir,
- port=args.port,
- open_browser=not args.no_browser
- )
-
- if process is None:
- return 1
-
- # If running in dashboard integration mode, return immediately
- if args.dashboard_integration:
- return 0
-
- # Otherwise, wait for process to complete
- try:
- print("\n" + "="*70)
- print("🔥 TensorBoard is running with enhanced training visualization!")
- print(f"📈 View training metrics at: http://localhost:{args.port}")
- print("⏹️ Press Ctrl+C to stop TensorBoard")
- print("="*70 + "\n")
-
- # Wait for process to complete or user interrupt
- process.wait()
- return 0
-
- except KeyboardInterrupt:
- print("\n🛑 TensorBoard stopped")
- process.terminate()
- try:
- process.wait(timeout=5)
- except subprocess.TimeoutExpired:
- process.kill()
- return 0
- except Exception as e:
- print(f"❌ Error: {e}")
- return 1
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/run_tests.py b/run_tests.py
deleted file mode 100644
index cdb0737..0000000
--- a/run_tests.py
+++ /dev/null
@@ -1,182 +0,0 @@
-#!/usr/bin/env python3
-"""
-Unified Test Runner for Trading System
-
-This script provides a unified interface to run all tests in the system:
-- Essential functionality tests
-- Model persistence tests
-- Training integration tests
-- Indicators and signals tests
-- Remaining individual test files
-
-Usage:
- python run_tests.py # Run all tests
- python run_tests.py essential # Run essential tests only
- python run_tests.py persistence # Run model persistence tests only
- python run_tests.py training # Run training integration tests only
- python run_tests.py indicators # Run indicators and signals tests only
- python run_tests.py individual # Run individual test files only
-"""
-
-import sys
-import os
-import subprocess
-import logging
-from pathlib import Path
-from safe_logging import setup_safe_logging
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import setup_logging
-
-logger = logging.getLogger(__name__)
-
-def run_test_module(module_path, test_type="all"):
- """Run a specific test module"""
- try:
- cmd = [sys.executable, str(module_path)]
- if test_type != "all":
- cmd.append(test_type)
-
- logger.info(f"Running: {' '.join(cmd)}")
- result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_root)
-
- if result.returncode == 0:
- logger.info(f"✅ {module_path.name} passed")
- if result.stdout:
- logger.info(result.stdout)
- return True
- else:
- logger.error(f"❌ {module_path.name} failed")
- if result.stderr:
- logger.error(result.stderr)
- if result.stdout:
- logger.error(result.stdout)
- return False
-
- except Exception as e:
- logger.error(f"❌ Error running {module_path}: {e}")
- return False
-
-def run_essential_tests():
- """Run essential functionality tests"""
- logger.info("=== Running Essential Tests ===")
- return run_test_module(project_root / "tests" / "test_essential.py")
-
-def run_persistence_tests():
- """Run model persistence tests"""
- logger.info("=== Running Model Persistence Tests ===")
- return run_test_module(project_root / "tests" / "test_model_persistence.py")
-
-def run_training_tests():
- """Run training integration tests"""
- logger.info("=== Running Training Integration Tests ===")
- return run_test_module(project_root / "tests" / "test_training_integration.py")
-
-def run_indicators_tests():
- """Run indicators and signals tests"""
- logger.info("=== Running Indicators and Signals Tests ===")
- return run_test_module(project_root / "tests" / "test_indicators_and_signals.py")
-
-def run_individual_tests():
- """Run remaining individual test files"""
- logger.info("=== Running Individual Test Files ===")
-
- individual_tests = [
- "test_positions.py",
- "test_tick_cache.py",
- "test_timestamps.py"
- ]
-
- results = []
- for test_file in individual_tests:
- test_path = project_root / test_file
- if test_path.exists():
- logger.info(f"Running {test_file}...")
- result = run_test_module(test_path)
- results.append(result)
- else:
- logger.warning(f"Test file not found: {test_file}")
- results.append(False)
-
- return all(results)
-
-def run_all_tests():
- """Run all test suites"""
- logger.info("🧪 Running All Trading System Tests")
- logger.info("=" * 60)
-
- test_suites = [
- ("Essential Tests", run_essential_tests),
- ("Model Persistence Tests", run_persistence_tests),
- ("Training Integration Tests", run_training_tests),
- ("Indicators and Signals Tests", run_indicators_tests),
- ("Individual Tests", run_individual_tests),
- ]
-
- results = []
- for suite_name, suite_func in test_suites:
- logger.info(f"\n📋 {suite_name}")
- logger.info("-" * 40)
- try:
- result = suite_func()
- results.append((suite_name, result))
- except Exception as e:
- logger.error(f"❌ {suite_name} crashed: {e}")
- results.append((suite_name, False))
-
- # Print summary
- logger.info("\n" + "=" * 60)
- logger.info("📊 TEST RESULTS SUMMARY")
- logger.info("=" * 60)
-
- passed = 0
- for suite_name, result in results:
- status = "✅ PASS" if result else "❌ FAIL"
- logger.info(f"{status}: {suite_name}")
- if result:
- passed += 1
-
- logger.info(f"\nPassed: {passed}/{len(results)} test suites")
-
- if passed == len(results):
- logger.info("🎉 All tests passed! Trading system is working correctly.")
- return True
- else:
- logger.warning(f"⚠️ {len(results) - passed} test suite(s) failed. Please check the issues above.")
- return False
-
-def main():
- """Main test runner"""
- setup_safe_logging()
-
- # Parse command line arguments
- if len(sys.argv) > 1:
- test_type = sys.argv[1].lower()
-
- if test_type == "essential":
- success = run_essential_tests()
- elif test_type == "persistence":
- success = run_persistence_tests()
- elif test_type == "training":
- success = run_training_tests()
- elif test_type == "indicators":
- success = run_indicators_tests()
- elif test_type == "individual":
- success = run_individual_tests()
- elif test_type in ["help", "-h", "--help"]:
- print(__doc__)
- return 0
- else:
- logger.error(f"Unknown test type: {test_type}")
- print(__doc__)
- return 1
- else:
- success = run_all_tests()
-
- return 0 if success else 1
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/scripts/kill_stale_processes.py b/scripts/kill_stale_processes.py
deleted file mode 100644
index 92ec71c..0000000
--- a/scripts/kill_stale_processes.py
+++ /dev/null
@@ -1,197 +0,0 @@
-#!/usr/bin/env python3
-"""
-Kill Stale Processes Script
-
-Safely terminates stale Python processes related to the trading dashboard
-with proper error handling and graceful termination.
-"""
-
-import os
-import sys
-import time
-import signal
-from pathlib import Path
-import threading
-
-# Global timeout flag
-timeout_reached = False
-
-def timeout_handler():
- """Handler for overall script timeout"""
- global timeout_reached
- timeout_reached = True
- print("\n⚠️ WARNING: Script timeout reached (10s) - forcing exit")
- os._exit(0) # Force exit
-
-def kill_stale_processes():
- """Kill stale trading dashboard processes safely"""
- global timeout_reached
-
- # Set up overall timeout (10 seconds)
- timer = threading.Timer(10.0, timeout_handler)
- timer.daemon = True
- timer.start()
-
- try:
- import psutil
- except ImportError:
- print("psutil not available - using fallback method")
- return kill_stale_fallback()
-
- current_pid = os.getpid()
- killed_processes = []
- failed_processes = []
-
- # Keywords to identify trading dashboard processes
- target_keywords = [
- 'dashboard', 'scalping', 'trading', 'tensorboard',
- 'run_clean', 'run_main', 'gogo2', 'mexc'
- ]
-
- try:
- print("Scanning for stale processes...")
-
- # Get all Python processes with timeout
- python_processes = []
- scan_start = time.time()
-
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
- if timeout_reached or (time.time() - scan_start) > 3.0: # 3s max for scanning
- print("Process scanning timeout - proceeding with found processes")
- break
-
- try:
- if proc.info['pid'] == current_pid:
- continue
-
- name = proc.info['name'].lower()
- if 'python' in name or 'tensorboard' in name:
- cmdline_str = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else ''
-
- # Check if this is a target process
- if any(keyword in cmdline_str.lower() for keyword in target_keywords):
- python_processes.append({
- 'proc': proc,
- 'pid': proc.info['pid'],
- 'name': proc.info['name'],
- 'cmdline': cmdline_str
- })
-
- except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
- continue
-
- if not python_processes:
- print("No stale processes found")
- timer.cancel() # Cancel the timeout
- return True
-
- print(f"Found {len(python_processes)} target processes to terminate:")
- for p in python_processes[:5]: # Show max 5 to save time
- print(f" - PID {p['pid']}: {p['name']} - {p['cmdline'][:80]}...")
- if len(python_processes) > 5:
- print(f" ... and {len(python_processes) - 5} more")
-
- # Graceful termination first (with reduced wait time)
- print("\nAttempting graceful termination...")
- termination_start = time.time()
-
- for p in python_processes:
- if timeout_reached or (time.time() - termination_start) > 2.0:
- print("Termination timeout - moving to force kill")
- break
-
- try:
- proc = p['proc']
- if proc.is_running():
- proc.terminate()
- print(f" Sent SIGTERM to PID {p['pid']}")
- except Exception as e:
- failed_processes.append(f"Failed to terminate PID {p['pid']}: {e}")
-
- # Wait for graceful shutdown (reduced from 2.0 to 1.0)
- time.sleep(1.0)
-
- # Force kill remaining processes
- print("\nChecking for remaining processes...")
- kill_start = time.time()
-
- for p in python_processes:
- if timeout_reached or (time.time() - kill_start) > 2.0:
- print("Force kill timeout - exiting")
- break
-
- try:
- proc = p['proc']
- if proc.is_running():
- print(f" Force killing PID {p['pid']} ({p['name']})")
- proc.kill()
- killed_processes.append(f"Force killed PID {p['pid']} ({p['name']})")
- else:
- killed_processes.append(f"Gracefully terminated PID {p['pid']} ({p['name']})")
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- killed_processes.append(f"Process PID {p['pid']} already terminated")
- except Exception as e:
- failed_processes.append(f"Failed to kill PID {p['pid']}: {e}")
-
- # Results (quick summary)
- print(f"\n=== Quick Results ===")
- print(f"✓ Cleaned up {len(killed_processes)} processes")
- if failed_processes:
- print(f"✗ Failed: {len(failed_processes)} processes")
-
- timer.cancel() # Cancel the timeout if we finished early
- return len(failed_processes) == 0
-
- except Exception as e:
- print(f"Error during process cleanup: {e}")
- timer.cancel()
- return False
-
-def kill_stale_fallback():
- """Fallback method using basic OS commands"""
- print("Using fallback process killing method...")
-
- try:
- if os.name == 'nt': # Windows
- import subprocess
- # Kill Python processes with dashboard keywords (with timeout)
- result = subprocess.run([
- 'taskkill', '/f', '/im', 'python.exe'
- ], capture_output=True, text=True, timeout=5.0)
-
- if result.returncode == 0:
- print("Windows: Killed all Python processes")
- else:
- print("Windows: No Python processes to kill or access denied")
-
- else: # Unix/Linux
- import subprocess
- # More targeted approach for Unix (with timeouts)
- subprocess.run(['pkill', '-f', 'dashboard'], capture_output=True, timeout=2.0)
- subprocess.run(['pkill', '-f', 'scalping'], capture_output=True, timeout=2.0)
- subprocess.run(['pkill', '-f', 'tensorboard'], capture_output=True, timeout=2.0)
- print("Unix: Killed dashboard-related processes")
-
- return True
-
- except subprocess.TimeoutExpired:
- print("Fallback method timed out")
- return False
- except Exception as e:
- print(f"Fallback method failed: {e}")
- return False
-
-if __name__ == "__main__":
- print("=" * 50)
- print("STALE PROCESS CLEANUP (10s timeout)")
- print("=" * 50)
-
- start_time = time.time()
- success = kill_stale_processes()
- elapsed = time.time() - start_time
-
- exit_code = 0 if success else 1
-
- print(f"Completed in {elapsed:.1f}s")
- print("=" * 50)
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/scripts/restart_dashboard_with_learning.py b/scripts/restart_dashboard_with_learning.py
deleted file mode 100644
index a272719..0000000
--- a/scripts/restart_dashboard_with_learning.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/usr/bin/env python3
-"""
-Restart Dashboard with Forced Learning Enabled
-Simple script to start dashboard with all learning features enabled
-"""
-
-import sys
-import logging
-from datetime import datetime
-
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def main():
- """Start dashboard with forced learning"""
- logger.info("🚀 Starting Dashboard with FORCED LEARNING ENABLED")
- logger.info("=" * 60)
- logger.info(f"Timestamp: {datetime.now()}")
- logger.info("Fixes Applied:")
- logger.info("✅ Enhanced RL: FORCED ENABLED")
- logger.info("✅ CNN Training: FORCED ENABLED")
- logger.info("✅ Williams Pivots: CNN INTEGRATED")
- logger.info("✅ Learning Pipeline: ACTIVE")
- logger.info("=" * 60)
-
- try:
- # Import and run main
- from main_clean import run_web_dashboard
- logger.info("Starting web dashboard...")
- run_web_dashboard()
-
- except KeyboardInterrupt:
- logger.info("Dashboard stopped by user")
- except Exception as e:
- logger.error(f"Dashboard failed: {e}")
- import traceback
- traceback.print_exc()
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/scripts/restart_main_overnight.py b/scripts/restart_main_overnight.py
deleted file mode 100644
index a80c6b3..0000000
--- a/scripts/restart_main_overnight.py
+++ /dev/null
@@ -1,188 +0,0 @@
-#!/usr/bin/env python3
-"""
-Overnight Training Restart Script
-Keeps main.py running continuously, restarting it if it crashes.
-Designed for overnight training sessions with unstable code.
-
-Usage:
- python restart_main_overnight.py
-
-Press Ctrl+C to stop the restart loop.
-"""
-
-import subprocess
-import sys
-import time
-import logging
-from datetime import datetime
-from pathlib import Path
-import signal
-import os
-
-# Setup logging for the restart script
-def setup_restart_logging():
- """Setup logging for restart events"""
- log_dir = Path("logs")
- log_dir.mkdir(exist_ok=True)
-
- # Create restart log file with timestamp
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- log_file = log_dir / f"restart_main_{timestamp}.log"
-
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler(log_file, encoding='utf-8'),
- logging.StreamHandler(sys.stdout)
- ]
- )
-
- logger = logging.getLogger(__name__)
- logger.info(f"Restart script logging to: {log_file}")
- return logger
-
-def kill_existing_processes(logger):
- """Kill any existing main.py processes to avoid conflicts"""
- try:
- if os.name == 'nt': # Windows
- # Kill any existing Python processes running main.py
- subprocess.run(['taskkill', '/f', '/im', 'python.exe'],
- capture_output=True, check=False)
- subprocess.run(['taskkill', '/f', '/im', 'pythonw.exe'],
- capture_output=True, check=False)
- time.sleep(2)
- except Exception as e:
- logger.warning(f"Could not kill existing processes: {e}")
-
-def run_main_with_restart(logger):
- """Main restart loop"""
- restart_count = 0
- consecutive_fast_exits = 0
- start_time = datetime.now()
-
- logger.info("=" * 60)
- logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED")
- logger.info("=" * 60)
- logger.info("Press Ctrl+C to stop the restart loop")
- logger.info("Main script: main.py")
- logger.info("Restart delay on crash: 10 seconds")
- logger.info("Fast exit protection: Enabled")
- logger.info("=" * 60)
-
- # Kill any existing processes
- kill_existing_processes(logger)
-
- while True:
- try:
- restart_count += 1
- run_start_time = datetime.now()
-
- logger.info(f"[RESTART #{restart_count}] Starting main.py at {run_start_time.strftime('%H:%M:%S')}")
-
- # Start main.py as subprocess
- process = subprocess.Popen([
- sys.executable, "main.py"
- ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- universal_newlines=True, bufsize=1)
-
- logger.info(f"[PROCESS] main.py started with PID: {process.pid}")
-
- # Stream output from main.py
- try:
- if process.stdout:
- while True:
- output = process.stdout.readline()
- if output == '' and process.poll() is not None:
- break
- if output:
- # Forward output from main.py (remove extra newlines)
- print(f"[MAIN] {output.rstrip()}")
- else:
- # If no stdout, just wait for process to complete
- process.wait()
- except KeyboardInterrupt:
- logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...")
- process.terminate()
- try:
- process.wait(timeout=10)
- except subprocess.TimeoutExpired:
- logger.warning("[FORCE KILL] Process didn't terminate, force killing...")
- process.kill()
- raise
-
- # Process has exited
- exit_code = process.poll()
- run_end_time = datetime.now()
- run_duration = (run_end_time - run_start_time).total_seconds()
-
- logger.info(f"[EXIT] main.py exited with code {exit_code}")
- logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds")
-
- # Check for fast exits (potential configuration issues)
- if run_duration < 30: # Less than 30 seconds
- consecutive_fast_exits += 1
- logger.warning(f"[FAST EXIT] Process exited quickly ({consecutive_fast_exits} consecutive)")
-
- if consecutive_fast_exits >= 5:
- logger.error("[ABORT] Too many consecutive fast exits (5+)")
- logger.error("This indicates a configuration or startup problem")
- logger.error("Please check the main.py script manually")
- break
-
- # Longer delay for fast exits
- delay = min(60, 10 * consecutive_fast_exits)
- logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...")
- time.sleep(delay)
- else:
- consecutive_fast_exits = 0 # Reset counter
- logger.info("[DELAY] Waiting 10 seconds before restart...")
- time.sleep(10)
-
- # Log session statistics every 10 restarts
- if restart_count % 10 == 0:
- total_duration = (datetime.now() - start_time).total_seconds()
- logger.info(f"[STATS] Session: {restart_count} restarts in {total_duration/3600:.1f} hours")
-
- except KeyboardInterrupt:
- logger.info("[SHUTDOWN] Restart loop interrupted by user")
- break
- except Exception as e:
- logger.error(f"[ERROR] Unexpected error in restart loop: {e}")
- logger.error("Continuing restart loop after 30 second delay...")
- time.sleep(30)
-
- total_duration = (datetime.now() - start_time).total_seconds()
- logger.info("=" * 60)
- logger.info("OVERNIGHT TRAINING SESSION COMPLETE")
- logger.info(f"Total restarts: {restart_count}")
- logger.info(f"Total session time: {total_duration/3600:.1f} hours")
- logger.info("=" * 60)
-
-def main():
- """Main entry point"""
- # Setup signal handlers for clean shutdown
- def signal_handler(signum, frame):
- logger.info(f"[SIGNAL] Received signal {signum}, shutting down...")
- sys.exit(0)
-
- signal.signal(signal.SIGINT, signal_handler)
- if hasattr(signal, 'SIGTERM'):
- signal.signal(signal.SIGTERM, signal_handler)
-
- # Setup logging
- global logger
- logger = setup_restart_logging()
-
- try:
- run_main_with_restart(logger)
- except Exception as e:
- logger.error(f"[FATAL] Fatal error in restart script: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return 1
-
- return 0
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/setup_mexc_browser.py b/setup_mexc_browser.py
deleted file mode 100644
index dc9b697..0000000
--- a/setup_mexc_browser.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#!/usr/bin/env python3
-"""
-MEXC Browser Setup & Runner
-
-This script automatically installs dependencies and runs the MEXC browser automation.
-"""
-
-import subprocess
-import sys
-import os
-import importlib
-
-def check_and_install_requirements():
- """Check and install required packages"""
- required_packages = [
- 'selenium',
- 'webdriver-manager',
- 'requests'
- ]
-
- print("🔍 Checking required packages...")
-
- missing_packages = []
- for package in required_packages:
- try:
- importlib.import_module(package.replace('-', '_'))
- print(f"✅ {package} - already installed")
- except ImportError:
- missing_packages.append(package)
- print(f"❌ {package} - missing")
-
- if missing_packages:
- print(f"\n📦 Installing missing packages: {', '.join(missing_packages)}")
-
- for package in missing_packages:
- try:
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
- print(f"✅ Successfully installed {package}")
- except subprocess.CalledProcessError as e:
- print(f"❌ Failed to install {package}: {e}")
- return False
-
- print("✅ All requirements satisfied!")
- return True
-
-def run_browser_automation():
- """Run the MEXC browser automation"""
- try:
- # Import and run the auto browser
- from core.mexc_webclient.auto_browser import main as auto_browser_main
- auto_browser_main()
- except ImportError:
- print("❌ Could not import auto browser module")
- print("Make sure core/mexc_webclient/auto_browser.py exists")
- except Exception as e:
- print(f"❌ Error running browser automation: {e}")
-
-def main():
- """Main setup and run function"""
- print("🚀 MEXC Browser Automation Setup")
- print("=" * 40)
-
- # Check Python version
- if sys.version_info < (3, 7):
- print("❌ Python 3.7+ required")
- return
-
- print(f"✅ Python {sys.version.split()[0]} detected")
-
- # Install requirements
- if not check_and_install_requirements():
- print("❌ Failed to install requirements")
- return
-
- print("\n🌐 Starting browser automation...")
- print("This will:")
- print("• Download ChromeDriver automatically")
- print("• Open MEXC futures page")
- print("• Capture all trading requests")
- print("• Extract session cookies")
-
- input("\nPress Enter to continue...")
-
- # Run the automation
- run_browser_automation()
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/start_monitoring.py b/start_monitoring.py
deleted file mode 100644
index 36d8c28..0000000
--- a/start_monitoring.py
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/usr/bin/env python3
-"""
-Helper script to start monitoring services for RL training
-"""
-
-import subprocess
-import sys
-import time
-import requests
-import os
-import json
-from pathlib import Path
-
-# Available ports to try for TensorBoard
-TENSORBOARD_PORTS = [6006, 6007, 6008, 6009, 6010, 6011, 6012]
-
-def check_port(port, service_name):
- """Check if a service is running on the specified port"""
- try:
- response = requests.get(f"http://localhost:{port}", timeout=3)
- print(f"✅ {service_name} is running on port {port}")
- return True
- except requests.exceptions.RequestException:
- return False
-
-def is_port_in_use(port):
- """Check if a port is already in use"""
- import socket
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- try:
- s.bind(('localhost', port))
- return False
- except OSError:
- return True
-
-def find_available_port(ports_list, service_name):
- """Find an available port from the list"""
- for port in ports_list:
- if not is_port_in_use(port):
- print(f"🔍 Found available port {port} for {service_name}")
- return port
- else:
- print(f"⚠️ Port {port} is already in use")
- return None
-
-def save_port_config(tensorboard_port):
- """Save the port configuration to a file"""
- config = {
- "tensorboard_port": tensorboard_port,
- "web_dashboard_port": 8051
- }
- with open("monitoring_ports.json", "w") as f:
- json.dump(config, f, indent=2)
- print(f"💾 Port configuration saved to monitoring_ports.json")
-
-def start_tensorboard():
- """Start TensorBoard in background on an available port"""
- try:
- # First check if TensorBoard is already running on any of our ports
- for port in TENSORBOARD_PORTS:
- if check_port(port, "TensorBoard"):
- print(f"✅ TensorBoard already running on port {port}")
- save_port_config(port)
- return port
-
- # Find an available port
- port = find_available_port(TENSORBOARD_PORTS, "TensorBoard")
- if port is None:
- print(f"❌ No available ports found in range {TENSORBOARD_PORTS}")
- return None
-
- print(f"🚀 Starting TensorBoard on port {port}...")
-
- # Create runs directory if it doesn't exist
- Path("runs").mkdir(exist_ok=True)
-
- # Start TensorBoard
- if os.name == 'nt': # Windows
- subprocess.Popen([
- sys.executable, "-m", "tensorboard",
- "--logdir=runs", f"--port={port}", "--reload_interval=1"
- ], creationflags=subprocess.CREATE_NEW_CONSOLE)
- else: # Linux/Mac
- subprocess.Popen([
- sys.executable, "-m", "tensorboard",
- "--logdir=runs", f"--port={port}", "--reload_interval=1"
- ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
- # Wait for TensorBoard to start
- print(f"⏳ Waiting for TensorBoard to start on port {port}...")
- for i in range(15):
- time.sleep(2)
- if check_port(port, "TensorBoard"):
- save_port_config(port)
- return port
-
- print(f"⚠️ TensorBoard failed to start on port {port} within 30 seconds")
- return None
-
- except Exception as e:
- print(f"❌ Error starting TensorBoard: {e}")
- return None
-
-def check_web_dashboard_port():
- """Check if web dashboard port is available"""
- port = 8051
- if is_port_in_use(port):
- print(f"⚠️ Web dashboard port {port} is in use")
- # Try alternative ports
- for alt_port in [8052, 8053, 8054, 8055]:
- if not is_port_in_use(alt_port):
- print(f"🔍 Alternative port {alt_port} available for web dashboard")
- return alt_port
- print("❌ No alternative ports found for web dashboard")
- return port
- else:
- print(f"✅ Web dashboard port {port} is available")
- return port
-
-def main():
- """Main function"""
- print("=" * 60)
- print("🎯 RL TRAINING MONITORING SETUP")
- print("=" * 60)
-
- # Check web dashboard port
- web_port = check_web_dashboard_port()
-
- # Start TensorBoard
- tensorboard_port = start_tensorboard()
-
- print("\n" + "=" * 60)
- print("📊 MONITORING STATUS")
- print("=" * 60)
-
- if tensorboard_port:
- print(f"✅ TensorBoard: http://localhost:{tensorboard_port}")
- # Update port config
- save_port_config(tensorboard_port)
- else:
- print("❌ TensorBoard: Failed to start")
- print(" Manual start: python -m tensorboard --logdir=runs --port=6007")
-
- if web_port:
- print(f"✅ Web Dashboard: Ready on port {web_port}")
-
- print(f"\n🎯 Ready to start RL training!")
- if tensorboard_port and web_port != 8051:
- print(f"Run: python train_realtime_with_tensorboard.py --episodes 10 --web-port {web_port}")
- else:
- print("Run: python train_realtime_with_tensorboard.py --episodes 10")
-
- print(f"\n📋 Available URLs:")
- if tensorboard_port:
- print(f" 📊 TensorBoard: http://localhost:{tensorboard_port}")
- if web_port:
- print(f" 🌐 Web Dashboard: http://localhost:{web_port} (starts with training)")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/start_overnight_training.py b/start_overnight_training.py
deleted file mode 100644
index 2479e85..0000000
--- a/start_overnight_training.py
+++ /dev/null
@@ -1,179 +0,0 @@
-#!/usr/bin/env python3
-"""
-Start Overnight Training Session
-
-This script starts a comprehensive overnight training session that:
-1. Ensures CNN and COB RL training processes are implemented and running
-2. Executes training passes on each signal when predictions change
-3. Calculates PnL and records trades in SIM mode
-4. Tracks model performance statistics
-5. Converts signals to actual trades for performance tracking
-"""
-
-import os
-import sys
-import time
-import logging
-from datetime import datetime
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler(f'overnight_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
- logging.StreamHandler()
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-def main():
- """Start the overnight training session"""
- try:
- logger.info("🌙 STARTING OVERNIGHT TRAINING SESSION")
- logger.info("=" * 80)
-
- # Import required components
- from core.config import get_config, setup_logging
- from core.data_provider import DataProvider
- from core.orchestrator import TradingOrchestrator
- from core.trading_executor import TradingExecutor
- from web.clean_dashboard import CleanTradingDashboard
-
- # Setup logging
- setup_logging()
-
- # Initialize components
- logger.info("Initializing components...")
-
- # Create data provider
- data_provider = DataProvider()
- logger.info("✅ Data Provider initialized")
-
- # Create trading executor in simulation mode
- trading_executor = TradingExecutor()
- trading_executor.simulation_mode = True # Ensure we're in simulation mode
- logger.info("✅ Trading Executor initialized (SIMULATION MODE)")
-
- # Create orchestrator with enhanced training
- orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True
- )
- logger.info("✅ Trading Orchestrator initialized")
-
- # Connect trading executor to orchestrator
- if hasattr(orchestrator, 'set_trading_executor'):
- orchestrator.set_trading_executor(trading_executor)
- logger.info("✅ Trading Executor connected to Orchestrator")
-
- # Create dashboard (this initializes the overnight training coordinator)
- dashboard = CleanTradingDashboard(
- data_provider=data_provider,
- orchestrator=orchestrator,
- trading_executor=trading_executor
- )
- logger.info("✅ Dashboard initialized with Overnight Training Coordinator")
-
- # Start the overnight training session
- logger.info("Starting overnight training session...")
- success = dashboard.start_overnight_training()
-
- if success:
- logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED SUCCESSFULLY")
- logger.info("=" * 80)
- logger.info("Training Features Active:")
- logger.info("✅ CNN training on signal changes")
- logger.info("✅ COB RL training on market microstructure")
- logger.info("✅ DQN training on trading decisions")
- logger.info("✅ Trade execution and recording (SIMULATION)")
- logger.info("✅ Performance tracking and statistics")
- logger.info("✅ Model checkpointing every 50 trades")
- logger.info("✅ Signal-to-trade conversion with PnL calculation")
- logger.info("=" * 80)
-
- # Monitor training progress
- logger.info("Monitoring training progress...")
- logger.info("Press Ctrl+C to stop the training session")
-
- # Keep the session running and periodically report progress
- start_time = datetime.now()
- last_report_time = start_time
-
- while True:
- try:
- time.sleep(60) # Check every minute
-
- current_time = datetime.now()
- elapsed_time = current_time - start_time
-
- # Get performance summary every 10 minutes
- if (current_time - last_report_time).total_seconds() >= 600: # 10 minutes
- performance = dashboard.get_training_performance_summary()
-
- logger.info("=" * 60)
- logger.info(f"🌙 TRAINING PROGRESS REPORT - {elapsed_time}")
- logger.info("=" * 60)
- logger.info(f"Total Signals: {performance.get('total_signals', 0)}")
- logger.info(f"Total Trades: {performance.get('total_trades', 0)}")
- logger.info(f"Successful Trades: {performance.get('successful_trades', 0)}")
- logger.info(f"Success Rate: {performance.get('success_rate', 0):.1%}")
- logger.info(f"Total P&L: ${performance.get('total_pnl', 0):.2f}")
- logger.info(f"Models Trained: {', '.join(performance.get('models_trained', []))}")
- logger.info(f"Training Status: {'ACTIVE' if performance.get('is_running', False) else 'INACTIVE'}")
- logger.info("=" * 60)
-
- last_report_time = current_time
-
- except KeyboardInterrupt:
- logger.info("\n🛑 Training session interrupted by user")
- break
- except Exception as e:
- logger.error(f"Error during training monitoring: {e}")
- time.sleep(30) # Wait 30 seconds before retrying
-
- # Stop the training session
- logger.info("Stopping overnight training session...")
- dashboard.stop_overnight_training()
-
- # Final report
- final_performance = dashboard.get_training_performance_summary()
- total_time = datetime.now() - start_time
-
- logger.info("=" * 80)
- logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")
- logger.info("=" * 80)
- logger.info(f"Total Duration: {total_time}")
- logger.info(f"Final Statistics:")
- logger.info(f" Total Signals: {final_performance.get('total_signals', 0)}")
- logger.info(f" Total Trades: {final_performance.get('total_trades', 0)}")
- logger.info(f" Successful Trades: {final_performance.get('successful_trades', 0)}")
- logger.info(f" Success Rate: {final_performance.get('success_rate', 0):.1%}")
- logger.info(f" Total P&L: ${final_performance.get('total_pnl', 0):.2f}")
- logger.info(f" Models Trained: {', '.join(final_performance.get('models_trained', []))}")
- logger.info("=" * 80)
-
- else:
- logger.error("❌ Failed to start overnight training session")
- return 1
-
- except KeyboardInterrupt:
- logger.info("\n🛑 Training session interrupted by user")
- return 0
- except Exception as e:
- logger.error(f"❌ Error in overnight training session: {e}")
- import traceback
- traceback.print_exc()
- return 1
-
- return 0
-
-if __name__ == "__main__":
- exit_code = main()
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/system_stability_audit.py b/system_stability_audit.py
deleted file mode 100644
index 26fbe94..0000000
--- a/system_stability_audit.py
+++ /dev/null
@@ -1,426 +0,0 @@
-#!/usr/bin/env python3
-"""
-System Stability Audit and Monitoring
-
-This script performs a comprehensive audit of the trading system to identify
-and fix stability issues, memory leaks, and performance bottlenecks.
-"""
-
-import os
-import sys
-import psutil
-import logging
-import time
-import threading
-import gc
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict, List, Optional, Any
-import traceback
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import setup_logging, get_config
-
-# Setup logging
-setup_logging()
-logger = logging.getLogger(__name__)
-
-class SystemStabilityAuditor:
- """
- Comprehensive system stability auditor and monitor
-
- Monitors:
- - Memory usage and leaks
- - CPU usage and performance
- - Thread health and deadlocks
- - Model performance and stability
- - Dashboard responsiveness
- - Data provider health
- """
-
- def __init__(self):
- """Initialize the stability auditor"""
- self.config = get_config()
- self.monitoring_active = False
- self.monitoring_thread = None
-
- # Performance baselines
- self.baseline_memory = psutil.virtual_memory().used
- self.baseline_cpu = psutil.cpu_percent()
-
- # Monitoring data
- self.memory_history = []
- self.cpu_history = []
- self.thread_history = []
- self.error_history = []
-
- # Stability metrics
- self.stability_score = 100.0
- self.critical_issues = []
- self.warnings = []
-
- logger.info("System Stability Auditor initialized")
-
- def start_monitoring(self):
- """Start continuous system monitoring"""
- if self.monitoring_active:
- logger.warning("Monitoring already active")
- return
-
- self.monitoring_active = True
- self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
- self.monitoring_thread.start()
-
- logger.info("System stability monitoring started")
-
- def stop_monitoring(self):
- """Stop system monitoring"""
- self.monitoring_active = False
- if self.monitoring_thread:
- self.monitoring_thread.join(timeout=5)
-
- logger.info("System stability monitoring stopped")
-
- def _monitoring_loop(self):
- """Main monitoring loop"""
- while self.monitoring_active:
- try:
- # Collect system metrics
- self._collect_system_metrics()
-
- # Check for memory leaks
- self._check_memory_leaks()
-
- # Check CPU usage
- self._check_cpu_usage()
-
- # Check thread health
- self._check_thread_health()
-
- # Check for deadlocks
- self._check_for_deadlocks()
-
- # Update stability score
- self._update_stability_score()
-
- # Log status every 60 seconds
- if len(self.memory_history) % 12 == 0: # Every 12 * 5s = 60s
- self._log_stability_status()
-
- time.sleep(5) # Check every 5 seconds
-
- except Exception as e:
- logger.error(f"Error in monitoring loop: {e}")
- self.error_history.append({
- 'timestamp': datetime.now(),
- 'error': str(e),
- 'traceback': traceback.format_exc()
- })
- time.sleep(10) # Wait longer on error
-
- def _collect_system_metrics(self):
- """Collect system performance metrics"""
- try:
- # Memory metrics
- memory = psutil.virtual_memory()
- memory_data = {
- 'timestamp': datetime.now(),
- 'used_gb': memory.used / (1024**3),
- 'available_gb': memory.available / (1024**3),
- 'percent': memory.percent
- }
- self.memory_history.append(memory_data)
-
- # Keep only last 720 entries (1 hour at 5s intervals)
- if len(self.memory_history) > 720:
- self.memory_history = self.memory_history[-720:]
-
- # CPU metrics
- cpu_percent = psutil.cpu_percent(interval=1)
- cpu_data = {
- 'timestamp': datetime.now(),
- 'percent': cpu_percent,
- 'cores': psutil.cpu_count()
- }
- self.cpu_history.append(cpu_data)
-
- # Keep only last 720 entries
- if len(self.cpu_history) > 720:
- self.cpu_history = self.cpu_history[-720:]
-
- # Thread metrics
- thread_count = threading.active_count()
- thread_data = {
- 'timestamp': datetime.now(),
- 'count': thread_count,
- 'threads': [t.name for t in threading.enumerate()]
- }
- self.thread_history.append(thread_data)
-
- # Keep only last 720 entries
- if len(self.thread_history) > 720:
- self.thread_history = self.thread_history[-720:]
-
- except Exception as e:
- logger.error(f"Error collecting system metrics: {e}")
-
- def _check_memory_leaks(self):
- """Check for memory leaks"""
- try:
- if len(self.memory_history) < 10:
- return
-
- # Check if memory usage is consistently increasing
- recent_memory = [m['used_gb'] for m in self.memory_history[-10:]]
- memory_trend = sum(recent_memory[-5:]) / 5 - sum(recent_memory[:5]) / 5
-
- # If memory increased by more than 100MB in last 10 checks
- if memory_trend > 0.1:
- warning = f"Potential memory leak detected: +{memory_trend:.2f}GB in last 50s"
- if warning not in self.warnings:
- self.warnings.append(warning)
- logger.warning(warning)
-
- # Force garbage collection
- gc.collect()
- logger.info("Forced garbage collection to free memory")
-
- # Check for excessive memory usage
- current_memory = self.memory_history[-1]['percent']
- if current_memory > 85:
- critical = f"High memory usage: {current_memory:.1f}%"
- if critical not in self.critical_issues:
- self.critical_issues.append(critical)
- logger.error(critical)
-
- except Exception as e:
- logger.error(f"Error checking memory leaks: {e}")
-
- def _check_cpu_usage(self):
- """Check CPU usage patterns"""
- try:
- if len(self.cpu_history) < 10:
- return
-
- # Check for sustained high CPU usage
- recent_cpu = [c['percent'] for c in self.cpu_history[-10:]]
- avg_cpu = sum(recent_cpu) / len(recent_cpu)
-
- if avg_cpu > 90:
- critical = f"Sustained high CPU usage: {avg_cpu:.1f}%"
- if critical not in self.critical_issues:
- self.critical_issues.append(critical)
- logger.error(critical)
- elif avg_cpu > 75:
- warning = f"High CPU usage: {avg_cpu:.1f}%"
- if warning not in self.warnings:
- self.warnings.append(warning)
- logger.warning(warning)
-
- except Exception as e:
- logger.error(f"Error checking CPU usage: {e}")
-
- def _check_thread_health(self):
- """Check thread health and detect issues"""
- try:
- if len(self.thread_history) < 5:
- return
-
- current_threads = self.thread_history[-1]['count']
-
- # Check for thread explosion
- if current_threads > 50:
- critical = f"Thread explosion detected: {current_threads} active threads"
- if critical not in self.critical_issues:
- self.critical_issues.append(critical)
- logger.error(critical)
-
- # Log thread names for debugging
- thread_names = self.thread_history[-1]['threads']
- logger.error(f"Active threads: {thread_names}")
-
- # Check for thread leaks (gradually increasing thread count)
- if len(self.thread_history) >= 10:
- thread_counts = [t['count'] for t in self.thread_history[-10:]]
- thread_trend = sum(thread_counts[-5:]) / 5 - sum(thread_counts[:5]) / 5
-
- if thread_trend > 2: # More than 2 threads increase on average
- warning = f"Potential thread leak: +{thread_trend:.1f} threads in last 50s"
- if warning not in self.warnings:
- self.warnings.append(warning)
- logger.warning(warning)
-
- except Exception as e:
- logger.error(f"Error checking thread health: {e}")
-
- def _check_for_deadlocks(self):
- """Check for potential deadlocks"""
- try:
- # Simple deadlock detection based on thread states
- all_threads = threading.enumerate()
- blocked_threads = []
-
- for thread in all_threads:
- if hasattr(thread, '_is_stopped') and not thread._is_stopped:
- # Thread is running but might be blocked
- # This is a simplified check - real deadlock detection is complex
- pass
-
- # For now, just check if we have threads that haven't been active
- # More sophisticated deadlock detection would require thread state analysis
-
- except Exception as e:
- logger.error(f"Error checking for deadlocks: {e}")
-
- def _update_stability_score(self):
- """Update overall system stability score"""
- try:
- score = 100.0
-
- # Deduct points for critical issues
- score -= len(self.critical_issues) * 20
-
- # Deduct points for warnings
- score -= len(self.warnings) * 5
-
- # Deduct points for recent errors
- recent_errors = [e for e in self.error_history
- if e['timestamp'] > datetime.now() - timedelta(minutes=10)]
- score -= len(recent_errors) * 10
-
- # Deduct points for high resource usage
- if self.memory_history:
- current_memory = self.memory_history[-1]['percent']
- if current_memory > 80:
- score -= (current_memory - 80) * 2
-
- if self.cpu_history:
- current_cpu = self.cpu_history[-1]['percent']
- if current_cpu > 80:
- score -= (current_cpu - 80) * 1
-
- self.stability_score = max(0, score)
-
- except Exception as e:
- logger.error(f"Error updating stability score: {e}")
-
- def _log_stability_status(self):
- """Log current stability status"""
- try:
- logger.info("=" * 50)
- logger.info("SYSTEM STABILITY STATUS")
- logger.info("=" * 50)
- logger.info(f"Stability Score: {self.stability_score:.1f}/100")
-
- if self.memory_history:
- mem = self.memory_history[-1]
- logger.info(f"Memory: {mem['used_gb']:.1f}GB used ({mem['percent']:.1f}%)")
-
- if self.cpu_history:
- cpu = self.cpu_history[-1]
- logger.info(f"CPU: {cpu['percent']:.1f}%")
-
- if self.thread_history:
- threads = self.thread_history[-1]
- logger.info(f"Threads: {threads['count']} active")
-
- if self.critical_issues:
- logger.error(f"Critical Issues ({len(self.critical_issues)}):")
- for issue in self.critical_issues[-5:]: # Show last 5
- logger.error(f" - {issue}")
-
- if self.warnings:
- logger.warning(f"Warnings ({len(self.warnings)}):")
- for warning in self.warnings[-5:]: # Show last 5
- logger.warning(f" - {warning}")
-
- logger.info("=" * 50)
-
- except Exception as e:
- logger.error(f"Error logging stability status: {e}")
-
- def get_stability_report(self) -> Dict[str, Any]:
- """Get comprehensive stability report"""
- try:
- return {
- 'stability_score': self.stability_score,
- 'critical_issues': self.critical_issues,
- 'warnings': self.warnings,
- 'memory_usage': self.memory_history[-1] if self.memory_history else None,
- 'cpu_usage': self.cpu_history[-1] if self.cpu_history else None,
- 'thread_count': self.thread_history[-1]['count'] if self.thread_history else 0,
- 'recent_errors': len([e for e in self.error_history
- if e['timestamp'] > datetime.now() - timedelta(minutes=10)]),
- 'monitoring_active': self.monitoring_active
- }
- except Exception as e:
- logger.error(f"Error generating stability report: {e}")
- return {'error': str(e)}
-
- def fix_common_issues(self):
- """Attempt to fix common stability issues"""
- try:
- logger.info("Attempting to fix common stability issues...")
-
- # Force garbage collection
- gc.collect()
- logger.info("✓ Forced garbage collection")
-
- # Clear old history to free memory
- if len(self.memory_history) > 360: # Keep only 30 minutes
- self.memory_history = self.memory_history[-360:]
- if len(self.cpu_history) > 360:
- self.cpu_history = self.cpu_history[-360:]
- if len(self.thread_history) > 360:
- self.thread_history = self.thread_history[-360:]
-
- logger.info("✓ Cleared old monitoring history")
-
- # Clear old errors
- cutoff_time = datetime.now() - timedelta(hours=1)
- self.error_history = [e for e in self.error_history if e['timestamp'] > cutoff_time]
- logger.info("✓ Cleared old error history")
-
- # Reset warnings and critical issues that might be stale
- self.warnings = []
- self.critical_issues = []
- logger.info("✓ Reset stale warnings and critical issues")
-
- logger.info("Common stability fixes applied")
-
- except Exception as e:
- logger.error(f"Error fixing common issues: {e}")
-
-def main():
- """Main function for standalone execution"""
- try:
- logger.info("Starting System Stability Audit")
-
- auditor = SystemStabilityAuditor()
- auditor.start_monitoring()
-
- # Run for 5 minutes then generate report
- time.sleep(300)
-
- report = auditor.get_stability_report()
- logger.info("FINAL STABILITY REPORT:")
- logger.info(f"Stability Score: {report['stability_score']:.1f}/100")
- logger.info(f"Critical Issues: {len(report['critical_issues'])}")
- logger.info(f"Warnings: {len(report['warnings'])}")
-
- # Attempt fixes if needed
- if report['stability_score'] < 80:
- auditor.fix_common_issues()
-
- auditor.stop_monitoring()
-
- except KeyboardInterrupt:
- logger.info("Audit interrupted by user")
- except Exception as e:
- logger.error(f"Error in stability audit: {e}")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_build_base_data_performance.py b/test_build_base_data_performance.py
deleted file mode 100644
index 112c4f7..0000000
--- a/test_build_base_data_performance.py
+++ /dev/null
@@ -1,191 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Build Base Data Performance
-
-This script tests the performance of build_base_data_input to ensure it's instantaneous.
-"""
-
-import sys
-import os
-import time
-import logging
-from datetime import datetime
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.orchestrator import TradingOrchestrator
-from core.config import get_config
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_build_base_data_performance():
- """Test the performance of build_base_data_input"""
-
- logger.info("=== Testing Build Base Data Performance ===")
-
- try:
- # Initialize orchestrator
- config = get_config()
- orchestrator = TradingOrchestrator(
- symbol="ETH/USDT",
- config=config
- )
-
- # Start the orchestrator to initialize data
- orchestrator.start()
- logger.info("✅ Orchestrator started")
-
- # Wait a bit for data to be populated
- time.sleep(2)
-
- # Test performance of build_base_data_input
- symbol = "ETH/USDT"
- num_tests = 10
- total_time = 0
-
- logger.info(f"Running {num_tests} performance tests...")
-
- for i in range(num_tests):
- start_time = time.time()
-
- base_data = orchestrator.build_base_data_input(symbol)
-
- end_time = time.time()
- duration = (end_time - start_time) * 1000 # Convert to milliseconds
- total_time += duration
-
- if base_data:
- logger.info(f"Test {i+1}: {duration:.2f}ms - ✅ Success")
- else:
- logger.warning(f"Test {i+1}: {duration:.2f}ms - ❌ Failed (no data)")
-
- avg_time = total_time / num_tests
-
- logger.info(f"=== Performance Results ===")
- logger.info(f"Average time: {avg_time:.2f}ms")
- logger.info(f"Total time: {total_time:.2f}ms")
-
- # Performance thresholds
- if avg_time < 10: # Less than 10ms is excellent
- logger.info("🎉 EXCELLENT: Build time is under 10ms")
- elif avg_time < 50: # Less than 50ms is good
- logger.info("✅ GOOD: Build time is under 50ms")
- elif avg_time < 100: # Less than 100ms is acceptable
- logger.info("⚠️ ACCEPTABLE: Build time is under 100ms")
- else:
- logger.error("❌ SLOW: Build time is over 100ms - needs optimization")
-
- # Test with multiple symbols
- logger.info("Testing with multiple symbols...")
- symbols = ["ETH/USDT", "BTC/USDT"]
-
- for symbol in symbols:
- start_time = time.time()
- base_data = orchestrator.build_base_data_input(symbol)
- end_time = time.time()
- duration = (end_time - start_time) * 1000
-
- logger.info(f"{symbol}: {duration:.2f}ms")
-
- # Stop orchestrator
- orchestrator.stop()
- logger.info("✅ Orchestrator stopped")
-
- return avg_time < 100 # Return True if performance is acceptable
-
- except Exception as e:
- logger.error(f"❌ Performance test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-def test_cache_effectiveness():
- """Test that caching is working effectively"""
-
- logger.info("=== Testing Cache Effectiveness ===")
-
- try:
- # Initialize orchestrator
- config = get_config()
- orchestrator = TradingOrchestrator(
- symbol="ETH/USDT",
- config=config
- )
-
- orchestrator.start()
- time.sleep(2) # Let data populate
-
- symbol = "ETH/USDT"
-
- # First call (should build cache)
- start_time = time.time()
- base_data1 = orchestrator.build_base_data_input(symbol)
- first_call_time = (time.time() - start_time) * 1000
-
- # Second call (should use cache)
- start_time = time.time()
- base_data2 = orchestrator.build_base_data_input(symbol)
- second_call_time = (time.time() - start_time) * 1000
-
- # Third call (should still use cache)
- start_time = time.time()
- base_data3 = orchestrator.build_base_data_input(symbol)
- third_call_time = (time.time() - start_time) * 1000
-
- logger.info(f"First call (build cache): {first_call_time:.2f}ms")
- logger.info(f"Second call (use cache): {second_call_time:.2f}ms")
- logger.info(f"Third call (use cache): {third_call_time:.2f}ms")
-
- # Cache should make subsequent calls faster
- if second_call_time < first_call_time * 0.5:
- logger.info("✅ Cache is working effectively")
- cache_effective = True
- else:
- logger.warning("⚠️ Cache may not be working as expected")
- cache_effective = False
-
- # Verify data consistency
- if base_data1 and base_data2 and base_data3:
- # Check that we get consistent data structure
- if (len(base_data1.ohlcv_1s) == len(base_data2.ohlcv_1s) == len(base_data3.ohlcv_1s)):
- logger.info("✅ Data consistency maintained")
- else:
- logger.warning("⚠️ Data consistency issues detected")
-
- orchestrator.stop()
-
- return cache_effective
-
- except Exception as e:
- logger.error(f"❌ Cache effectiveness test failed: {e}")
- return False
-
-def main():
- """Run all performance tests"""
-
- logger.info("Starting Build Base Data Performance Tests")
-
- # Test 1: Basic performance
- test1_passed = test_build_base_data_performance()
-
- # Test 2: Cache effectiveness
- test2_passed = test_cache_effectiveness()
-
- # Summary
- logger.info("=== Test Summary ===")
- logger.info(f"Performance Test: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
- logger.info(f"Cache Effectiveness: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
-
- if test1_passed and test2_passed:
- logger.info("🎉 All tests passed! build_base_data_input is optimized.")
- logger.info("The system now:")
- logger.info(" - Builds BaseDataInput in under 100ms")
- logger.info(" - Uses effective caching for repeated calls")
- logger.info(" - Maintains data consistency")
- else:
- logger.error("❌ Some tests failed. Performance optimization needed.")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_bybit_eth_futures.py b/test_bybit_eth_futures.py
deleted file mode 100644
index 8a8f085..0000000
--- a/test_bybit_eth_futures.py
+++ /dev/null
@@ -1,348 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for Bybit ETH futures position opening/closing
-"""
-
-import os
-import sys
-import time
-import logging
-from datetime import datetime
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-# Load environment variables from .env file
-try:
- from dotenv import load_dotenv
- load_dotenv()
-except ImportError:
- # If dotenv is not available, try to load .env manually
- if os.path.exists('.env'):
- with open('.env', 'r') as f:
- for line in f:
- if line.strip() and not line.startswith('#'):
- key, value = line.strip().split('=', 1)
- os.environ[key] = value
-
-from core.exchanges.bybit_interface import BybitInterface
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-class BybitEthFuturesTest:
- """Test class for Bybit ETH futures trading"""
-
- def __init__(self, test_mode=True):
- self.test_mode = test_mode
- self.bybit = BybitInterface(test_mode=test_mode)
- self.test_symbol = 'ETHUSDT'
- self.test_quantity = 0.01 # Small test amount
-
- def run_tests(self):
- """Run all tests"""
- print("=" * 60)
- print("BYBIT ETH FUTURES POSITION TESTING")
- print("=" * 60)
- print(f"Test mode: {'TESTNET' if self.test_mode else 'LIVE'}")
- print(f"Symbol: {self.test_symbol}")
- print(f"Test quantity: {self.test_quantity} ETH")
- print("=" * 60)
-
- # Test 1: Connection
- if not self.test_connection():
- print("❌ Connection failed - stopping tests")
- return False
-
- # Test 2: Check balance
- if not self.test_balance():
- print("❌ Balance check failed - stopping tests")
- return False
-
- # Test 3: Check current positions
- self.test_current_positions()
-
- # Test 4: Get ticker
- if not self.test_ticker():
- print("❌ Ticker test failed - stopping tests")
- return False
-
- # Test 5: Open a long position
- long_order = self.test_open_long_position()
- if not long_order:
- print("❌ Open long position failed")
- return False
-
- # Test 6: Check position after opening
- time.sleep(2) # Wait for position to be reflected
- if not self.test_position_after_open():
- print("❌ Position check after opening failed")
- return False
-
- # Test 7: Close the position
- if not self.test_close_position():
- print("❌ Close position failed")
- return False
-
- # Test 8: Check position after closing
- time.sleep(2) # Wait for position to be reflected
- self.test_position_after_close()
-
- print("\n" + "=" * 60)
- print("✅ ALL TESTS COMPLETED SUCCESSFULLY")
- print("=" * 60)
- return True
-
- def test_connection(self):
- """Test connection to Bybit"""
- print("\n📡 Testing connection to Bybit...")
-
- # First test simple connectivity without auth
- print("Testing basic API connectivity...")
- try:
- from core.exchanges.bybit_rest_client import BybitRestClient
- client = BybitRestClient(
- api_key="dummy",
- api_secret="dummy",
- testnet=True
- )
-
- # Test public endpoint (server time)
- server_time = client.get_server_time()
- print(f"✅ Public API working - Server time: {server_time.get('result', {}).get('timeSecond')}")
-
- except Exception as e:
- print(f"❌ Public API failed: {e}")
- return False
-
- # Now test with actual credentials
- print("Testing with API credentials...")
- try:
- connected = self.bybit.connect()
- if connected:
- print("✅ Successfully connected to Bybit with credentials")
- return True
- else:
- print("❌ Failed to connect to Bybit with credentials")
- print("This might be due to:")
- print("- Invalid API credentials")
- print("- Credentials not enabled for testnet")
- print("- Missing required permissions")
- return False
- except Exception as e:
- print(f"❌ Connection error: {e}")
- return False
-
- def test_balance(self):
- """Test getting account balance"""
- print("\n💰 Testing account balance...")
-
- try:
- # Get USDT balance (for margin)
- usdt_balance = self.bybit.get_balance('USDT')
- print(f"USDT Balance: {usdt_balance}")
-
- # Get all balances
- all_balances = self.bybit.get_all_balances()
- print("All balances:")
- for asset, balance in all_balances.items():
- if balance['total'] > 0:
- print(f" {asset}: Free={balance['free']}, Locked={balance['locked']}, Total={balance['total']}")
-
- if usdt_balance > 10: # Need at least $10 for testing
- print("✅ Sufficient balance for testing")
- return True
- else:
- print("❌ Insufficient USDT balance for testing (need at least $10)")
- return False
-
- except Exception as e:
- print(f"❌ Balance check error: {e}")
- return False
-
- def test_current_positions(self):
- """Test getting current positions"""
- print("\n📊 Checking current positions...")
-
- try:
- positions = self.bybit.get_positions()
- if positions:
- print(f"Found {len(positions)} open positions:")
- for pos in positions:
- print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
- print(f" PnL: ${pos['unrealized_pnl']:.2f} ({pos['percentage']:.2f}%)")
- else:
- print("No open positions found")
-
- except Exception as e:
- print(f"❌ Position check error: {e}")
-
- def test_ticker(self):
- """Test getting ticker information"""
- print(f"\n📈 Testing ticker for {self.test_symbol}...")
-
- try:
- ticker = self.bybit.get_ticker(self.test_symbol)
- if ticker:
- print(f"✅ Ticker data received:")
- print(f" Last Price: ${ticker['last_price']:.2f}")
- print(f" Bid: ${ticker['bid_price']:.2f}")
- print(f" Ask: ${ticker['ask_price']:.2f}")
- print(f" 24h Volume: {ticker['volume_24h']:.2f}")
- print(f" 24h Change: {ticker['change_24h']:.4f}%")
- return True
- else:
- print("❌ Failed to get ticker data")
- return False
-
- except Exception as e:
- print(f"❌ Ticker error: {e}")
- return False
-
- def test_open_long_position(self):
- """Test opening a long position"""
- print(f"\n🚀 Opening long position for {self.test_quantity} {self.test_symbol}...")
-
- try:
- # Place market buy order
- order = self.bybit.place_order(
- symbol=self.test_symbol,
- side='buy',
- order_type='market',
- quantity=self.test_quantity
- )
-
- if 'error' in order:
- print(f"❌ Order failed: {order['error']}")
- return None
-
- print("✅ Long position opened successfully:")
- print(f" Order ID: {order['order_id']}")
- print(f" Symbol: {order['symbol']}")
- print(f" Side: {order['side']}")
- print(f" Quantity: {order['quantity']}")
- print(f" Status: {order['status']}")
-
- return order
-
- except Exception as e:
- print(f"❌ Open position error: {e}")
- return None
-
- def test_position_after_open(self):
- """Test checking position after opening"""
- print(f"\n📊 Checking position after opening...")
-
- try:
- positions = self.bybit.get_positions(self.test_symbol)
- if positions:
- position = positions[0]
- print("✅ Position found:")
- print(f" Symbol: {position['symbol']}")
- print(f" Side: {position['side']}")
- print(f" Size: {position['size']}")
- print(f" Entry Price: ${position['entry_price']:.2f}")
- print(f" Mark Price: ${position['mark_price']:.2f}")
- print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
- print(f" Percentage: {position['percentage']:.2f}%")
- print(f" Leverage: {position['leverage']}x")
- return True
- else:
- print("❌ No position found after opening")
- return False
-
- except Exception as e:
- print(f"❌ Position check error: {e}")
- return False
-
- def test_close_position(self):
- """Test closing the position"""
- print(f"\n🔄 Closing position for {self.test_symbol}...")
-
- try:
- # Close the position
- close_order = self.bybit.close_position(self.test_symbol)
-
- if 'error' in close_order:
- print(f"❌ Close order failed: {close_order['error']}")
- return False
-
- print("✅ Position closed successfully:")
- print(f" Order ID: {close_order['order_id']}")
- print(f" Symbol: {close_order['symbol']}")
- print(f" Side: {close_order['side']}")
- print(f" Quantity: {close_order['quantity']}")
- print(f" Status: {close_order['status']}")
-
- return True
-
- except Exception as e:
- print(f"❌ Close position error: {e}")
- return False
-
- def test_position_after_close(self):
- """Test checking position after closing"""
- print(f"\n📊 Checking position after closing...")
-
- try:
- positions = self.bybit.get_positions(self.test_symbol)
- if positions:
- position = positions[0]
- print("⚠️ Position still exists (may be partially closed):")
- print(f" Symbol: {position['symbol']}")
- print(f" Side: {position['side']}")
- print(f" Size: {position['size']}")
- print(f" Entry Price: ${position['entry_price']:.2f}")
- print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
- else:
- print("✅ Position successfully closed - no open positions")
-
- except Exception as e:
- print(f"❌ Position check error: {e}")
-
- def test_order_history(self):
- """Test getting order history"""
- print(f"\n📋 Checking recent orders...")
-
- try:
- # Get open orders
- open_orders = self.bybit.get_open_orders(self.test_symbol)
- print(f"Open orders: {len(open_orders)}")
- for order in open_orders:
- print(f" {order['order_id']}: {order['side']} {order['quantity']} @ ${order['price']:.2f} - {order['status']}")
-
- except Exception as e:
- print(f"❌ Order history error: {e}")
-
-def main():
- """Main function"""
- print("Starting Bybit ETH Futures Test...")
-
- # Check if API credentials are set
- api_key = os.getenv('BYBIT_API_KEY')
- api_secret = os.getenv('BYBIT_API_SECRET')
-
- if not api_key or not api_secret:
- print("❌ Please set BYBIT_API_KEY and BYBIT_API_SECRET environment variables")
- return False
-
- # Create test instance
- test = BybitEthFuturesTest(test_mode=True) # Always use testnet for safety
-
- # Run tests
- success = test.run_tests()
-
- if success:
- print("\n🎉 All tests passed!")
- else:
- print("\n💥 Some tests failed!")
-
- return success
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
diff --git a/test_bybit_eth_futures_fixed.py b/test_bybit_eth_futures_fixed.py
deleted file mode 100644
index ec07c56..0000000
--- a/test_bybit_eth_futures_fixed.py
+++ /dev/null
@@ -1,304 +0,0 @@
-#!/usr/bin/env python3
-"""
-Fixed Bybit ETH futures trading test with proper minimum order size handling
-"""
-
-import os
-import sys
-import time
-import logging
-import json
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-# Load environment variables
-try:
- from dotenv import load_dotenv
- load_dotenv()
-except ImportError:
- if os.path.exists('.env'):
- with open('.env', 'r') as f:
- for line in f:
- if line.strip() and not line.startswith('#'):
- key, value = line.strip().split('=', 1)
- os.environ[key] = value
-
-from core.exchanges.bybit_interface import BybitInterface
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def get_instrument_info(bybit: BybitInterface, symbol: str) -> dict:
- """Get instrument information including minimum order size"""
- try:
- instruments = bybit.get_instruments("linear")
- for instrument in instruments:
- if instrument.get('symbol') == symbol:
- return instrument
- return {}
- except Exception as e:
- logger.error(f"Error getting instrument info: {e}")
- return {}
-
-def test_eth_futures_trading():
- """Test ETH futures trading with proper minimum order size"""
- print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
- print("=" * 60)
- print("BYBIT ETH FUTURES LIVE TRADING TEST (FIXED)")
- print("=" * 60)
- print("⚠️ This uses LIVE environment with real money!")
- print("⚠️ Will check minimum order size first")
- print("=" * 60)
-
- # Check if API credentials are set
- api_key = os.getenv('BYBIT_API_KEY')
- api_secret = os.getenv('BYBIT_API_SECRET')
-
- if not api_key or not api_secret:
- print("❌ API credentials not found in environment")
- return False
-
- # Create Bybit interface with live environment
- bybit = BybitInterface(
- api_key=api_key,
- api_secret=api_secret,
- test_mode=False # Use live environment
- )
-
- symbol = 'ETHUSDT'
-
- # Test 1: Connection
- print(f"\n📡 Testing connection to Bybit live environment...")
- try:
- if not bybit.connect():
- print("❌ Failed to connect to Bybit")
- return False
- print("✅ Successfully connected to Bybit live environment")
- except Exception as e:
- print(f"❌ Connection error: {e}")
- return False
-
- # Test 2: Get instrument information to check minimum order size
- print(f"\n📋 Getting instrument information for {symbol}...")
- try:
- instrument_info = get_instrument_info(bybit, symbol)
- if not instrument_info:
- print(f"❌ Failed to get instrument info for {symbol}")
- return False
-
- print("✅ Instrument information retrieved:")
- print(f" Symbol: {instrument_info.get('symbol')}")
- print(f" Status: {instrument_info.get('status')}")
- print(f" Base Coin: {instrument_info.get('baseCoin')}")
- print(f" Quote Coin: {instrument_info.get('quoteCoin')}")
-
- # Extract minimum order size
- lot_size_filter = instrument_info.get('lotSizeFilter', {})
- min_order_qty = float(lot_size_filter.get('minOrderQty', 0.01))
- max_order_qty = float(lot_size_filter.get('maxOrderQty', 10000))
- qty_step = float(lot_size_filter.get('qtyStep', 0.01))
-
- print(f" Minimum Order Qty: {min_order_qty}")
- print(f" Maximum Order Qty: {max_order_qty}")
- print(f" Quantity Step: {qty_step}")
-
- # Use minimum order size for testing
- test_quantity = min_order_qty
- print(f" Using test quantity: {test_quantity} ETH")
-
- except Exception as e:
- print(f"❌ Instrument info error: {e}")
- return False
-
- # Test 3: Get account balance
- print(f"\n💰 Checking account balance...")
- try:
- usdt_balance = bybit.get_balance('USDT')
- print(f"USDT Balance: ${usdt_balance:.2f}")
-
- # Calculate required balance (with some buffer)
- current_price_data = bybit.get_ticker(symbol)
- if not current_price_data:
- print("❌ Failed to get current ETH price")
- return False
-
- current_price = current_price_data['last_price']
- required_balance = current_price * test_quantity * 1.1 # 10% buffer
-
- print(f"Current ETH price: ${current_price:.2f}")
- print(f"Required balance: ${required_balance:.2f}")
-
- if usdt_balance < required_balance:
- print(f"❌ Insufficient USDT balance for testing (need at least ${required_balance:.2f})")
- return False
-
- print("✅ Sufficient balance for testing")
-
- except Exception as e:
- print(f"❌ Balance check error: {e}")
- return False
-
- # Test 4: Check existing positions
- print(f"\n📊 Checking existing positions...")
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- print(f"Found {len(positions)} existing positions:")
- for pos in positions:
- print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
- print(f" PnL: ${pos['unrealized_pnl']:.2f}")
- else:
- print("No existing positions found")
- except Exception as e:
- print(f"❌ Position check error: {e}")
- return False
-
- # Test 5: Ask user confirmation before trading
- print(f"\n⚠️ TRADING CONFIRMATION")
- print(f" Symbol: {symbol}")
- print(f" Quantity: {test_quantity} ETH")
- print(f" Estimated cost: ${current_price * test_quantity:.2f}")
- print(f" Environment: LIVE (real money)")
- print(f" Minimum order size confirmed: {min_order_qty}")
-
- response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower()
- if response != 'y' and response != 'yes':
- print("❌ Trading test cancelled by user")
- return False
-
- # Test 6: Open a small long position
- print(f"\n🚀 Opening small long position...")
- try:
- order = bybit.place_order(
- symbol=symbol,
- side='buy',
- order_type='market',
- quantity=test_quantity
- )
-
- if 'error' in order:
- print(f"❌ Order failed: {order['error']}")
- return False
-
- print("✅ Long position opened successfully:")
- print(f" Order ID: {order['order_id']}")
- print(f" Symbol: {order['symbol']}")
- print(f" Side: {order['side']}")
- print(f" Quantity: {order['quantity']}")
- print(f" Status: {order['status']}")
-
- order_id = order['order_id']
-
- except Exception as e:
- print(f"❌ Order placement error: {e}")
- return False
-
- # Test 7: Wait a moment and check position
- print(f"\n⏳ Waiting 5 seconds for position to be reflected...")
- time.sleep(5)
-
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- position = positions[0]
- print("✅ Position confirmed:")
- print(f" Symbol: {position['symbol']}")
- print(f" Side: {position['side']}")
- print(f" Size: {position['size']}")
- print(f" Entry Price: ${position['entry_price']:.2f}")
- print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
- print(f" Leverage: {position['leverage']}x")
- else:
- print("⚠️ No position found (may already be closed)")
-
- except Exception as e:
- print(f"❌ Position check error: {e}")
-
- # Test 8: Close the position
- print(f"\n🔄 Closing the position...")
- try:
- close_order = bybit.close_position(symbol)
-
- if 'error' in close_order:
- print(f"❌ Close order failed: {close_order['error']}")
- # Don't return False here, as the position might still exist
- print("⚠️ You may need to manually close the position")
- else:
- print("✅ Position closed successfully:")
- print(f" Order ID: {close_order['order_id']}")
- print(f" Symbol: {close_order['symbol']}")
- print(f" Side: {close_order['side']}")
- print(f" Quantity: {close_order['quantity']}")
- print(f" Status: {close_order['status']}")
-
- except Exception as e:
- print(f"❌ Close position error: {e}")
- print("⚠️ You may need to manually close the position")
-
- # Test 9: Final position check
- print(f"\n📊 Final position check...")
- time.sleep(3)
-
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- position = positions[0]
- print("⚠️ Position still exists:")
- print(f" Size: {position['size']}")
- print(f" PnL: ${position['unrealized_pnl']:.2f}")
- print("💡 You may want to manually close this position")
- else:
- print("✅ No open positions - trading test completed successfully")
-
- except Exception as e:
- print(f"❌ Final position check error: {e}")
-
- # Test 10: Final balance check
- print(f"\n💰 Final balance check...")
- try:
- final_balance = bybit.get_balance('USDT')
- print(f"Final USDT Balance: ${final_balance:.2f}")
-
- balance_change = final_balance - usdt_balance
- if balance_change > 0:
- print(f"💰 Profit: +${balance_change:.2f}")
- elif balance_change < 0:
- print(f"📉 Loss: ${balance_change:.2f}")
- else:
- print(f"🔄 No change: ${balance_change:.2f}")
-
- except Exception as e:
- print(f"❌ Final balance check error: {e}")
-
- return True
-
-def main():
- """Main function"""
- print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
-
- success = test_eth_futures_trading()
-
- if success:
- print("\n" + "=" * 60)
- print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED")
- print("=" * 60)
- print("🎯 Your Bybit integration is fully functional!")
- print("🔄 Position opening and closing works correctly")
- print("💰 Account balance integration works")
- print("📊 All trading functions are operational")
- print("📏 Minimum order size handling works")
- print("=" * 60)
- else:
- print("\n💥 Trading test failed!")
- print("🔍 Check the error messages above for details")
-
- return success
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_bybit_eth_live.py b/test_bybit_eth_live.py
deleted file mode 100644
index 9554a91..0000000
--- a/test_bybit_eth_live.py
+++ /dev/null
@@ -1,249 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Bybit ETH futures trading with live environment
-"""
-
-import os
-import sys
-import time
-import logging
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-# Load environment variables
-try:
- from dotenv import load_dotenv
- load_dotenv()
-except ImportError:
- if os.path.exists('.env'):
- with open('.env', 'r') as f:
- for line in f:
- if line.strip() and not line.startswith('#'):
- key, value = line.strip().split('=', 1)
- os.environ[key] = value
-
-from core.exchanges.bybit_interface import BybitInterface
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def test_eth_futures_trading():
- """Test ETH futures trading with live environment"""
- print("=" * 60)
- print("BYBIT ETH FUTURES LIVE TRADING TEST")
- print("=" * 60)
- print("⚠️ This uses LIVE environment with real money!")
- print("⚠️ Test amount: 0.001 ETH (very small)")
- print("=" * 60)
-
- # Check if API credentials are set
- api_key = os.getenv('BYBIT_API_KEY')
- api_secret = os.getenv('BYBIT_API_SECRET')
-
- if not api_key or not api_secret:
- print("❌ API credentials not found in environment")
- return False
-
- # Create Bybit interface with live environment
- bybit = BybitInterface(
- api_key=api_key,
- api_secret=api_secret,
- test_mode=False # Use live environment
- )
-
- symbol = 'ETHUSDT'
- test_quantity = 0.01 # Minimum order size for ETH futures
-
- # Test 1: Connection
- print(f"\n📡 Testing connection to Bybit live environment...")
- try:
- if not bybit.connect():
- print("❌ Failed to connect to Bybit")
- return False
- print("✅ Successfully connected to Bybit live environment")
- except Exception as e:
- print(f"❌ Connection error: {e}")
- return False
-
- # Test 2: Get account balance
- print(f"\n💰 Checking account balance...")
- try:
- usdt_balance = bybit.get_balance('USDT')
- print(f"USDT Balance: ${usdt_balance:.2f}")
-
- if usdt_balance < 5:
- print("❌ Insufficient USDT balance for testing (need at least $5)")
- return False
-
- print("✅ Sufficient balance for testing")
- except Exception as e:
- print(f"❌ Balance check error: {e}")
- return False
-
- # Test 3: Get current ETH price
- print(f"\n📈 Getting current ETH price...")
- try:
- ticker = bybit.get_ticker(symbol)
- if not ticker:
- print("❌ Failed to get ticker")
- return False
-
- current_price = ticker['last_price']
- print(f"Current ETH price: ${current_price:.2f}")
- print(f"Test order value: ${current_price * test_quantity:.2f}")
-
- except Exception as e:
- print(f"❌ Ticker error: {e}")
- return False
-
- # Test 4: Check existing positions
- print(f"\n📊 Checking existing positions...")
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- print(f"Found {len(positions)} existing positions:")
- for pos in positions:
- print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
- print(f" PnL: ${pos['unrealized_pnl']:.2f}")
- else:
- print("No existing positions found")
- except Exception as e:
- print(f"❌ Position check error: {e}")
- return False
-
- # Test 5: Ask user confirmation before trading
- print(f"\n⚠️ TRADING CONFIRMATION")
- print(f" Symbol: {symbol}")
- print(f" Quantity: {test_quantity} ETH")
- print(f" Estimated cost: ${current_price * test_quantity:.2f}")
- print(f" Environment: LIVE (real money)")
-
- response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower()
- if response != 'y' and response != 'yes':
- print("❌ Trading test cancelled by user")
- return False
-
- # Test 6: Open a small long position
- print(f"\n🚀 Opening small long position...")
- try:
- order = bybit.place_order(
- symbol=symbol,
- side='buy',
- order_type='market',
- quantity=test_quantity
- )
-
- if 'error' in order:
- print(f"❌ Order failed: {order['error']}")
- return False
-
- print("✅ Long position opened successfully:")
- print(f" Order ID: {order['order_id']}")
- print(f" Symbol: {order['symbol']}")
- print(f" Side: {order['side']}")
- print(f" Quantity: {order['quantity']}")
- print(f" Status: {order['status']}")
-
- order_id = order['order_id']
-
- except Exception as e:
- print(f"❌ Order placement error: {e}")
- return False
-
- # Test 7: Wait a moment and check position
- print(f"\n⏳ Waiting 3 seconds for position to be reflected...")
- time.sleep(3)
-
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- position = positions[0]
- print("✅ Position confirmed:")
- print(f" Symbol: {position['symbol']}")
- print(f" Side: {position['side']}")
- print(f" Size: {position['size']}")
- print(f" Entry Price: ${position['entry_price']:.2f}")
- print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
- print(f" Leverage: {position['leverage']}x")
- else:
- print("⚠️ No position found (may already be closed)")
-
- except Exception as e:
- print(f"❌ Position check error: {e}")
-
- # Test 8: Close the position
- print(f"\n🔄 Closing the position...")
- try:
- close_order = bybit.close_position(symbol)
-
- if 'error' in close_order:
- print(f"❌ Close order failed: {close_order['error']}")
- return False
-
- print("✅ Position closed successfully:")
- print(f" Order ID: {close_order['order_id']}")
- print(f" Symbol: {close_order['symbol']}")
- print(f" Side: {close_order['side']}")
- print(f" Quantity: {close_order['quantity']}")
- print(f" Status: {close_order['status']}")
-
- except Exception as e:
- print(f"❌ Close position error: {e}")
- return False
-
- # Test 9: Final position check
- print(f"\n📊 Final position check...")
- time.sleep(2)
-
- try:
- positions = bybit.get_positions(symbol)
- if positions:
- position = positions[0]
- print("⚠️ Position still exists:")
- print(f" Size: {position['size']}")
- print(f" PnL: ${position['unrealized_pnl']:.2f}")
- else:
- print("✅ No open positions - trading test completed successfully")
-
- except Exception as e:
- print(f"❌ Final position check error: {e}")
-
- # Test 10: Final balance check
- print(f"\n💰 Final balance check...")
- try:
- final_balance = bybit.get_balance('USDT')
- print(f"Final USDT Balance: ${final_balance:.2f}")
-
- except Exception as e:
- print(f"❌ Final balance check error: {e}")
-
- return True
-
-def main():
- """Main function"""
- print("🚀 Starting Bybit ETH Futures Live Trading Test...")
-
- success = test_eth_futures_trading()
-
- if success:
- print("\n" + "=" * 60)
- print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED")
- print("=" * 60)
- print("🎯 Your Bybit integration is fully functional!")
- print("🔄 Position opening and closing works correctly")
- print("💰 Account balance integration works")
- print("📊 All trading functions are operational")
- print("=" * 60)
- else:
- print("\n💥 Trading test failed!")
-
- return success
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
diff --git a/test_bybit_public_api.py b/test_bybit_public_api.py
deleted file mode 100644
index de58964..0000000
--- a/test_bybit_public_api.py
+++ /dev/null
@@ -1,220 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Bybit public API functionality (no authentication required)
-"""
-
-import os
-import sys
-import time
-import logging
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.exchanges.bybit_rest_client import BybitRestClient
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def test_public_api():
- """Test public API endpoints"""
- print("=" * 60)
- print("BYBIT PUBLIC API TEST")
- print("=" * 60)
-
- # Test both testnet and live for public endpoints
- for testnet in [True, False]:
- env_name = "TESTNET" if testnet else "LIVE"
- print(f"\n🔄 Testing {env_name} environment...")
-
- client = BybitRestClient(
- api_key="dummy",
- api_secret="dummy",
- testnet=testnet
- )
-
- # Test 1: Server time
- try:
- server_time = client.get_server_time()
- time_second = server_time.get('result', {}).get('timeSecond')
- print(f"✅ Server time: {time_second}")
- except Exception as e:
- print(f"❌ Server time failed: {e}")
- continue
-
- # Test 2: Get ticker for ETHUSDT
- try:
- ticker = client.get_ticker('ETHUSDT', 'linear')
- ticker_data = ticker.get('result', {}).get('list', [])
- if ticker_data:
- data = ticker_data[0]
- print(f"✅ ETH/USDT ticker:")
- print(f" Last Price: ${float(data.get('lastPrice', 0)):.2f}")
- print(f" 24h Volume: {float(data.get('volume24h', 0)):.2f}")
- print(f" 24h Change: {float(data.get('price24hPcnt', 0)) * 100:.2f}%")
- else:
- print("❌ No ticker data received")
- except Exception as e:
- print(f"❌ Ticker failed: {e}")
-
- # Test 3: Get instruments info
- try:
- instruments = client.get_instruments_info('linear')
- instruments_list = instruments.get('result', {}).get('list', [])
- eth_instruments = [i for i in instruments_list if 'ETH' in i.get('symbol', '')]
- print(f"✅ Found {len(eth_instruments)} ETH instruments")
- for instr in eth_instruments[:3]: # Show first 3
- print(f" {instr.get('symbol')} - Status: {instr.get('status')}")
- except Exception as e:
- print(f"❌ Instruments failed: {e}")
-
- # Test 4: Get orderbook
- try:
- orderbook = client.get_orderbook('ETHUSDT', 'linear', 5)
- ob_data = orderbook.get('result', {})
- bids = ob_data.get('b', [])
- asks = ob_data.get('a', [])
-
- if bids and asks:
- print(f"✅ Orderbook (top 3):")
- print(f" Best bid: ${float(bids[0][0]):.2f} (qty: {float(bids[0][1]):.4f})")
- print(f" Best ask: ${float(asks[0][0]):.2f} (qty: {float(asks[0][1]):.4f})")
- spread = float(asks[0][0]) - float(bids[0][0])
- print(f" Spread: ${spread:.2f}")
- else:
- print("❌ No orderbook data received")
- except Exception as e:
- print(f"❌ Orderbook failed: {e}")
-
- print(f"📊 {env_name} environment test completed")
-
-def test_live_authentication():
- """Test live authentication (if user wants to test with live credentials)"""
- print("\n" + "=" * 60)
- print("BYBIT LIVE AUTHENTICATION TEST")
- print("=" * 60)
- print("⚠️ This will test with LIVE credentials (not testnet)")
-
- # Load environment variables
- try:
- from dotenv import load_dotenv
- load_dotenv()
- except ImportError:
- # If dotenv is not available, try to load .env manually
- if os.path.exists('.env'):
- with open('.env', 'r') as f:
- for line in f:
- if line.strip() and not line.startswith('#'):
- key, value = line.strip().split('=', 1)
- os.environ[key] = value
-
- api_key = os.getenv('BYBIT_API_KEY')
- api_secret = os.getenv('BYBIT_API_SECRET')
-
- if not api_key or not api_secret:
- print("❌ No API credentials found in environment")
- return
-
- print(f"🔑 Using API key: {api_key[:8]}...")
-
- # Test with live environment (testnet=False)
- client = BybitRestClient(
- api_key=api_key,
- api_secret=api_secret,
- testnet=False # Use live environment
- )
-
- # Test connectivity
- try:
- if client.test_connectivity():
- print("✅ Basic connectivity OK")
- else:
- print("❌ Basic connectivity failed")
- return
- except Exception as e:
- print(f"❌ Connectivity error: {e}")
- return
-
- # Test authentication
- try:
- if client.test_authentication():
- print("✅ Authentication successful!")
-
- # Get account info
- account_info = client.get_account_info()
- accounts = account_info.get('result', {}).get('list', [])
-
- if accounts:
- print("📊 Account information:")
- for account in accounts:
- account_type = account.get('accountType', 'Unknown')
- print(f" Account Type: {account_type}")
-
- coins = account.get('coin', [])
- usdt_balance = None
- for coin in coins:
- if coin.get('coin') == 'USDT':
- usdt_balance = float(coin.get('walletBalance', 0))
- break
-
- if usdt_balance:
- print(f" USDT Balance: ${usdt_balance:.2f}")
-
- # Show positions if any
- try:
- positions = client.get_positions('linear')
- pos_list = positions.get('result', {}).get('list', [])
- active_positions = [p for p in pos_list if float(p.get('size', 0)) != 0]
-
- if active_positions:
- print(f" Active Positions: {len(active_positions)}")
- for pos in active_positions:
- symbol = pos.get('symbol')
- side = pos.get('side')
- size = float(pos.get('size', 0))
- pnl = float(pos.get('unrealisedPnl', 0))
- print(f" {symbol}: {side} {size} (PnL: ${pnl:.2f})")
- else:
- print(" No active positions")
- except Exception as e:
- print(f" ⚠️ Could not get positions: {e}")
-
- return True
- else:
- print("❌ Authentication failed")
- return False
-
- except Exception as e:
- print(f"❌ Authentication error: {e}")
- return False
-
-def main():
- """Main function"""
- print("🚀 Starting Bybit API Tests...")
-
- # Test public API
- test_public_api()
-
- # Ask user if they want to test live authentication
- print("\n" + "=" * 60)
- response = input("Do you want to test live authentication? (y/N): ").lower()
-
- if response == 'y' or response == 'yes':
- success = test_live_authentication()
- if success:
- print("\n✅ Live authentication test passed!")
- print("🎯 Your Bybit integration is working!")
- else:
- print("\n❌ Live authentication test failed")
- else:
- print("\n📋 Skipping live authentication test")
-
- print("\n🎉 Public API tests completed successfully!")
- print("📈 Bybit integration is functional for market data")
-
-if __name__ == "__main__":
- main()
diff --git a/test_cache_fix.py b/test_cache_fix.py
deleted file mode 100644
index 078af14..0000000
--- a/test_cache_fix.py
+++ /dev/null
@@ -1,137 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Cache Fix
-
-Creates a corrupted Parquet file to test the fix mechanism
-"""
-
-import os
-import pandas as pd
-from pathlib import Path
-from utils.cache_manager import get_cache_manager
-
-def create_test_data():
- """Create test cache files including a corrupted one"""
- print("Creating test cache files...")
-
- # Ensure cache directory exists
- cache_dir = Path("data/cache")
- cache_dir.mkdir(parents=True, exist_ok=True)
-
- # Create a valid Parquet file
- valid_data = pd.DataFrame({
- 'timestamp': pd.date_range('2025-01-01', periods=100, freq='1min'),
- 'open': [100.0 + i for i in range(100)],
- 'high': [101.0 + i for i in range(100)],
- 'low': [99.0 + i for i in range(100)],
- 'close': [100.5 + i for i in range(100)],
- 'volume': [1000 + i*10 for i in range(100)]
- })
-
- valid_file = cache_dir / "ETHUSDT_1m.parquet"
- valid_data.to_parquet(valid_file, index=False)
- print(f"Created valid file: {valid_file}")
-
- # Create a corrupted Parquet file by writing invalid data
- corrupted_file = cache_dir / "BTCUSDT_1m.parquet"
- with open(corrupted_file, 'wb') as f:
- f.write(b"This is not a valid Parquet file - corrupted data")
- print(f"Created corrupted file: {corrupted_file}")
-
- # Create an empty file
- empty_file = cache_dir / "SOLUSDT_1m.parquet"
- empty_file.touch()
- print(f"Created empty file: {empty_file}")
-
-def test_cache_manager():
- """Test the cache manager's ability to detect and fix issues"""
- print("\n=== Testing Cache Manager ===")
-
- cache_manager = get_cache_manager()
-
- # Scan health
- print("1. Scanning cache health...")
- health_summary = cache_manager.get_cache_summary()
-
- print(f"Total files: {health_summary['total_files']}")
- print(f"Valid files: {health_summary['valid_files']}")
- print(f"Corrupted files: {health_summary['corrupted_files']}")
- print(f"Health percentage: {health_summary['health_percentage']:.1f}%")
-
- # Show corrupted files
- for cache_dir, report in health_summary['directories'].items():
- if report['corrupted_files'] > 0:
- print(f"\nCorrupted files in {cache_dir}:")
- for corrupted in report['corrupted_files_list']:
- print(f" - {corrupted['file']}: {corrupted['error']}")
-
- # Test cleanup
- print("\n2. Testing cleanup...")
- deleted_files = cache_manager.cleanup_corrupted_files(dry_run=False)
-
- deleted_count = 0
- for cache_dir, files in deleted_files.items():
- for file_info in files:
- if "DELETED:" in file_info:
- deleted_count += 1
- print(f" {file_info}")
-
- print(f"Deleted {deleted_count} corrupted files")
-
- # Verify cleanup
- print("\n3. Verifying cleanup...")
- health_summary_after = cache_manager.get_cache_summary()
- print(f"Corrupted files after cleanup: {health_summary_after['corrupted_files']}")
-
-def test_data_provider_integration():
- """Test that the data provider can handle corrupted cache gracefully"""
- print("\n=== Testing Data Provider Integration ===")
-
- # Create another corrupted file
- cache_dir = Path("data/cache")
- corrupted_file = cache_dir / "ETHUSDT_5m.parquet"
- with open(corrupted_file, 'wb') as f:
- f.write(b"PAR1\x00\x00corrupted thrift data that will cause deserialization error")
- print(f"Created corrupted file with thrift error: {corrupted_file}")
-
- # Try to import and use data provider
- try:
- from core.data_provider import DataProvider
-
- # Create data provider (should auto-fix corrupted cache)
- data_provider = DataProvider()
- print("Data provider created successfully - auto-fix worked!")
-
- # Try to load data (should handle corruption gracefully)
- try:
- data = data_provider._load_from_cache("ETH/USDT", "5m")
- if data is None:
- print("Cache loading returned None (expected for corrupted file)")
- else:
- print(f"Loaded {len(data)} rows from cache")
- except Exception as e:
- print(f"Cache loading failed: {e}")
-
- except Exception as e:
- print(f"Data provider test failed: {e}")
-
-def main():
- """Run all tests"""
- print("=== Cache Fix Test Suite ===")
-
- # Clean up any existing test files
- cache_dir = Path("data/cache")
- if cache_dir.exists():
- for file in cache_dir.glob("*.parquet"):
- file.unlink()
-
- # Run tests
- create_test_data()
- test_cache_manager()
- test_data_provider_integration()
-
- print("\n=== Test Complete ===")
- print("The cache fix system is working correctly!")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_cnn_integration.py b/test_cnn_integration.py
deleted file mode 100644
index 46671dc..0000000
--- a/test_cnn_integration.py
+++ /dev/null
@@ -1,175 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test CNN Integration
-
-This script tests if the CNN adapter is working properly and identifies issues.
-"""
-
-import logging
-import sys
-import os
-from datetime import datetime
-
-# Setup logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def test_cnn_adapter():
- """Test CNN adapter initialization and basic functionality"""
- try:
- logger.info("Testing CNN adapter initialization...")
-
- # Test 1: Import CNN adapter
- from core.enhanced_cnn_adapter import EnhancedCNNAdapter
- logger.info("✅ CNN adapter import successful")
-
- # Test 2: Initialize CNN adapter
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
- logger.info("✅ CNN adapter initialization successful")
-
- # Test 3: Check adapter attributes
- logger.info(f"CNN adapter model: {cnn_adapter.model}")
- logger.info(f"CNN adapter device: {cnn_adapter.device}")
- logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}")
-
- # Test 4: Check metrics tracking
- logger.info(f"Inference count: {cnn_adapter.inference_count}")
- logger.info(f"Training count: {cnn_adapter.training_count}")
- logger.info(f"Training data length: {len(cnn_adapter.training_data)}")
-
- # Test 5: Test simple training sample addition
- cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
- logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}")
-
- # Test 6: Test training if we have enough samples
- if len(cnn_adapter.training_data) >= 2:
- # Add another sample to have minimum for training
- cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05)
-
- # Try training
- training_result = cnn_adapter.train(epochs=1)
- logger.info(f"✅ Training successful: {training_result}")
-
- # Check if metrics were updated
- logger.info(f"Last training time: {cnn_adapter.last_training_time}")
- logger.info(f"Last training loss: {cnn_adapter.last_training_loss}")
- logger.info(f"Training count: {cnn_adapter.training_count}")
- else:
- logger.info("⚠️ Not enough training samples for training test")
-
- return True
-
- except Exception as e:
- logger.error(f"❌ CNN adapter test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-def test_base_data_input():
- """Test BaseDataInput creation"""
- try:
- logger.info("Testing BaseDataInput creation...")
-
- # Test 1: Import BaseDataInput
- from core.data_models import BaseDataInput, OHLCVBar, COBData
- logger.info("✅ BaseDataInput import successful")
-
- # Test 2: Create sample OHLCV bars
- sample_bars = []
- for i in range(10): # Create 10 sample bars
- bar = OHLCVBar(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- open=3500.0 + i,
- high=3510.0 + i,
- low=3490.0 + i,
- close=3505.0 + i,
- volume=1000.0,
- timeframe="1s"
- )
- sample_bars.append(bar)
-
- logger.info(f"✅ Created {len(sample_bars)} sample OHLCV bars")
-
- # Test 3: Create BaseDataInput
- base_data = BaseDataInput(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- ohlcv_1s=sample_bars,
- ohlcv_1m=sample_bars,
- ohlcv_1h=sample_bars,
- ohlcv_1d=sample_bars,
- btc_ohlcv_1s=sample_bars
- )
-
- logger.info("✅ BaseDataInput created successfully")
-
- # Test 4: Validate BaseDataInput
- is_valid = base_data.validate()
- logger.info(f"BaseDataInput validation: {is_valid}")
-
- # Test 5: Get feature vector
- feature_vector = base_data.get_feature_vector()
- logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}")
-
- return base_data
-
- except Exception as e:
- logger.error(f"❌ BaseDataInput test failed: {e}")
- import traceback
- traceback.print_exc()
- return None
-
-def test_cnn_prediction():
- """Test CNN prediction with BaseDataInput"""
- try:
- logger.info("Testing CNN prediction...")
-
- # Get CNN adapter and base data
- from core.enhanced_cnn_adapter import EnhancedCNNAdapter
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
-
- base_data = test_base_data_input()
- if not base_data:
- logger.error("❌ Cannot test prediction without valid BaseDataInput")
- return False
-
- # Test prediction
- model_output = cnn_adapter.predict(base_data)
- logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})")
-
- # Check if metrics were updated
- logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}")
- logger.info(f"Last inference time: {cnn_adapter.last_inference_time}")
- logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}")
-
- return True
-
- except Exception as e:
- logger.error(f"❌ CNN prediction test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-def main():
- """Run all tests"""
- logger.info("🧪 Starting CNN Integration Tests")
-
- # Test 1: CNN Adapter
- if not test_cnn_adapter():
- logger.error("❌ CNN adapter test failed - stopping")
- return False
-
- # Test 2: CNN Prediction
- if not test_cnn_prediction():
- logger.error("❌ CNN prediction test failed - stopping")
- return False
-
- logger.info("✅ All CNN integration tests passed!")
- logger.info("🎯 The CNN adapter should now work properly in the dashboard")
-
- return True
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_cob_dashboard.py b/test_cob_dashboard.py
deleted file mode 100644
index 87f1363..0000000
--- a/test_cob_dashboard.py
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test COB Dashboard with Enhanced WebSocket
-"""
-
-import asyncio
-import logging
-from web.cob_realtime_dashboard import COBDashboardServer
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-
-async def main():
- """Test the COB dashboard"""
- dashboard = COBDashboardServer(host='localhost', port=8053)
- await dashboard.start()
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/test_cob_data_quality.py b/test_cob_data_quality.py
deleted file mode 100644
index 948434c..0000000
--- a/test_cob_data_quality.py
+++ /dev/null
@@ -1,118 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for COB data quality and imbalance indicators
-"""
-
-import time
-import logging
-from core.data_provider import DataProvider
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_cob_data_quality():
- """Test COB data quality and imbalance indicators"""
- logger.info("Testing COB data quality and imbalance indicators...")
-
- # Initialize data provider
- dp = DataProvider()
-
- # Wait for initial data load and COB connection
- logger.info("Waiting for initial data load and COB connection...")
- time.sleep(20)
-
- # Test 1: Check cached data summary
- logger.info("\n=== Test 1: Cached Data Summary ===")
- cache_summary = dp.get_cached_data_summary()
- for symbol in cache_summary['cached_data']:
- logger.info(f"\n{symbol}:")
- for timeframe, info in cache_summary['cached_data'][symbol].items():
- if 'candle_count' in info and info['candle_count'] > 0:
- logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
- else:
- logger.info(f" {timeframe}: {info.get('status', 'no data')}")
-
- # Test 2: Check COB data quality
- logger.info("\n=== Test 2: COB Data Quality ===")
- cob_quality = dp.get_cob_data_quality()
-
- for symbol in cob_quality['symbols']:
- logger.info(f"\n{symbol} COB Data:")
-
- # Raw ticks
- raw_info = cob_quality['raw_ticks'].get(symbol, {})
- logger.info(f" Raw ticks: {raw_info.get('count', 0)} ticks")
- if raw_info.get('age_seconds') is not None:
- logger.info(f" Raw data age: {raw_info['age_seconds']:.1f} seconds")
-
- # Aggregated 1s data
- agg_info = cob_quality['aggregated_1s'].get(symbol, {})
- logger.info(f" Aggregated 1s: {agg_info.get('count', 0)} records")
- if agg_info.get('age_seconds') is not None:
- logger.info(f" Aggregated data age: {agg_info['age_seconds']:.1f} seconds")
-
- # Imbalance indicators
- imbalance_info = cob_quality['imbalance_indicators'].get(symbol, {})
- if imbalance_info:
- logger.info(f" Imbalance 1s: {imbalance_info.get('imbalance_1s', 0):.4f}")
- logger.info(f" Imbalance 5s: {imbalance_info.get('imbalance_5s', 0):.4f}")
- logger.info(f" Imbalance 15s: {imbalance_info.get('imbalance_15s', 0):.4f}")
- logger.info(f" Imbalance 60s: {imbalance_info.get('imbalance_60s', 0):.4f}")
- logger.info(f" Total volume: {imbalance_info.get('total_volume', 0):.2f}")
- logger.info(f" Price buckets: {imbalance_info.get('bucket_count', 0)}")
-
- # Data freshness
- freshness = cob_quality['data_freshness'].get(symbol, 'unknown')
- logger.info(f" Data freshness: {freshness}")
-
- # Test 3: Get recent COB aggregated data
- logger.info("\n=== Test 3: Recent COB Aggregated Data ===")
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- recent_cob = dp.get_cob_1s_aggregated(symbol, count=5)
- logger.info(f"\n{symbol} - Last 5 aggregated records:")
-
- for i, record in enumerate(recent_cob[-5:]):
- timestamp = record.get('timestamp', 0)
- imbalance_1s = record.get('imbalance_1s', 0)
- imbalance_5s = record.get('imbalance_5s', 0)
- total_volume = record.get('total_volume', 0)
- bucket_count = len(record.get('bid_buckets', {})) + len(record.get('ask_buckets', {}))
-
- logger.info(f" [{i+1}] Time: {timestamp}, Imb1s: {imbalance_1s:.4f}, "
- f"Imb5s: {imbalance_5s:.4f}, Vol: {total_volume:.2f}, Buckets: {bucket_count}")
-
- # Test 4: Monitor real-time updates
- logger.info("\n=== Test 4: Real-time Updates (30 seconds) ===")
- logger.info("Monitoring COB data updates...")
-
- initial_quality = dp.get_cob_data_quality()
- time.sleep(30)
- updated_quality = dp.get_cob_data_quality()
-
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- initial_count = initial_quality['raw_ticks'].get(symbol, {}).get('count', 0)
- updated_count = updated_quality['raw_ticks'].get(symbol, {}).get('count', 0)
- new_ticks = updated_count - initial_count
-
- initial_agg = initial_quality['aggregated_1s'].get(symbol, {}).get('count', 0)
- updated_agg = updated_quality['aggregated_1s'].get(symbol, {}).get('count', 0)
- new_agg = updated_agg - initial_agg
-
- logger.info(f"{symbol}: +{new_ticks} raw ticks, +{new_agg} aggregated records")
-
- # Show latest imbalances
- latest_imbalances = updated_quality['imbalance_indicators'].get(symbol, {})
- if latest_imbalances:
- logger.info(f" Latest imbalances: 1s={latest_imbalances.get('imbalance_1s', 0):.4f}, "
- f"5s={latest_imbalances.get('imbalance_5s', 0):.4f}, "
- f"15s={latest_imbalances.get('imbalance_15s', 0):.4f}, "
- f"60s={latest_imbalances.get('imbalance_60s', 0):.4f}")
-
- # Clean shutdown
- logger.info("\n=== Shutting Down ===")
- dp.stop_automatic_data_maintenance()
- logger.info("COB data quality test completed")
-
-if __name__ == "__main__":
- test_cob_data_quality()
\ No newline at end of file
diff --git a/test_cob_websocket_only.py b/test_cob_websocket_only.py
deleted file mode 100644
index 73859eb..0000000
--- a/test_cob_websocket_only.py
+++ /dev/null
@@ -1,131 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test COB WebSocket Only Integration
-
-This script tests that COB integration works with Enhanced WebSocket only,
-without falling back to REST API calls.
-"""
-
-import asyncio
-import time
-from datetime import datetime
-from typing import Dict
-from core.cob_integration import COBIntegration
-
-async def test_cob_websocket_only():
- """Test COB integration with WebSocket only"""
- print("=== Testing COB WebSocket Only Integration ===")
-
- # Initialize COB integration
- print("1. Initializing COB integration...")
- symbols = ['ETH/USDT', 'BTC/USDT']
- cob_integration = COBIntegration(symbols=symbols)
-
- # Track updates
- update_count = 0
- last_update_time = None
-
- def dashboard_callback(symbol: str, data: Dict):
- nonlocal update_count, last_update_time
- update_count += 1
- last_update_time = datetime.now()
-
- if update_count <= 5: # Show first 5 updates
- data_type = data.get('type', 'unknown')
- if data_type == 'cob_update':
- stats = data.get('data', {}).get('stats', {})
- mid_price = stats.get('mid_price', 0)
- spread_bps = stats.get('spread_bps', 0)
- source = stats.get('source', 'unknown')
- print(f" Update #{update_count}: {symbol} - Price: ${mid_price:.2f}, Spread: {spread_bps:.1f}bps, Source: {source}")
- elif data_type == 'websocket_status':
- status_data = data.get('data', {})
- status = status_data.get('status', 'unknown')
- print(f" Status #{update_count}: {symbol} - WebSocket: {status}")
-
- # Add dashboard callback
- cob_integration.add_dashboard_callback(dashboard_callback)
-
- # Start COB integration
- print("2. Starting COB integration...")
- try:
- # Start in background
- start_task = asyncio.create_task(cob_integration.start())
-
- # Wait for initialization
- await asyncio.sleep(3)
-
- # Check if COB provider is disabled
- print("3. Checking COB provider status:")
- if cob_integration.cob_provider is None:
- print(" ✅ COB provider is disabled (using Enhanced WebSocket only)")
- else:
- print(" ❌ COB provider is still active (may cause REST API fallback)")
-
- # Check Enhanced WebSocket status
- print("4. Checking Enhanced WebSocket status:")
- if cob_integration.enhanced_websocket:
- print(" ✅ Enhanced WebSocket is initialized")
-
- # Check WebSocket status for each symbol
- websocket_status = cob_integration.get_websocket_status()
- for symbol, status in websocket_status.items():
- print(f" {symbol}: {status}")
- else:
- print(" ❌ Enhanced WebSocket is not initialized")
-
- # Monitor updates for a few seconds
- print("5. Monitoring COB updates...")
- initial_count = update_count
- monitor_start = time.time()
-
- # Wait for updates
- await asyncio.sleep(5)
-
- monitor_duration = time.time() - monitor_start
- updates_received = update_count - initial_count
- update_rate = updates_received / monitor_duration
-
- print(f" Received {updates_received} updates in {monitor_duration:.1f}s")
- print(f" Update rate: {update_rate:.1f} updates/second")
-
- if update_rate >= 8: # Should be around 10 updates/second
- print(" ✅ Update rate is excellent (8+ updates/second)")
- elif update_rate >= 5:
- print(" ✅ Update rate is good (5+ updates/second)")
- elif update_rate >= 1:
- print(" ⚠️ Update rate is low (1+ updates/second)")
- else:
- print(" ❌ Update rate is too low (<1 update/second)")
-
- # Check data quality
- print("6. Data quality check:")
- if last_update_time:
- time_since_last = (datetime.now() - last_update_time).total_seconds()
- if time_since_last < 1:
- print(f" ✅ Recent data (last update {time_since_last:.1f}s ago)")
- else:
- print(f" ⚠️ Stale data (last update {time_since_last:.1f}s ago)")
- else:
- print(" ❌ No updates received")
-
- # Stop the integration
- print("7. Stopping COB integration...")
- await cob_integration.stop()
-
- # Cancel the start task
- start_task.cancel()
- try:
- await start_task
- except asyncio.CancelledError:
- pass
-
- except Exception as e:
- print(f" ❌ Error during COB integration test: {e}")
-
- print(f"\n✅ COB WebSocket only test completed!")
- print(f"Total updates received: {update_count}")
- print("Enhanced WebSocket is now the sole data source (no REST API fallback)")
-
-if __name__ == "__main__":
- asyncio.run(test_cob_websocket_only())
\ No newline at end of file
diff --git a/test_continuous_cnn_training.py b/test_continuous_cnn_training.py
deleted file mode 100644
index 0a6778b..0000000
--- a/test_continuous_cnn_training.py
+++ /dev/null
@@ -1,155 +0,0 @@
-"""
-Test Continuous CNN Training
-
-This script demonstrates how the CNN model can be trained with each new inference result
-using collected data, implementing a continuous learning loop.
-"""
-
-import logging
-import time
-from datetime import datetime
-import random
-import os
-
-from core.standardized_data_provider import StandardizedDataProvider
-from core.enhanced_cnn_adapter import EnhancedCNNAdapter
-from core.data_models import create_model_output
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def simulate_market_feedback(action, symbol):
- """
- Simulate market feedback for a given action
-
- In a real system, this would be replaced with actual market performance data
-
- Args:
- action: Trading action ('BUY', 'SELL', 'HOLD')
- symbol: Trading symbol
-
- Returns:
- tuple: (actual_action, reward)
- """
- # Simulate market movement (random for demonstration)
- market_direction = random.choice(['up', 'down', 'sideways'])
-
- # Determine actual best action based on market direction
- if market_direction == 'up':
- best_action = 'BUY'
- elif market_direction == 'down':
- best_action = 'SELL'
- else:
- best_action = 'HOLD'
-
- # Calculate reward based on whether the action matched the best action
- if action == best_action:
- reward = random.uniform(0.01, 0.1) # Positive reward for correct action
- else:
- reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action
-
- logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")
-
- return best_action, reward
-
-def test_continuous_training():
- """Test continuous training of the CNN model with new inference results"""
- try:
- # Initialize data provider
- symbols = ['ETH/USDT', 'BTC/USDT']
- timeframes = ['1s', '1m', '1h', '1d']
- data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
-
- # Initialize CNN adapter
- checkpoint_dir = "models/enhanced_cnn"
- os.makedirs(checkpoint_dir, exist_ok=True)
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
-
- # Load best checkpoint if available
- cnn_adapter.load_best_checkpoint()
-
- # Continuous learning loop
- num_iterations = 10
- training_frequency = 3 # Train every N iterations
- samples_collected = 0
-
- logger.info(f"Starting continuous learning loop with {num_iterations} iterations")
-
- for i in range(num_iterations):
- logger.info(f"\nIteration {i+1}/{num_iterations}")
-
- # Get standardized input data
- symbol = random.choice(symbols)
- logger.info(f"Getting data for {symbol}...")
- base_data = data_provider.get_base_data_input(symbol)
-
- if base_data is None:
- logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
- continue
-
- # Make prediction
- logger.info(f"Making prediction for {symbol}...")
- model_output = cnn_adapter.predict(base_data)
-
- # Log prediction
- action = model_output.predictions['action']
- confidence = model_output.confidence
- logger.info(f"Prediction: {action} with confidence {confidence:.4f}")
-
- # Store model output
- data_provider.store_model_output(model_output)
-
- # Simulate market feedback
- best_action, reward = simulate_market_feedback(action, symbol)
-
- # Add training sample
- logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
- cnn_adapter.add_training_sample(base_data, best_action, reward)
- samples_collected += 1
-
- # Train model periodically
- if (i + 1) % training_frequency == 0 and samples_collected >= 3:
- logger.info(f"Training model with {samples_collected} samples...")
- metrics = cnn_adapter.train(epochs=1)
-
- # Log training metrics
- logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
-
- # Simulate time passing
- time.sleep(1)
-
- logger.info("\nContinuous learning loop completed")
-
- # Final evaluation
- logger.info("Performing final evaluation...")
-
- # Get data for evaluation
- symbol = 'ETH/USDT'
- base_data = data_provider.get_base_data_input(symbol)
-
- if base_data is not None:
- # Make prediction
- model_output = cnn_adapter.predict(base_data)
-
- # Log prediction
- action = model_output.predictions['action']
- confidence = model_output.confidence
- logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")
-
- # Get model output manager
- output_manager = data_provider.get_model_output_manager()
-
- # Evaluate model performance
- metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
- logger.info(f"Performance metrics: {metrics}")
- else:
- logger.warning(f"Failed to get base data input for final evaluation")
-
- logger.info("Test completed successfully")
-
- except Exception as e:
- logger.error(f"Error in test: {e}", exc_info=True)
-
-if __name__ == "__main__":
- test_continuous_training()
\ No newline at end of file
diff --git a/test_dashboard_data_flow.py b/test_dashboard_data_flow.py
deleted file mode 100644
index 2d620ef..0000000
--- a/test_dashboard_data_flow.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to debug dashboard data flow issues
-
-This script tests if the dashboard can properly retrieve and display model data.
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import logging
-logging.basicConfig(level=logging.DEBUG)
-
-from web.clean_dashboard import CleanTradingDashboard
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-def test_dashboard_data_flow():
- """Test if dashboard can retrieve model data correctly"""
-
- print("🧪 DASHBOARD DATA FLOW TEST")
- print("=" * 50)
-
- try:
- # Initialize components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- print(f"✅ Orchestrator initialized")
- print(f" Model registry models: {list(orchestrator.model_registry.get_all_models().keys())}")
- print(f" Model toggle states: {list(orchestrator.model_toggle_states.keys())}")
-
- # Initialize dashboard
- dashboard = CleanTradingDashboard(
- data_provider=data_provider,
- orchestrator=orchestrator
- )
-
- print(f"✅ Dashboard initialized")
-
- # Test available models
- available_models = dashboard._get_available_models()
- print(f" Available models: {list(available_models.keys())}")
-
- # Test training metrics
- print("\n📊 Testing training metrics...")
- toggle_states = {}
- for model_name in available_models.keys():
- toggle_states[model_name] = orchestrator.get_model_toggle_state(model_name)
-
- print(f" Toggle states: {list(toggle_states.keys())}")
-
- metrics_data = dashboard._get_training_metrics(toggle_states)
- print(f" Metrics data type: {type(metrics_data)}")
-
- if metrics_data and isinstance(metrics_data, dict):
- print(f" Metrics keys: {list(metrics_data.keys())}")
- if 'loaded_models' in metrics_data:
- loaded_models = metrics_data['loaded_models']
- print(f" Loaded models count: {len(loaded_models)}")
- for model_name, model_info in loaded_models.items():
- print(f" - {model_name}: active={model_info.get('active', False)}")
- else:
- print(" ❌ No 'loaded_models' in metrics_data!")
- else:
- print(f" ❌ Invalid metrics_data: {metrics_data}")
-
- # Test component manager formatting
- print("\n🎨 Testing component manager...")
- formatted_components = dashboard.component_manager.format_training_metrics(metrics_data)
- print(f" Formatted components type: {type(formatted_components)}")
- print(f" Formatted components count: {len(formatted_components) if formatted_components else 0}")
-
- if formatted_components:
- print(" ✅ Component manager returned formatted data")
- else:
- print(" ❌ Component manager returned empty data")
-
- print("\n🚀 Dashboard data flow test completed!")
- return True
-
- except Exception as e:
- print(f"❌ Test failed with error: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-if __name__ == "__main__":
- test_dashboard_data_flow()
\ No newline at end of file
diff --git a/test_dashboard_performance.py b/test_dashboard_performance.py
deleted file mode 100644
index 35de9d1..0000000
--- a/test_dashboard_performance.py
+++ /dev/null
@@ -1,164 +0,0 @@
-#!/usr/bin/env python3
-"""
-Dashboard Performance Test
-
-Test the optimized callback structure to ensure we've reduced
-the number of requests per second.
-"""
-
-import time
-from web.clean_dashboard import CleanTradingDashboard
-from core.data_provider import DataProvider
-
-def test_callback_optimization():
- """Test that we've optimized the callback structure"""
- print("=== Dashboard Performance Optimization Test ===")
-
- print("✅ BEFORE Optimization:")
- print(" - 7 callbacks on 1-second interval = 7 requests/second")
- print(" - Server overload with single client")
- print(" - Poor user experience")
-
- print("\n✅ AFTER Optimization:")
- print(" - Main interval: 2 seconds (reduced from 1s)")
- print(" - Slow interval: 10 seconds (increased from 5s)")
- print(" - Critical metrics: 2s interval (3 requests every 2s)")
- print(" - Non-critical data: 10s interval (4 requests every 10s)")
-
- print("\n📊 Performance Improvement:")
- print(" - Before: 7 requests/second = 420 requests/minute")
- print(" - After: ~1.9 requests/second = 114 requests/minute")
- print(" - Reduction: ~73% fewer requests")
-
- print("\n🎯 Callback Distribution:")
- print(" Fast Interval (2s):")
- print(" 1. update_metrics (price, PnL, position, status)")
- print(" 2. update_price_chart (trading chart)")
- print(" 3. update_cob_data (order book for trading)")
- print(" ")
- print(" Slow Interval (10s):")
- print(" 4. update_recent_decisions (trading history)")
- print(" 5. update_closed_trades (completed trades)")
- print(" 6. update_pending_orders (pending orders)")
- print(" 7. update_training_metrics (ML model stats)")
-
- print("\n✅ Benefits:")
- print(" - Server can handle multiple clients")
- print(" - Reduced CPU usage")
- print(" - Better responsiveness")
- print(" - Still real-time for critical trading data")
-
- return True
-
-def test_interval_configuration():
- """Test the interval configuration"""
- print("\n=== Interval Configuration Test ===")
-
- try:
- from web.layout_manager import DashboardLayoutManager
-
- # Create layout manager to test intervals
- layout_manager = DashboardLayoutManager(100.0, None)
- layout = layout_manager.create_main_layout()
-
- # Check if intervals are properly configured
- print("✅ Layout created successfully")
- print("✅ Intervals should be configured as:")
- print(" - interval-component: 2000ms (2s)")
- print(" - slow-interval-component: 10000ms (10s)")
-
- return True
-
- except Exception as e:
- print(f"❌ Error testing interval configuration: {e}")
- return False
-
-def calculate_performance_metrics():
- """Calculate the performance improvement metrics"""
- print("\n=== Performance Metrics Calculation ===")
-
- # Old system
- old_callbacks = 7
- old_interval = 1 # second
- old_requests_per_second = old_callbacks / old_interval
- old_requests_per_minute = old_requests_per_second * 60
-
- # New system
- fast_callbacks = 3 # metrics, chart, cob
- fast_interval = 2 # seconds
- slow_callbacks = 4 # decisions, trades, orders, training
- slow_interval = 10 # seconds
-
- new_requests_per_second = (fast_callbacks / fast_interval) + (slow_callbacks / slow_interval)
- new_requests_per_minute = new_requests_per_second * 60
-
- reduction_percent = ((old_requests_per_second - new_requests_per_second) / old_requests_per_second) * 100
-
- print(f"📊 Detailed Performance Analysis:")
- print(f" Old System:")
- print(f" - {old_callbacks} callbacks × {old_interval}s = {old_requests_per_second:.1f} req/s")
- print(f" - {old_requests_per_minute:.0f} requests/minute")
- print(f" ")
- print(f" New System:")
- print(f" - Fast: {fast_callbacks} callbacks ÷ {fast_interval}s = {fast_callbacks/fast_interval:.1f} req/s")
- print(f" - Slow: {slow_callbacks} callbacks ÷ {slow_interval}s = {slow_callbacks/slow_interval:.1f} req/s")
- print(f" - Total: {new_requests_per_second:.1f} req/s")
- print(f" - {new_requests_per_minute:.0f} requests/minute")
- print(f" ")
- print(f" 🎉 Improvement: {reduction_percent:.1f}% reduction in requests")
-
- # Server capacity estimation
- print(f"\n🖥️ Server Capacity Estimation:")
- print(f" - Old: Could handle ~{100/old_requests_per_second:.0f} concurrent users")
- print(f" - New: Can handle ~{100/new_requests_per_second:.0f} concurrent users")
- print(f" - Capacity increase: {(100/new_requests_per_second)/(100/old_requests_per_second):.1f}x")
-
- return {
- 'old_rps': old_requests_per_second,
- 'new_rps': new_requests_per_second,
- 'reduction_percent': reduction_percent,
- 'capacity_multiplier': (100/new_requests_per_second)/(100/old_requests_per_second)
- }
-
-def main():
- """Run all performance tests"""
- print("=== Dashboard Performance Optimization Test Suite ===")
-
- tests = [
- ("Callback Optimization", test_callback_optimization),
- ("Interval Configuration", test_interval_configuration)
- ]
-
- passed = 0
- total = len(tests)
-
- for test_name, test_func in tests:
- print(f"\n{'='*60}")
- try:
- if test_func():
- passed += 1
- print(f"✅ {test_name}: PASSED")
- else:
- print(f"❌ {test_name}: FAILED")
- except Exception as e:
- print(f"❌ {test_name}: ERROR - {e}")
-
- # Calculate performance metrics
- metrics = calculate_performance_metrics()
-
- print(f"\n{'='*60}")
- print(f"=== Test Results: {passed}/{total} passed ===")
-
- if passed == total:
- print("\n🎉 ALL TESTS PASSED!")
- print("✅ Dashboard performance optimized successfully")
- print(f"✅ {metrics['reduction_percent']:.1f}% reduction in server requests")
- print(f"✅ {metrics['capacity_multiplier']:.1f}x increase in server capacity")
- print("✅ Better user experience with responsive UI")
- print("✅ Ready for production with multiple users")
- else:
- print(f"\n⚠️ {total - passed} tests failed")
- print("Check individual test results above")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_data_integration.py b/test_data_integration.py
deleted file mode 100644
index 4a86d0a..0000000
--- a/test_data_integration.py
+++ /dev/null
@@ -1,182 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Data Integration
-
-Test that the FIFO queues are properly populated from the data provider
-"""
-
-import time
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-def test_data_provider_methods():
- """Test what methods are available in the data provider"""
- print("=== Testing Data Provider Methods ===")
-
- try:
- data_provider = DataProvider()
-
- # Check available methods
- methods = [method for method in dir(data_provider) if not method.startswith('_') and callable(getattr(data_provider, method))]
- data_methods = [method for method in methods if 'data' in method.lower() or 'ohlcv' in method.lower() or 'historical' in method.lower() or 'latest' in method.lower()]
-
- print("Available data-related methods:")
- for method in sorted(data_methods):
- print(f" - {method}")
-
- # Test getting historical data
- print(f"\nTesting get_historical_data:")
- try:
- df = data_provider.get_historical_data('ETH/USDT', '1m', limit=5)
- if df is not None and not df.empty:
- print(f" ✅ Got {len(df)} rows of 1m data")
- print(f" Columns: {list(df.columns)}")
- print(f" Latest close: {df['close'].iloc[-1]}")
- else:
- print(f" ❌ No data returned")
- except Exception as e:
- print(f" ❌ Error: {e}")
-
- # Test getting latest candles if available
- if hasattr(data_provider, 'get_latest_candles'):
- print(f"\nTesting get_latest_candles:")
- try:
- df = data_provider.get_latest_candles('ETH/USDT', '1m', limit=5)
- if df is not None and not df.empty:
- print(f" ✅ Got {len(df)} rows of latest candles")
- print(f" Latest close: {df['close'].iloc[-1]}")
- else:
- print(f" ❌ No data returned")
- except Exception as e:
- print(f" ❌ Error: {e}")
-
- return True
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def test_queue_population():
- """Test that queues get populated with data"""
- print("\n=== Testing Queue Population ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Wait a moment for initial population
- print("Waiting 3 seconds for initial data population...")
- time.sleep(3)
-
- # Check queue status
- print("\nQueue status after initialization:")
- orchestrator.log_queue_status(detailed=True)
-
- # Check if we have minimum data
- symbols_to_check = ['ETH/USDT', 'BTC/USDT']
- timeframes_to_check = ['1s', '1m', '1h', '1d']
- min_requirements = {'1s': 100, '1m': 50, '1h': 20, '1d': 10}
-
- print(f"\nChecking minimum data requirements:")
- for symbol in symbols_to_check:
- print(f"\n{symbol}:")
- for timeframe in timeframes_to_check:
- min_count = min_requirements.get(timeframe, 10)
- has_min = orchestrator.ensure_minimum_data(f'ohlcv_{timeframe}', symbol, min_count)
- actual_count = 0
- if f'ohlcv_{timeframe}' in orchestrator.data_queues and symbol in orchestrator.data_queues[f'ohlcv_{timeframe}']:
- with orchestrator.data_queue_locks[f'ohlcv_{timeframe}'][symbol]:
- actual_count = len(orchestrator.data_queues[f'ohlcv_{timeframe}'][symbol])
-
- status = "✅" if has_min else "❌"
- print(f" {timeframe}: {status} {actual_count}/{min_count}")
-
- # Test BaseDataInput building
- print(f"\nTesting BaseDataInput building:")
- base_data = orchestrator.build_base_data_input('ETH/USDT')
- if base_data:
- features = base_data.get_feature_vector()
- print(f" ✅ BaseDataInput built successfully")
- print(f" Feature vector size: {len(features)}")
- print(f" OHLCV 1s bars: {len(base_data.ohlcv_1s)}")
- print(f" OHLCV 1m bars: {len(base_data.ohlcv_1m)}")
- print(f" BTC bars: {len(base_data.btc_ohlcv_1s)}")
- else:
- print(f" ❌ Failed to build BaseDataInput")
-
- return base_data is not None
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def test_polling_thread():
- """Test that the polling thread is working"""
- print("\n=== Testing Polling Thread ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Get initial queue counts
- initial_status = orchestrator.get_queue_status()
- print("Initial queue counts:")
- for data_type, symbols in initial_status.items():
- for symbol, count in symbols.items():
- if count > 0:
- print(f" {data_type}/{symbol}: {count}")
-
- # Wait for polling thread to run
- print("\nWaiting 10 seconds for polling thread...")
- time.sleep(10)
-
- # Get updated queue counts
- updated_status = orchestrator.get_queue_status()
- print("\nUpdated queue counts:")
- for data_type, symbols in updated_status.items():
- for symbol, count in symbols.items():
- if count > 0:
- print(f" {data_type}/{symbol}: {count}")
-
- # Check if any queues grew
- growth_detected = False
- for data_type in initial_status:
- for symbol in initial_status[data_type]:
- initial_count = initial_status[data_type][symbol]
- updated_count = updated_status[data_type][symbol]
- if updated_count > initial_count:
- print(f" ✅ Growth detected: {data_type}/{symbol} {initial_count} -> {updated_count}")
- growth_detected = True
-
- if not growth_detected:
- print(" ⚠️ No queue growth detected - polling may not be working")
-
- return growth_detected
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def main():
- """Run all data integration tests"""
- print("=== Data Integration Test Suite ===")
-
- test1_passed = test_data_provider_methods()
- test2_passed = test_queue_population()
- test3_passed = test_polling_thread()
-
- print(f"\n=== Results ===")
- print(f"Data provider methods: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
- print(f"Queue population: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
- print(f"Polling thread: {'✅ PASSED' if test3_passed else '❌ FAILED'}")
-
- if test1_passed and test2_passed:
- print("\n✅ Data integration is working!")
- print("✅ FIFO queues should be populated with data")
- print("✅ Models should be able to make predictions")
- else:
- print("\n❌ Data integration issues detected")
- print("❌ Check data provider connectivity and methods")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_data_provider_integration.py b/test_data_provider_integration.py
deleted file mode 100644
index d3feff8..0000000
--- a/test_data_provider_integration.py
+++ /dev/null
@@ -1,112 +0,0 @@
-#!/usr/bin/env python3
-"""
-Integration test for the simplified data provider with other components
-"""
-
-import time
-import logging
-import pandas as pd
-from core.data_provider import DataProvider
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_integration():
- """Test integration with other components"""
- logger.info("Testing DataProvider integration...")
-
- # Initialize data provider
- dp = DataProvider()
-
- # Wait for initial data load
- logger.info("Waiting for initial data load...")
- time.sleep(15)
-
- # Test 1: Feature matrix generation
- logger.info("\n=== Test 1: Feature Matrix Generation ===")
- try:
- feature_matrix = dp.get_feature_matrix('ETH/USDT', ['1m', '1h'], window_size=20)
- if feature_matrix is not None:
- logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
- else:
- logger.warning("❌ Feature matrix generation failed")
- except Exception as e:
- logger.error(f"❌ Feature matrix error: {e}")
-
- # Test 2: Multi-symbol data access
- logger.info("\n=== Test 2: Multi-Symbol Data Access ===")
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- for timeframe in ['1s', '1m', '1h', '1d']:
- data = dp.get_historical_data(symbol, timeframe, limit=10)
- if data is not None and not data.empty:
- logger.info(f"✅ {symbol} {timeframe}: {len(data)} candles")
- else:
- logger.warning(f"❌ {symbol} {timeframe}: No data")
-
- # Test 3: Data consistency checks
- logger.info("\n=== Test 3: Data Consistency ===")
- eth_1m = dp.get_historical_data('ETH/USDT', '1m', limit=100)
- if eth_1m is not None and not eth_1m.empty:
- # Check for proper OHLCV structure
- required_cols = ['open', 'high', 'low', 'close', 'volume']
- has_all_cols = all(col in eth_1m.columns for col in required_cols)
- logger.info(f"✅ OHLCV columns present: {has_all_cols}")
-
- # Check data types
- numeric_cols = eth_1m[required_cols].dtypes
- all_numeric = all(pd.api.types.is_numeric_dtype(dtype) for dtype in numeric_cols)
- logger.info(f"✅ All columns numeric: {all_numeric}")
-
- # Check for NaN values
- has_nan = eth_1m[required_cols].isna().any().any()
- logger.info(f"✅ No NaN values: {not has_nan}")
-
- # Check price relationships (high >= low, etc.)
- price_valid = (eth_1m['high'] >= eth_1m['low']).all()
- logger.info(f"✅ Price relationships valid: {price_valid}")
-
- # Test 4: Performance test
- logger.info("\n=== Test 4: Performance Test ===")
- start_time = time.time()
- for i in range(100):
- data = dp.get_historical_data('ETH/USDT', '1m', limit=50)
- end_time = time.time()
- avg_time = (end_time - start_time) / 100 * 1000 # ms
- logger.info(f"✅ Average data access time: {avg_time:.2f}ms")
-
- # Test 5: Current price accuracy
- logger.info("\n=== Test 5: Current Price Accuracy ===")
- eth_price = dp.get_current_price('ETH/USDT')
- eth_data = dp.get_historical_data('ETH/USDT', '1s', limit=1)
- if eth_data is not None and not eth_data.empty:
- latest_close = eth_data.iloc[-1]['close']
- price_match = abs(eth_price - latest_close) < 0.01
- logger.info(f"✅ Current price matches latest candle: {price_match}")
- logger.info(f" Current price: ${eth_price}")
- logger.info(f" Latest close: ${latest_close}")
-
- # Test 6: Cache efficiency
- logger.info("\n=== Test 6: Cache Efficiency ===")
- cache_summary = dp.get_cached_data_summary()
- total_candles = 0
- for symbol_data in cache_summary['cached_data'].values():
- for tf_data in symbol_data.values():
- if isinstance(tf_data, dict) and 'candle_count' in tf_data:
- total_candles += tf_data['candle_count']
-
- logger.info(f"✅ Total cached candles: {total_candles}")
- logger.info(f"✅ Data maintenance active: {cache_summary['data_maintenance_active']}")
-
- # Test 7: Memory usage estimation
- logger.info("\n=== Test 7: Memory Usage Estimation ===")
- # Rough estimation: 8 columns * 8 bytes * total_candles
- estimated_memory_mb = (total_candles * 8 * 8) / (1024 * 1024)
- logger.info(f"✅ Estimated memory usage: {estimated_memory_mb:.2f} MB")
-
- # Clean shutdown
- dp.stop_automatic_data_maintenance()
- logger.info("\n✅ Integration test completed successfully!")
-
-if __name__ == "__main__":
- test_integration()
\ No newline at end of file
diff --git a/test_db_migration.py b/test_db_migration.py
deleted file mode 100644
index ade072e..0000000
--- a/test_db_migration.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify database migration works correctly
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from utils.database_manager import get_database_manager, reset_database_manager
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_migration():
- """Test the database migration"""
- try:
- logger.info("Testing database migration...")
-
- # Reset the database manager to force re-initialization
- reset_database_manager()
-
- # Get a new instance (this will trigger migration)
- db_manager = get_database_manager()
-
- # Test if we can access the input_features_blob column
- with db_manager._get_connection() as conn:
- cursor = conn.execute("PRAGMA table_info(inference_records)")
- columns = [row[1] for row in cursor.fetchall()]
-
- if 'input_features_blob' in columns:
- logger.info("✅ input_features_blob column exists - migration successful!")
- return True
- else:
- logger.error("❌ input_features_blob column missing - migration failed!")
- return False
-
- except Exception as e:
- logger.error(f"❌ Migration test failed: {e}")
- return False
-
-if __name__ == "__main__":
- success = test_migration()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_deribit_integration.py b/test_deribit_integration.py
deleted file mode 100644
index 24ebb18..0000000
--- a/test_deribit_integration.py
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Deribit Integration
-Test the new DeribitInterface and ExchangeFactory
-"""
-import os
-import sys
-import logging
-from dotenv import load_dotenv
-
-# Load environment variables
-load_dotenv()
-
-# Add project paths
-sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
-sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
-
-from core.exchanges.exchange_factory import ExchangeFactory
-from core.exchanges.deribit_interface import DeribitInterface
-from core.config import get_config
-
-# Setup logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_deribit_credentials():
- """Test Deribit API credentials"""
- api_key = os.getenv('DERIBIT_API_CLIENTID')
- api_secret = os.getenv('DERIBIT_API_SECRET')
-
- logger.info(f"Deribit API Key: {'*' * 8 + api_key[-4:] if api_key and len(api_key) > 4 else 'Not set'}")
- logger.info(f"Deribit API Secret: {'*' * 8 + api_secret[-4:] if api_secret and len(api_secret) > 4 else 'Not set'}")
-
- return bool(api_key and api_secret)
-
-def test_deribit_interface():
- """Test DeribitInterface directly"""
- logger.info("Testing DeribitInterface directly...")
-
- try:
- # Create Deribit interface
- deribit = DeribitInterface(test_mode=True)
-
- # Test connection
- if deribit.connect():
- logger.info("✓ Successfully connected to Deribit testnet")
-
- # Test getting instruments
- btc_instruments = deribit.get_instruments('BTC')
- logger.info(f"✓ Found {len(btc_instruments)} BTC instruments")
-
- # Test getting ticker
- ticker = deribit.get_ticker('BTC-PERPETUAL')
- if ticker:
- logger.info(f"✓ BTC-PERPETUAL ticker: ${ticker.get('last_price', 'N/A')}")
-
- # Test getting account summary (if authenticated)
- account = deribit.get_account_summary('BTC')
- if account:
- logger.info(f"✓ BTC account balance: {account.get('available_funds', 'N/A')}")
-
- return True
- else:
- logger.error("✗ Failed to connect to Deribit")
- return False
-
- except Exception as e:
- logger.error(f"✗ Error testing DeribitInterface: {e}")
- return False
-
-def test_exchange_factory():
- """Test ExchangeFactory with config"""
- logger.info("Testing ExchangeFactory...")
-
- try:
- # Load config
- config = get_config()
- exchanges_config = config.get('exchanges', {})
-
- logger.info(f"Primary exchange: {exchanges_config.get('primary', 'Not set')}")
-
- # Test creating primary exchange
- primary_exchange = ExchangeFactory.get_primary_exchange(exchanges_config)
- if primary_exchange:
- logger.info(f"✓ Successfully created primary exchange: {type(primary_exchange).__name__}")
-
- # Test basic operations
- if hasattr(primary_exchange, 'get_ticker'):
- ticker = primary_exchange.get_ticker('BTC-PERPETUAL')
- if ticker:
- logger.info(f"✓ Primary exchange ticker test successful")
-
- return True
- else:
- logger.error("✗ Failed to create primary exchange")
- return False
-
- except Exception as e:
- logger.error(f"✗ Error testing ExchangeFactory: {e}")
- return False
-
-def test_multiple_exchanges():
- """Test creating multiple exchanges"""
- logger.info("Testing multiple exchanges...")
-
- try:
- config = get_config()
- exchanges_config = config.get('exchanges', {})
-
- # Create all configured exchanges
- exchanges = ExchangeFactory.create_multiple_exchanges(exchanges_config)
-
- logger.info(f"✓ Created {len(exchanges)} exchange interfaces:")
- for name, exchange in exchanges.items():
- logger.info(f" - {name}: {type(exchange).__name__}")
-
- return len(exchanges) > 0
-
- except Exception as e:
- logger.error(f"✗ Error testing multiple exchanges: {e}")
- return False
-
-def main():
- """Run all tests"""
- logger.info("=" * 50)
- logger.info("TESTING DERIBIT INTEGRATION")
- logger.info("=" * 50)
-
- tests = [
- ("Credentials", test_deribit_credentials),
- ("DeribitInterface", test_deribit_interface),
- ("ExchangeFactory", test_exchange_factory),
- ("Multiple Exchanges", test_multiple_exchanges)
- ]
-
- results = []
- for test_name, test_func in tests:
- logger.info(f"\n--- Testing {test_name} ---")
- try:
- result = test_func()
- results.append((test_name, result))
- status = "PASS" if result else "FAIL"
- logger.info(f"{test_name}: {status}")
- except Exception as e:
- logger.error(f"{test_name}: ERROR - {e}")
- results.append((test_name, False))
-
- # Summary
- logger.info("\n" + "=" * 50)
- logger.info("TEST SUMMARY")
- logger.info("=" * 50)
-
- passed = sum(1 for _, result in results if result)
- total = len(results)
-
- for test_name, result in results:
- status = "✓ PASS" if result else "✗ FAIL"
- logger.info(f"{status}: {test_name}")
-
- logger.info(f"\nOverall: {passed}/{total} tests passed")
-
- if passed == total:
- logger.info("🎉 All tests passed! Deribit integration is working.")
- return True
- else:
- logger.error("❌ Some tests failed. Check the logs above.")
- return False
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_device_fix.py b/test_device_fix.py
deleted file mode 100644
index 20b0c51..0000000
--- a/test_device_fix.py
+++ /dev/null
@@ -1,141 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify device mismatch fixes for GPU training
-"""
-
-import torch
-import logging
-import sys
-import os
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from NN.models.enhanced_cnn import EnhancedCNN
-from core.data_models import BaseDataInput, OHLCVBar
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_device_consistency():
- """Test that all tensors are on the same device"""
-
- logger.info("Testing device consistency for EnhancedCNN...")
-
- # Check if CUDA is available
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- logger.info(f"Using device: {device}")
-
- try:
- # Initialize the adapter
- adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
-
- # Verify adapter device
- logger.info(f"Adapter device: {adapter.device}")
- logger.info(f"Model device: {next(adapter.model.parameters()).device}")
-
- # Create sample data
- sample_ohlcv = [
- OHLCVBar(
- symbol="ETH/USDT",
- timeframe="1s",
- timestamp=1640995200.0, # 2022-01-01
- open=50000.0,
- high=51000.0,
- low=49000.0,
- close=50500.0,
- volume=1000.0
- )
- ] * 300 # 300 frames
-
- base_data = BaseDataInput(
- symbol="ETH/USDT",
- timestamp=1640995200.0,
- ohlcv_1s=sample_ohlcv,
- ohlcv_1m=sample_ohlcv,
- ohlcv_5m=sample_ohlcv,
- ohlcv_15m=sample_ohlcv,
- btc_ohlcv=sample_ohlcv,
- cob_data={},
- ma_data={},
- technical_indicators={},
- last_predictions={}
- )
-
- # Test prediction
- logger.info("Testing prediction...")
- prediction = adapter.predict(base_data)
- logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
-
- # Test training sample addition
- logger.info("Testing training sample addition...")
- adapter.add_training_sample(base_data, "BUY", 0.1)
- adapter.add_training_sample(base_data, "SELL", -0.05)
- adapter.add_training_sample(base_data, "HOLD", 0.02)
-
- # Test training
- logger.info("Testing training...")
- training_results = adapter.train(epochs=1)
- logger.info(f"Training results: {training_results}")
-
- logger.info("✅ All device consistency tests passed!")
- return True
-
- except Exception as e:
- logger.error(f"❌ Device consistency test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-def test_orchestrator_inference_history():
- """Test that orchestrator properly initializes inference history"""
-
- logger.info("Testing orchestrator inference history initialization...")
-
- try:
- from core.orchestrator import TradingOrchestrator
- from core.data_provider import DataProvider
-
- # Initialize orchestrator
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Check if inference history is initialized
- logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
-
- # Check if models are registered
- logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
-
- # Verify each registered model has inference history
- for model_name in orchestrator.model_registry.models.keys():
- if model_name in orchestrator.inference_history:
- logger.info(f"✅ {model_name} has inference history initialized")
- else:
- logger.warning(f"❌ {model_name} missing inference history")
-
- logger.info("✅ Orchestrator inference history test completed!")
- return True
-
- except Exception as e:
- logger.error(f"❌ Orchestrator test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-if __name__ == "__main__":
- logger.info("Starting device fix verification tests...")
-
- # Test 1: Device consistency
- test1_passed = test_device_consistency()
-
- # Test 2: Orchestrator inference history
- test2_passed = test_orchestrator_inference_history()
-
- # Summary
- if test1_passed and test2_passed:
- logger.info("🎉 All tests passed! Device issues should be fixed.")
- sys.exit(0)
- else:
- logger.error("❌ Some tests failed. Please check the logs above.")
- sys.exit(1)
\ No newline at end of file
diff --git a/test_device_training_fix.py b/test_device_training_fix.py
deleted file mode 100644
index b883043..0000000
--- a/test_device_training_fix.py
+++ /dev/null
@@ -1,153 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify device handling and training sample population fixes
-"""
-
-import logging
-import asyncio
-import torch
-from datetime import datetime
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def test_device_handling():
- """Test that device handling is working correctly"""
- try:
- logger.info("Testing device handling...")
-
- # Test 1: Check CUDA availability
- cuda_available = torch.cuda.is_available()
- device = torch.device("cuda" if cuda_available else "cpu")
- logger.info(f"CUDA available: {cuda_available}")
- logger.info(f"Using device: {device}")
-
- # Test 2: Initialize CNN adapter
- from core.enhanced_cnn_adapter import EnhancedCNNAdapter
-
- logger.info("Initializing CNN adapter...")
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
-
- logger.info(f"CNN adapter device: {cnn_adapter.device}")
- logger.info(f"CNN model device: {cnn_adapter.model.device}")
-
- # Test 3: Create test data
- from core.data_models import BaseDataInput
-
- logger.info("Creating test BaseDataInput...")
- base_data = BaseDataInput(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- ohlcv_1s=[],
- ohlcv_1m=[],
- ohlcv_1h=[],
- ohlcv_1d=[],
- btc_ohlcv_1s=[],
- cob_data=None,
- technical_indicators={},
- last_predictions={}
- )
-
- # Test 4: Make prediction (this should not cause device mismatch)
- logger.info("Making prediction...")
- prediction = cnn_adapter.predict(base_data)
-
- logger.info(f"Prediction successful: {prediction.predictions['action']}")
- logger.info(f"Confidence: {prediction.confidence:.4f}")
-
- # Test 5: Add training samples
- logger.info("Adding training samples...")
- cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
- cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
- cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
-
- logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
-
- # Test 6: Try training if we have enough samples
- if len(cnn_adapter.training_data) >= 2:
- logger.info("Attempting training...")
- training_results = cnn_adapter.train(epochs=1)
- logger.info(f"Training results: {training_results}")
- else:
- logger.info("Not enough samples for training")
-
- logger.info("✅ Device handling test passed!")
- return True
-
- except Exception as e:
- logger.error(f"❌ Device handling test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-async def test_orchestrator_training():
- """Test that orchestrator properly adds training samples"""
- try:
- logger.info("Testing orchestrator training integration...")
-
- # Test 1: Initialize orchestrator
- from core.orchestrator import TradingOrchestrator
- from core.standardized_data_provider import StandardizedDataProvider
-
- logger.info("Initializing data provider...")
- data_provider = StandardizedDataProvider()
-
- logger.info("Initializing orchestrator...")
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Test 2: Check if CNN adapter is available
- if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
- logger.info(f"✅ CNN adapter available in orchestrator")
- logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
- else:
- logger.warning("⚠️ CNN adapter not available in orchestrator")
- return False
-
- # Test 3: Make a trading decision (this should add training samples)
- logger.info("Making trading decision...")
- decision = await orchestrator.make_trading_decision("ETH/USDT")
-
- if decision:
- logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
- logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
- else:
- logger.warning("No decision made")
-
- # Test 4: Check inference history
- logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
- for model_name, history in orchestrator.inference_history.items():
- logger.info(f" {model_name}: {len(history)} records")
-
- logger.info("✅ Orchestrator training test passed!")
- return True
-
- except Exception as e:
- logger.error(f"❌ Orchestrator training test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-async def main():
- """Run all tests"""
- logger.info("Starting device and training fix tests...")
-
- # Test 1: Device handling
- test1_passed = test_device_handling()
-
- # Test 2: Orchestrator training
- test2_passed = await test_orchestrator_training()
-
- # Summary
- logger.info("\n" + "="*50)
- logger.info("TEST SUMMARY:")
- logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
- logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
-
- if test1_passed and test2_passed:
- logger.info("🎉 All tests passed! Device and training issues should be fixed.")
- else:
- logger.error("❌ Some tests failed. Please check the logs above.")
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/test_enhanced_cnn_adapter.py b/test_enhanced_cnn_adapter.py
deleted file mode 100644
index a03ff1b..0000000
--- a/test_enhanced_cnn_adapter.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Test Enhanced CNN Adapter
-
-This script tests the EnhancedCNNAdapter with standardized input format.
-"""
-
-import logging
-import time
-from datetime import datetime
-
-from core.standardized_data_provider import StandardizedDataProvider
-from core.enhanced_cnn_adapter import EnhancedCNNAdapter
-from core.data_models import create_model_output
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def test_cnn_adapter():
- """Test the EnhancedCNNAdapter with standardized input format"""
- try:
- # Initialize data provider
- symbols = ['ETH/USDT', 'BTC/USDT']
- timeframes = ['1s', '1m', '1h', '1d']
- data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
-
- # Initialize CNN adapter
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
-
- # Load best checkpoint if available
- cnn_adapter.load_best_checkpoint()
-
- # Get standardized input data
- logger.info("Getting standardized input data...")
- base_data = data_provider.get_base_data_input('ETH/USDT')
-
- if base_data is None:
- logger.error("Failed to get base data input")
- return
-
- # Make prediction
- logger.info("Making prediction...")
- model_output = cnn_adapter.predict(base_data)
-
- # Log prediction
- logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
-
- # Store model output
- data_provider.store_model_output(model_output)
-
- # Add training sample (simulated)
- logger.info("Adding training sample...")
- cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)
-
- # Train model
- logger.info("Training model...")
- metrics = cnn_adapter.train(epochs=1)
-
- # Log training metrics
- logger.info(f"Training metrics: {metrics}")
-
- # Make another prediction
- logger.info("Making another prediction...")
- model_output = cnn_adapter.predict(base_data)
-
- # Log prediction
- logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
-
- # Test model output manager
- logger.info("Testing model output manager...")
- output_manager = data_provider.get_model_output_manager()
-
- # Get current outputs
- current_outputs = output_manager.get_all_current_outputs('ETH/USDT')
- logger.info(f"Current outputs: {len(current_outputs)} models")
-
- # Evaluate model performance
- metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')
- logger.info(f"Performance metrics: {metrics}")
-
- logger.info("Test completed successfully")
-
- except Exception as e:
- logger.error(f"Error in test: {e}", exc_info=True)
-
-if __name__ == "__main__":
- test_cnn_adapter()
\ No newline at end of file
diff --git a/test_enhanced_cob_websocket.py b/test_enhanced_cob_websocket.py
deleted file mode 100644
index 367ec95..0000000
--- a/test_enhanced_cob_websocket.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Enhanced COB WebSocket Implementation
-
-This script tests the enhanced COB WebSocket system to ensure:
-1. WebSocket connections work properly
-2. Fallback to REST API when WebSocket fails
-3. Dashboard status updates are working
-4. Clear error messages and warnings are displayed
-"""
-
-import asyncio
-import logging
-import sys
-import time
-from datetime import datetime
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Import the enhanced COB WebSocket
-try:
- from core.enhanced_cob_websocket import EnhancedCOBWebSocket, get_enhanced_cob_websocket
- print("✅ Enhanced COB WebSocket imported successfully")
-except ImportError as e:
- print(f"❌ Failed to import Enhanced COB WebSocket: {e}")
- sys.exit(1)
-
-async def test_dashboard_callback(status_data):
- """Test dashboard callback function"""
- print(f"📊 Dashboard callback received: {status_data}")
-
-async def test_cob_callback(symbol, cob_data):
- """Test COB data callback function"""
- stats = cob_data.get('stats', {})
- mid_price = stats.get('mid_price', 0)
- bid_levels = len(cob_data.get('bids', []))
- ask_levels = len(cob_data.get('asks', []))
- source = cob_data.get('source', 'unknown')
-
- print(f"📈 COB data for {symbol}: ${mid_price:.2f}, {bid_levels} bids, {ask_levels} asks (via {source})")
-
-async def main():
- """Main test function"""
- print("🚀 Testing Enhanced COB WebSocket System")
- print("=" * 60)
-
- # Test 1: Initialize Enhanced COB WebSocket
- print("\n1. Initializing Enhanced COB WebSocket...")
- try:
- cob_ws = EnhancedCOBWebSocket(
- symbols=['BTC/USDT', 'ETH/USDT'],
- dashboard_callback=test_dashboard_callback
- )
-
- # Add callbacks
- cob_ws.add_cob_callback(test_cob_callback)
-
- print("✅ Enhanced COB WebSocket initialized")
- except Exception as e:
- print(f"❌ Failed to initialize: {e}")
- return
-
- # Test 2: Start WebSocket connections
- print("\n2. Starting WebSocket connections...")
- try:
- await cob_ws.start()
- print("✅ WebSocket connections started")
- except Exception as e:
- print(f"❌ Failed to start connections: {e}")
- return
-
- # Test 3: Monitor connections for 30 seconds
- print("\n3. Monitoring connections for 30 seconds...")
- start_time = time.time()
-
- while time.time() - start_time < 30:
- try:
- # Get status summary
- status = cob_ws.get_status_summary()
- overall_status = status.get('overall_status', 'unknown')
-
- print(f"⏱️ Status: {overall_status}")
-
- # Print symbol-specific status
- for symbol, symbol_status in status.get('symbols', {}).items():
- connected = symbol_status.get('connected', False)
- fallback = symbol_status.get('rest_fallback_active', False)
- messages = symbol_status.get('messages_received', 0)
-
- if connected:
- print(f" {symbol}: ✅ Connected ({messages} messages)")
- elif fallback:
- print(f" {symbol}: ⚠️ REST fallback active")
- else:
- error = symbol_status.get('last_error', 'Unknown error')
- print(f" {symbol}: ❌ Error - {error}")
-
- await asyncio.sleep(5) # Check every 5 seconds
-
- except KeyboardInterrupt:
- print("\n⏹️ Test interrupted by user")
- break
- except Exception as e:
- print(f"❌ Error during monitoring: {e}")
- break
-
- # Test 4: Final status check
- print("\n4. Final status check...")
- try:
- final_status = cob_ws.get_status_summary()
- print(f"Final overall status: {final_status.get('overall_status', 'unknown')}")
-
- for symbol, symbol_status in final_status.get('symbols', {}).items():
- print(f" {symbol}:")
- print(f" Connected: {symbol_status.get('connected', False)}")
- print(f" Messages received: {symbol_status.get('messages_received', 0)}")
- print(f" REST fallback: {symbol_status.get('rest_fallback_active', False)}")
- if symbol_status.get('last_error'):
- print(f" Last error: {symbol_status.get('last_error')}")
-
- except Exception as e:
- print(f"❌ Error getting final status: {e}")
-
- # Test 5: Stop connections
- print("\n5. Stopping connections...")
- try:
- await cob_ws.stop()
- print("✅ Connections stopped successfully")
- except Exception as e:
- print(f"❌ Error stopping connections: {e}")
-
- print("\n" + "=" * 60)
- print("🏁 Enhanced COB WebSocket test completed")
-
-if __name__ == "__main__":
- try:
- asyncio.run(main())
- except KeyboardInterrupt:
- print("\n⏹️ Test interrupted")
- except Exception as e:
- print(f"❌ Test failed: {e}")
- import traceback
- traceback.print_exc()
\ No newline at end of file
diff --git a/test_enhanced_data_provider_websocket.py b/test_enhanced_data_provider_websocket.py
deleted file mode 100644
index 116766e..0000000
--- a/test_enhanced_data_provider_websocket.py
+++ /dev/null
@@ -1,149 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Enhanced Data Provider WebSocket Integration
-
-This script tests the integration between the Enhanced COB WebSocket and the Data Provider.
-"""
-
-import asyncio
-import logging
-import sys
-import time
-from datetime import datetime
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Import the enhanced data provider
-try:
- from core.data_provider import DataProvider
- print("✅ Enhanced Data Provider imported successfully")
-except ImportError as e:
- print(f"❌ Failed to import Enhanced Data Provider: {e}")
- sys.exit(1)
-
-async def test_enhanced_websocket_integration():
- """Test the enhanced WebSocket integration with data provider"""
- print("🚀 Testing Enhanced WebSocket Integration with Data Provider")
- print("=" * 70)
-
- # Test 1: Initialize Data Provider
- print("\n1. Initializing Data Provider...")
- try:
- data_provider = DataProvider(
- symbols=['ETH/USDT', 'BTC/USDT'],
- timeframes=['1m', '1h']
- )
- print("✅ Data Provider initialized")
- except Exception as e:
- print(f"❌ Failed to initialize Data Provider: {e}")
- return
-
- # Test 2: Start Enhanced WebSocket Streaming
- print("\n2. Starting Enhanced WebSocket streaming...")
- try:
- await data_provider.start_real_time_streaming()
- print("✅ Enhanced WebSocket streaming started")
- except Exception as e:
- print(f"❌ Failed to start WebSocket streaming: {e}")
- return
-
- # Test 3: Check WebSocket Status
- print("\n3. Checking WebSocket status...")
- try:
- status = data_provider.get_cob_websocket_status()
- overall_status = status.get('overall_status', 'unknown')
- print(f"Overall WebSocket status: {overall_status}")
-
- for symbol, symbol_status in status.get('symbols', {}).items():
- connected = symbol_status.get('connected', False)
- messages = symbol_status.get('messages_received', 0)
- fallback = symbol_status.get('rest_fallback_active', False)
-
- if connected:
- print(f" {symbol}: ✅ Connected ({messages} messages)")
- elif fallback:
- print(f" {symbol}: ⚠️ REST fallback active")
- else:
- print(f" {symbol}: ❌ Disconnected")
-
- except Exception as e:
- print(f"❌ Error checking WebSocket status: {e}")
-
- # Test 4: Monitor COB Data for 30 seconds
- print("\n4. Monitoring COB data for 30 seconds...")
- start_time = time.time()
- data_received = {'ETH/USDT': 0, 'BTC/USDT': 0}
-
- while time.time() - start_time < 30:
- try:
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- cob_data = data_provider.get_latest_cob_data(symbol)
- if cob_data:
- data_received[symbol] += 1
- if data_received[symbol] % 10 == 1: # Print every 10th update
- bids = len(cob_data.get('bids', []))
- asks = len(cob_data.get('asks', []))
- source = cob_data.get('source', 'unknown')
- mid_price = cob_data.get('stats', {}).get('mid_price', 0)
- print(f" 📊 {symbol}: ${mid_price:.2f}, {bids} bids, {asks} asks (via {source})")
-
- await asyncio.sleep(2) # Check every 2 seconds
-
- except KeyboardInterrupt:
- print("\n⏹️ Test interrupted by user")
- break
- except Exception as e:
- print(f"❌ Error monitoring COB data: {e}")
- break
-
- # Test 5: Final Status Check
- print("\n5. Final status check...")
- try:
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- count = data_received[symbol]
- if count > 0:
- print(f" {symbol}: ✅ Received {count} COB updates")
- else:
- print(f" {symbol}: ❌ No COB data received")
-
- # Check overall WebSocket status again
- final_status = data_provider.get_cob_websocket_status()
- print(f"Final WebSocket status: {final_status.get('overall_status', 'unknown')}")
-
- except Exception as e:
- print(f"❌ Error in final status check: {e}")
-
- # Test 6: Stop WebSocket Streaming
- print("\n6. Stopping WebSocket streaming...")
- try:
- await data_provider.stop_real_time_streaming()
- print("✅ WebSocket streaming stopped")
- except Exception as e:
- print(f"❌ Error stopping WebSocket streaming: {e}")
-
- print("\n" + "=" * 70)
- print("🏁 Enhanced WebSocket Integration Test Completed")
-
- # Summary
- total_updates = sum(data_received.values())
- if total_updates > 0:
- print(f"✅ SUCCESS: Received {total_updates} total COB updates")
- print("🎉 Enhanced WebSocket integration is working!")
- else:
- print("❌ FAILURE: No COB data received")
- print("⚠️ Enhanced WebSocket integration needs investigation")
-
-if __name__ == "__main__":
- try:
- asyncio.run(test_enhanced_websocket_integration())
- except KeyboardInterrupt:
- print("\n⏹️ Test interrupted")
- except Exception as e:
- print(f"❌ Test failed: {e}")
- import traceback
- traceback.print_exc()
\ No newline at end of file
diff --git a/test_enhanced_inference_logging.py b/test_enhanced_inference_logging.py
deleted file mode 100644
index 3a656f9..0000000
--- a/test_enhanced_inference_logging.py
+++ /dev/null
@@ -1,193 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Enhanced Inference Logging
-
-This script tests the enhanced inference logging system that stores
-full input features for training feedback.
-"""
-
-import sys
-import os
-import logging
-import numpy as np
-from datetime import datetime
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.enhanced_cnn_adapter import EnhancedCNNAdapter
-from core.data_models import BaseDataInput, OHLCVBar
-from utils.database_manager import get_database_manager
-from utils.inference_logger import get_inference_logger
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def create_test_base_data():
- """Create test BaseDataInput with realistic data"""
-
- # Create OHLCV bars for different timeframes
- def create_ohlcv_bars(symbol, timeframe, count=300):
- bars = []
- base_price = 3000.0 if 'ETH' in symbol else 50000.0
-
- for i in range(count):
- price = base_price + np.random.normal(0, base_price * 0.01)
- bars.append(OHLCVBar(
- symbol=symbol,
- timestamp=datetime.now(),
- open=price,
- high=price * 1.002,
- low=price * 0.998,
- close=price + np.random.normal(0, price * 0.005),
- volume=np.random.uniform(100, 1000),
- timeframe=timeframe
- ))
- return bars
-
- base_data = BaseDataInput(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300),
- ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300),
- ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300),
- ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300),
- btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300),
- technical_indicators={
- 'rsi': 45.5,
- 'macd': 0.12,
- 'bb_upper': 3100.0,
- 'bb_lower': 2900.0,
- 'volume_ma': 500.0
- }
- )
-
- return base_data
-
-def test_enhanced_inference_logging():
- """Test the enhanced inference logging system"""
-
- logger.info("=== Testing Enhanced Inference Logging ===")
-
- try:
- # Initialize CNN adapter
- cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
- logger.info("✅ CNN adapter initialized")
-
- # Create test data
- base_data = create_test_base_data()
- logger.info("✅ Test data created")
-
- # Make a prediction (this should log inference data)
- logger.info("Making prediction...")
- model_output = cnn_adapter.predict(base_data)
- logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})")
-
- # Verify inference was logged to database
- db_manager = get_database_manager()
- recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1)
-
- if recent_inferences:
- latest_inference = recent_inferences[0]
- logger.info(f"✅ Inference logged to database:")
- logger.info(f" Model: {latest_inference.model_name}")
- logger.info(f" Action: {latest_inference.action}")
- logger.info(f" Confidence: {latest_inference.confidence:.3f}")
- logger.info(f" Processing time: {latest_inference.processing_time_ms:.1f}ms")
- logger.info(f" Has input features: {latest_inference.input_features is not None}")
-
- if latest_inference.input_features is not None:
- logger.info(f" Input features shape: {latest_inference.input_features.shape}")
- logger.info(f" Input features sample: {latest_inference.input_features[:5]}")
- else:
- logger.error("❌ No inference records found in database")
- return False
-
- # Test training data loading from inference history
- logger.info("Testing training data loading from inference history...")
- original_training_count = len(cnn_adapter.training_data)
- cnn_adapter._load_training_data_from_inference_history()
- new_training_count = len(cnn_adapter.training_data)
-
- logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples")
-
- # Test prediction evaluation
- logger.info("Testing prediction evaluation...")
- evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1)
- logger.info(f"✅ Evaluation metrics: {evaluation_metrics}")
-
- # Test training with inference data
- if new_training_count >= cnn_adapter.batch_size:
- logger.info("Testing training with inference data...")
- training_metrics = cnn_adapter.train(epochs=1)
- logger.info(f"✅ Training completed: {training_metrics}")
- else:
- logger.info("⚠️ Not enough training data for training test")
-
- return True
-
- except Exception as e:
- logger.error(f"❌ Test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-def test_database_query_methods():
- """Test the new database query methods"""
-
- logger.info("=== Testing Database Query Methods ===")
-
- try:
- db_manager = get_database_manager()
-
- # Test getting inference records for training
- training_records = db_manager.get_inference_records_for_training(
- model_name="enhanced_cnn",
- hours_back=24,
- limit=10
- )
-
- logger.info(f"✅ Found {len(training_records)} training records")
-
- for i, record in enumerate(training_records[:3]): # Show first 3
- logger.info(f" Record {i+1}:")
- logger.info(f" Action: {record.action}")
- logger.info(f" Confidence: {record.confidence:.3f}")
- logger.info(f" Has features: {record.input_features is not None}")
- if record.input_features is not None:
- logger.info(f" Features shape: {record.input_features.shape}")
-
- return True
-
- except Exception as e:
- logger.error(f"❌ Database query test failed: {e}")
- return False
-
-def main():
- """Run all tests"""
-
- logger.info("Starting Enhanced Inference Logging Tests")
-
- # Test 1: Enhanced inference logging
- test1_passed = test_enhanced_inference_logging()
-
- # Test 2: Database query methods
- test2_passed = test_database_query_methods()
-
- # Summary
- logger.info("=== Test Summary ===")
- logger.info(f"Enhanced Inference Logging: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
- logger.info(f"Database Query Methods: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
-
- if test1_passed and test2_passed:
- logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.")
- logger.info("The system now:")
- logger.info(" - Stores full input features with each inference")
- logger.info(" - Can retrieve inference data for training feedback")
- logger.info(" - Supports continuous learning from inference history")
- logger.info(" - Evaluates prediction accuracy over time")
- else:
- logger.error("❌ Some tests failed. Please check the implementation.")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_enhanced_training_integration.py b/test_enhanced_training_integration.py
deleted file mode 100644
index 3568fff..0000000
--- a/test_enhanced_training_integration.py
+++ /dev/null
@@ -1,144 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Enhanced Training Integration
-
-This script tests the integration of EnhancedRealtimeTrainingSystem
-into the TradingOrchestrator to ensure it works correctly.
-"""
-
-import sys
-import os
-import logging
-import asyncio
-from datetime import datetime
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-async def test_enhanced_training_integration():
- """Test the enhanced training system integration"""
- try:
- logger.info("=" * 60)
- logger.info("TESTING ENHANCED TRAINING INTEGRATION")
- logger.info("=" * 60)
-
- # 1. Initialize orchestrator with enhanced training
- logger.info("1. Initializing orchestrator with enhanced training...")
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True
- )
-
- # 2. Check if training system is available
- logger.info("2. Checking training system availability...")
- training_available = hasattr(orchestrator, 'enhanced_training_system')
- training_enabled = getattr(orchestrator, 'training_enabled', False)
-
- logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}")
- logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}")
-
- # 3. Test training system initialization
- if training_available and orchestrator.enhanced_training_system:
- logger.info("3. Testing training system methods...")
-
- # Test getting training statistics
- stats = orchestrator.get_enhanced_training_stats()
- logger.info(f" - Training stats retrieved: {len(stats)} fields")
- logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}")
- logger.info(f" - System available: {stats.get('system_available', False)}")
-
- # Test starting training
- start_result = orchestrator.start_enhanced_training()
- logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}")
-
- if start_result:
- # Let it run for a few seconds
- logger.info(" - Letting training run for 5 seconds...")
- await asyncio.sleep(5)
-
- # Get updated stats
- updated_stats = orchestrator.get_enhanced_training_stats()
- logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}")
-
- # Stop training
- stop_result = orchestrator.stop_enhanced_training()
- logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}")
-
- else:
- logger.warning("3. Training system not available - checking fallback behavior...")
-
- # Test methods when training system is not available
- stats = orchestrator.get_enhanced_training_stats()
- logger.info(f" - Fallback stats: {stats}")
-
- start_result = orchestrator.start_enhanced_training()
- logger.info(f" - Fallback start result: {start_result}")
-
- # 4. Test dashboard connection method
- logger.info("4. Testing dashboard connection method...")
- try:
- orchestrator.set_training_dashboard(None) # Test with None
- logger.info(" - Dashboard connection method: ✅ Available")
- except Exception as e:
- logger.error(f" - Dashboard connection method error: {e}")
-
- # 5. Summary
- logger.info("=" * 60)
- logger.info("INTEGRATION TEST SUMMARY")
- logger.info("=" * 60)
-
- if training_available and training_enabled:
- logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL")
- logger.info(" - Training system properly integrated")
- logger.info(" - All methods available and functional")
- logger.info(" - Ready for real-time training")
- elif training_available:
- logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED")
- logger.info(" - Training system available but not enabled")
- logger.info(" - Check EnhancedRealtimeTrainingSystem import")
- else:
- logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED")
- logger.info(" - Training system not properly integrated")
- logger.info(" - Methods missing or non-functional")
-
- return training_available and training_enabled
-
- except Exception as e:
- logger.error(f"Error in integration test: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return False
-
-async def main():
- """Main test function"""
- try:
- success = await test_enhanced_training_integration()
-
- if success:
- logger.info("🎉 All tests passed! Enhanced training integration is working.")
- return 0
- else:
- logger.warning("⚠️ Some tests failed. Check the integration.")
- return 1
-
- except KeyboardInterrupt:
- logger.info("Test interrupted by user")
- return 0
- except Exception as e:
- logger.error(f"Fatal error in test: {e}")
- return 1
-
-if __name__ == "__main__":
- exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/test_enhanced_training_simple.py b/test_enhanced_training_simple.py
deleted file mode 100644
index f3f600c..0000000
--- a/test_enhanced_training_simple.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-"""
-Simple Enhanced Training Test
-
-Quick test to verify enhanced training system can be enabled and controlled.
-"""
-
-import sys
-import os
-import logging
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_enhanced_training():
- """Test enhanced training system"""
- try:
- logger.info("Testing Enhanced Training System...")
-
- # 1. Create data provider
- data_provider = DataProvider()
-
- # 2. Create orchestrator with enhanced training ENABLED
- logger.info("Creating orchestrator with enhanced_rl_training=True...")
- orchestrator = TradingOrchestrator(
- data_provider=data_provider,
- enhanced_rl_training=True # 🔥 THIS ENABLES IT
- )
-
- # 3. Check if training system is available
- logger.info(f"Training system available: {orchestrator.enhanced_training_system is not None}")
- logger.info(f"Training enabled: {orchestrator.training_enabled}")
-
- # 4. Get training stats
- stats = orchestrator.get_enhanced_training_stats()
- logger.info(f"Training stats: {stats}")
-
- # 5. Test start/stop
- if orchestrator.enhanced_training_system:
- logger.info("Testing start/stop functionality...")
-
- # Start training
- start_result = orchestrator.start_enhanced_training()
- logger.info(f"Start result: {start_result}")
-
- # Get updated stats
- updated_stats = orchestrator.get_enhanced_training_stats()
- logger.info(f"Updated stats: {updated_stats}")
-
- # Stop training
- stop_result = orchestrator.stop_enhanced_training()
- logger.info(f"Stop result: {stop_result}")
-
- logger.info("✅ Enhanced training system is working!")
- return True
- else:
- logger.warning("❌ Enhanced training system not available")
- return False
-
- except Exception as e:
- logger.error(f"Error testing enhanced training: {e}")
- return False
-
-if __name__ == "__main__":
- success = test_enhanced_training()
- if success:
- print("\n🎉 Enhanced training system is ready to use!")
- print("To enable it in your main system, use:")
- print(" enhanced_rl_training=True when creating TradingOrchestrator")
- else:
- print("\n⚠️ Enhanced training system has issues. Check the logs above.")
\ No newline at end of file
diff --git a/test_fifo_queues.py b/test_fifo_queues.py
deleted file mode 100644
index ea905c9..0000000
--- a/test_fifo_queues.py
+++ /dev/null
@@ -1,285 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test FIFO Queue System
-
-Verify that the orchestrator's FIFO queue system works correctly
-"""
-
-import time
-from datetime import datetime
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-from core.data_models import OHLCVBar
-
-def test_fifo_queue_operations():
- """Test basic FIFO queue operations"""
- print("=== Testing FIFO Queue Operations ===")
-
- try:
- # Create orchestrator
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Test queue status
- status = orchestrator.get_queue_status()
- print(f"Initial queue status: {status}")
-
- # Test adding data to queues
- test_bar = OHLCVBar(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- open=2500.0,
- high=2510.0,
- low=2490.0,
- close=2505.0,
- volume=1000.0,
- timeframe="1s"
- )
-
- # Add test data
- success = orchestrator.update_data_queue('ohlcv_1s', 'ETH/USDT', test_bar)
- print(f"Added OHLCV data: {success}")
-
- # Check queue status after adding data
- status = orchestrator.get_queue_status()
- print(f"Queue status after adding data: {status}")
-
- # Test retrieving data
- latest_data = orchestrator.get_latest_data('ohlcv_1s', 'ETH/USDT', 1)
- print(f"Retrieved latest data: {len(latest_data)} items")
-
- if latest_data:
- bar = latest_data[0]
- print(f" Bar: {bar.symbol} {bar.close} @ {bar.timestamp}")
-
- # Test minimum data check
- has_min_data = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 1)
- print(f"Has minimum data (1): {has_min_data}")
-
- has_min_data_100 = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 100)
- print(f"Has minimum data (100): {has_min_data_100}")
-
- return True
-
- except Exception as e:
- print(f"❌ FIFO queue operations test failed: {e}")
- return False
-
-def test_data_queue_filling():
- """Test filling queues with multiple data points"""
- print("\n=== Testing Data Queue Filling ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Add multiple OHLCV bars
- for i in range(150): # Add 150 bars
- test_bar = OHLCVBar(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- open=2500.0 + i,
- high=2510.0 + i,
- low=2490.0 + i,
- close=2505.0 + i,
- volume=1000.0 + i,
- timeframe="1s"
- )
- orchestrator.update_data_queue('ohlcv_1s', 'ETH/USDT', test_bar)
-
- # Check queue status
- status = orchestrator.get_queue_status()
- print(f"Queue status after adding 150 bars: {status}")
-
- # Test minimum data requirements
- has_min_data = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 100)
- print(f"Has minimum data (100): {has_min_data}")
-
- # Get all data
- all_data = orchestrator.get_queue_data('ohlcv_1s', 'ETH/USDT')
- print(f"Total data in queue: {len(all_data)} items")
-
- # Test max_items parameter
- limited_data = orchestrator.get_queue_data('ohlcv_1s', 'ETH/USDT', max_items=50)
- print(f"Limited data (50): {len(limited_data)} items")
-
- return True
-
- except Exception as e:
- print(f"❌ Data queue filling test failed: {e}")
- return False
-
-def test_base_data_input_building():
- """Test building BaseDataInput from FIFO queues"""
- print("\n=== Testing BaseDataInput Building ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Fill queues with sufficient data
- timeframes = ['1s', '1m', '1h', '1d']
- min_counts = [100, 50, 20, 10]
-
- for timeframe, min_count in zip(timeframes, min_counts):
- for i in range(min_count + 10): # Add a bit more than minimum
- test_bar = OHLCVBar(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- open=2500.0 + i,
- high=2510.0 + i,
- low=2490.0 + i,
- close=2505.0 + i,
- volume=1000.0 + i,
- timeframe=timeframe
- )
- orchestrator.update_data_queue(f'ohlcv_{timeframe}', 'ETH/USDT', test_bar)
-
- # Add BTC data
- for i in range(110):
- btc_bar = OHLCVBar(
- symbol="BTC/USDT",
- timestamp=datetime.now(),
- open=50000.0 + i,
- high=50100.0 + i,
- low=49900.0 + i,
- close=50050.0 + i,
- volume=100.0 + i,
- timeframe="1s"
- )
- orchestrator.update_data_queue('ohlcv_1s', 'BTC/USDT', btc_bar)
-
- # Add technical indicators
- test_indicators = {'rsi': 50.0, 'macd': 0.1, 'bb_upper': 2520.0, 'bb_lower': 2480.0}
- orchestrator.update_data_queue('technical_indicators', 'ETH/USDT', test_indicators)
-
- # Try to build BaseDataInput
- base_data = orchestrator.build_base_data_input('ETH/USDT')
-
- if base_data:
- print("✅ BaseDataInput built successfully")
-
- # Test feature vector
- features = base_data.get_feature_vector()
- print(f" Feature vector size: {len(features)}")
- print(f" Symbol: {base_data.symbol}")
- print(f" OHLCV 1s data: {len(base_data.ohlcv_1s)} bars")
- print(f" OHLCV 1m data: {len(base_data.ohlcv_1m)} bars")
- print(f" BTC data: {len(base_data.btc_ohlcv_1s)} bars")
- print(f" Technical indicators: {len(base_data.technical_indicators)}")
-
- # Validate
- is_valid = base_data.validate()
- print(f" Validation: {is_valid}")
-
- return True
- else:
- print("❌ Failed to build BaseDataInput")
- return False
-
- except Exception as e:
- print(f"❌ BaseDataInput building test failed: {e}")
- return False
-
-def test_consistent_feature_size():
- """Test that feature vectors are always the same size"""
- print("\n=== Testing Consistent Feature Size ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Fill with minimal data first
- for timeframe, min_count in [('1s', 100), ('1m', 50), ('1h', 20), ('1d', 10)]:
- for i in range(min_count):
- test_bar = OHLCVBar(
- symbol="ETH/USDT",
- timestamp=datetime.now(),
- open=2500.0 + i,
- high=2510.0 + i,
- low=2490.0 + i,
- close=2505.0 + i,
- volume=1000.0 + i,
- timeframe=timeframe
- )
- orchestrator.update_data_queue(f'ohlcv_{timeframe}', 'ETH/USDT', test_bar)
-
- # Add BTC data
- for i in range(100):
- btc_bar = OHLCVBar(
- symbol="BTC/USDT",
- timestamp=datetime.now(),
- open=50000.0 + i,
- high=50100.0 + i,
- low=49900.0 + i,
- close=50050.0 + i,
- volume=100.0 + i,
- timeframe="1s"
- )
- orchestrator.update_data_queue('ohlcv_1s', 'BTC/USDT', btc_bar)
-
- feature_sizes = []
-
- # Test multiple scenarios
- scenarios = [
- ("Minimal data", {}),
- ("With indicators", {'rsi': 50.0, 'macd': 0.1}),
- ("More indicators", {'rsi': 45.0, 'macd': 0.2, 'bb_upper': 2520.0, 'bb_lower': 2480.0, 'ema_20': 2500.0})
- ]
-
- for name, indicators in scenarios:
- if indicators:
- orchestrator.update_data_queue('technical_indicators', 'ETH/USDT', indicators)
-
- base_data = orchestrator.build_base_data_input('ETH/USDT')
- if base_data:
- features = base_data.get_feature_vector()
- feature_sizes.append(len(features))
- print(f"{name}: {len(features)} features")
- else:
- print(f"{name}: Failed to build BaseDataInput")
- return False
-
- # Check consistency
- if len(set(feature_sizes)) == 1:
- print(f"✅ All feature vectors have consistent size: {feature_sizes[0]}")
- return True
- else:
- print(f"❌ Inconsistent feature sizes: {feature_sizes}")
- return False
-
- except Exception as e:
- print(f"❌ Consistent feature size test failed: {e}")
- return False
-
-def main():
- """Run all FIFO queue tests"""
- print("=== FIFO Queue System Test Suite ===\n")
-
- tests = [
- test_fifo_queue_operations,
- test_data_queue_filling,
- test_base_data_input_building,
- test_consistent_feature_size
- ]
-
- passed = 0
- total = len(tests)
-
- for test in tests:
- if test():
- passed += 1
- print()
-
- print(f"=== Test Results: {passed}/{total} passed ===")
-
- if passed == total:
- print("✅ ALL TESTS PASSED!")
- print("✅ FIFO queue system is working correctly")
- print("✅ Consistent data flow ensured")
- print("✅ No more network rebuilding issues")
- else:
- print("❌ Some tests failed")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_hold_position_fix.py b/test_hold_position_fix.py
deleted file mode 100644
index e69de29..0000000
diff --git a/test_imbalance_calculation.py b/test_imbalance_calculation.py
deleted file mode 100644
index 4372985..0000000
--- a/test_imbalance_calculation.py
+++ /dev/null
@@ -1,118 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for imbalance calculation logic
-"""
-
-import time
-import logging
-from datetime import datetime
-from core.data_provider import DataProvider
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_imbalance_calculation():
- """Test the imbalance calculation logic with mock data"""
- logger.info("Testing imbalance calculation logic...")
-
- # Initialize data provider
- dp = DataProvider()
-
- # Create mock COB tick data
- mock_ticks = []
- current_time = time.time()
-
- # Create 10 mock ticks with different imbalances
- for i in range(10):
- tick = {
- 'symbol': 'ETH/USDT',
- 'timestamp': current_time - (10 - i), # 10 seconds ago to now
- 'bids': [
- [3800 + i, 100 + i * 10], # Price, Volume
- [3799 + i, 50 + i * 5],
- [3798 + i, 25 + i * 2]
- ],
- 'asks': [
- [3801 + i, 80 + i * 8], # Price, Volume
- [3802 + i, 40 + i * 4],
- [3803 + i, 20 + i * 2]
- ],
- 'stats': {
- 'mid_price': 3800.5 + i,
- 'spread_bps': 2.5 + i * 0.1,
- 'imbalance': (i - 5) * 0.1 # Varying imbalance from -0.5 to +0.4
- },
- 'source': 'mock'
- }
- mock_ticks.append(tick)
-
- # Add mock ticks to the data provider
- for tick in mock_ticks:
- dp.cob_raw_ticks['ETH/USDT'].append(tick)
-
- logger.info(f"Added {len(mock_ticks)} mock COB ticks")
-
- # Test the aggregation function
- logger.info("\n=== Testing COB Aggregation ===")
- target_second = int(current_time - 5) # 5 seconds ago
-
- # Manually call the aggregation function
- dp._aggregate_cob_1s('ETH/USDT', target_second)
-
- # Check the results
- aggregated_data = list(dp.cob_1s_aggregated['ETH/USDT'])
- if aggregated_data:
- latest = aggregated_data[-1]
- logger.info(f"Aggregated data created:")
- logger.info(f" Timestamp: {latest.get('timestamp')}")
- logger.info(f" Tick count: {latest.get('tick_count')}")
- logger.info(f" Current imbalance: {latest.get('imbalance', 0):.4f}")
- logger.info(f" Total volume: {latest.get('total_volume', 0):.2f}")
- logger.info(f" Bid buckets: {len(latest.get('bid_buckets', {}))}")
- logger.info(f" Ask buckets: {len(latest.get('ask_buckets', {}))}")
-
- # Check multi-timeframe imbalances
- logger.info(f" Imbalance 1s: {latest.get('imbalance_1s', 0):.4f}")
- logger.info(f" Imbalance 5s: {latest.get('imbalance_5s', 0):.4f}")
- logger.info(f" Imbalance 15s: {latest.get('imbalance_15s', 0):.4f}")
- logger.info(f" Imbalance 60s: {latest.get('imbalance_60s', 0):.4f}")
- else:
- logger.warning("No aggregated data created")
-
- # Test multiple aggregations to build history
- logger.info("\n=== Testing Multi-timeframe Imbalances ===")
- for i in range(1, 6):
- target_second = int(current_time - 5 + i)
- dp._aggregate_cob_1s('ETH/USDT', target_second)
-
- # Check the final results
- final_data = list(dp.cob_1s_aggregated['ETH/USDT'])
- logger.info(f"Created {len(final_data)} aggregated records")
-
- if final_data:
- latest = final_data[-1]
- logger.info(f"Final imbalance indicators:")
- logger.info(f" 1s: {latest.get('imbalance_1s', 0):.4f}")
- logger.info(f" 5s: {latest.get('imbalance_5s', 0):.4f}")
- logger.info(f" 15s: {latest.get('imbalance_15s', 0):.4f}")
- logger.info(f" 60s: {latest.get('imbalance_60s', 0):.4f}")
-
- # Test the COB data quality function
- logger.info("\n=== Testing COB Data Quality Function ===")
- quality = dp.get_cob_data_quality()
-
- eth_quality = quality.get('imbalance_indicators', {}).get('ETH/USDT', {})
- if eth_quality:
- logger.info("COB quality indicators:")
- logger.info(f" Imbalance 1s: {eth_quality.get('imbalance_1s', 0):.4f}")
- logger.info(f" Imbalance 5s: {eth_quality.get('imbalance_5s', 0):.4f}")
- logger.info(f" Imbalance 15s: {eth_quality.get('imbalance_15s', 0):.4f}")
- logger.info(f" Imbalance 60s: {eth_quality.get('imbalance_60s', 0):.4f}")
- logger.info(f" Total volume: {eth_quality.get('total_volume', 0):.2f}")
- logger.info(f" Bucket count: {eth_quality.get('bucket_count', 0)}")
-
- logger.info("\n✅ Imbalance calculation test completed successfully!")
-
-if __name__ == "__main__":
- test_imbalance_calculation()
\ No newline at end of file
diff --git a/test_improved_data_integration.py b/test_improved_data_integration.py
deleted file mode 100644
index c10bc32..0000000
--- a/test_improved_data_integration.py
+++ /dev/null
@@ -1,222 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Improved Data Integration
-
-Test the enhanced data integration with fallback strategies
-"""
-
-import time
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-def test_enhanced_data_population():
- """Test enhanced data population with fallback strategies"""
- print("=== Testing Enhanced Data Population ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Wait for initial population
- print("Waiting 5 seconds for enhanced data population...")
- time.sleep(5)
-
- # Check detailed queue status
- print("\nDetailed queue status after enhanced population:")
- orchestrator.log_queue_status(detailed=True)
-
- # Check minimum data requirements
- symbols_to_check = ['ETH/USDT', 'BTC/USDT']
- timeframes_to_check = ['1s', '1m', '1h', '1d']
- min_requirements = {'1s': 100, '1m': 50, '1h': 20, '1d': 10}
-
- print(f"\nChecking minimum data requirements with fallback:")
- all_requirements_met = True
-
- for symbol in symbols_to_check:
- print(f"\n{symbol}:")
- symbol_requirements_met = True
-
- for timeframe in timeframes_to_check:
- min_count = min_requirements.get(timeframe, 10)
- has_min = orchestrator.ensure_minimum_data(f'ohlcv_{timeframe}', symbol, min_count)
- actual_count = 0
- if f'ohlcv_{timeframe}' in orchestrator.data_queues and symbol in orchestrator.data_queues[f'ohlcv_{timeframe}']:
- with orchestrator.data_queue_locks[f'ohlcv_{timeframe}'][symbol]:
- actual_count = len(orchestrator.data_queues[f'ohlcv_{timeframe}'][symbol])
-
- status = "✅" if has_min else "❌"
- print(f" {timeframe}: {status} {actual_count}/{min_count}")
-
- if not has_min:
- symbol_requirements_met = False
- all_requirements_met = False
-
- # Check technical indicators
- indicators_count = 0
- if 'technical_indicators' in orchestrator.data_queues and symbol in orchestrator.data_queues['technical_indicators']:
- with orchestrator.data_queue_locks['technical_indicators'][symbol]:
- indicators_data = list(orchestrator.data_queues['technical_indicators'][symbol])
- if indicators_data:
- indicators_count = len(indicators_data[-1]) # Latest indicators dict
-
- indicators_status = "✅" if indicators_count > 0 else "❌"
- print(f" indicators: {indicators_status} {indicators_count} calculated")
-
- # Test BaseDataInput building
- print(f"\nTesting BaseDataInput building with fallback:")
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- base_data = orchestrator.build_base_data_input(symbol)
- if base_data:
- features = base_data.get_feature_vector()
- print(f" ✅ {symbol}: BaseDataInput built successfully")
- print(f" Feature vector size: {len(features)}")
- print(f" OHLCV 1s bars: {len(base_data.ohlcv_1s)}")
- print(f" OHLCV 1m bars: {len(base_data.ohlcv_1m)}")
- print(f" OHLCV 1h bars: {len(base_data.ohlcv_1h)}")
- print(f" OHLCV 1d bars: {len(base_data.ohlcv_1d)}")
- print(f" BTC bars: {len(base_data.btc_ohlcv_1s)}")
- print(f" Technical indicators: {len(base_data.technical_indicators)}")
-
- # Validate feature vector
- if len(features) == 7850:
- print(f" ✅ Feature vector has correct size (7850)")
- else:
- print(f" ❌ Feature vector size mismatch: {len(features)} != 7850")
- else:
- print(f" ❌ {symbol}: Failed to build BaseDataInput")
-
- return all_requirements_met
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def test_fallback_strategies():
- """Test specific fallback strategies"""
- print("\n=== Testing Fallback Strategies ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Wait for initial population
- time.sleep(3)
-
- # Check if fallback strategies were used
- print("Checking fallback strategy usage:")
-
- # Check ETH/USDT 1s data (likely to need fallback)
- eth_1s_count = 0
- if 'ohlcv_1s' in orchestrator.data_queues and 'ETH/USDT' in orchestrator.data_queues['ohlcv_1s']:
- with orchestrator.data_queue_locks['ohlcv_1s']['ETH/USDT']:
- eth_1s_count = len(orchestrator.data_queues['ohlcv_1s']['ETH/USDT'])
-
- if eth_1s_count >= 100:
- print(f" ✅ ETH/USDT 1s data: {eth_1s_count} bars (fallback likely used)")
- else:
- print(f" ❌ ETH/USDT 1s data: {eth_1s_count} bars (fallback may have failed)")
-
- # Check ETH/USDT 1h data (likely to need fallback)
- eth_1h_count = 0
- if 'ohlcv_1h' in orchestrator.data_queues and 'ETH/USDT' in orchestrator.data_queues['ohlcv_1h']:
- with orchestrator.data_queue_locks['ohlcv_1h']['ETH/USDT']:
- eth_1h_count = len(orchestrator.data_queues['ohlcv_1h']['ETH/USDT'])
-
- if eth_1h_count >= 20:
- print(f" ✅ ETH/USDT 1h data: {eth_1h_count} bars (fallback likely used)")
- else:
- print(f" ❌ ETH/USDT 1h data: {eth_1h_count} bars (fallback may have failed)")
-
- # Test manual fallback strategy
- print(f"\nTesting manual fallback strategy:")
- missing_data = [('ohlcv_1s', 0, 100), ('ohlcv_1h', 0, 20)]
- fallback_success = orchestrator._try_fallback_data_strategy('ETH/USDT', missing_data)
- print(f" Manual fallback result: {'✅ SUCCESS' if fallback_success else '❌ FAILED'}")
-
- return eth_1s_count >= 100 and eth_1h_count >= 20
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def test_model_predictions():
- """Test that models can now make predictions with the improved data"""
- print("\n=== Testing Model Predictions ===")
-
- try:
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
-
- # Wait for data population
- time.sleep(5)
-
- # Try to make predictions
- print("Testing model prediction capability:")
-
- # Test CNN prediction
- try:
- base_data = orchestrator.build_base_data_input('ETH/USDT')
- if base_data:
- print(" ✅ BaseDataInput available for CNN")
-
- # Test feature vector
- features = base_data.get_feature_vector()
- if len(features) == 7850:
- print(" ✅ Feature vector has correct size for CNN")
- print(" ✅ CNN should be able to make predictions without rebuilding")
- else:
- print(f" ❌ Feature vector size issue: {len(features)} != 7850")
- else:
- print(" ❌ BaseDataInput not available for CNN")
- except Exception as e:
- print(f" ❌ CNN prediction test failed: {e}")
-
- # Test RL prediction
- try:
- base_data = orchestrator.build_base_data_input('ETH/USDT')
- if base_data:
- print(" ✅ BaseDataInput available for RL")
-
- # Test state features
- state_features = base_data.get_feature_vector()
- if len(state_features) == 7850:
- print(" ✅ State features have correct size for RL")
- else:
- print(f" ❌ State features size issue: {len(state_features)} != 7850")
- else:
- print(" ❌ BaseDataInput not available for RL")
- except Exception as e:
- print(f" ❌ RL prediction test failed: {e}")
-
- return base_data is not None
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
-def main():
- """Run all enhanced data integration tests"""
- print("=== Enhanced Data Integration Test Suite ===")
-
- test1_passed = test_enhanced_data_population()
- test2_passed = test_fallback_strategies()
- test3_passed = test_model_predictions()
-
- print(f"\n=== Results ===")
- print(f"Enhanced data population: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
- print(f"Fallback strategies: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
- print(f"Model predictions: {'✅ PASSED' if test3_passed else '❌ FAILED'}")
-
- if test1_passed and test2_passed and test3_passed:
- print("\n✅ ALL TESTS PASSED!")
- print("✅ Enhanced data integration is working!")
- print("✅ Fallback strategies provide missing data")
- print("✅ Models should be able to make predictions")
- print("✅ No more 'Insufficient data' errors expected")
- else:
- print("\n⚠️ Some tests failed, but system may still work")
- print("⚠️ Check specific failures above")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_integrated_standardized_provider.py b/test_integrated_standardized_provider.py
deleted file mode 100644
index 87164e4..0000000
--- a/test_integrated_standardized_provider.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""
-Test script for integrated StandardizedDataProvider with ModelOutputManager
-
-This script tests the complete standardized data provider with extensible model output storage
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import logging
-from datetime import datetime
-from core.standardized_data_provider import StandardizedDataProvider
-from core.data_models import create_model_output
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_integrated_standardized_provider():
- """Test the integrated StandardizedDataProvider with ModelOutputManager"""
-
- print("Testing Integrated StandardizedDataProvider with ModelOutputManager...")
-
- # Initialize the provider
- symbols = ['ETH/USDT', 'BTC/USDT']
- timeframes = ['1s', '1m', '1h', '1d']
-
- provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
-
- print("✅ StandardizedDataProvider initialized with ModelOutputManager")
-
- # Test 1: Store model outputs from different types
- print("\n1. Testing model output storage integration...")
-
- # Create and store outputs from different model types
- model_outputs = [
- create_model_output('cnn', 'enhanced_cnn_v1', 'ETH/USDT', 'BUY', 0.85),
- create_model_output('rl', 'dqn_agent_v2', 'ETH/USDT', 'SELL', 0.72),
- create_model_output('transformer', 'transformer_v1', 'ETH/USDT', 'BUY', 0.91),
- create_model_output('orchestrator', 'main_orchestrator', 'ETH/USDT', 'BUY', 0.78)
- ]
-
- for output in model_outputs:
- provider.store_model_output(output)
- print(f"✅ Stored {output.model_type} output: {output.predictions['action']} ({output.confidence})")
-
- # Test 2: Retrieve model outputs
- print("\n2. Testing model output retrieval...")
-
- all_outputs = provider.get_model_outputs('ETH/USDT')
- print(f"✅ Retrieved {len(all_outputs)} model outputs for ETH/USDT")
-
- for model_name, output in all_outputs.items():
- print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
-
- # Test 3: Test BaseDataInput with cross-model feeding
- print("\n3. Testing BaseDataInput with cross-model predictions...")
-
- # Use real current prices only - no mock data
-
- base_input = provider.get_base_data_input('ETH/USDT')
-
- if base_input:
- print("✅ BaseDataInput created with cross-model predictions!")
- print(f" Symbol: {base_input.symbol}")
- print(f" OHLCV frames: 1s={len(base_input.ohlcv_1s)}, 1m={len(base_input.ohlcv_1m)}, 1h={len(base_input.ohlcv_1h)}, 1d={len(base_input.ohlcv_1d)}")
- print(f" BTC frames: {len(base_input.btc_ohlcv_1s)}")
- print(f" COB data: {'Available' if base_input.cob_data else 'Not available'}")
- print(f" Last predictions: {len(base_input.last_predictions)} models")
-
- # Show cross-model predictions
- for model_name, prediction in base_input.last_predictions.items():
- print(f" {model_name}: {prediction.predictions['action']} ({prediction.confidence})")
-
- # Test feature vector creation
- try:
- feature_vector = base_input.get_feature_vector()
- print(f"✅ Feature vector created: shape {feature_vector.shape}")
- except Exception as e:
- print(f"❌ Feature vector creation failed: {e}")
- else:
- print("⚠️ BaseDataInput creation failed - this may be due to insufficient historical data")
-
- # Test 4: Advanced ModelOutputManager features
- print("\n4. Testing advanced model output manager features...")
-
- output_manager = provider.get_model_output_manager()
-
- # Test consensus prediction
- consensus = output_manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.7)
- if consensus:
- print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
- print(f" Votes: {consensus['votes']}")
- print(f" Contributing models: {consensus['model_types']}")
- else:
- print("⚠️ No consensus reached")
-
- # Test cross-model states
- cross_states = output_manager.get_cross_model_states('ETH/USDT', 'dqn_agent_v2')
- print(f"✅ Cross-model states available for RL model: {len(cross_states)} models")
-
- # Test performance summary
- performance = output_manager.get_performance_summary('ETH/USDT')
- print(f"✅ Performance summary: {performance['active_models']} active models")
-
- # Test 5: Custom model type support
- print("\n5. Testing custom model type extensibility...")
-
- # Add a custom model type
- output_manager.add_custom_model_type('hybrid_lstm_transformer')
-
- # Create and store custom model output
- custom_output = create_model_output(
- model_type='hybrid_lstm_transformer',
- model_name='hybrid_model_v1',
- symbol='ETH/USDT',
- action='BUY',
- confidence=0.89,
- metadata={'hybrid_components': ['lstm', 'transformer'], 'ensemble_weight': 0.6}
- )
-
- provider.store_model_output(custom_output)
- print("✅ Custom model type 'hybrid_lstm_transformer' stored successfully")
-
- # Verify it's included in BaseDataInput
- updated_base_input = provider.get_base_data_input('ETH/USDT')
- if updated_base_input and 'hybrid_model_v1' in updated_base_input.last_predictions:
- print("✅ Custom model output included in BaseDataInput cross-model feeding")
-
- print(f" Total supported model types: {len(output_manager.get_supported_model_types())}")
-
- # Test 6: Historical tracking
- print("\n6. Testing historical output tracking...")
-
- # Store a few more outputs to build history
- for i in range(3):
- historical_output = create_model_output(
- model_type='cnn',
- model_name='enhanced_cnn_v1',
- symbol='ETH/USDT',
- action='HOLD',
- confidence=0.6 + i * 0.05
- )
- provider.store_model_output(historical_output)
-
- history = output_manager.get_output_history('ETH/USDT', 'enhanced_cnn_v1', count=5)
- print(f"✅ Historical tracking: {len(history)} outputs for enhanced_cnn_v1")
-
- # Test 7: Real-time data integration readiness
- print("\n7. Testing real-time integration readiness...")
-
- print("✅ Real-time processing methods available:")
- print(" - start_real_time_processing()")
- print(" - stop_real_time_processing()")
- print(" - COB provider integration ready")
- print(" - Model output persistence enabled")
-
- print("\n✅ Integrated StandardizedDataProvider test completed successfully!")
- print("\n🎯 Key achievements:")
- print("✓ Standardized BaseDataInput format for all models")
- print("✓ Extensible ModelOutput storage (CNN, RL, LSTM, Transformer, Custom)")
- print("✓ Cross-model feeding with last predictions")
- print("✓ COB data integration with moving averages")
- print("✓ Consensus prediction calculation")
- print("✓ Historical output tracking")
- print("✓ Performance analytics")
- print("✓ Thread-safe operations")
- print("✓ Persistent storage capabilities")
-
- print("\n🚀 Ready for model integration:")
- print("1. CNN models can use BaseDataInput and store ModelOutput")
- print("2. RL models can access CNN hidden states via cross-model feeding")
- print("3. Orchestrator can calculate consensus from all models")
- print("4. New model types can be added without code changes")
- print("5. All models receive identical standardized input format")
-
- return provider
-
-if __name__ == "__main__":
- test_integrated_standardized_provider()
\ No newline at end of file
diff --git a/test_leverage_fix.py b/test_leverage_fix.py
deleted file mode 100644
index 93c708c..0000000
--- a/test_leverage_fix.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Leverage Fix
-
-This script tests if the leverage is now being applied correctly to trade P&L calculations.
-"""
-
-import sys
-import os
-from datetime import datetime
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.trading_executor import TradingExecutor, Position
-
-def test_leverage_fix():
- """Test that leverage is now being applied correctly"""
- print("🧪 Testing Leverage Fix")
- print("=" * 50)
-
- # Create trading executor
- executor = TradingExecutor()
-
- # Check current leverage setting
- current_leverage = executor.get_leverage()
- print(f"Current leverage setting: x{current_leverage}")
-
- # Test leverage in P&L calculation
- position = Position(
- symbol="ETH/USDT",
- side="SHORT",
- quantity=0.005, # 0.005 ETH
- entry_price=3755.33,
- entry_time=datetime.now(),
- order_id="test_123"
- )
-
- # Test P&L calculation with current price
- current_price = 3740.51 # Price went down, should be profitable for SHORT
-
- # Calculate P&L with leverage
- pnl_with_leverage = position.calculate_pnl(current_price, leverage=current_leverage)
- pnl_without_leverage = position.calculate_pnl(current_price, leverage=1.0)
-
- print(f"\nPosition: SHORT 0.005 ETH @ $3755.33")
- print(f"Current price: $3740.51")
- print(f"Price difference: ${3755.33 - 3740.51:.2f} (favorable for SHORT)")
-
- print(f"\nP&L without leverage (x1): ${pnl_without_leverage:.2f}")
- print(f"P&L with leverage (x{current_leverage}): ${pnl_with_leverage:.2f}")
- print(f"Leverage multiplier effect: {pnl_with_leverage / pnl_without_leverage:.1f}x")
-
- # Expected calculation
- position_value = 0.005 * 3755.33 # ~$18.78
- price_diff = 3755.33 - 3740.51 # $14.82 favorable
- raw_pnl = price_diff * 0.005 # ~$0.074
- leveraged_pnl = raw_pnl * current_leverage # ~$3.70
-
- print(f"\nExpected calculation:")
- print(f"Position value: ${position_value:.2f}")
- print(f"Raw P&L: ${raw_pnl:.3f}")
- print(f"Leveraged P&L (before fees): ${leveraged_pnl:.2f}")
-
- # Check if the calculation is correct
- if abs(pnl_with_leverage - leveraged_pnl) < 0.1: # Allow for small fee differences
- print("✅ Leverage calculation appears correct!")
- else:
- print("❌ Leverage calculation may have issues")
-
- print("\n" + "=" * 50)
- print("Test completed. Check if new trades show leveraged P&L in dashboard.")
-
-if __name__ == "__main__":
- test_leverage_fix()
\ No newline at end of file
diff --git a/test_massive_dqn.py b/test_massive_dqn.py
deleted file mode 100644
index 3d03c69..0000000
--- a/test_massive_dqn.py
+++ /dev/null
@@ -1,232 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for the massive 50M parameter DQN agent
-Tests:
-1. Model initialization and parameter count
-2. Forward pass functionality
-3. Gradient flow verification
-4. Training step simulation
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import torch
-import numpy as np
-from NN.models.dqn_agent import DQNAgent, DQNNetwork
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_dqn_architecture():
- """Test the new massive DQN architecture"""
- print("🔥 Testing Massive DQN Architecture (Target: 50M parameters)")
-
- # Test the network directly first
- input_dim = 7850 # BaseDataInput feature size
- n_actions = 3 # BUY, SELL, HOLD
-
- print(f"\n1. Creating DQN Network with input_dim={input_dim}, n_actions={n_actions}")
- network = DQNNetwork(input_dim, n_actions)
-
- # Count parameters
- total_params = sum(p.numel() for p in network.parameters())
- print(f" ✅ Total parameters: {total_params:,}")
- print(f" 🎯 Target achieved: {total_params >= 50_000_000}")
-
- # Test forward pass
- print(f"\n2. Testing forward pass...")
- batch_size = 4
- test_input = torch.randn(batch_size, input_dim)
-
- with torch.no_grad():
- output = network(test_input)
-
- if isinstance(output, tuple):
- q_values, regime_pred, price_pred, volatility_pred, features = output
- print(f" ✅ Q-values shape: {q_values.shape}")
- print(f" ✅ Regime prediction shape: {regime_pred.shape}")
- print(f" ✅ Price prediction shape: {price_pred.shape}")
- print(f" ✅ Volatility prediction shape: {volatility_pred.shape}")
- print(f" ✅ Features shape: {features.shape}")
- else:
- print(f" ✅ Output shape: {output.shape}")
-
- return network
-
-def test_gradient_flow():
- """Test that gradients flow properly through the network"""
- print(f"\n🧪 Testing Gradient Flow...")
-
- # Create agent
- state_shape = (7850,)
- agent = DQNAgent(
- state_shape=state_shape,
- n_actions=3,
- learning_rate=0.001,
- batch_size=16,
- buffer_size=1000
- )
-
- # Force disable mixed precision
- agent.use_mixed_precision = False
- print(f" ✅ Mixed precision disabled: {not agent.use_mixed_precision}")
-
- # Ensure model is in training mode
- agent.policy_net.train()
- print(f" ✅ Model in training mode: {agent.policy_net.training}")
-
- # Create test batch
- batch_size = 8
- state_dim = 7850
-
- states = torch.randn(batch_size, state_dim, requires_grad=True)
- actions = torch.randint(0, 3, (batch_size,))
- rewards = torch.randn(batch_size)
- next_states = torch.randn(batch_size, state_dim)
- dones = torch.zeros(batch_size)
-
- print(f" 📊 Test batch created - states: {states.shape}, actions: {actions.shape}")
-
- # Test forward pass and check gradients
- agent.optimizer.zero_grad()
-
- # Forward pass
- output = agent.policy_net(states)
- if isinstance(output, tuple):
- q_values = output[0]
- else:
- q_values = output
-
- print(f" ✅ Forward pass successful - Q-values: {q_values.shape}")
- print(f" ✅ Q-values require grad: {q_values.requires_grad}")
-
- # Gather Q-values for actions
- current_q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
- print(f" ✅ Gathered Q-values require grad: {current_q_values.requires_grad}")
-
- # Compute simple loss
- target_q_values = rewards # Simplified target
- loss = torch.nn.MSELoss()(current_q_values, target_q_values)
- print(f" ✅ Loss computed: {loss.item():.6f}")
- print(f" ✅ Loss requires grad: {loss.requires_grad}")
-
- # Backward pass
- loss.backward()
-
- # Check if gradients exist and are finite
- grad_norms = []
- params_with_grad = 0
- total_params = 0
-
- for name, param in agent.policy_net.named_parameters():
- total_params += 1
- if param.grad is not None:
- params_with_grad += 1
- grad_norm = param.grad.norm().item()
- grad_norms.append(grad_norm)
- if not torch.isfinite(param.grad).all():
- print(f" ❌ Non-finite gradients in {name}")
- return False
-
- print(f" ✅ Parameters with gradients: {params_with_grad}/{total_params}")
- print(f" ✅ Average gradient norm: {np.mean(grad_norms):.6f}")
- print(f" ✅ Max gradient norm: {max(grad_norms):.6f}")
-
- # Test optimizer step
- agent.optimizer.step()
- print(f" ✅ Optimizer step completed successfully")
-
- return True
-
-def test_training_step():
- """Test a complete training step"""
- print(f"\n🏋️ Testing Complete Training Step...")
-
- # Create agent
- state_shape = (7850,)
- agent = DQNAgent(
- state_shape=state_shape,
- n_actions=3,
- learning_rate=0.001,
- batch_size=8,
- buffer_size=1000
- )
-
- # Force disable mixed precision
- agent.use_mixed_precision = False
-
- # Add some experiences
- for i in range(20):
- state = np.random.randn(7850).astype(np.float32)
- action = np.random.randint(0, 3)
- reward = np.random.randn() * 0.1
- next_state = np.random.randn(7850).astype(np.float32)
- done = np.random.random() < 0.1
-
- agent.remember(state, action, reward, next_state, done)
-
- print(f" ✅ Added {len(agent.memory)} experiences to memory")
-
- # Test replay training
- if len(agent.memory) >= agent.batch_size:
- loss = agent.replay()
- print(f" ✅ Training completed with loss: {loss:.6f}")
-
- if loss > 0:
- print(f" ✅ Training successful - non-zero loss indicates learning")
- return True
- else:
- print(f" ❌ Training failed - zero loss indicates gradient issues")
- return False
- else:
- print(f" ⚠️ Not enough experiences for training")
- return True
-
-def main():
- """Run all tests"""
- print("🚀 MASSIVE DQN AGENT TESTING SUITE")
- print("=" * 50)
-
- # Test 1: Architecture
- try:
- network = test_dqn_architecture()
- print(" ✅ Architecture test PASSED")
- except Exception as e:
- print(f" ❌ Architecture test FAILED: {e}")
- return False
-
- # Test 2: Gradient flow
- try:
- gradient_success = test_gradient_flow()
- if gradient_success:
- print(" ✅ Gradient flow test PASSED")
- else:
- print(" ❌ Gradient flow test FAILED")
- return False
- except Exception as e:
- print(f" ❌ Gradient flow test FAILED: {e}")
- return False
-
- # Test 3: Training step
- try:
- training_success = test_training_step()
- if training_success:
- print(" ✅ Training step test PASSED")
- else:
- print(" ❌ Training step test FAILED")
- return False
- except Exception as e:
- print(f" ❌ Training step test FAILED: {e}")
- return False
-
- print("\n🎉 ALL TESTS PASSED!")
- print("✅ Massive DQN agent is ready for 50M parameter learning!")
- return True
-
-if __name__ == "__main__":
- success = main()
- exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_mexc_order_fix.py b/test_mexc_order_fix.py
deleted file mode 100644
index 6c3f1b6..0000000
--- a/test_mexc_order_fix.py
+++ /dev/null
@@ -1,174 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test MEXC Order Fix
-
-Tests the fixed MEXC interface to ensure order execution works correctly
-"""
-
-import os
-import sys
-import logging
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-
-logger = logging.getLogger(__name__)
-
-def test_mexc_order_fix():
- """Test the fixed MEXC interface"""
- print("Testing Fixed MEXC Interface")
- print("=" * 50)
-
- # Import after path setup
- try:
- from core.exchanges.mexc_interface import MEXCInterface
- except ImportError as e:
- print(f"❌ Import error: {e}")
- return False
-
- # Get API credentials
- api_key = os.getenv('MEXC_API_KEY', '')
- api_secret = os.getenv('MEXC_SECRET_KEY', '')
-
- if not api_key or not api_secret:
- print("❌ No MEXC API credentials found")
- print("Set MEXC_API_KEY and MEXC_SECRET_KEY environment variables")
- return False
-
- # Initialize MEXC interface
- mexc = MEXCInterface(
- api_key=api_key,
- api_secret=api_secret,
- test_mode=False, # Use live API (MEXC doesn't have testnet)
- trading_mode='live'
- )
-
- # Test 1: Connection
- print("\n1. Testing connection...")
- if mexc.connect():
- print("✅ Connection successful")
- else:
- print("❌ Connection failed")
- return False
-
- # Test 2: Account info
- print("\n2. Testing account info...")
- account_info = mexc.get_account_info()
- if account_info:
- print("✅ Account info retrieved")
- print(f"Account type: {account_info.get('accountType', 'N/A')}")
- else:
- print("❌ Failed to get account info")
- return False
-
- # Test 3: Balance check
- print("\n3. Testing balance retrieval...")
- usdc_balance = mexc.get_balance('USDC')
- usdt_balance = mexc.get_balance('USDT')
- print(f"USDC balance: {usdc_balance}")
- print(f"USDT balance: {usdt_balance}")
-
- if usdc_balance <= 0 and usdt_balance <= 0:
- print("❌ No USDC or USDT balance for testing")
- return False
-
- # Test 4: Symbol support check
- print("\n4. Testing symbol support...")
- symbol = 'ETH/USDT' # Will be converted to ETHUSDC internally
- formatted_symbol = mexc._format_spot_symbol(symbol)
- print(f"Symbol {symbol} formatted to: {formatted_symbol}")
-
- if mexc.is_symbol_supported(symbol):
- print(f"✅ Symbol {formatted_symbol} is supported")
- else:
- print(f"❌ Symbol {formatted_symbol} is not supported")
- print("Checking supported symbols...")
- supported = mexc.get_api_symbols()
- print(f"Found {len(supported)} supported symbols")
- if 'ETHUSDC' in supported:
- print("✅ ETHUSDC is in supported list")
- else:
- print("❌ ETHUSDC not in supported list")
-
- # Test 5: Get ticker
- print("\n5. Testing ticker retrieval...")
- ticker = mexc.get_ticker(symbol)
- if ticker:
- print(f"✅ Ticker retrieved for {symbol}")
- print(f"Last price: ${ticker['last']:.2f}")
- print(f"Bid: ${ticker['bid']:.2f}, Ask: ${ticker['ask']:.2f}")
- else:
- print(f"❌ Failed to get ticker for {symbol}")
- return False
-
- # Test 6: Small test order (only if balance available)
- print("\n6. Testing small order placement...")
- if usdc_balance >= 10.0: # Need at least $10 for minimum order
- try:
- # Calculate small test quantity
- test_price = ticker['last'] * 1.01 # 1% above market for quick execution
- test_quantity = round(10.0 / test_price, 5) # $10 worth
-
- print(f"Attempting to place test order:")
- print(f"- Symbol: {symbol} -> {formatted_symbol}")
- print(f"- Side: BUY")
- print(f"- Type: LIMIT")
- print(f"- Quantity: {test_quantity}")
- print(f"- Price: ${test_price:.2f}")
-
- # Note: This is a real order that will use real funds!
- confirm = input("⚠️ This will place a REAL order with REAL funds! Continue? (yes/no): ")
- if confirm.lower() != 'yes':
- print("❌ Order test skipped by user")
- return True
-
- order_result = mexc.place_order(
- symbol=symbol,
- side='BUY',
- order_type='LIMIT',
- quantity=test_quantity,
- price=test_price
- )
-
- if order_result:
- print("✅ Order placed successfully!")
- print(f"Order ID: {order_result.get('orderId')}")
- print(f"Order result: {order_result}")
-
- # Try to cancel the order immediately
- order_id = order_result.get('orderId')
- if order_id:
- print(f"\n7. Testing order cancellation...")
- cancel_result = mexc.cancel_order(symbol, str(order_id))
- if cancel_result:
- print("✅ Order cancelled successfully")
- else:
- print("❌ Failed to cancel order")
- print("⚠️ You may have an open order to manually cancel")
- else:
- print("❌ Order placement failed")
- return False
-
- except Exception as e:
- print(f"❌ Order test failed with exception: {e}")
- return False
- else:
- print(f"⚠️ Insufficient balance for order test (need $10+, have ${usdc_balance:.2f} USDC)")
- print("✅ All other tests passed - order API should work when balance is sufficient")
-
- print("\n" + "=" * 50)
- print("✅ MEXC Interface Test Completed Successfully!")
- print("✅ Order execution should now work correctly")
- return True
-
-if __name__ == "__main__":
- success = test_mexc_order_fix()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_model_output_manager.py b/test_model_output_manager.py
deleted file mode 100644
index 85f6048..0000000
--- a/test_model_output_manager.py
+++ /dev/null
@@ -1,182 +0,0 @@
-"""
-Test script for ModelOutputManager
-
-This script tests the extensible model output storage functionality
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import logging
-from datetime import datetime
-from core.model_output_manager import ModelOutputManager
-from core.data_models import create_model_output, ModelOutput
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_model_output_manager():
- """Test the ModelOutputManager functionality"""
-
- print("Testing ModelOutputManager...")
-
- # Initialize the manager
- manager = ModelOutputManager(cache_dir="test_cache/model_outputs", max_history=100)
-
- print(f"✅ ModelOutputManager initialized")
- print(f" Supported model types: {manager.get_supported_model_types()}")
-
- # Test 1: Store outputs from different model types
- print("\n1. Testing model output storage...")
-
- # Create outputs from different model types
- models_to_test = [
- ('cnn', 'enhanced_cnn_v1', 'BUY', 0.85),
- ('rl', 'dqn_agent_v2', 'SELL', 0.72),
- ('lstm', 'lstm_predictor', 'HOLD', 0.65),
- ('transformer', 'transformer_v1', 'BUY', 0.91),
- ('orchestrator', 'main_orchestrator', 'BUY', 0.78)
- ]
-
- symbol = 'ETH/USDT'
- stored_outputs = []
-
- for model_type, model_name, action, confidence in models_to_test:
- # Create model output with hidden states for cross-model feeding
- hidden_states = {
- 'layer_1': [0.1, 0.2, 0.3],
- 'layer_2': [0.4, 0.5, 0.6],
- 'attention_weights': [0.7, 0.8, 0.9]
- } if model_type in ['cnn', 'transformer'] else None
-
- metadata = {
- 'model_version': '1.0',
- 'training_iterations': 1000,
- 'last_updated': datetime.now().isoformat()
- }
-
- model_output = create_model_output(
- model_type=model_type,
- model_name=model_name,
- symbol=symbol,
- action=action,
- confidence=confidence,
- hidden_states=hidden_states,
- metadata=metadata
- )
-
- # Store the output
- success = manager.store_output(model_output)
- if success:
- print(f"✅ Stored {model_type} output: {action} ({confidence})")
- stored_outputs.append(model_output)
- else:
- print(f"❌ Failed to store {model_type} output")
-
- # Test 2: Retrieve current outputs
- print("\n2. Testing output retrieval...")
-
- all_current = manager.get_all_current_outputs(symbol)
- print(f"✅ Retrieved {len(all_current)} current outputs for {symbol}")
-
- for model_name, output in all_current.items():
- print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
-
- # Test 3: Cross-model feeding
- print("\n3. Testing cross-model feeding...")
-
- cross_model_states = manager.get_cross_model_states(symbol, 'dqn_agent_v2')
- print(f"✅ Retrieved cross-model states for RL model: {len(cross_model_states)} models")
-
- for model_name, states in cross_model_states.items():
- if states:
- print(f" {model_name}: {len(states)} hidden state layers")
-
- # Test 4: Consensus prediction
- print("\n4. Testing consensus prediction...")
-
- consensus = manager.get_consensus_prediction(symbol, confidence_threshold=0.7)
- if consensus:
- print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
- print(f" Votes: {consensus['votes']}")
- print(f" Models: {consensus['model_types']}")
- else:
- print("⚠️ No consensus reached (insufficient high-confidence predictions)")
-
- # Test 5: Performance summary
- print("\n5. Testing performance tracking...")
-
- performance = manager.get_performance_summary(symbol)
- print(f"✅ Performance summary for {symbol}:")
- print(f" Active models: {performance['active_models']}")
-
- for model_name, stats in performance['model_stats'].items():
- print(f" {model_name} ({stats['model_type']}): {stats['predictions']} predictions, "
- f"avg confidence: {stats['avg_confidence']}")
-
- # Test 6: Custom model type support
- print("\n6. Testing custom model type support...")
-
- # Add a custom model type
- manager.add_custom_model_type('hybrid_ensemble')
-
- # Create output with custom model type
- custom_output = create_model_output(
- model_type='hybrid_ensemble',
- model_name='custom_ensemble_v1',
- symbol=symbol,
- action='BUY',
- confidence=0.88,
- metadata={'ensemble_size': 5, 'voting_method': 'weighted'}
- )
-
- success = manager.store_output(custom_output)
- if success:
- print("✅ Custom model type 'hybrid_ensemble' stored successfully")
- else:
- print("❌ Failed to store custom model type")
-
- print(f" Updated supported types: {len(manager.get_supported_model_types())} types")
-
- # Test 7: Historical outputs
- print("\n7. Testing historical output tracking...")
-
- # Store a few more outputs to build history
- for i in range(3):
- historical_output = create_model_output(
- model_type='cnn',
- model_name='enhanced_cnn_v1',
- symbol=symbol,
- action='HOLD',
- confidence=0.6 + i * 0.1
- )
- manager.store_output(historical_output)
-
- history = manager.get_output_history(symbol, 'enhanced_cnn_v1', count=5)
- print(f"✅ Retrieved {len(history)} historical outputs for enhanced_cnn_v1")
-
- for i, output in enumerate(history):
- print(f" {i+1}. {output.predictions['action']} ({output.confidence}) at {output.timestamp}")
-
- # Test 8: Active model types
- print("\n8. Testing active model type detection...")
-
- active_types = manager.get_model_types_active(symbol)
- print(f"✅ Active model types for {symbol}: {active_types}")
-
- print("\n✅ ModelOutputManager test completed successfully!")
- print("\nKey features verified:")
- print("✓ Extensible model type support (CNN, RL, LSTM, Transformer, Custom)")
- print("✓ Cross-model feeding with hidden states")
- print("✓ Historical output tracking")
- print("✓ Performance analytics")
- print("✓ Consensus prediction calculation")
- print("✓ Metadata management")
- print("✓ Thread-safe storage operations")
-
- return manager
-
-if __name__ == "__main__":
- test_model_output_manager()
\ No newline at end of file
diff --git a/test_model_registry.py b/test_model_registry.py
deleted file mode 100644
index b3236bd..0000000
--- a/test_model_registry.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python3
-import logging
-import sys
-import os
-
-# Add the project root to the path
-sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-def test_model_registry():
- """Test the model registry state"""
- try:
- from core.orchestrator import TradingOrchestrator
- from core.data_provider import DataProvider
-
- logger.info("Testing model registry...")
-
- # Initialize data provider
- data_provider = DataProvider()
-
- # Initialize orchestrator
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Check model registry state
- logger.info(f"Model registry models: {len(orchestrator.model_registry.models)}")
- logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
-
- # Check if models were created
- logger.info(f"RL Agent: {orchestrator.rl_agent is not None}")
- logger.info(f"CNN Model: {orchestrator.cnn_model is not None}")
- logger.info(f"CNN Adapter: {orchestrator.cnn_adapter is not None}")
-
- # Check model weights
- logger.info(f"Model weights: {orchestrator.model_weights}")
-
- return True
-
- except Exception as e:
- logger.error(f"Error testing model registry: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-if __name__ == "__main__":
- success = test_model_registry()
- if success:
- logger.info("✅ Model registry test completed successfully")
- else:
- logger.error("❌ Model registry test failed")
\ No newline at end of file
diff --git a/test_model_statistics.py b/test_model_statistics.py
deleted file mode 100644
index 2480c6c..0000000
--- a/test_model_statistics.py
+++ /dev/null
@@ -1,58 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Model Statistics Implementation
-
-This script tests the new model statistics tracking functionality.
-"""
-
-import asyncio
-import time
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-async def test_model_statistics():
- """Test the model statistics tracking"""
- print("=== Testing Model Statistics ===")
-
- # Initialize orchestrator
- print("1. Initializing orchestrator...")
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Wait for initialization
- await asyncio.sleep(2)
-
- # Test initial statistics
- print("\n2. Initial model statistics:")
- orchestrator.log_model_statistics()
-
- # Run some predictions to generate statistics
- print("\n3. Running predictions to generate statistics...")
- for i in range(5):
- print(f" Running prediction batch {i+1}/5...")
- predictions = await orchestrator._get_all_predictions('ETH/USDT')
- print(f" Got {len(predictions)} predictions")
- await asyncio.sleep(1) # Small delay between batches
-
- # Show updated statistics
- print("\n4. Updated model statistics:")
- orchestrator.log_model_statistics(detailed=True)
-
- # Test statistics summary
- print("\n5. Statistics summary (JSON format):")
- summary = orchestrator.get_model_statistics_summary()
- for model_name, stats in summary.items():
- print(f" {model_name}: {stats}")
-
- # Test individual model statistics
- print("\n6. Individual model statistics:")
- for model_name in orchestrator.model_statistics.keys():
- stats = orchestrator.get_model_statistics(model_name)
- if stats:
- print(f" {model_name}: {stats.total_inferences} inferences, "
- f"rate={stats.inference_rate_per_minute:.1f}/min")
-
- print("\n✅ Model statistics test completed successfully!")
-
-if __name__ == "__main__":
- asyncio.run(test_model_statistics())
\ No newline at end of file
diff --git a/test_model_stats.py b/test_model_stats.py
deleted file mode 100644
index d9e1941..0000000
--- a/test_model_stats.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify model stats functionality
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import logging
-from core.orchestrator import TradingOrchestrator
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_model_stats():
- """Test the model stats functionality"""
- try:
- logger.info("Testing model stats functionality...")
-
- # Create orchestrator instance (this will initialize model states)
- orchestrator = TradingOrchestrator()
-
- # Sync with dashboard values
- orchestrator.sync_model_states_with_dashboard()
-
- # Get current model stats
- stats = orchestrator.get_model_training_stats()
-
- logger.info("Current model training stats:")
- for model_name, model_stats in stats.items():
- if model_stats['current_loss'] is not None:
- logger.info(f" {model_name.upper()}: {model_stats['current_loss']:.4f} loss, {model_stats['improvement_pct']:.1f}% improvement")
- else:
- logger.info(f" {model_name.upper()}: No training data yet")
-
- # Test updating a model loss
- orchestrator.update_model_loss('cnn', 0.0001)
- logger.info("Updated CNN loss to 0.0001")
-
- # Get updated stats
- updated_stats = orchestrator.get_model_training_stats()
- cnn_stats = updated_stats['cnn']
- logger.info(f"CNN updated: {cnn_stats['current_loss']:.4f} loss, {cnn_stats['improvement_pct']:.1f}% improvement")
-
- return True
-
- except Exception as e:
- logger.error(f"❌ Model stats test failed: {e}")
- return False
-
-if __name__ == "__main__":
- success = test_model_stats()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_model_training.py b/test_model_training.py
deleted file mode 100644
index f0488e0..0000000
--- a/test_model_training.py
+++ /dev/null
@@ -1,139 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Model Training Implementation
-
-This script tests the improved model training functionality.
-"""
-
-import asyncio
-import time
-import numpy as np
-from datetime import datetime
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-async def test_model_training():
- """Test the improved model training system"""
- print("=== Testing Model Training System ===")
-
- # Initialize orchestrator
- print("1. Initializing orchestrator...")
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Wait for initialization
- await asyncio.sleep(3)
-
- # Show initial model statistics
- print("\n2. Initial model statistics:")
- orchestrator.log_model_statistics()
-
- # Run predictions to generate training data
- print("\n3. Running predictions to generate training data...")
- predictions_data = []
-
- for i in range(3):
- print(f" Running prediction batch {i+1}/3...")
- predictions = await orchestrator._get_all_predictions('ETH/USDT')
- print(f" Got {len(predictions)} predictions")
-
- # Store prediction data for training simulation
- for pred in predictions:
- predictions_data.append({
- 'model_name': pred.model_name,
- 'prediction': {
- 'action': pred.action,
- 'confidence': pred.confidence
- },
- 'timestamp': pred.timestamp,
- 'symbol': 'ETH/USDT'
- })
-
- await asyncio.sleep(1)
-
- print(f"\n4. Collected {len(predictions_data)} predictions for training")
-
- # Simulate training with different outcomes
- print("\n5. Testing training with simulated outcomes...")
-
- for i, pred_data in enumerate(predictions_data[:6]): # Test first 6 predictions
- # Simulate market outcome
- was_correct = i % 2 == 0 # Alternate between correct and incorrect
- price_change_pct = 0.5 if was_correct else -0.3
- sophisticated_reward = 1.0 if was_correct else -0.5
-
- # Create training record
- training_record = {
- 'model_name': pred_data['model_name'],
- 'model_input': np.random.randn(7850), # Simulate model input
- 'prediction': pred_data['prediction'],
- 'symbol': pred_data['symbol'],
- 'timestamp': pred_data['timestamp']
- }
-
- print(f" Training {pred_data['model_name']}: "
- f"action={pred_data['prediction']['action']}, "
- f"correct={was_correct}, reward={sophisticated_reward}")
-
- # Test the training method
- try:
- await orchestrator._train_model_on_outcome(
- training_record, was_correct, price_change_pct, sophisticated_reward
- )
- print(f" ✅ Training completed for {pred_data['model_name']}")
- except Exception as e:
- print(f" ❌ Training failed for {pred_data['model_name']}: {e}")
-
- # Show updated statistics
- print("\n6. Updated model statistics after training:")
- orchestrator.log_model_statistics(detailed=True)
-
- # Test specific model training methods
- print("\n7. Testing specific model training methods...")
-
- # Test DQN training
- if 'dqn_agent' in orchestrator.model_statistics:
- print(" Testing DQN agent training...")
- dqn_record = {
- 'model_name': 'dqn_agent',
- 'model_input': np.random.randn(7850),
- 'prediction': {'action': 'BUY', 'confidence': 0.8},
- 'symbol': 'ETH/USDT',
- 'timestamp': datetime.now()
- }
- try:
- await orchestrator._train_model_on_outcome(dqn_record, True, 0.5, 1.0)
- print(" ✅ DQN training test passed")
- except Exception as e:
- print(f" ❌ DQN training test failed: {e}")
-
- # Test CNN training
- if 'enhanced_cnn' in orchestrator.model_statistics:
- print(" Testing CNN model training...")
- cnn_record = {
- 'model_name': 'enhanced_cnn',
- 'model_input': np.random.randn(7850),
- 'prediction': {'action': 'SELL', 'confidence': 0.6},
- 'symbol': 'ETH/USDT',
- 'timestamp': datetime.now()
- }
- try:
- await orchestrator._train_model_on_outcome(cnn_record, False, -0.3, -0.5)
- print(" ✅ CNN training test passed")
- except Exception as e:
- print(f" ❌ CNN training test failed: {e}")
-
- # Show final statistics
- print("\n8. Final model statistics:")
- summary = orchestrator.get_model_statistics_summary()
- for model_name, stats in summary.items():
- print(f" {model_name}:")
- print(f" Inferences: {stats['total_inferences']}")
- print(f" Rate: {stats['inference_rate_per_minute']:.1f}/min")
- print(f" Current loss: {stats['current_loss']}")
- print(f" Last prediction: {stats['last_prediction']} ({stats['last_confidence']})")
-
- print("\n✅ Model training test completed!")
-
-if __name__ == "__main__":
- asyncio.run(test_model_training())
\ No newline at end of file
diff --git a/test_orchestrator_fix.py b/test_orchestrator_fix.py
deleted file mode 100644
index 6b23b82..0000000
--- a/test_orchestrator_fix.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify orchestrator fix
-"""
-
-import logging
-import os
-
-# Fix OpenMP library conflicts
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
-os.environ['OMP_NUM_THREADS'] = '4'
-
-# Fix matplotlib backend
-import matplotlib
-matplotlib.use('Agg')
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_orchestrator():
- """Test orchestrator initialization"""
- try:
- logger.info("Testing orchestrator initialization...")
-
- # Import required modules
- from core.standardized_data_provider import StandardizedDataProvider
- from core.orchestrator import TradingOrchestrator
-
- logger.info("Imports successful")
-
- # Create data provider
- data_provider = StandardizedDataProvider()
- logger.info("StandardizedDataProvider created")
-
- # Create orchestrator
- orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
- logger.info("TradingOrchestrator created successfully!")
-
- # Test basic functionality
- status = orchestrator.get_queue_status()
- logger.info(f"Queue status: {status}")
-
- logger.info("✅ Orchestrator test completed successfully!")
-
- except Exception as e:
- logger.error(f"❌ Orchestrator test failed: {e}")
- import traceback
- traceback.print_exc()
-
-if __name__ == "__main__":
- test_orchestrator()
\ No newline at end of file
diff --git a/test_order_sync_and_fees.py b/test_order_sync_and_fees.py
deleted file mode 100644
index 0323b3c..0000000
--- a/test_order_sync_and_fees.py
+++ /dev/null
@@ -1,122 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Open Order Sync and Fee Calculation
-Verify that open orders are properly synchronized and fees are correctly calculated in PnL
-"""
-
-import os
-import sys
-import logging
-
-# Add the project root to the path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-# Load environment variables
-try:
- from dotenv import load_dotenv
- load_dotenv()
-except ImportError:
- if os.path.exists('.env'):
- with open('.env', 'r') as f:
- for line in f:
- if line.strip() and not line.startswith('#'):
- key, value = line.strip().split('=', 1)
- os.environ[key] = value
-
-from core.trading_executor import TradingExecutor
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def test_open_order_sync_and_fees():
- """Test open order synchronization and fee calculation"""
- print("🧪 Testing Open Order Sync and Fee Calculation...")
- print("=" * 70)
-
- try:
- # Create trading executor
- executor = TradingExecutor()
-
- print(f"📊 Current State Analysis:")
- print(f" Open orders count: {executor._get_open_orders_count()}")
- print(f" Max open orders: {executor.max_open_orders}")
- print(f" Can place new order: {executor._can_place_new_order()}")
-
- # Test open order synchronization
- print(f"\n🔍 Open Order Sync Analysis:")
- print(f" - Current sync method: _get_open_orders_count()")
- print(f" - Counts orders across all symbols")
- print(f" - Real-time API queries")
- print(f" - Handles API errors gracefully")
-
- # Check if there's a dedicated sync method
- if hasattr(executor, 'sync_open_orders'):
- print(f" ✅ Dedicated sync method exists")
- else:
- print(f" ⚠️ No dedicated sync method - using count method")
-
- # Test fee calculation in PnL
- print(f"\n💰 Fee Calculation Analysis:")
-
- # Check fee calculation methods
- if hasattr(executor, '_calculate_trading_fee'):
- print(f" ✅ Fee calculation method exists")
- else:
- print(f" ❌ No dedicated fee calculation method")
-
- # Check if fees are included in PnL
- print(f"\n📈 PnL Fee Integration:")
- print(f" - TradeRecord includes fees field")
- print(f" - PnL calculation: pnl = gross_pnl - fees")
- print(f" - Fee rates from config: taker_fee, maker_fee")
-
- # Check fee sync
- print(f"\n🔄 Fee Synchronization:")
- if hasattr(executor, 'sync_fees_with_api'):
- print(f" ✅ Fee sync method exists")
- else:
- print(f" ❌ No fee sync method")
-
- # Check config sync
- if hasattr(executor, 'config_sync'):
- print(f" ✅ Config synchronizer exists")
- else:
- print(f" ❌ No config synchronizer")
-
- print(f"\n📋 Issues Found:")
-
- # Issue 1: No dedicated open order sync method
- if not hasattr(executor, 'sync_open_orders'):
- print(f" ❌ Missing: Dedicated open order synchronization method")
- print(f" Current: Only counts orders, doesn't sync state")
-
- # Issue 2: Fee calculation may not be comprehensive
- print(f" ⚠️ Potential: Fee calculation uses simulated rates")
- print(f" Should: Use actual API fees when available")
-
- # Issue 3: Check if fees are properly tracked
- print(f" ✅ Good: Fees are tracked in TradeRecord")
- print(f" ✅ Good: PnL includes fee deduction")
-
- print(f"\n🔧 Recommended Fixes:")
- print(f" 1. Add dedicated open order sync method")
- print(f" 2. Enhance fee calculation with real API data")
- print(f" 3. Add periodic order state synchronization")
- print(f" 4. Improve fee tracking accuracy")
-
- return True
-
- except Exception as e:
- print(f"❌ Error testing order sync and fees: {e}")
- return False
-
-if __name__ == "__main__":
- success = test_open_order_sync_and_fees()
- if success:
- print(f"\n🎉 Order sync and fee test completed!")
- else:
- print(f"\n💥 Order sync and fee test failed!")
\ No newline at end of file
diff --git a/test_position_based_rewards.py b/test_position_based_rewards.py
deleted file mode 100644
index 316da43..0000000
--- a/test_position_based_rewards.py
+++ /dev/null
@@ -1,159 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for position-based reward system
-
-This script tests the enhanced reward calculations that incentivize:
-1. Holding profitable positions (let winners run)
-2. Closing losing positions (cut losses)
-3. Taking action when appropriate based on P&L
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.orchestrator import TradingOrchestrator
-from NN.models.enhanced_cnn import EnhancedCNN
-import numpy as np
-
-def test_position_reward_scenarios():
- """Test various position-based reward scenarios"""
-
- print("🧪 POSITION-BASED REWARD SYSTEM TEST")
- print("=" * 50)
-
- # Initialize orchestrator
- orchestrator = TradingOrchestrator()
-
- # Test scenarios
- scenarios = [
- # (action, position_pnl, has_position, price_change_pct, description)
- ("HOLD", 50.0, True, 0.5, "Hold profitable position with continued gains"),
- ("HOLD", 50.0, True, -0.3, "Hold profitable position with small pullback"),
- ("HOLD", -30.0, True, 0.8, "Hold losing position that recovers"),
- ("HOLD", -30.0, True, -0.5, "Hold losing position that continues down"),
- ("SELL", 50.0, True, 0.0, "Close profitable position"),
- ("SELL", -30.0, True, 0.0, "Close losing position (good)"),
- ("BUY", 0.0, False, 1.0, "New buy position with immediate gain"),
- ("HOLD", 0.0, False, 0.1, "Hold with no position (stable market)"),
- ]
-
- print("\n📊 SOPHISTICATED REWARD CALCULATION TESTS:")
- print("-" * 80)
-
- for i, (action, position_pnl, has_position, price_change_pct, description) in enumerate(scenarios, 1):
- # Test sophisticated reward calculation
- reward, was_correct = orchestrator._calculate_sophisticated_reward(
- predicted_action=action,
- prediction_confidence=0.8,
- price_change_pct=price_change_pct,
- time_diff_minutes=5.0,
- has_price_prediction=False,
- symbol="ETH/USDT",
- has_position=has_position,
- current_position_pnl=position_pnl
- )
-
- print(f"{i:2d}. {description}")
- print(f" Action: {action}, P&L: ${position_pnl:+.1f}, Price Change: {price_change_pct:+.1f}%")
- print(f" Reward: {reward:+.3f}, Correct: {was_correct}")
- print()
-
- print("\n🧠 CNN POSITION-ENHANCED REWARD TESTS:")
- print("-" * 80)
-
- # Initialize CNN model
- cnn_model = EnhancedCNN(input_shape=100, n_actions=3)
-
- for i, (action, position_pnl, has_position, _, description) in enumerate(scenarios, 1):
- base_reward = 0.5 # Moderate base reward
- enhanced_reward = cnn_model._calculate_position_enhanced_reward(
- base_reward=base_reward,
- action=action,
- position_pnl=position_pnl,
- has_position=has_position
- )
-
- enhancement = enhanced_reward - base_reward
- print(f"{i:2d}. {description}")
- print(f" Action: {action}, P&L: ${position_pnl:+.1f}")
- print(f" Base Reward: {base_reward:+.3f} → Enhanced: {enhanced_reward:+.3f} (Δ{enhancement:+.3f})")
- print()
-
- print("\n🤖 DQN POSITION-ENHANCED REWARD TESTS:")
- print("-" * 80)
-
- for i, (action, position_pnl, has_position, _, description) in enumerate(scenarios, 1):
- base_reward = 0.5 # Moderate base reward
- enhanced_reward = orchestrator._calculate_position_enhanced_reward_for_dqn(
- base_reward=base_reward,
- action=action,
- position_pnl=position_pnl,
- has_position=has_position
- )
-
- enhancement = enhanced_reward - base_reward
- print(f"{i:2d}. {description}")
- print(f" Action: {action}, P&L: ${position_pnl:+.1f}")
- print(f" Base Reward: {base_reward:+.3f} → Enhanced: {enhanced_reward:+.3f} (Δ{enhancement:+.3f})")
- print()
-
-def test_reward_incentives():
- """Test that rewards properly incentivize desired behaviors"""
-
- print("\n🎯 REWARD INCENTIVE VALIDATION:")
- print("-" * 50)
-
- orchestrator = TradingOrchestrator()
- cnn_model = EnhancedCNN(input_shape=100, n_actions=3)
-
- # Test 1: Holding winners vs holding losers
- print("1. HOLD action comparison:")
-
- hold_winner_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", 100.0, True)
- hold_loser_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", -100.0, True)
-
- print(f" Hold profitable position (+$100): {hold_winner_reward:+.3f}")
- print(f" Hold losing position (-$100): {hold_loser_reward:+.3f}")
- print(f" ✅ Incentive correct: {hold_winner_reward > hold_loser_reward}")
-
- # Test 2: Closing losers vs closing winners
- print("\n2. SELL action comparison:")
-
- sell_winner_reward = cnn_model._calculate_position_enhanced_reward(0.5, "SELL", 100.0, True)
- sell_loser_reward = cnn_model._calculate_position_enhanced_reward(0.5, "SELL", -100.0, True)
-
- print(f" Sell profitable position (+$100): {sell_winner_reward:+.3f}")
- print(f" Sell losing position (-$100): {sell_loser_reward:+.3f}")
- print(f" ✅ Incentive correct: {sell_loser_reward > sell_winner_reward}")
-
- # Test 3: DQN reward scaling
- print("\n3. DQN vs CNN reward scaling:")
-
- dqn_reward = orchestrator._calculate_position_enhanced_reward_for_dqn(0.5, "HOLD", -100.0, True)
- cnn_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", -100.0, True)
-
- print(f" DQN penalty for holding loser: {dqn_reward:+.3f}")
- print(f" CNN penalty for holding loser: {cnn_reward:+.3f}")
- print(f" ✅ DQN more sensitive: {abs(dqn_reward) > abs(cnn_reward)}")
-
-def main():
- """Run all position-based reward tests"""
- try:
- test_position_reward_scenarios()
- test_reward_incentives()
-
- print("\n🚀 POSITION-BASED REWARD SYSTEM VALIDATION COMPLETE!")
- print("✅ System properly incentivizes:")
- print(" • Holding profitable positions (let winners run)")
- print(" • Closing losing positions (cut losses)")
- print(" • Taking appropriate action based on P&L")
- print(" • Different reward scaling for CNN vs DQN models")
-
- except Exception as e:
- print(f"❌ Test failed with error: {e}")
- import traceback
- traceback.print_exc()
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/test_profitability_reward_system.py b/test_profitability_reward_system.py
deleted file mode 100644
index 06b53c8..0000000
--- a/test_profitability_reward_system.py
+++ /dev/null
@@ -1,294 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for the dynamic profitability reward system
-
-This script tests:
-1. Fee reversion to normal 0.1% (0.001)
-2. Dynamic profitability reward multiplier adjustment
-3. Success rate calculation
-4. Integration with dashboard display
-"""
-
-import sys
-import os
-import time
-from datetime import datetime, timedelta
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from core.trading_executor import TradingExecutor, TradeRecord
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-def test_fee_configuration():
- """Test that fees are reverted to normal 0.1%"""
- print("=" * 60)
- print("🧪 TESTING FEE CONFIGURATION")
- print("=" * 60)
-
- executor = TradingExecutor()
-
- # Check fee configuration
- expected_open_fee = 0.001 # 0.1%
- expected_close_fee = 0.001 # 0.1%
- expected_total_fee = 0.002 # 0.2%
-
- actual_open_fee = executor.trading_fees['open_fee_percent']
- actual_close_fee = executor.trading_fees['close_fee_percent']
- actual_total_fee = executor.trading_fees['total_round_trip_fee']
-
- print(f"Expected Open Fee: {expected_open_fee} (0.1%)")
- print(f"Actual Open Fee: {actual_open_fee} (0.1%)")
- print(f"✅ Open Fee: {'PASS' if actual_open_fee == expected_open_fee else 'FAIL'}")
- print()
-
- print(f"Expected Close Fee: {expected_close_fee} (0.1%)")
- print(f"Actual Close Fee: {actual_close_fee} (0.1%)")
- print(f"✅ Close Fee: {'PASS' if actual_close_fee == expected_close_fee else 'FAIL'}")
- print()
-
- print(f"Expected Total Fee: {expected_total_fee} (0.2%)")
- print(f"Actual Total Fee: {actual_total_fee} (0.2%)")
- print(f"✅ Total Fee: {'PASS' if actual_total_fee == expected_total_fee else 'FAIL'}")
- print()
-
- return actual_open_fee == expected_open_fee and actual_close_fee == expected_close_fee
-
-def test_profitability_multiplier_initialization():
- """Test profitability multiplier initialization"""
- print("=" * 60)
- print("🧪 TESTING PROFITABILITY MULTIPLIER INITIALIZATION")
- print("=" * 60)
-
- executor = TradingExecutor()
-
- # Check initial values
- initial_multiplier = executor.profitability_reward_multiplier
- min_multiplier = executor.min_profitability_multiplier
- max_multiplier = executor.max_profitability_multiplier
- adjustment_step = executor.profitability_adjustment_step
-
- print(f"Initial Multiplier: {initial_multiplier} (should be 0.0)")
- print(f"Min Multiplier: {min_multiplier} (should be 0.0)")
- print(f"Max Multiplier: {max_multiplier} (should be 2.0)")
- print(f"Adjustment Step: {adjustment_step} (should be 0.1)")
- print()
-
- # Check thresholds
- increase_threshold = executor.success_rate_increase_threshold
- decrease_threshold = executor.success_rate_decrease_threshold
- trades_window = executor.recent_trades_window
-
- print(f"Increase Threshold: {increase_threshold:.1%} (should be 60%)")
- print(f"Decrease Threshold: {decrease_threshold:.1%} (should be 51%)")
- print(f"Trades Window: {trades_window} (should be 20)")
- print()
-
- # Test getter method
- multiplier_from_getter = executor.get_profitability_reward_multiplier()
- print(f"Multiplier via getter: {multiplier_from_getter}")
- print(f"✅ Getter method: {'PASS' if multiplier_from_getter == initial_multiplier else 'FAIL'}")
-
- return (initial_multiplier == 0.0 and
- min_multiplier == 0.0 and
- max_multiplier == 2.0 and
- adjustment_step == 0.1)
-
-def simulate_trades_and_test_adjustment(executor, winning_trades, total_trades):
- """Simulate trades and test multiplier adjustment"""
- print(f"📊 Simulating {winning_trades}/{total_trades} winning trades ({winning_trades/total_trades:.1%} success rate)")
-
- # Clear existing trade records
- executor.trade_records = []
-
- # Create simulated trade records
- base_time = datetime.now() - timedelta(hours=1)
-
- for i in range(total_trades):
- # Create winning or losing trade based on ratio
- is_winning = i < winning_trades
- pnl = 10.0 if is_winning else -5.0 # $10 profit or $5 loss
-
- trade_record = TradeRecord(
- symbol="ETH/USDT",
- side="LONG",
- quantity=0.01,
- entry_price=3000.0,
- exit_price=3010.0 if is_winning else 2995.0,
- entry_time=base_time + timedelta(minutes=i*2),
- exit_time=base_time + timedelta(minutes=i*2+1),
- pnl=pnl,
- fees=2.0,
- confidence=0.8,
- net_pnl=pnl - 2.0 # After fees
- )
-
- executor.trade_records.append(trade_record)
-
- # Force adjustment by setting last adjustment time to past
- executor.last_profitability_adjustment = datetime.now() - timedelta(minutes=10)
-
- # Get initial multiplier
- initial_multiplier = executor.get_profitability_reward_multiplier()
-
- # Calculate success rate
- success_rate = executor._calculate_recent_success_rate()
- print(f"Calculated success rate: {success_rate:.1%}")
-
- # Trigger adjustment
- executor._adjust_profitability_reward_multiplier()
-
- # Get new multiplier
- new_multiplier = executor.get_profitability_reward_multiplier()
-
- print(f"Initial multiplier: {initial_multiplier:.1f}")
- print(f"New multiplier: {new_multiplier:.1f}")
-
- # Determine expected change
- if success_rate > executor.success_rate_increase_threshold:
- expected_change = "increase"
- expected_new = min(executor.max_profitability_multiplier, initial_multiplier + executor.profitability_adjustment_step)
- elif success_rate < executor.success_rate_decrease_threshold:
- expected_change = "decrease"
- expected_new = max(executor.min_profitability_multiplier, initial_multiplier - executor.profitability_adjustment_step)
- else:
- expected_change = "no change"
- expected_new = initial_multiplier
-
- print(f"Expected change: {expected_change}")
- print(f"Expected new value: {expected_new:.1f}")
-
- success = abs(new_multiplier - expected_new) < 0.01
- print(f"✅ Adjustment: {'PASS' if success else 'FAIL'}")
- print()
-
- return success
-
-def test_orchestrator_integration():
- """Test orchestrator integration with profitability multiplier"""
- print("=" * 60)
- print("🧪 TESTING ORCHESTRATOR INTEGRATION")
- print("=" * 60)
-
- # Create components
- data_provider = DataProvider()
- executor = TradingExecutor()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Connect executor to orchestrator
- orchestrator.set_trading_executor(executor)
-
- # Set a test multiplier
- executor.profitability_reward_multiplier = 1.5
-
- # Test getting multiplier through orchestrator
- multiplier = orchestrator.get_profitability_reward_multiplier()
- print(f"Multiplier via orchestrator: {multiplier}")
- print(f"✅ Orchestrator getter: {'PASS' if multiplier == 1.5 else 'FAIL'}")
-
- # Test enhanced reward calculation
- base_pnl = 100.0 # $100 profit
- confidence = 0.8
-
- enhanced_reward = orchestrator.calculate_enhanced_reward(base_pnl, confidence)
- expected_enhanced = base_pnl * (1.0 + 1.5) # 100 * 2.5 = 250
-
- print(f"Base P&L: ${base_pnl:.2f}")
- print(f"Enhanced reward: ${enhanced_reward:.2f}")
- print(f"Expected: ${expected_enhanced:.2f}")
- print(f"✅ Enhanced reward: {'PASS' if abs(enhanced_reward - expected_enhanced) < 0.01 else 'FAIL'}")
-
- # Test with losing trade (should not be enhanced)
- losing_pnl = -50.0
- enhanced_losing = orchestrator.calculate_enhanced_reward(losing_pnl, confidence)
- print(f"Losing P&L: ${losing_pnl:.2f}")
- print(f"Enhanced losing: ${enhanced_losing:.2f}")
- print(f"✅ No enhancement for losses: {'PASS' if enhanced_losing == losing_pnl else 'FAIL'}")
-
- return multiplier == 1.5 and abs(enhanced_reward - expected_enhanced) < 0.01
-
-def main():
- """Run all tests"""
- print("🚀 DYNAMIC PROFITABILITY REWARD SYSTEM TEST")
- print("Testing fee reversion and dynamic reward adjustment")
- print()
-
- all_tests_passed = True
-
- # Test 1: Fee configuration
- try:
- fee_test_passed = test_fee_configuration()
- all_tests_passed = all_tests_passed and fee_test_passed
- except Exception as e:
- print(f"❌ Fee configuration test failed: {e}")
- all_tests_passed = False
-
- # Test 2: Profitability multiplier initialization
- try:
- init_test_passed = test_profitability_multiplier_initialization()
- all_tests_passed = all_tests_passed and init_test_passed
- except Exception as e:
- print(f"❌ Initialization test failed: {e}")
- all_tests_passed = False
-
- # Test 3: Multiplier adjustment scenarios
- print("=" * 60)
- print("🧪 TESTING MULTIPLIER ADJUSTMENT SCENARIOS")
- print("=" * 60)
-
- executor = TradingExecutor()
-
- try:
- # Scenario 1: High success rate (should increase multiplier)
- print("Scenario 1: High success rate (65% - should increase)")
- high_success_test = simulate_trades_and_test_adjustment(executor, 13, 20) # 65%
- all_tests_passed = all_tests_passed and high_success_test
-
- # Scenario 2: Low success rate (should decrease multiplier)
- print("Scenario 2: Low success rate (45% - should decrease)")
- low_success_test = simulate_trades_and_test_adjustment(executor, 9, 20) # 45%
- all_tests_passed = all_tests_passed and low_success_test
-
- # Scenario 3: Medium success rate (should not change)
- print("Scenario 3: Medium success rate (55% - should not change)")
- medium_success_test = simulate_trades_and_test_adjustment(executor, 11, 20) # 55%
- all_tests_passed = all_tests_passed and medium_success_test
-
- except Exception as e:
- print(f"❌ Adjustment scenario tests failed: {e}")
- all_tests_passed = False
-
- # Test 4: Orchestrator integration
- try:
- orchestrator_test_passed = test_orchestrator_integration()
- all_tests_passed = all_tests_passed and orchestrator_test_passed
- except Exception as e:
- print(f"❌ Orchestrator integration test failed: {e}")
- all_tests_passed = False
-
- # Final results
- print("=" * 60)
- print("📋 TEST RESULTS SUMMARY")
- print("=" * 60)
-
- if all_tests_passed:
- print("🎉 ALL TESTS PASSED!")
- print("✅ Fees reverted to normal 0.1%")
- print("✅ Dynamic profitability multiplier working")
- print("✅ Success rate calculation accurate")
- print("✅ Orchestrator integration functional")
- print()
- print("🚀 System ready for trading with dynamic profitability rewards!")
- print("📈 The model will learn to prioritize more profitable trades over time")
- print("🎯 Success rate >60% → increase reward multiplier")
- print("⚠️ Success rate <51% → decrease reward multiplier")
- else:
- print("❌ SOME TESTS FAILED!")
- print("Please check the error messages above and fix issues before trading.")
-
- return all_tests_passed
-
-if __name__ == "__main__":
- success = main()
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/test_training_data_collection.py b/test_training_data_collection.py
deleted file mode 100644
index c73c1a2..0000000
--- a/test_training_data_collection.py
+++ /dev/null
@@ -1,118 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Training Data Collection and Checkpoint Storage
-
-This script tests if the training system is working correctly and storing checkpoints.
-"""
-
-import os
-import sys
-import logging
-import asyncio
-from pathlib import Path
-from datetime import datetime
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config, setup_logging
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-from utils.checkpoint_manager import get_checkpoint_manager
-
-# Setup logging
-setup_logging()
-logger = logging.getLogger(__name__)
-
-async def test_training_system():
- """Test if the training system is working and storing checkpoints"""
- logger.info("Testing training system and checkpoint storage...")
-
- # Initialize components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
-
- # Get checkpoint manager
- checkpoint_manager = get_checkpoint_manager()
-
- # Check if checkpoint directory exists
- checkpoint_dir = Path("models/saved")
- if not checkpoint_dir.exists():
- logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
- # Check for existing checkpoints
- checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
- logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
- logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")
-
- # List checkpoint files
- checkpoint_files = list(checkpoint_dir.glob("*.pt"))
- if checkpoint_files:
- logger.info("Recent checkpoint files:")
- for i, file in enumerate(sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]):
- file_size = file.stat().st_size / (1024 * 1024) # Convert to MB
- modified_time = datetime.fromtimestamp(file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
- logger.info(f" {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})")
- else:
- logger.warning("No checkpoint files found.")
-
- # Test training by making trading decisions
- logger.info("\nTesting training by making trading decisions...")
- symbols = orchestrator.symbols
-
- for symbol in symbols:
- logger.info(f"Making trading decision for {symbol}...")
- decision = await orchestrator.make_trading_decision(symbol)
-
- if decision:
- logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
- else:
- logger.warning(f"No decision made for {symbol}.")
-
- # Check if new checkpoints were created
- new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
- new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
-
- if new_checkpoints > 0:
- logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
- logger.info("Training system is working correctly.")
- else:
- logger.warning("\nNo new checkpoints were created.")
- logger.warning("This could be normal if the training threshold wasn't met.")
- logger.warning("Check the orchestrator's checkpoint saving logic.")
-
- # Check model states
- model_states = orchestrator.get_model_states()
- logger.info("\nModel states:")
- for model_name, state in model_states.items():
- checkpoint_loaded = state.get('checkpoint_loaded', False)
- checkpoint_filename = state.get('checkpoint_filename', 'none')
- current_loss = state.get('current_loss', None)
-
- status = "LOADED" if checkpoint_loaded else "FRESH"
- loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
-
- logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")
-
- return new_checkpoints > 0
-
-async def main():
- """Main function"""
- logger.info("=" * 70)
- logger.info("TRAINING SYSTEM TEST")
- logger.info("=" * 70)
-
- success = await test_training_system()
-
- if success:
- logger.info("\nTraining system test passed!")
- return 0
- else:
- logger.warning("\nTraining system test completed with warnings.")
- logger.info("Check the logs for details.")
- return 1
-
-if __name__ == "__main__":
- sys.exit(asyncio.run(main()))
\ No newline at end of file
diff --git a/test_training_fixes.py b/test_training_fixes.py
deleted file mode 100644
index b3484e9..0000000
--- a/test_training_fixes.py
+++ /dev/null
@@ -1,118 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Training Fixes
-
-This script tests the fixes for CNN adapter and DQN training issues.
-"""
-
-import asyncio
-import time
-import numpy as np
-from datetime import datetime
-from core.orchestrator import TradingOrchestrator
-from core.data_provider import DataProvider
-
-async def test_training_fixes():
- """Test the training fixes"""
- print("=== Testing Training Fixes ===")
-
- # Initialize orchestrator
- print("1. Initializing orchestrator...")
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider=data_provider)
-
- # Wait for initialization
- await asyncio.sleep(3)
-
- # Check CNN adapter initialization
- print("\n2. Checking CNN adapter initialization:")
- if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
- print(" ✅ CNN adapter is properly initialized")
- print(f" CNN adapter type: {type(orchestrator.cnn_adapter)}")
- else:
- print(" ❌ CNN adapter is None or missing")
-
- # Check DQN agent initialization
- print("\n3. Checking DQN agent initialization:")
- if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
- print(" ✅ DQN agent is properly initialized")
- print(f" DQN agent type: {type(orchestrator.rl_agent)}")
- if hasattr(orchestrator.rl_agent, 'policy_net'):
- print(" ✅ DQN policy network is available")
- else:
- print(" ❌ DQN policy network is missing")
- else:
- print(" ❌ DQN agent is None or missing")
-
- # Test CNN predictions
- print("\n4. Testing CNN predictions:")
- try:
- predictions = await orchestrator._get_all_predictions('ETH/USDT')
- cnn_predictions = [p for p in predictions if 'cnn' in p.model_name.lower()]
- if cnn_predictions:
- print(f" ✅ Got {len(cnn_predictions)} CNN predictions")
- for pred in cnn_predictions:
- print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})")
- else:
- print(" ❌ No CNN predictions received")
- except Exception as e:
- print(f" ❌ CNN prediction failed: {e}")
-
- # Test training with validation
- print("\n5. Testing training with validation:")
- for i in range(3):
- print(f" Training iteration {i+1}/3...")
-
- # Create training records for different models
- training_records = [
- {
- 'model_name': 'enhanced_cnn',
- 'model_input': np.random.randn(7850),
- 'prediction': {'action': 'BUY', 'confidence': 0.7},
- 'symbol': 'ETH/USDT',
- 'timestamp': datetime.now()
- },
- {
- 'model_name': 'dqn_agent',
- 'model_input': np.random.randn(7850),
- 'prediction': {'action': 'SELL', 'confidence': 0.8},
- 'symbol': 'ETH/USDT',
- 'timestamp': datetime.now()
- }
- ]
-
- for record in training_records:
- try:
- success = await orchestrator._train_model_on_outcome(
- record, True, 0.5, 1.0
- )
- if success:
- print(f" ✅ Training succeeded for {record['model_name']}")
- else:
- print(f" ⚠️ Training failed for {record['model_name']}")
- except Exception as e:
- print(f" ❌ Training error for {record['model_name']}: {e}")
-
- await asyncio.sleep(1)
-
- # Show final statistics
- print("\n6. Final model statistics:")
- orchestrator.log_model_statistics(detailed=True)
-
- # Check for overfitting warnings
- print("\n7. Checking for training quality:")
- summary = orchestrator.get_model_statistics_summary()
- for model_name, stats in summary.items():
- if stats['total_trainings'] > 0:
- print(f" {model_name}: {stats['total_trainings']} trainings, "
- f"avg time: {stats['average_training_time_ms']:.1f}ms")
- if stats['current_loss'] is not None:
- if stats['current_loss'] < 0.001:
- print(f" ⚠️ {model_name} has very low loss ({stats['current_loss']:.6f}) - check for overfitting")
- else:
- print(f" ✅ {model_name} has reasonable loss ({stats['current_loss']:.6f})")
-
- print("\n✅ Training fixes test completed!")
-
-if __name__ == "__main__":
- asyncio.run(test_training_fixes())
\ No newline at end of file
diff --git a/test_websocket_cob_data.py b/test_websocket_cob_data.py
deleted file mode 100644
index bb51691..0000000
--- a/test_websocket_cob_data.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to check if we're getting real COB data from WebSocket
-"""
-
-import time
-import logging
-from core.data_provider import DataProvider
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_websocket_cob_data():
- """Test if we're getting real COB data from WebSocket"""
- logger.info("Testing WebSocket COB data reception...")
-
- # Initialize data provider
- dp = DataProvider()
-
- # Wait for WebSocket connections
- logger.info("Waiting for WebSocket connections...")
- time.sleep(15)
-
- # Check WebSocket status
- logger.info("\n=== WebSocket Status ===")
- try:
- if hasattr(dp, 'enhanced_cob_websocket') and dp.enhanced_cob_websocket:
- status = dp.enhanced_cob_websocket.get_status_summary()
- logger.info(f"WebSocket status: {status}")
- else:
- logger.warning("Enhanced COB WebSocket not available")
- except Exception as e:
- logger.error(f"Error getting WebSocket status: {e}")
-
- # Check if we have any COB WebSocket data
- logger.info("\n=== COB WebSocket Data Check ===")
- if hasattr(dp, 'cob_websocket_data'):
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- if symbol in dp.cob_websocket_data:
- data = dp.cob_websocket_data[symbol]
- logger.info(f"{symbol}: {type(data)} - {len(str(data))} chars")
- if isinstance(data, dict):
- logger.info(f" Keys: {list(data.keys())}")
- if 'bids' in data:
- logger.info(f" Bids: {len(data['bids'])} levels")
- if 'asks' in data:
- logger.info(f" Asks: {len(data['asks'])} levels")
- else:
- logger.info(f"{symbol}: No WebSocket data")
- else:
- logger.warning("No cob_websocket_data attribute found")
-
- # Check raw COB ticks
- logger.info("\n=== Raw COB Ticks ===")
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks:
- raw_ticks = list(dp.cob_raw_ticks[symbol])
- logger.info(f"{symbol}: {len(raw_ticks)} raw ticks")
- if raw_ticks:
- latest = raw_ticks[-1]
- logger.info(f" Latest tick keys: {list(latest.keys())}")
- if 'timestamp' in latest:
- logger.info(f" Latest timestamp: {latest['timestamp']}")
- else:
- logger.info(f"{symbol}: No raw ticks")
-
- # Monitor for 30 seconds to see if data comes in
- logger.info("\n=== Monitoring for 30 seconds ===")
- initial_counts = {}
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks:
- initial_counts[symbol] = len(dp.cob_raw_ticks[symbol])
- else:
- initial_counts[symbol] = 0
-
- time.sleep(30)
-
- logger.info("After 30 seconds:")
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks:
- current_count = len(dp.cob_raw_ticks[symbol])
- new_ticks = current_count - initial_counts[symbol]
- logger.info(f"{symbol}: +{new_ticks} new ticks (total: {current_count})")
- else:
- logger.info(f"{symbol}: No raw ticks available")
-
- # Check if Enhanced WebSocket has latest data
- logger.info("\n=== Enhanced WebSocket Latest Data ===")
- try:
- if hasattr(dp, 'enhanced_cob_websocket') and dp.enhanced_cob_websocket:
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- if hasattr(dp.enhanced_cob_websocket, 'latest_cob_data'):
- latest_data = dp.enhanced_cob_websocket.latest_cob_data.get(symbol)
- if latest_data:
- logger.info(f"{symbol}: Latest WebSocket data available")
- logger.info(f" Keys: {list(latest_data.keys())}")
- if 'bids' in latest_data and 'asks' in latest_data:
- logger.info(f" Bids: {len(latest_data['bids'])}, Asks: {len(latest_data['asks'])}")
- else:
- logger.info(f"{symbol}: No latest WebSocket data")
- except Exception as e:
- logger.error(f"Error checking Enhanced WebSocket data: {e}")
-
- # Clean shutdown
- logger.info("\n=== Shutting Down ===")
- dp.stop_automatic_data_maintenance()
- logger.info("WebSocket COB data test completed")
-
-if __name__ == "__main__":
- test_websocket_cob_data()
\ No newline at end of file
diff --git a/tests/cob/test_cob_comparison.py b/tests/cob/test_cob_comparison.py
deleted file mode 100644
index 1215dd1..0000000
--- a/tests/cob/test_cob_comparison.py
+++ /dev/null
@@ -1,276 +0,0 @@
-#!/usr/bin/env python3
-"""
-Compare COB data quality between DataProvider and COBIntegration
-
-This test compares:
-1. DataProvider COB collection (used in our test)
-2. COBIntegration direct access (used in cob_realtime_dashboard.py)
-
-To understand why cob_realtime_dashboard.py gets more stable data.
-"""
-
-import asyncio
-import logging
-import time
-from collections import deque
-from datetime import datetime, timedelta
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-
-from core.data_provider import DataProvider, MarketTick
-from core.config import get_config
-
-# Try to import COBIntegration like cob_realtime_dashboard does
-try:
- from core.cob_integration import COBIntegration
- COB_INTEGRATION_AVAILABLE = True
-except ImportError:
- COB_INTEGRATION_AVAILABLE = False
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-
-class COBComparisonTester:
- def __init__(self, symbol='ETH/USDT', duration_seconds=15):
- self.symbol = symbol
- self.duration = timedelta(seconds=duration_seconds)
-
- # Data storage for both methods
- self.dp_ticks = deque() # DataProvider ticks
- self.cob_data = deque() # COBIntegration data
-
- # Initialize DataProvider (method 1)
- logger.info("Initializing DataProvider...")
- self.data_provider = DataProvider()
- self.dp_cob_received = 0
-
- # Initialize COBIntegration (method 2)
- self.cob_integration = None
- self.cob_received = 0
- if COB_INTEGRATION_AVAILABLE:
- logger.info("Initializing COBIntegration...")
- self.cob_integration = COBIntegration(symbols=[self.symbol])
- else:
- logger.warning("COBIntegration not available - will only test DataProvider")
-
- self.start_time = None
- self.subscriber_id = None
-
- def _dp_cob_callback(self, symbol: str, cob_data: dict):
- """Callback for DataProvider COB data"""
- self.dp_cob_received += 1
-
- if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
- mid_price = cob_data['stats']['mid_price']
- if mid_price > 0:
- synthetic_tick = MarketTick(
- symbol=symbol,
- timestamp=cob_data.get('timestamp', datetime.now()),
- price=mid_price,
- volume=cob_data.get('stats', {}).get('total_volume', 0),
- quantity=0,
- side='dp_cob',
- trade_id=f"dp_{self.dp_cob_received}",
- is_buyer_maker=False,
- raw_data=cob_data
- )
- self.dp_ticks.append(synthetic_tick)
-
- if self.dp_cob_received % 20 == 0:
- logger.info(f"[DataProvider] Update #{self.dp_cob_received}: {symbol} @ ${mid_price:.2f}")
-
- def _cob_integration_callback(self, symbol: str, data: dict):
- """Callback for COBIntegration data"""
- self.cob_received += 1
-
- # Store COBIntegration data directly
- cob_record = {
- 'symbol': symbol,
- 'timestamp': datetime.now(),
- 'data': data,
- 'source': 'cob_integration'
- }
- self.cob_data.append(cob_record)
-
- if self.cob_received % 20 == 0:
- stats = data.get('stats', {})
- mid_price = stats.get('mid_price', 0)
- logger.info(f"[COBIntegration] Update #{self.cob_received}: {symbol} @ ${mid_price:.2f}")
-
- async def run_comparison_test(self):
- """Run the comparison test"""
- logger.info(f"Starting COB comparison test for {self.symbol} for {self.duration.total_seconds()} seconds...")
-
- # Start DataProvider COB collection
- try:
- logger.info("Starting DataProvider COB collection...")
- self.data_provider.start_cob_collection()
- self.data_provider.subscribe_to_cob(self._dp_cob_callback)
- await self.data_provider.start_real_time_streaming()
- logger.info("DataProvider streaming started")
- except Exception as e:
- logger.error(f"Failed to start DataProvider: {e}")
-
- # Start COBIntegration if available
- if self.cob_integration:
- try:
- logger.info("Starting COBIntegration...")
- self.cob_integration.add_dashboard_callback(self._cob_integration_callback)
- await self.cob_integration.start()
- logger.info("COBIntegration started")
- except Exception as e:
- logger.error(f"Failed to start COBIntegration: {e}")
-
- # Collect data for specified duration
- self.start_time = datetime.now()
- while datetime.now() - self.start_time < self.duration:
- await asyncio.sleep(1)
- logger.info(f"DataProvider: {len(self.dp_ticks)} ticks | COBIntegration: {len(self.cob_data)} updates")
-
- # Stop data collection
- try:
- await self.data_provider.stop_real_time_streaming()
- if self.cob_integration:
- await self.cob_integration.stop()
- except Exception as e:
- logger.error(f"Error stopping data collection: {e}")
-
- logger.info(f"Comparison complete:")
- logger.info(f" DataProvider: {len(self.dp_ticks)} ticks received")
- logger.info(f" COBIntegration: {len(self.cob_data)} updates received")
-
- # Analyze and plot the differences
- self.analyze_differences()
- self.create_comparison_plots()
-
- def analyze_differences(self):
- """Analyze the differences between the two data sources"""
- logger.info("Analyzing data quality differences...")
-
- # Analyze DataProvider data
- dp_order_book_count = 0
- dp_mid_prices = []
-
- for tick in self.dp_ticks:
- if hasattr(tick, 'raw_data') and tick.raw_data:
- if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
- dp_order_book_count += 1
- if 'stats' in tick.raw_data and 'mid_price' in tick.raw_data['stats']:
- dp_mid_prices.append(tick.raw_data['stats']['mid_price'])
-
- # Analyze COBIntegration data
- cob_order_book_count = 0
- cob_mid_prices = []
-
- for record in self.cob_data:
- data = record['data']
- if 'bids' in data and 'asks' in data:
- cob_order_book_count += 1
- if 'stats' in data and 'mid_price' in data['stats']:
- cob_mid_prices.append(data['stats']['mid_price'])
-
- logger.info("Data Quality Analysis:")
- logger.info(f" DataProvider:")
- logger.info(f" Total updates: {len(self.dp_ticks)}")
- logger.info(f" With order book data: {dp_order_book_count}")
- logger.info(f" Mid prices collected: {len(dp_mid_prices)}")
- if dp_mid_prices:
- logger.info(f" Price range: ${min(dp_mid_prices):.2f} - ${max(dp_mid_prices):.2f}")
-
- logger.info(f" COBIntegration:")
- logger.info(f" Total updates: {len(self.cob_data)}")
- logger.info(f" With order book data: {cob_order_book_count}")
- logger.info(f" Mid prices collected: {len(cob_mid_prices)}")
- if cob_mid_prices:
- logger.info(f" Price range: ${min(cob_mid_prices):.2f} - ${max(cob_mid_prices):.2f}")
-
- def create_comparison_plots(self):
- """Create comparison plots showing the difference"""
- logger.info("Creating comparison plots...")
-
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))
-
- # Plot 1: Price comparison
- dp_times = []
- dp_prices = []
- for tick in self.dp_ticks:
- if tick.price > 0:
- dp_times.append(tick.timestamp)
- dp_prices.append(tick.price)
-
- cob_times = []
- cob_prices = []
- for record in self.cob_data:
- data = record['data']
- if 'stats' in data and 'mid_price' in data['stats']:
- cob_times.append(record['timestamp'])
- cob_prices.append(data['stats']['mid_price'])
-
- if dp_times:
- ax1.plot(pd.to_datetime(dp_times), dp_prices, 'b-', alpha=0.7, label='DataProvider COB', linewidth=1)
- if cob_times:
- ax1.plot(pd.to_datetime(cob_times), cob_prices, 'r-', alpha=0.7, label='COBIntegration', linewidth=1)
-
- ax1.set_title('Price Comparison: DataProvider vs COBIntegration')
- ax1.set_ylabel('Price (USDT)')
- ax1.legend()
- ax1.grid(True, alpha=0.3)
-
- # Plot 2: Data quality comparison (order book depth)
- dp_bid_counts = []
- dp_ask_counts = []
- dp_ob_times = []
-
- for tick in self.dp_ticks:
- if hasattr(tick, 'raw_data') and tick.raw_data:
- if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
- dp_bid_counts.append(len(tick.raw_data['bids']))
- dp_ask_counts.append(len(tick.raw_data['asks']))
- dp_ob_times.append(tick.timestamp)
-
- cob_bid_counts = []
- cob_ask_counts = []
- cob_ob_times = []
-
- for record in self.cob_data:
- data = record['data']
- if 'bids' in data and 'asks' in data:
- cob_bid_counts.append(len(data['bids']))
- cob_ask_counts.append(len(data['asks']))
- cob_ob_times.append(record['timestamp'])
-
- if dp_ob_times:
- ax2.plot(pd.to_datetime(dp_ob_times), dp_bid_counts, 'b--', alpha=0.7, label='DP Bid Levels')
- ax2.plot(pd.to_datetime(dp_ob_times), dp_ask_counts, 'b:', alpha=0.7, label='DP Ask Levels')
- if cob_ob_times:
- ax2.plot(pd.to_datetime(cob_ob_times), cob_bid_counts, 'r--', alpha=0.7, label='COB Bid Levels')
- ax2.plot(pd.to_datetime(cob_ob_times), cob_ask_counts, 'r:', alpha=0.7, label='COB Ask Levels')
-
- ax2.set_title('Order Book Depth Comparison')
- ax2.set_ylabel('Number of Levels')
- ax2.set_xlabel('Time')
- ax2.legend()
- ax2.grid(True, alpha=0.3)
-
- plt.tight_layout()
-
- plot_filename = f"cob_comparison_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
- plt.savefig(plot_filename, dpi=150)
- logger.info(f"Comparison plot saved to {plot_filename}")
- plt.show()
-
-
-async def main():
- tester = COBComparisonTester()
- await tester.run_comparison_test()
-
-
-if __name__ == "__main__":
- try:
- asyncio.run(main())
- except KeyboardInterrupt:
- logger.info("Test interrupted by user.")
diff --git a/tests/cob/test_cob_data_stability.py b/tests/cob/test_cob_data_stability.py
deleted file mode 100644
index a736146..0000000
--- a/tests/cob/test_cob_data_stability.py
+++ /dev/null
@@ -1,502 +0,0 @@
-import asyncio
-import logging
-import time
-from collections import deque
-from datetime import datetime, timedelta
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-from matplotlib.colors import LogNorm
-
-from core.data_provider import DataProvider, MarketTick
-from core.config import get_config
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-
-class COBStabilityTester:
- def __init__(self, symbol='ETHUSDT', duration_seconds=10):
- self.symbol = symbol
- self.duration = timedelta(seconds=duration_seconds)
- self.ticks = deque()
-
- # Set granularity (buckets) based on symbol
- if 'ETH' in symbol.upper():
- self.price_granularity = 1.0 # 1 USD for ETH
- elif 'BTC' in symbol.upper():
- self.price_granularity = 10.0 # 10 USD for BTC
- else:
- self.price_granularity = 1.0 # Default 1 USD
-
- logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}")
-
- # Initialize DataProvider the same way as clean_dashboard
- logger.info("Initializing DataProvider like in clean_dashboard...")
- self.data_provider = DataProvider() # Use default constructor like clean_dashboard
-
- # Initialize COB data collection like clean_dashboard does
- self.cob_data_received = 0
- self.latest_cob_data = {}
-
- # Store all COB snapshots for heatmap generation
- self.cob_snapshots = deque()
- self.price_data = [] # For price line chart
-
- self.start_time = None
- self.subscriber_id = None
- self.last_log_time = None
-
- def _tick_callback(self, tick: MarketTick):
- """Callback function to receive ticks from the DataProvider."""
- if self.start_time is None:
- self.start_time = datetime.now()
- logger.info(f"Started collecting ticks at {self.start_time}")
-
- # Store all ticks
- self.ticks.append(tick)
-
- def _cob_data_callback(self, symbol: str, cob_data: dict):
- """Callback function to receive COB data from the DataProvider."""
- # Debug: Log first few callbacks to see what symbols we're getting
- if self.cob_data_received < 5:
- logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')")
-
- # Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT)
- normalized_symbol = symbol.replace('/', '')
- normalized_target = self.symbol.replace('/', '')
- if normalized_symbol != normalized_target:
- if self.cob_data_received < 5:
- logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')")
- return
-
- self.cob_data_received += 1
- self.latest_cob_data[symbol] = cob_data
-
- # Store the complete COB snapshot for heatmap generation
- if 'bids' in cob_data and 'asks' in cob_data:
- # Debug: Log structure of first few COB snapshots
- if len(self.cob_snapshots) < 3:
- logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items")
- if cob_data['bids']:
- logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}")
- if cob_data['asks']:
- logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")
-
- # Use current time for timestamp consistency
- current_time = datetime.now()
- snapshot = {
- 'timestamp': current_time,
- 'bids': cob_data['bids'],
- 'asks': cob_data['asks'],
- 'stats': cob_data.get('stats', {})
- }
- self.cob_snapshots.append(snapshot)
-
- # Log bucketed COB data every second
- now = datetime.now()
- if self.last_log_time is None or (now - self.last_log_time).total_seconds() >= 1.0:
- self.last_log_time = now
- self._log_bucketed_cob_data(cob_data)
-
- # Convert COB data to tick-like format for analysis
- if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
- mid_price = cob_data['stats']['mid_price']
- if mid_price > 0:
- # Filter out extreme price movements (±10% of recent average)
- if len(self.price_data) > 5:
- recent_prices = [p['price'] for p in self.price_data[-5:]]
- avg_recent_price = sum(recent_prices) / len(recent_prices)
- price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price
-
- if price_deviation > 0.10: # More than 10% deviation
- logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})")
- return # Skip this data point
-
- # Store price data for line chart with consistent timestamp
- current_time = datetime.now()
- self.price_data.append({
- 'timestamp': current_time,
- 'price': mid_price
- })
-
- # Create a synthetic tick from COB data with consistent timestamp
- current_time = datetime.now()
- synthetic_tick = MarketTick(
- symbol=symbol,
- timestamp=current_time,
- price=mid_price,
- volume=cob_data.get('stats', {}).get('total_volume', 0),
- quantity=0, # Not available in COB data
- side='unknown', # COB data doesn't have side info
- trade_id=f"cob_{self.cob_data_received}",
- is_buyer_maker=False,
- raw_data=cob_data
- )
- self.ticks.append(synthetic_tick)
-
- if self.cob_data_received % 10 == 0: # Log every 10th update
- logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}")
-
- def _log_bucketed_cob_data(self, cob_data: dict):
- """Log bucketed COB data every second"""
- try:
- if 'bids' not in cob_data or 'asks' not in cob_data:
- logger.info("COB-1s: No order book data available")
- return
-
- if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']:
- logger.info("COB-1s: No mid price available")
- return
-
- mid_price = cob_data['stats']['mid_price']
- if mid_price <= 0:
- return
-
- # Bucket the order book data
- bid_buckets = {}
- ask_buckets = {}
-
- # Process bids (top 10)
- for bid in cob_data['bids'][:10]:
- try:
- if isinstance(bid, dict):
- price = float(bid['price'])
- size = float(bid['size'])
- elif isinstance(bid, (list, tuple)) and len(bid) >= 2:
- price = float(bid[0])
- size = float(bid[1])
- else:
- continue
-
- bucketed_price = round(price / self.price_granularity) * self.price_granularity
- bid_buckets[bucketed_price] = bid_buckets.get(bucketed_price, 0) + size
- except (ValueError, TypeError, IndexError):
- continue
-
- # Process asks (top 10)
- for ask in cob_data['asks'][:10]:
- try:
- if isinstance(ask, dict):
- price = float(ask['price'])
- size = float(ask['size'])
- elif isinstance(ask, (list, tuple)) and len(ask) >= 2:
- price = float(ask[0])
- size = float(ask[1])
- else:
- continue
-
- bucketed_price = round(price / self.price_granularity) * self.price_granularity
- ask_buckets[bucketed_price] = ask_buckets.get(bucketed_price, 0) + size
- except (ValueError, TypeError, IndexError):
- continue
-
- # Format for log output
- bid_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True)])
- ask_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items())])
-
- logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}")
-
- except Exception as e:
- logger.warning(f"Error logging bucketed COB data: {e}")
-
- async def run_test(self):
- """Run the data collection and plotting test."""
- logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
-
- # Initialize COB collection like clean_dashboard does
- try:
- logger.info("Starting COB collection in data provider...")
- self.data_provider.start_cob_collection()
- logger.info("Started COB collection in data provider")
-
- # Subscribe to COB updates
- logger.info("Subscribing to COB data updates...")
- self.data_provider.subscribe_to_cob(self._cob_data_callback)
- logger.info("Subscribed to COB data updates from data provider")
- except Exception as e:
- logger.error(f"Failed to start COB collection or subscribe: {e}")
-
- # Subscribe to ticks as fallback
- try:
- self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
- logger.info("Subscribed to tick data as fallback")
- except Exception as e:
- logger.warning(f"Failed to subscribe to ticks: {e}")
-
- # Start the data provider's real-time streaming
- try:
- await self.data_provider.start_real_time_streaming()
- logger.info("Started real-time streaming")
- except Exception as e:
- logger.error(f"Failed to start real-time streaming: {e}")
-
- # Collect data for the specified duration
- self.start_time = datetime.now()
- while datetime.now() - self.start_time < self.duration:
- await asyncio.sleep(1)
- logger.info(f"Collected {len(self.ticks)} ticks so far...")
-
- # Stop streaming and unsubscribe
- await self.data_provider.stop_real_time_streaming()
- self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
-
- logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
-
- # Plot the results
- if self.price_data and self.cob_snapshots:
- self.create_price_heatmap_chart()
- elif self.ticks:
- self._create_simple_price_chart()
- else:
- logger.warning("No data was collected. Cannot generate plot.")
-
- def create_price_heatmap_chart(self):
- """Create a visualization with price chart and order book scatter plot."""
- if not self.price_data or not self.cob_snapshots:
- logger.warning("Insufficient data to plot.")
- return
-
- logger.info(f"Creating price and order book chart...")
- logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")
-
- # Prepare price data
- price_df = pd.DataFrame(self.price_data)
- price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])
-
- logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
- logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")
-
- # Create figure with subplots
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), height_ratios=[3, 2])
-
- # Top plot: Price chart with order book levels
- ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10)
-
- # Plot order book levels as scatter points
- bid_times, bid_prices, bid_sizes = [], [], []
- ask_times, ask_prices, ask_sizes = [], [], []
-
- # Calculate average price for filtering
- avg_price = price_df['price'].mean() if not price_df.empty else 3500 # Fallback price
- price_lower = avg_price * 0.9 # -10%
- price_upper = avg_price * 1.1 # +10%
-
- logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})")
-
- for snapshot in list(self.cob_snapshots)[-50:]: # Use last 50 snapshots for clarity
- timestamp = pd.to_datetime(snapshot['timestamp'])
-
- # Process bids (top 10)
- for order in snapshot.get('bids', [])[:10]:
- try:
- if isinstance(order, dict):
- price = float(order['price'])
- size = float(order['size'])
- elif isinstance(order, (list, tuple)) and len(order) >= 2:
- price = float(order[0])
- size = float(order[1])
- else:
- continue
-
- # Filter out prices outside ±10% range
- if price < price_lower or price > price_upper:
- continue
-
- bid_times.append(timestamp)
- bid_prices.append(price)
- bid_sizes.append(size)
- except (ValueError, TypeError, IndexError):
- continue
-
- # Process asks (top 10)
- for order in snapshot.get('asks', [])[:10]:
- try:
- if isinstance(order, dict):
- price = float(order['price'])
- size = float(order['size'])
- elif isinstance(order, (list, tuple)) and len(order) >= 2:
- price = float(order[0])
- size = float(order[1])
- else:
- continue
-
- # Filter out prices outside ±10% range
- if price < price_lower or price > price_upper:
- continue
-
- ask_times.append(timestamp)
- ask_prices.append(price)
- ask_sizes.append(size)
- except (ValueError, TypeError, IndexError):
- continue
-
- # Plot order book data as scatter with size indicating volume
- if bid_times:
- bid_sizes_normalized = np.array(bid_sizes) * 3 # Scale for visibility
- ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids')
- logger.info(f"Plotted {len(bid_times)} bid levels")
-
- if ask_times:
- ask_sizes_normalized = np.array(ask_sizes) * 3 # Scale for visibility
- ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks')
- logger.info(f"Plotted {len(ask_times)} ask levels")
-
- ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s')
- ax1.set_ylabel('Price (USDT)')
- ax1.legend()
- ax1.grid(True, alpha=0.3)
-
- # Set proper time range (X-axis) - use actual data collection period
- time_min = price_df['timestamp'].min()
- time_max = price_df['timestamp'].max()
- actual_duration = (time_max - time_min).total_seconds()
- logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds")
-
- ax1.set_xlim(time_min, time_max)
-
- # Set tight price range (Y-axis) - use ±2% of price range for better visibility
- price_min = price_df['price'].min()
- price_max = price_df['price'].max()
- price_center = (price_min + price_max) / 2
- price_range = price_max - price_min
-
- # If price range is very small, use a minimum range of $5
- if price_range < 5:
- price_range = 5
-
- # Add 20% padding to the price range for better visualization
- y_padding = price_range * 0.2
- y_min = price_min - y_padding
- y_max = price_max + y_padding
-
- ax1.set_ylim(y_min, y_max)
- logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})")
-
- # Bottom plot: Order book depth over time (aggregated)
- time_buckets = []
- bid_depths = []
- ask_depths = []
-
- # Create time buckets (every few snapshots)
- snapshots_list = list(self.cob_snapshots)
- bucket_size = max(1, len(snapshots_list) // 20) # ~20 buckets
- for i in range(0, len(snapshots_list), bucket_size):
- bucket_snapshots = snapshots_list[i:i+bucket_size]
- if not bucket_snapshots:
- continue
-
- # Use middle timestamp of bucket
- mid_snapshot = bucket_snapshots[len(bucket_snapshots)//2]
- time_buckets.append(pd.to_datetime(mid_snapshot['timestamp']))
-
- # Calculate average depths
- total_bid_depth = 0
- total_ask_depth = 0
- snapshot_count = 0
-
- for snapshot in bucket_snapshots:
- bid_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
- for order in snapshot.get('bids', [])[:10]])
- ask_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
- for order in snapshot.get('asks', [])[:10]])
- total_bid_depth += bid_depth
- total_ask_depth += ask_depth
- snapshot_count += 1
-
- if snapshot_count > 0:
- bid_depths.append(total_bid_depth / snapshot_count)
- ask_depths.append(total_ask_depth / snapshot_count)
- else:
- bid_depths.append(0)
- ask_depths.append(0)
-
- if time_buckets:
- ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7)
- ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7)
- ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green')
- ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red')
-
- ax2.set_title('Order Book Depth Over Time')
- ax2.set_xlabel('Time')
- ax2.set_ylabel('Depth (Volume)')
- ax2.legend()
- ax2.grid(True, alpha=0.3)
-
- # Set same time range for bottom chart
- ax2.set_xlim(time_min, time_max)
-
- # Format time axes
- fig.autofmt_xdate()
- plt.tight_layout()
-
- plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
- plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
- logger.info(f"Price and order book chart saved to {plot_filename}")
- plt.show()
-
- def _create_simple_price_chart(self):
- """Create a simple price chart as fallback"""
- logger.info("Creating simple price chart as fallback...")
-
- prices = []
- times = []
-
- for tick in self.ticks:
- if tick.price > 0:
- prices.append(tick.price)
- times.append(tick.timestamp)
-
- if not prices:
- logger.warning("No price data to plot")
- return
-
- fig, ax = plt.subplots(figsize=(15, 8))
- ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1)
- ax.set_title(f'Price Chart - {self.symbol}')
- ax.set_xlabel('Time')
- ax.set_ylabel('Price (USDT)')
- fig.autofmt_xdate()
-
- plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
- plt.savefig(plot_filename)
- logger.info(f"Price chart saved to {plot_filename}")
- plt.show()
-
-
-async def main(symbol='ETHUSDT', duration_seconds=10):
- """Main function to run the COB test with configurable parameters.
-
- Args:
- symbol: Trading symbol (default: ETHUSDT)
- duration_seconds: Test duration in seconds (default: 10)
- """
- logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
- tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds)
- await tester.run_test()
-
-
-if __name__ == "__main__":
- import sys
-
- # Parse command line arguments
- symbol = 'ETHUSDT' # Default
- duration = 10 # Default
-
- if len(sys.argv) > 1:
- symbol = sys.argv[1]
- if len(sys.argv) > 2:
- try:
- duration = int(sys.argv[2])
- except ValueError:
- logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds")
-
- logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")
- logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}")
-
- try:
- asyncio.run(main(symbol, duration))
- except KeyboardInterrupt:
- logger.info("Test interrupted by user.")
diff --git a/tests/test_training.py b/tests/test_training.py
deleted file mode 100644
index c3bb012..0000000
--- a/tests/test_training.py
+++ /dev/null
@@ -1,310 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Training Script for AI Trading Models
-
-This script tests the training functionality of our CNN and RL models
-and demonstrates the learning capabilities.
-"""
-
-import logging
-import sys
-import asyncio
-from pathlib import Path
-from datetime import datetime, timedelta
-from safe_logging import setup_safe_logging
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import setup_logging
-from core.data_provider import DataProvider
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from models import get_model_registry, CNNModelWrapper, RLAgentWrapper
-
-# Setup logging
-setup_logging()
-logger = logging.getLogger(__name__)
-
-def test_model_loading():
- """Test that models load correctly"""
- logger.info("=== TESTING MODEL LOADING ===")
-
- try:
- # Get model registry
- registry = get_model_registry()
-
- # Check loaded models
- logger.info(f"Loaded models: {list(registry.models.keys())}")
-
- # Test each model
- for name, model in registry.models.items():
- logger.info(f"Testing {name} model...")
-
- # Test prediction
- import numpy as np
- test_features = np.random.random((20, 5)) # 20 timesteps, 5 features
-
- try:
- predictions, confidence = model.predict(test_features)
- logger.info(f" ✅ {name} prediction: {predictions} (confidence: {confidence:.3f})")
- except Exception as e:
- logger.error(f" ❌ {name} prediction failed: {e}")
-
- # Memory stats
- stats = registry.get_memory_stats()
- logger.info(f"Memory usage: {stats['total_used_mb']:.1f}MB / {stats['total_limit_mb']:.1f}MB")
-
- return True
-
- except Exception as e:
- logger.error(f"Model loading test failed: {e}")
- return False
-
-async def test_orchestrator_integration():
- """Test orchestrator integration with models"""
- logger.info("=== TESTING ORCHESTRATOR INTEGRATION ===")
-
- try:
- # Initialize components
- data_provider = DataProvider()
- orchestrator = EnhancedTradingOrchestrator(data_provider)
-
- # Test coordinated decisions
- logger.info("Testing coordinated decision making...")
- decisions = await orchestrator.make_coordinated_decisions()
-
- if decisions:
- for symbol, decision in decisions.items():
- if decision:
- logger.info(f" ✅ {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
- else:
- logger.info(f" ⏸️ {symbol}: No decision (waiting)")
- else:
- logger.warning(" ❌ No decisions made")
-
- # Test RL evaluation
- logger.info("Testing RL evaluation...")
- await orchestrator.evaluate_actions_with_rl()
-
- return True
-
- except Exception as e:
- logger.error(f"Orchestrator integration test failed: {e}")
- return False
-
-def test_rl_learning():
- """Test RL learning functionality"""
- logger.info("=== TESTING RL LEARNING ===")
-
- try:
- registry = get_model_registry()
- rl_agent = registry.get_model('RL')
-
- if not rl_agent:
- logger.error("RL agent not found")
- return False
-
- # Simulate some experiences
- import numpy as np
-
- logger.info("Simulating trading experiences...")
- for i in range(50):
- state = np.random.random(10)
- action = np.random.randint(0, 3)
- reward = np.random.uniform(-0.1, 0.1) # Random P&L
- next_state = np.random.random(10)
- done = False
-
- # Store experience
- rl_agent.remember(state, action, reward, next_state, done)
-
- logger.info(f"Stored {len(rl_agent.experience_buffer)} experiences")
-
- # Test replay training
- logger.info("Testing replay training...")
- loss = rl_agent.replay()
-
- if loss is not None:
- logger.info(f" ✅ Training loss: {loss:.4f}")
- else:
- logger.info(" ⏸️ Not enough experiences for training")
-
- return True
-
- except Exception as e:
- logger.error(f"RL learning test failed: {e}")
- return False
-
-def test_cnn_training():
- """Test CNN training functionality"""
- logger.info("=== TESTING CNN TRAINING ===")
-
- try:
- registry = get_model_registry()
- cnn_model = registry.get_model('CNN')
-
- if not cnn_model:
- logger.error("CNN model not found")
- return False
-
- # Test training with mock perfect moves
- training_data = {
- 'perfect_moves': [],
- 'market_data': {},
- 'symbols': ['ETH/USDT', 'BTC/USDT'],
- 'timeframes': ['1m', '1h']
- }
-
- # Mock some perfect moves
- for i in range(10):
- perfect_move = {
- 'symbol': 'ETH/USDT',
- 'timeframe': '1m',
- 'timestamp': datetime.now() - timedelta(hours=i),
- 'optimal_action': 'BUY' if i % 2 == 0 else 'SELL',
- 'confidence_should_have_been': 0.8 + i * 0.01,
- 'actual_outcome': 0.02 if i % 2 == 0 else -0.015
- }
- training_data['perfect_moves'].append(perfect_move)
-
- logger.info(f"Testing training with {len(training_data['perfect_moves'])} perfect moves...")
-
- # Test training
- result = cnn_model.train(training_data)
-
- if result and result.get('status') == 'training_simulated':
- logger.info(f" ✅ Training completed: {result}")
- else:
- logger.warning(f" ⚠️ Training result: {result}")
-
- return True
-
- except Exception as e:
- logger.error(f"CNN training test failed: {e}")
- return False
-
-def test_prediction_tracking():
- """Test prediction tracking and learning feedback"""
- logger.info("=== TESTING PREDICTION TRACKING ===")
-
- try:
- # Initialize components
- data_provider = DataProvider()
- orchestrator = EnhancedTradingOrchestrator(data_provider)
-
- # Get some market data for testing
- test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=100)
-
- if test_data is None or test_data.empty:
- logger.warning("No market data available for testing")
- return True
-
- logger.info(f"Testing with {len(test_data)} candles of ETH/USDT 1m data")
-
- # Simulate some predictions and outcomes
- correct_predictions = 0
- total_predictions = 0
-
- for i in range(min(10, len(test_data) - 5)):
- # Get a slice of data
- current_data = test_data.iloc[i:i+20]
- future_data = test_data.iloc[i+20:i+25]
-
- if len(current_data) < 20 or len(future_data) < 5:
- continue
-
- # Make prediction
- current_price = current_data['close'].iloc[-1]
- future_price = future_data['close'].iloc[-1]
- actual_change = (future_price - current_price) / current_price
-
- # Simulate model prediction
- predicted_action = 'BUY' if actual_change > 0.001 else 'SELL' if actual_change < -0.001 else 'HOLD'
-
- # Check if prediction was correct
- if predicted_action == 'BUY' and actual_change > 0:
- correct_predictions += 1
- logger.info(f" ✅ Correct BUY prediction: {actual_change:.4f}")
- elif predicted_action == 'SELL' and actual_change < 0:
- correct_predictions += 1
- logger.info(f" ✅ Correct SELL prediction: {actual_change:.4f}")
- elif predicted_action == 'HOLD' and abs(actual_change) < 0.001:
- correct_predictions += 1
- logger.info(f" ✅ Correct HOLD prediction: {actual_change:.4f}")
- else:
- logger.info(f" ❌ Wrong {predicted_action} prediction: {actual_change:.4f}")
-
- total_predictions += 1
-
- if total_predictions > 0:
- accuracy = correct_predictions / total_predictions
- logger.info(f"Prediction accuracy: {accuracy:.1%} ({correct_predictions}/{total_predictions})")
-
- return True
-
- except Exception as e:
- logger.error(f"Prediction tracking test failed: {e}")
- return False
-
-async def main():
- """Main test function"""
- logger.info("🧪 STARTING AI TRADING MODEL TESTS")
- logger.info("Testing model loading, training, and learning capabilities")
-
- tests = [
- ("Model Loading", test_model_loading),
- ("Orchestrator Integration", test_orchestrator_integration),
- ("RL Learning", test_rl_learning),
- ("CNN Training", test_cnn_training),
- ("Prediction Tracking", test_prediction_tracking)
- ]
-
- results = {}
-
- for test_name, test_func in tests:
- logger.info(f"\n{'='*50}")
- logger.info(f"Running: {test_name}")
- logger.info(f"{'='*50}")
-
- try:
- if asyncio.iscoroutinefunction(test_func):
- result = await test_func()
- else:
- result = test_func()
-
- results[test_name] = result
-
- if result:
- logger.info(f"✅ {test_name}: PASSED")
- else:
- logger.error(f"❌ {test_name}: FAILED")
-
- except Exception as e:
- logger.error(f"❌ {test_name}: ERROR - {e}")
- results[test_name] = False
-
- # Summary
- logger.info(f"\n{'='*50}")
- logger.info("TEST SUMMARY")
- logger.info(f"{'='*50}")
-
- passed = sum(1 for result in results.values() if result)
- total = len(results)
-
- for test_name, result in results.items():
- status = "✅ PASSED" if result else "❌ FAILED"
- logger.info(f"{test_name}: {status}")
-
- logger.info(f"\nOverall: {passed}/{total} tests passed ({passed/total:.1%})")
-
- if passed == total:
- logger.info("🎉 All tests passed! The AI trading system is working correctly.")
- else:
- logger.warning(f"⚠️ {total-passed} tests failed. Please check the logs above.")
-
- return 0 if passed == total else 1
-
-if __name__ == "__main__":
- exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
diff --git a/tests/test_training_integration.py b/tests/test_training_integration.py
deleted file mode 100644
index 04c854f..0000000
--- a/tests/test_training_integration.py
+++ /dev/null
@@ -1,204 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Training Integration with Dashboard
-
-This script tests the enhanced dashboard's ability to:
-1. Stream training data to CNN and DQN models
-2. Display real-time training metrics and progress
-3. Show model learning curves and performance
-4. Integrate with the continuous training system
-"""
-
-import sys
-import logging
-import time
-import asyncio
-from datetime import datetime, timedelta
-from pathlib import Path
-
-# Setup logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def test_training_integration():
- """Test the training integration functionality"""
- try:
- print("="*60)
- print("TESTING TRAINING INTEGRATION WITH DASHBOARD")
- print("="*60)
-
- # Import dashboard
- from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
- from core.data_provider import DataProvider
- from core.orchestrator import TradingOrchestrator
-
- # Create components
- data_provider = DataProvider()
- orchestrator = TradingOrchestrator(data_provider)
- dashboard = TradingDashboard(data_provider, orchestrator)
-
- print(f"✓ Dashboard created with training integration")
- print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
-
- # Test 1: Simulate tick data for training
- print("\n📊 TEST 1: Simulating Tick Data")
- print("-" * 40)
-
- # Add simulated tick data to cache
- base_price = 3500.0
- for i in range(1000):
- tick_data = {
- 'timestamp': datetime.now() - timedelta(seconds=1000-i),
- 'price': base_price + (i % 100) * 0.1,
- 'volume': 100 + (i % 50),
- 'side': 'buy' if i % 2 == 0 else 'sell'
- }
- dashboard.tick_cache.append(tick_data)
-
- print(f"✓ Added {len(dashboard.tick_cache)} ticks to cache")
-
- # Test 2: Prepare training data
- print("\n🔄 TEST 2: Preparing Training Data")
- print("-" * 40)
-
- training_data = dashboard._prepare_training_data()
- if training_data:
- print(f"✓ Training data prepared successfully")
- print(f" - OHLCV bars: {len(training_data['ohlcv'])}")
- print(f" - Features: {training_data['features']}")
- print(f" - Symbol: {training_data['symbol']}")
- else:
- print("❌ Failed to prepare training data")
-
- # Test 3: Format data for CNN
- print("\n🧠 TEST 3: CNN Data Formatting")
- print("-" * 40)
-
- if training_data:
- cnn_data = dashboard._format_data_for_cnn(training_data)
- if cnn_data and 'sequences' in cnn_data:
- print(f"✓ CNN data formatted successfully")
- print(f" - Sequences shape: {cnn_data['sequences'].shape}")
- print(f" - Targets shape: {cnn_data['targets'].shape}")
- print(f" - Sequence length: {cnn_data['sequence_length']}")
- else:
- print("❌ Failed to format CNN data")
-
- # Test 4: Format data for RL
- print("\n🤖 TEST 4: RL Data Formatting")
- print("-" * 40)
-
- if training_data:
- rl_experiences = dashboard._format_data_for_rl(training_data)
- if rl_experiences:
- print(f"✓ RL experiences formatted successfully")
- print(f" - Number of experiences: {len(rl_experiences)}")
- print(f" - Experience format: (state, action, reward, next_state, done)")
- print(f" - Sample experience shapes: {[len(exp) for exp in rl_experiences[:3]]}")
- else:
- print("❌ Failed to format RL experiences")
-
- # Test 5: Send training data to models
- print("\n📤 TEST 5: Sending Training Data to Models")
- print("-" * 40)
-
- success = dashboard.send_training_data_to_models()
- print(f"✓ Training data sent: {success}")
-
- if hasattr(dashboard, 'training_stats'):
- stats = dashboard.training_stats
- print(f" - Total training sessions: {stats.get('total_training_sessions', 0)}")
- print(f" - CNN training count: {stats.get('cnn_training_count', 0)}")
- print(f" - RL training count: {stats.get('rl_training_count', 0)}")
- print(f" - Training data points: {stats.get('training_data_points', 0)}")
-
- # Test 6: Training metrics display
- print("\n📈 TEST 6: Training Metrics Display")
- print("-" * 40)
-
- training_metrics = dashboard._create_training_metrics()
- print(f"✓ Training metrics created: {len(training_metrics)} components")
-
- # Test 7: Model training status
- print("\n🔍 TEST 7: Model Training Status")
- print("-" * 40)
-
- training_status = dashboard._get_model_training_status()
- print(f"✓ Training status retrieved")
- print(f" - CNN status: {training_status['cnn']['status']}")
- print(f" - CNN accuracy: {training_status['cnn']['accuracy']:.1%}")
- print(f" - RL status: {training_status['rl']['status']}")
- print(f" - RL win rate: {training_status['rl']['win_rate']:.1%}")
-
- # Test 8: Training events log
- print("\n📝 TEST 8: Training Events Log")
- print("-" * 40)
-
- training_events = dashboard._get_recent_training_events()
- print(f"✓ Training events retrieved: {len(training_events)} events")
-
- # Test 9: Mini training chart
- print("\n📊 TEST 9: Mini Training Chart")
- print("-" * 40)
-
- try:
- training_chart = dashboard._create_mini_training_chart(training_status)
- print(f"✓ Mini training chart created")
- print(f" - Chart type: {type(training_chart)}")
- except Exception as e:
- print(f"❌ Error creating training chart: {e}")
-
- # Test 10: Continuous training loop
- print("\n🔄 TEST 10: Continuous Training Loop")
- print("-" * 40)
-
- print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
- if hasattr(dashboard, 'training_thread'):
- print(f"✓ Training thread alive: {dashboard.training_thread.is_alive()}")
-
- # Test 11: Integration with existing continuous training system
- print("\n🔗 TEST 11: Integration with Continuous Training System")
- print("-" * 40)
-
- try:
- # Check if we can get tick cache for external training
- tick_cache = dashboard.get_tick_cache_for_training()
- print(f"✓ Tick cache accessible: {len(tick_cache)} ticks")
-
- # Check if we can get 1-second bars
- one_second_bars = dashboard.get_one_second_bars()
- print(f"✓ 1-second bars accessible: {len(one_second_bars)} bars")
-
- except Exception as e:
- print(f"❌ Error accessing training data: {e}")
-
- print("\n" + "="*60)
- print("TRAINING INTEGRATION TEST COMPLETED")
- print("="*60)
-
- # Summary
- print("\n📋 SUMMARY:")
- print(f"✓ Dashboard with training integration: WORKING")
- print(f"✓ Training data preparation: WORKING")
- print(f"✓ CNN data formatting: WORKING")
- print(f"✓ RL data formatting: WORKING")
- print(f"✓ Training metrics display: WORKING")
- print(f"✓ Continuous training: ACTIVE")
- print(f"✓ Model status tracking: WORKING")
- print(f"✓ Training events logging: WORKING")
-
- return True
-
- except Exception as e:
- logger.error(f"Training integration test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-if __name__ == "__main__":
- success = test_training_integration()
- if success:
- print("\n🎉 All training integration tests passed!")
- else:
- print("\n❌ Some training integration tests failed!")
- sys.exit(1)
\ No newline at end of file
diff --git a/tests/test_training_status.py b/tests/test_training_status.py
deleted file mode 100644
index 49836d9..0000000
--- a/tests/test_training_status.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to check training status functionality
-"""
-
-import logging
-logging.basicConfig(level=logging.INFO)
-
-print("Testing training status functionality...")
-
-try:
- from web.old_archived.scalping_dashboard import create_scalping_dashboard
- from core.data_provider import DataProvider
- from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-
- print("✅ Imports successful")
-
- # Create components
- data_provider = DataProvider()
- orchestrator = EnhancedTradingOrchestrator(data_provider)
- dashboard = create_scalping_dashboard(data_provider, orchestrator)
-
- print("✅ Dashboard created successfully")
-
- # Test training status
- training_status = dashboard._get_model_training_status()
- print("\n📊 Training Status:")
- print(f"CNN Status: {training_status['cnn']['status']}")
- print(f"CNN Accuracy: {training_status['cnn']['accuracy']:.1%}")
- print(f"CNN Loss: {training_status['cnn']['loss']:.4f}")
- print(f"CNN Epochs: {training_status['cnn']['epochs']}")
-
- print(f"RL Status: {training_status['rl']['status']}")
- print(f"RL Win Rate: {training_status['rl']['win_rate']:.1%}")
- print(f"RL Episodes: {training_status['rl']['episodes']}")
- print(f"RL Memory: {training_status['rl']['memory_size']}")
-
- # Test extrema stats
- if hasattr(orchestrator, 'get_extrema_stats'):
- extrema_stats = orchestrator.get_extrema_stats()
- print(f"\n🎯 Extrema Stats:")
- print(f"Total extrema detected: {extrema_stats.get('total_extrema_detected', 0)}")
- print(f"Training queue size: {extrema_stats.get('training_queue_size', 0)}")
- print("✅ Extrema stats available")
- else:
- print("❌ Extrema stats not available")
-
- # Test tick cache
- print(f"\n📈 Training Data:")
- print(f"Tick cache size: {len(dashboard.tick_cache)}")
- print(f"1s bars cache size: {len(dashboard.one_second_bars)}")
- print(f"Streaming status: {dashboard.is_streaming}")
-
- print("\n✅ All tests completed successfully!")
-
-except Exception as e:
- print(f"❌ Error: {e}")
- import traceback
- traceback.print_exc()
\ No newline at end of file
diff --git a/tests/test_universal_data_format.py b/tests/test_universal_data_format.py
deleted file mode 100644
index e30558a..0000000
--- a/tests/test_universal_data_format.py
+++ /dev/null
@@ -1,262 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Universal Data Format Compliance
-
-This script verifies that our enhanced trading system properly feeds
-the 5 required timeseries streams to all models:
-- ETH/USDT: ticks (1s), 1m, 1h, 1d
-- BTC/USDT: ticks (1s) as reference
-
-This is our universal trading system input format.
-"""
-
-import asyncio
-import logging
-import sys
-from pathlib import Path
-import numpy as np
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config
-from core.data_provider import DataProvider
-from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from training.enhanced_cnn_trainer import EnhancedCNNTrainer
-from training.enhanced_rl_trainer import EnhancedRLTrainer
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-async def test_universal_data_format():
- """Test that all components properly use the universal 5-timeseries format"""
- logger.info("="*80)
- logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE")
- logger.info("="*80)
-
- try:
- # Initialize components
- config = get_config()
- data_provider = DataProvider(config)
-
- # Test 1: Universal Data Adapter
- logger.info("\n📊 TEST 1: Universal Data Adapter")
- logger.info("-" * 40)
-
- adapter = UniversalDataAdapter(data_provider)
- universal_stream = adapter.get_universal_data_stream()
-
- if universal_stream is None:
- logger.error("❌ Failed to get universal data stream")
- return False
-
- # Validate format
- is_valid, issues = adapter.validate_universal_format(universal_stream)
- if not is_valid:
- logger.error(f"❌ Universal format validation failed: {issues}")
- return False
-
- logger.info("✅ Universal Data Adapter: PASSED")
- logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples")
- logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles")
- logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles")
- logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles")
- logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples")
- logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}")
-
- # Test 2: Enhanced Orchestrator
- logger.info("\n🎯 TEST 2: Enhanced Orchestrator")
- logger.info("-" * 40)
-
- orchestrator = EnhancedTradingOrchestrator(data_provider)
-
- # Test that orchestrator uses universal adapter
- if not hasattr(orchestrator, 'universal_adapter'):
- logger.error("❌ Orchestrator missing universal_adapter")
- return False
-
- # Test coordinated decisions
- decisions = await orchestrator.make_coordinated_decisions()
-
- logger.info("✅ Enhanced Orchestrator: PASSED")
- logger.info(f" Generated {len(decisions)} decisions")
- logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}")
-
- for symbol, decision in decisions.items():
- if decision:
- logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})")
-
- # Test 3: CNN Model Data Format
- logger.info("\n🧠 TEST 3: CNN Model Data Format")
- logger.info("-" * 40)
-
- # Format data for CNN
- cnn_data = adapter.format_for_model(universal_stream, 'cnn')
-
- required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
- missing_keys = [key for key in required_cnn_keys if key not in cnn_data]
-
- if missing_keys:
- logger.error(f"❌ CNN data missing keys: {missing_keys}")
- return False
-
- logger.info("✅ CNN Model Data Format: PASSED")
- for key, data in cnn_data.items():
- if isinstance(data, np.ndarray):
- logger.info(f" {key}: shape {data.shape}")
- else:
- logger.info(f" {key}: {type(data)}")
-
- # Test 4: RL Model Data Format
- logger.info("\n🤖 TEST 4: RL Model Data Format")
- logger.info("-" * 40)
-
- # Format data for RL
- rl_data = adapter.format_for_model(universal_stream, 'rl')
-
- if 'state_vector' not in rl_data:
- logger.error("❌ RL data missing state_vector")
- return False
-
- state_vector = rl_data['state_vector']
- if not isinstance(state_vector, np.ndarray):
- logger.error("❌ RL state_vector is not numpy array")
- return False
-
- logger.info("✅ RL Model Data Format: PASSED")
- logger.info(f" State vector shape: {state_vector.shape}")
- logger.info(f" State vector size: {len(state_vector)} features")
-
- # Test 5: CNN Trainer Integration
- logger.info("\n🎓 TEST 5: CNN Trainer Integration")
- logger.info("-" * 40)
-
- try:
- cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
- logger.info("✅ CNN Trainer Integration: PASSED")
- logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}")
- logger.info(f" Model device: {cnn_trainer.model.device}")
- except Exception as e:
- logger.error(f"❌ CNN Trainer Integration failed: {e}")
- return False
-
- # Test 6: RL Trainer Integration
- logger.info("\n🎮 TEST 6: RL Trainer Integration")
- logger.info("-" * 40)
-
- try:
- rl_trainer = EnhancedRLTrainer(config, orchestrator)
- logger.info("✅ RL Trainer Integration: PASSED")
- logger.info(f" RL agents: {len(rl_trainer.agents)}")
- for symbol, agent in rl_trainer.agents.items():
- logger.info(f" {symbol} agent: {type(agent).__name__}")
- except Exception as e:
- logger.error(f"❌ RL Trainer Integration failed: {e}")
- return False
-
- # Test 7: Data Flow Verification
- logger.info("\n🔄 TEST 7: Data Flow Verification")
- logger.info("-" * 40)
-
- # Verify that models receive the correct data format
- test_predictions = await orchestrator._get_enhanced_predictions_universal(
- 'ETH/USDT',
- list(orchestrator.market_states['ETH/USDT'])[-1] if orchestrator.market_states['ETH/USDT'] else None,
- universal_stream
- )
-
- if test_predictions:
- logger.info("✅ Data Flow Verification: PASSED")
- for pred in test_predictions:
- logger.info(f" Model: {pred.model_name}")
- logger.info(f" Action: {pred.overall_action}")
- logger.info(f" Confidence: {pred.overall_confidence:.2f}")
- logger.info(f" Timeframes: {len(pred.timeframe_predictions)}")
- else:
- logger.warning("⚠️ No predictions generated (may be normal if no models loaded)")
-
- # Test 8: Configuration Compliance
- logger.info("\n⚙️ TEST 8: Configuration Compliance")
- logger.info("-" * 40)
-
- # Check that config matches universal format
- expected_symbols = ['ETH/USDT', 'BTC/USDT']
- expected_timeframes = ['1s', '1m', '1h', '1d']
-
- config_symbols = config.symbols
- config_timeframes = config.timeframes
-
- symbols_match = all(symbol in config_symbols for symbol in expected_symbols)
- timeframes_match = all(tf in config_timeframes for tf in expected_timeframes)
-
- if not symbols_match:
- logger.warning(f"⚠️ Config symbols may not match universal format")
- logger.warning(f" Expected: {expected_symbols}")
- logger.warning(f" Config: {config_symbols}")
-
- if not timeframes_match:
- logger.warning(f"⚠️ Config timeframes may not match universal format")
- logger.warning(f" Expected: {expected_timeframes}")
- logger.warning(f" Config: {config_timeframes}")
-
- if symbols_match and timeframes_match:
- logger.info("✅ Configuration Compliance: PASSED")
- else:
- logger.info("⚠️ Configuration Compliance: PARTIAL")
-
- logger.info(f" Symbols: {config_symbols}")
- logger.info(f" Timeframes: {config_timeframes}")
-
- # Final Summary
- logger.info("\n" + "="*80)
- logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY")
- logger.info("="*80)
- logger.info("✅ All core tests PASSED!")
- logger.info("")
- logger.info("📋 VERIFIED COMPLIANCE:")
- logger.info(" ✓ Universal Data Adapter working")
- logger.info(" ✓ Enhanced Orchestrator using universal format")
- logger.info(" ✓ CNN models receive 5 timeseries streams")
- logger.info(" ✓ RL models receive combined state vector")
- logger.info(" ✓ Trainers properly integrated")
- logger.info(" ✓ Data flow verified")
- logger.info("")
- logger.info("🎯 UNIVERSAL FORMAT ACTIVE:")
- logger.info(" 1. ETH/USDT ticks (1s) ✓")
- logger.info(" 2. ETH/USDT 1m ✓")
- logger.info(" 3. ETH/USDT 1h ✓")
- logger.info(" 4. ETH/USDT 1d ✓")
- logger.info(" 5. BTC/USDT reference ticks ✓")
- logger.info("")
- logger.info("🚀 Your enhanced trading system is ready with universal data format!")
- logger.info("="*80)
-
- return True
-
- except Exception as e:
- logger.error(f"❌ Universal data format test failed: {e}")
- import traceback
- logger.error(traceback.format_exc())
- return False
-
-async def main():
- """Main test function"""
- logger.info("🚀 Starting Universal Data Format Compliance Test...")
-
- success = await test_universal_data_format()
-
- if success:
- logger.info("\n🎉 All tests passed! Universal data format is properly implemented.")
- logger.info("Your enhanced trading system respects the 5-timeseries input format.")
- else:
- logger.error("\n💥 Tests failed! Please check the universal data format implementation.")
- sys.exit(1)
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/tests/test_universal_stream_integration.py b/tests/test_universal_stream_integration.py
deleted file mode 100644
index 1689cf7..0000000
--- a/tests/test_universal_stream_integration.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test Universal Data Stream Integration with Dashboard
-
-This script validates that:
-1. CleanTradingDashboard properly subscribes to UnifiedDataStream
-2. All 5 timeseries are properly received and processed
-3. Data flows correctly from provider -> adapter -> stream -> dashboard
-4. Consumer callback functions work as expected
-"""
-
-import asyncio
-import logging
-import sys
-import time
-from pathlib import Path
-
-# Add project root to path
-project_root = Path(__file__).parent
-sys.path.insert(0, str(project_root))
-
-from core.config import get_config
-from core.data_provider import DataProvider
-from core.enhanced_orchestrator import EnhancedTradingOrchestrator
-from core.trading_executor import TradingExecutor
-from web.clean_dashboard import CleanTradingDashboard
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-async def test_universal_stream_integration():
- """Test Universal Data Stream integration with dashboard"""
- logger.info("="*80)
- logger.info("🧪 TESTING UNIVERSAL DATA STREAM INTEGRATION")
- logger.info("="*80)
-
- try:
- # Initialize components
- logger.info("\n📦 STEP 1: Initialize Components")
- logger.info("-" * 40)
-
- config = get_config()
- data_provider = DataProvider()
- orchestrator = EnhancedTradingOrchestrator(
- data_provider=data_provider,
- symbols=['ETH/USDT', 'BTC/USDT'],
- enhanced_rl_training=True
- )
- trading_executor = TradingExecutor()
-
- logger.info("✅ Core components initialized")
-
- # Initialize dashboard with Universal Data Stream
- logger.info("\n📊 STEP 2: Initialize Dashboard with Universal Stream")
- logger.info("-" * 40)
-
- dashboard = CleanTradingDashboard(
- data_provider=data_provider,
- orchestrator=orchestrator,
- trading_executor=trading_executor
- )
-
- # Check Universal Stream initialization
- if hasattr(dashboard, 'unified_stream') and dashboard.unified_stream:
- logger.info("✅ Universal Data Stream initialized successfully")
- logger.info(f"📋 Consumer ID: {dashboard.stream_consumer_id}")
- else:
- logger.error("❌ Universal Data Stream not initialized")
- return False
-
- # Test consumer registration
- logger.info("\n🔗 STEP 3: Validate Consumer Registration")
- logger.info("-" * 40)
-
- stream_stats = dashboard.unified_stream.get_stream_stats()
- logger.info(f"📊 Stream Stats: {stream_stats}")
-
- if stream_stats['total_consumers'] > 0:
- logger.info(f"✅ {stream_stats['total_consumers']} consumers registered")
- else:
- logger.warning("⚠️ No consumers registered")
-
- # Test data callback
- logger.info("\n📡 STEP 4: Test Data Callback")
- logger.info("-" * 40)
-
- # Create test data packet
- test_data = {
- 'timestamp': time.time(),
- 'consumer_id': dashboard.stream_consumer_id,
- 'consumer_name': 'CleanTradingDashboard',
- 'ticks': [
- {'symbol': 'ETHUSDT', 'price': 3000.0, 'volume': 1.5, 'timestamp': time.time()},
- {'symbol': 'ETHUSDT', 'price': 3001.0, 'volume': 2.0, 'timestamp': time.time()},
- ],
- 'ohlcv': {'one_second_bars': [], 'multi_timeframe': {
- 'ETH/USDT': {
- '1s': [{'timestamp': time.time(), 'open': 3000, 'high': 3002, 'low': 2999, 'close': 3001, 'volume': 10}],
- '1m': [{'timestamp': time.time(), 'open': 2990, 'high': 3010, 'low': 2985, 'close': 3001, 'volume': 100}],
- '1h': [{'timestamp': time.time(), 'open': 2900, 'high': 3050, 'low': 2880, 'close': 3001, 'volume': 1000}],
- '1d': [{'timestamp': time.time(), 'open': 2800, 'high': 3200, 'low': 2750, 'close': 3001, 'volume': 10000}]
- },
- 'BTC/USDT': {
- '1s': [{'timestamp': time.time(), 'open': 65000, 'high': 65020, 'low': 64980, 'close': 65010, 'volume': 0.5}]
- }
- }},
- 'training_data': {'market_state': 'test', 'features': []},
- 'ui_data': {'formatted_data': 'test_ui_data'}
- }
-
- # Test callback manually
- try:
- dashboard._handle_unified_stream_data(test_data)
- logger.info("✅ Data callback executed successfully")
-
- # Check if data was processed
- if hasattr(dashboard, 'current_prices') and 'ETH/USDT' in dashboard.current_prices:
- logger.info(f"✅ Price updated: ETH/USDT = ${dashboard.current_prices['ETH/USDT']}")
- else:
- logger.warning("⚠️ Prices not updated in dashboard")
-
- except Exception as e:
- logger.error(f"❌ Data callback failed: {e}")
- return False
-
- # Test Universal Data Adapter
- logger.info("\n🔄 STEP 5: Test Universal Data Adapter")
- logger.info("-" * 40)
-
- if hasattr(orchestrator, 'universal_adapter'):
- universal_stream = orchestrator.universal_adapter.get_universal_data_stream()
- if universal_stream:
- logger.info("✅ Universal Data Adapter working")
- logger.info(f"📊 ETH ticks: {len(universal_stream.eth_ticks)} samples")
- logger.info(f"📊 ETH 1m: {len(universal_stream.eth_1m)} candles")
- logger.info(f"📊 ETH 1h: {len(universal_stream.eth_1h)} candles")
- logger.info(f"📊 ETH 1d: {len(universal_stream.eth_1d)} candles")
- logger.info(f"📊 BTC ticks: {len(universal_stream.btc_ticks)} samples")
-
- # Validate format
- is_valid, issues = orchestrator.universal_adapter.validate_universal_format(universal_stream)
- if is_valid:
- logger.info("✅ Universal format validation passed")
- else:
- logger.warning(f"⚠️ Format issues: {issues}")
- else:
- logger.error("❌ Universal Data Adapter failed to get stream")
- return False
- else:
- logger.error("❌ Universal Data Adapter not found in orchestrator")
- return False
-
- # Summary
- logger.info("\n🎯 SUMMARY")
- logger.info("-" * 40)
- logger.info("✅ Universal Data Stream properly integrated")
- logger.info("✅ Dashboard subscribes as consumer")
- logger.info("✅ All 5 timeseries format validated")
- logger.info("✅ Data callback processing works")
- logger.info("✅ Universal Data Adapter functional")
-
- logger.info("\n🏆 INTEGRATION TEST PASSED")
- return True
-
- except Exception as e:
- logger.error(f"❌ Integration test failed: {e}")
- import traceback
- traceback.print_exc()
- return False
-
-if __name__ == "__main__":
- success = asyncio.run(test_universal_stream_integration())
- sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/trading_main.py b/trading_main.py
deleted file mode 100644
index 9505c84..0000000
--- a/trading_main.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import os
-import time
-import logging
-import sys
-import argparse
-import json
-
-# Add the NN directory to the Python path
-sys.path.append(os.path.abspath("NN"))
-
-from NN.main import load_model
-from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
-from NN.realtime_data_interface import RealtimeDataInterface
-
-# Initialize logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler("trading_bot.log"),
- logging.StreamHandler()
- ]
-)
-logger = logging.getLogger(__name__)
-
-def main():
- """Main function for the trading bot."""
- # Parse command-line arguments
- parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
- parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
- help='Trading symbols to monitor')
- parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
- help='Timeframes to monitor')
- parser.add_argument('--window-size', type=int, default=20,
- help='Window size for model input')
- parser.add_argument('--output-size', type=int, default=3,
- help='Output size of the model (3 for BUY/HOLD/SELL)')
- parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
- help='Type of neural network model')
- parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
- help='Trading mode')
- parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
- help='Exchange to use for trading')
- parser.add_argument('--api-key', type=str, default=None,
- help='API key for the exchange')
- parser.add_argument('--api-secret', type=str, default=None,
- help='API secret for the exchange')
- parser.add_argument('--test-mode', action='store_true',
- help='Use test/sandbox exchange environment')
- parser.add_argument('--position-size', type=float, default=0.1,
- help='Position size as a fraction of total balance (0.0-1.0)')
- parser.add_argument('--max-trades-per-day', type=int, default=5,
- help='Maximum number of trades per day')
- parser.add_argument('--trade-cooldown', type=int, default=60,
- help='Trade cooldown period in minutes')
- parser.add_argument('--config-file', type=str, default=None,
- help='Path to configuration file')
-
- args = parser.parse_args()
-
- # Load configuration from file if provided
- if args.config_file and os.path.exists(args.config_file):
- with open(args.config_file, 'r') as f:
- config = json.load(f)
- # Override config with command-line args
- for key, value in vars(args).items():
- if key != 'config_file' and value is not None:
- config[key] = value
- else:
- # Use command-line args as config
- config = vars(args)
-
- # Initialize real-time charts and data interfaces
- try:
- from dataprovider_realtime import RealTimeChart
-
- # Create a real-time chart for each symbol
- charts = {}
- for symbol in config['symbols']:
- charts[symbol] = RealTimeChart(symbol=symbol)
-
- main_chart = charts[config['symbols'][0]]
-
- # Create a data interface for retrieving market data
- data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart)
-
- # Load trained model
- model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
- model = load_model(
- model_type=model_type,
- input_shape=(config['window_size'], len(config['symbols']), 5), # 5 features (OHLCV)
- output_size=config['output_size']
- )
-
- # Configure trading agent
- exchange_config = {
- "exchange": config['exchange'],
- "api_key": config['api_key'],
- "api_secret": config['api_secret'],
- "test_mode": config['test_mode'],
- "trade_symbols": config['symbols'],
- "position_size": config['position_size'],
- "max_trades_per_day": config['max_trades_per_day'],
- "trade_cooldown_minutes": config['trade_cooldown']
- }
-
- # Initialize neural network orchestrator
- orchestrator = NeuralNetworkOrchestrator(
- model=model,
- data_interface=data_interface,
- chart=main_chart,
- symbols=config['symbols'],
- timeframes=config['timeframes'],
- window_size=config['window_size'],
- num_features=5, # OHLCV
- output_size=config['output_size'],
- exchange_config=exchange_config
- )
-
- # Start data collection
- logger.info("Starting data collection threads...")
- for symbol in config['symbols']:
- charts[symbol].start()
-
- # Start neural network inference
- if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
- logger.info("Starting neural network inference...")
- orchestrator.start_inference()
- else:
- logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")
-
- # Start web servers for chart display
- logger.info("Starting web servers for chart display...")
- main_chart.start_server()
-
- logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")
-
- # Keep the main thread alive
- try:
- while True:
- time.sleep(1)
- except KeyboardInterrupt:
- logger.info("Keyboard interrupt received. Shutting down...")
- # Stop all threads
- for symbol in config['symbols']:
- charts[symbol].stop()
- orchestrator.stop_inference()
- logger.info("Trading bot stopped.")
-
- except Exception as e:
- logger.error(f"Error in main function: {str(e)}", exc_info=True)
- sys.exit(1)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
deleted file mode 100644
index 71aab0c..0000000
--- a/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""
-Utils package for the multi-modal trading system
-"""
\ No newline at end of file
diff --git a/utils/async_task_manager.py b/utils/async_task_manager.py
deleted file mode 100644
index 30ee711..0000000
--- a/utils/async_task_manager.py
+++ /dev/null
@@ -1,232 +0,0 @@
-"""
-Async Task Manager - Handles async tasks with comprehensive error handling
-Prevents silent failures in async operations
-"""
-
-import asyncio
-import logging
-import functools
-import traceback
-from typing import Any, Callable, Optional, Dict, List
-from datetime import datetime
-
-logger = logging.getLogger(__name__)
-
-class AsyncTaskManager:
- """Manage async tasks with error handling and monitoring"""
-
- def __init__(self):
- self.active_tasks: Dict[str, asyncio.Task] = {}
- self.completed_tasks: List[Dict[str, Any]] = []
- self.failed_tasks: List[Dict[str, Any]] = []
- self.max_history = 100
-
- def create_task_with_error_handling(self,
- coro: Any,
- name: str,
- error_callback: Optional[Callable] = None,
- success_callback: Optional[Callable] = None) -> asyncio.Task:
- """
- Create an async task with comprehensive error handling
-
- Args:
- coro: Coroutine to run
- name: Task name for identification
- error_callback: Called on error with (name, exception)
- success_callback: Called on success with (name, result)
- """
-
- async def wrapped_coro():
- """Wrapper coroutine with error handling"""
- start_time = datetime.now()
- try:
- logger.debug(f"Starting async task: {name}")
- result = await coro
-
- # Log success
- duration = (datetime.now() - start_time).total_seconds()
- logger.debug(f"Async task '{name}' completed successfully in {duration:.2f}s")
-
- # Store completion info
- completion_info = {
- 'name': name,
- 'status': 'completed',
- 'start_time': start_time,
- 'end_time': datetime.now(),
- 'duration': duration,
- 'result': str(result)[:200] if result else None # Truncate long results
- }
- self.completed_tasks.append(completion_info)
-
- # Trim history
- if len(self.completed_tasks) > self.max_history:
- self.completed_tasks.pop(0)
-
- # Call success callback
- if success_callback:
- try:
- success_callback(name, result)
- except Exception as cb_error:
- logger.error(f"Error in success callback for task '{name}': {cb_error}")
-
- return result
-
- except asyncio.CancelledError:
- logger.info(f"Async task '{name}' was cancelled")
- raise
-
- except Exception as e:
- # Log error with full traceback
- duration = (datetime.now() - start_time).total_seconds()
- error_msg = f"Async task '{name}' failed after {duration:.2f}s: {e}"
- logger.error(error_msg)
- logger.error(f"Task '{name}' traceback: {traceback.format_exc()}")
-
- # Store failure info
- failure_info = {
- 'name': name,
- 'status': 'failed',
- 'start_time': start_time,
- 'end_time': datetime.now(),
- 'duration': duration,
- 'error': str(e),
- 'traceback': traceback.format_exc()
- }
- self.failed_tasks.append(failure_info)
-
- # Trim history
- if len(self.failed_tasks) > self.max_history:
- self.failed_tasks.pop(0)
-
- # Call error callback
- if error_callback:
- try:
- error_callback(name, e)
- except Exception as cb_error:
- logger.error(f"Error in error callback for task '{name}': {cb_error}")
-
- # Don't re-raise to prevent task from crashing the event loop
- # Instead, return None to indicate failure
- return None
-
- finally:
- # Remove from active tasks
- if name in self.active_tasks:
- del self.active_tasks[name]
-
- # Create and store task
- task = asyncio.create_task(wrapped_coro(), name=name)
- self.active_tasks[name] = task
-
- return task
-
- def cancel_task(self, name: str) -> bool:
- """Cancel a specific task"""
- if name in self.active_tasks:
- task = self.active_tasks[name]
- if not task.done():
- task.cancel()
- logger.info(f"Cancelled async task: {name}")
- return True
- return False
-
- def cancel_all_tasks(self):
- """Cancel all active tasks"""
- for name, task in list(self.active_tasks.items()):
- if not task.done():
- task.cancel()
- logger.info(f"Cancelled async task: {name}")
-
- def get_task_status(self) -> Dict[str, Any]:
- """Get status of all tasks"""
- active_count = len(self.active_tasks)
- completed_count = len(self.completed_tasks)
- failed_count = len(self.failed_tasks)
-
- # Get recent failures
- recent_failures = self.failed_tasks[-5:] if self.failed_tasks else []
-
- return {
- 'active_tasks': active_count,
- 'completed_tasks': completed_count,
- 'failed_tasks': failed_count,
- 'active_task_names': list(self.active_tasks.keys()),
- 'recent_failures': [
- {
- 'name': f['name'],
- 'error': f['error'],
- 'duration': f['duration'],
- 'time': f['end_time'].strftime('%H:%M:%S')
- }
- for f in recent_failures
- ]
- }
-
- def get_failure_summary(self) -> Dict[str, Any]:
- """Get summary of task failures"""
- if not self.failed_tasks:
- return {'total_failures': 0, 'failure_patterns': {}}
-
- # Count failures by error type
- error_counts = {}
- for failure in self.failed_tasks:
- error_type = type(failure.get('error', 'Unknown')).__name__
- error_counts[error_type] = error_counts.get(error_type, 0) + 1
-
- # Recent failure rate
- recent_failures = [f for f in self.failed_tasks if
- (datetime.now() - f['end_time']).total_seconds() < 3600] # Last hour
-
- return {
- 'total_failures': len(self.failed_tasks),
- 'recent_failures_1h': len(recent_failures),
- 'failure_patterns': error_counts,
- 'most_common_error': max(error_counts.items(), key=lambda x: x[1])[0] if error_counts else None
- }
-
-# Global instance
-_task_manager = None
-
-def get_async_task_manager() -> AsyncTaskManager:
- """Get global async task manager instance"""
- global _task_manager
- if _task_manager is None:
- _task_manager = AsyncTaskManager()
- return _task_manager
-
-def create_safe_task(coro: Any,
- name: str,
- error_callback: Optional[Callable] = None,
- success_callback: Optional[Callable] = None) -> asyncio.Task:
- """
- Create a safe async task with error handling
-
- Args:
- coro: Coroutine to run
- name: Task name for identification
- error_callback: Called on error with (name, exception)
- success_callback: Called on success with (name, result)
- """
- manager = get_async_task_manager()
- return manager.create_task_with_error_handling(coro, name, error_callback, success_callback)
-
-def safe_async_wrapper(name: str,
- error_callback: Optional[Callable] = None,
- success_callback: Optional[Callable] = None):
- """
- Decorator for creating safe async functions
-
- Usage:
- @safe_async_wrapper("my_task")
- async def my_async_function():
- # Your async code here
- pass
- """
- def decorator(func):
- @functools.wraps(func)
- async def wrapper(*args, **kwargs):
- coro = func(*args, **kwargs)
- task = create_safe_task(coro, name, error_callback, success_callback)
- return await task
- return wrapper
- return decorator
\ No newline at end of file
diff --git a/utils/launch_tensorboard.py b/utils/launch_tensorboard.py
deleted file mode 100644
index 23e601f..0000000
--- a/utils/launch_tensorboard.py
+++ /dev/null
@@ -1,164 +0,0 @@
-#!/usr/bin/env python3
-"""
-TensorBoard Launcher with Automatic Port Management
-
-This script launches TensorBoard with automatic port fallback if the preferred port is in use.
-It also kills any stale debug instances that might be running.
-
-Usage:
- python launch_tensorboard.py --logdir=path/to/logs --preferred-port=6007 --port-range=6000-7000
-"""
-
-import os
-import sys
-import subprocess
-import argparse
-import logging
-from pathlib import Path
-
-# Add project root to path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-if project_root not in sys.path:
- sys.path.append(project_root)
-
-from utils.port_manager import get_port_with_fallback, kill_stale_debug_instances
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger('tensorboard_launcher')
-
-def launch_tensorboard(logdir, port, host='localhost', open_browser=True):
- """
- Launch TensorBoard on the specified port
-
- Args:
- logdir (str): Path to log directory
- port (int): Port to use
- host (str): Host to bind to
- open_browser (bool): Whether to open browser automatically
-
- Returns:
- subprocess.Popen: Process object
- """
- cmd = [
- sys.executable, "-m", "tensorboard.main",
- f"--logdir={logdir}",
- f"--port={port}",
- f"--host={host}"
- ]
-
- # Add --load_fast=false to improve startup times
- cmd.append("--load_fast=false")
-
- # Control whether to open browser
- if not open_browser:
- cmd.append("--window_title=TensorBoard")
-
- logger.info(f"Launching TensorBoard: {' '.join(cmd)}")
-
- # Use subprocess.Popen to start TensorBoard without waiting for it to finish
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- universal_newlines=True,
- bufsize=1
- )
-
- # Log the first few lines of output to confirm it's starting correctly
- line_count = 0
- for line in process.stdout:
- logger.info(f"TensorBoard: {line.strip()}")
- line_count += 1
-
- # Check if TensorBoard has started successfully
- if "TensorBoard" in line and "http://" in line:
- url = line.strip().split("http://")[1].split(" ")[0]
- logger.info(f"TensorBoard available at: http://{url}")
-
- # Only log the first few lines
- if line_count >= 10:
- break
-
- # Continue reading output in background to prevent pipe from filling
- def read_output():
- for line in process.stdout:
- pass
-
- import threading
- threading.Thread(target=read_output, daemon=True).start()
-
- return process
-
-def main():
- parser = argparse.ArgumentParser(description='Launch TensorBoard with automatic port management')
- parser.add_argument('--logdir', type=str, default='NN/models/saved/logs',
- help='Directory containing TensorBoard event files')
- parser.add_argument('--preferred-port', type=int, default=6007,
- help='Preferred port to use')
- parser.add_argument('--port-range', type=str, default='6000-7000',
- help='Port range to try if preferred port is unavailable (format: min-max)')
- parser.add_argument('--host', type=str, default='localhost',
- help='Host to bind to')
- parser.add_argument('--no-browser', action='store_true',
- help='Do not open browser automatically')
- parser.add_argument('--kill-stale', action='store_true',
- help='Kill stale debug instances before starting')
-
- args = parser.parse_args()
-
- # Parse port range
- try:
- min_port, max_port = map(int, args.port_range.split('-'))
- except ValueError:
- logger.error(f"Invalid port range format: {args.port_range}. Use format: min-max")
- return 1
-
- # Kill stale instances if requested
- if args.kill_stale:
- logger.info("Killing stale debug instances...")
- count, _ = kill_stale_debug_instances()
- logger.info(f"Killed {count} stale instances")
-
- # Get an available port
- try:
- port = get_port_with_fallback(args.preferred_port, min_port, max_port)
- logger.info(f"Using port {port} for TensorBoard")
- except RuntimeError as e:
- logger.error(str(e))
- return 1
-
- # Ensure log directory exists
- logdir = os.path.abspath(args.logdir)
- os.makedirs(logdir, exist_ok=True)
-
- # Launch TensorBoard
- process = launch_tensorboard(
- logdir=logdir,
- port=port,
- host=args.host,
- open_browser=not args.no_browser
- )
-
- # Wait for process to end (it shouldn't unless there's an error or user kills it)
- try:
- return_code = process.wait()
- if return_code != 0:
- logger.error(f"TensorBoard exited with code {return_code}")
- return return_code
- except KeyboardInterrupt:
- logger.info("Received keyboard interrupt, shutting down TensorBoard...")
- process.terminate()
- try:
- process.wait(timeout=5)
- except subprocess.TimeoutExpired:
- logger.warning("TensorBoard didn't terminate gracefully, forcing kill")
- process.kill()
-
- return 0
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/utils/model_utils.py b/utils/model_utils.py
deleted file mode 100644
index df87455..0000000
--- a/utils/model_utils.py
+++ /dev/null
@@ -1,241 +0,0 @@
-#!/usr/bin/env python
-"""
-Model utilities for robust saving and loading of PyTorch models
-"""
-
-import os
-import logging
-import torch
-import shutil
-import gc
-import json
-from typing import Any, Dict, Optional, Union
-
-logger = logging.getLogger(__name__)
-
-def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
- """
- Robust model saving with multiple fallback approaches
-
- Args:
- model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
- path: Path to save the model
- include_optimizer: Whether to include optimizer state in the save
-
- Returns:
- bool: True if successful, False otherwise
- """
- # Create directory if it doesn't exist
- os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
-
- # Backup path in case the main save fails
- backup_path = f"{path}.backup"
-
- # Clean up GPU memory before saving
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect()
-
- # Prepare checkpoint data
- checkpoint = {
- 'policy_net': model.policy_net.state_dict(),
- 'target_net': model.target_net.state_dict(),
- 'epsilon': getattr(model, 'epsilon', 0.0),
- 'state_size': getattr(model, 'state_size', None),
- 'action_size': getattr(model, 'action_size', None),
- 'hidden_size': getattr(model, 'hidden_size', None),
- }
-
- # Add optimizer state if requested and available
- if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
- checkpoint['optimizer'] = model.optimizer.state_dict()
-
- # Attempt 1: Try with default settings in a separate file first
- try:
- logger.info(f"Saving model to {backup_path} (attempt 1)")
- torch.save(checkpoint, backup_path)
- logger.info(f"Successfully saved to {backup_path}")
-
- # If backup worked, copy to the actual path
- if os.path.exists(backup_path):
- shutil.copy(backup_path, path)
- logger.info(f"Copied backup to {path}")
- return True
- except Exception as e:
- logger.warning(f"First save attempt failed: {e}")
-
- # Attempt 2: Try with pickle protocol 2 (more compatible)
- try:
- logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
- torch.save(checkpoint, path, pickle_protocol=2)
- logger.info(f"Successfully saved to {path} with pickle_protocol=2")
- return True
- except Exception as e:
- logger.warning(f"Second save attempt failed: {e}")
-
- # Attempt 3: Try without optimizer state (which can be large and cause issues)
- try:
- logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
- checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
- torch.save(checkpoint_no_opt, path)
- logger.info(f"Successfully saved to {path} without optimizer state")
- return True
- except Exception as e:
- logger.warning(f"Third save attempt failed: {e}")
-
- # Attempt 4: Try with torch.jit.save instead
- try:
- logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
- # Save policy network using jit
- scripted_policy = torch.jit.script(model.policy_net)
- torch.jit.save(scripted_policy, f"{path}.policy.jit")
-
- # Save target network using jit
- scripted_target = torch.jit.script(model.target_net)
- torch.jit.save(scripted_target, f"{path}.target.jit")
-
- # Save parameters separately as JSON
- params = {
- 'epsilon': float(getattr(model, 'epsilon', 0.0)),
- 'state_size': int(getattr(model, 'state_size', 0)),
- 'action_size': int(getattr(model, 'action_size', 0)),
- 'hidden_size': int(getattr(model, 'hidden_size', 0))
- }
- with open(f"{path}.params.json", "w") as f:
- json.dump(params, f)
-
- logger.info(f"Successfully saved model components with jit.save")
- return True
- except Exception as e:
- logger.error(f"All save attempts failed: {e}")
- return False
-
-def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
- """
- Robust model loading with fallback approaches
-
- Args:
- model: The model object to load into
- path: Path to load the model from
- device: Device to load the model on
-
- Returns:
- bool: True if successful, False otherwise
- """
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Try regular PyTorch load first
- try:
- logger.info(f"Loading model from {path}")
- if os.path.exists(path):
- checkpoint = torch.load(path, map_location=device)
-
- # Load network states
- if 'policy_net' in checkpoint:
- model.policy_net.load_state_dict(checkpoint['policy_net'])
- if 'target_net' in checkpoint:
- model.target_net.load_state_dict(checkpoint['target_net'])
-
- # Load other attributes
- if 'epsilon' in checkpoint:
- model.epsilon = checkpoint['epsilon']
- if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
- try:
- model.optimizer.load_state_dict(checkpoint['optimizer'])
- except Exception as e:
- logger.warning(f"Failed to load optimizer state: {e}")
-
- logger.info("Successfully loaded model")
- return True
- except Exception as e:
- logger.warning(f"Regular load failed: {e}")
-
- # Try loading JIT saved components
- try:
- policy_path = f"{path}.policy.jit"
- target_path = f"{path}.target.jit"
- params_path = f"{path}.params.json"
-
- if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
- logger.info(f"Loading JIT model components")
-
- # Load JIT models (this is more complex and may need model reconstruction)
- # For now, just log that we found JIT files
- logger.info("Found JIT model files, but loading them requires special handling")
- with open(params_path, 'r') as f:
- params = json.load(f)
- logger.info(f"Model parameters: {params}")
-
- # Note: Actually loading JIT models would require recreating the model architecture
- # This is a placeholder for future implementation
- return False
- except Exception as e:
- logger.error(f"JIT load failed: {e}")
-
- logger.error(f"All load attempts failed for {path}")
- return False
-
-def get_model_info(path: str) -> Dict[str, Any]:
- """
- Get information about a saved model
-
- Args:
- path: Path to the model file
-
- Returns:
- dict: Model information
- """
- info = {
- 'exists': False,
- 'size_bytes': 0,
- 'has_optimizer': False,
- 'parameters': {}
- }
-
- try:
- if os.path.exists(path):
- info['exists'] = True
- info['size_bytes'] = os.path.getsize(path)
-
- # Try to load and inspect
- checkpoint = torch.load(path, map_location='cpu')
- info['has_optimizer'] = 'optimizer' in checkpoint
-
- # Extract parameter info
- for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
- if key in checkpoint:
- info['parameters'][key] = checkpoint[key]
- except Exception as e:
- logger.warning(f"Failed to get model info for {path}: {e}")
-
- return info
-
-def verify_save_load_cycle(model: Any, test_path: str) -> bool:
- """
- Test that a model can be saved and loaded correctly
-
- Args:
- model: Model to test
- test_path: Path for test file
-
- Returns:
- bool: True if save/load cycle successful
- """
- try:
- # Save the model
- if not robust_save(model, test_path):
- return False
-
- # Create a new model instance (this would need model creation logic)
- # For now, just verify the file exists and has content
- if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
- logger.info("Save/load cycle verification successful")
- # Clean up test file
- os.remove(test_path)
- return True
- else:
- return False
- except Exception as e:
- logger.error(f"Save/load cycle verification failed: {e}")
- return False
\ No newline at end of file
diff --git a/utils/port_manager.py b/utils/port_manager.py
deleted file mode 100644
index e33c91c..0000000
--- a/utils/port_manager.py
+++ /dev/null
@@ -1,238 +0,0 @@
-#!/usr/bin/env python3
-"""
-Port Management Utility
-
-This script provides utilities to:
-1. Find available ports in a specified range
-2. Kill stale processes running on specific ports
-3. Kill all debug/training instances
-
-Usage:
- - As a module: import port_manager and use its functions
- - Directly: python port_manager.py --kill-stale --min-port 6000 --max-port 7000
-"""
-
-import os
-import sys
-import socket
-import argparse
-import psutil
-import logging
-import time
-import signal
-from typing import List, Tuple, Optional, Set
-
-# Configure logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger('port_manager')
-
-# Define process names to look for when killing stale instances
-DEBUG_PROCESS_KEYWORDS = [
- 'tensorboard',
- 'python train_',
- 'realtime.py',
- 'train_rl_with_realtime.py'
-]
-
-def is_port_in_use(port: int) -> bool:
- """
- Check if a port is in use
-
- Args:
- port (int): Port number to check
-
- Returns:
- bool: True if port is in use, False otherwise
- """
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- return s.connect_ex(('localhost', port)) == 0
-
-def find_available_port(start_port: int, end_port: int) -> Optional[int]:
- """
- Find an available port in the specified range
-
- Args:
- start_port (int): Lower bound of port range
- end_port (int): Upper bound of port range
-
- Returns:
- Optional[int]: Available port number or None if no ports available
- """
- for port in range(start_port, end_port + 1):
- if not is_port_in_use(port):
- return port
- return None
-
-def get_process_by_port(port: int) -> List[psutil.Process]:
- """
- Get processes using a specific port
-
- Args:
- port (int): Port number to check
-
- Returns:
- List[psutil.Process]: List of processes using the port
- """
- processes = []
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
- try:
- for conn in proc.connections(kind='inet'):
- if conn.laddr.port == port:
- processes.append(proc)
- except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
- pass
- return processes
-
-def kill_process_by_port(port: int) -> Tuple[int, List[str]]:
- """
- Kill processes using a specific port
-
- Args:
- port (int): Port number to check
-
- Returns:
- Tuple[int, List[str]]: Count of killed processes and their names
- """
- processes = get_process_by_port(port)
- killed = []
-
- for proc in processes:
- try:
- proc_name = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
- logger.info(f"Terminating process {proc.pid}: {proc_name}")
- proc.terminate()
- killed.append(proc_name)
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
-
- # Give processes time to terminate gracefully
- if processes:
- time.sleep(0.5)
-
- # Force kill any remaining processes
- for proc in processes:
- try:
- if proc.is_running():
- logger.info(f"Force killing process {proc.pid}")
- proc.kill()
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
-
- return len(killed), killed
-
-def kill_stale_debug_instances() -> Tuple[int, Set[str]]:
- """
- Kill all stale debug and training instances based on process names
-
- Returns:
- Tuple[int, Set[str]]: Count of killed processes and their names
- """
- killed_count = 0
- killed_procs = set()
-
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
- try:
- cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
-
- # Check if this is a debug/training process we should kill
- if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS):
- logger.info(f"Terminating stale process {proc.pid}: {cmd}")
- proc.terminate()
- killed_count += 1
- killed_procs.add(cmd)
- except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
- pass
-
- # Give processes time to terminate
- if killed_count > 0:
- time.sleep(1)
-
- # Force kill any remaining processes
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
- try:
- cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
-
- if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS) and proc.is_running():
- logger.info(f"Force killing stale process {proc.pid}")
- proc.kill()
- except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
- pass
-
- return killed_count, killed_procs
-
-def get_port_with_fallback(preferred_port: int, min_port: int, max_port: int) -> int:
- """
- Try to use preferred port, fall back to any available port in range
-
- Args:
- preferred_port (int): Preferred port to use
- min_port (int): Minimum port in fallback range
- max_port (int): Maximum port in fallback range
-
- Returns:
- int: Available port number
- """
- # First try the preferred port
- if not is_port_in_use(preferred_port):
- return preferred_port
-
- # If preferred port is in use, try to free it
- logger.info(f"Preferred port {preferred_port} is in use, attempting to free it")
- kill_count, _ = kill_process_by_port(preferred_port)
-
- if kill_count > 0 and not is_port_in_use(preferred_port):
- logger.info(f"Successfully freed port {preferred_port}")
- return preferred_port
-
- # If we couldn't free the preferred port, find another available port
- logger.info(f"Looking for available port in range {min_port}-{max_port}")
- available_port = find_available_port(min_port, max_port)
-
- if available_port:
- logger.info(f"Using alternative port: {available_port}")
- return available_port
- else:
- # If no ports are available, force kill processes in the entire range
- logger.warning(f"No available ports in range {min_port}-{max_port}, freeing ports")
- for port in range(min_port, max_port + 1):
- kill_process_by_port(port)
-
- # Try again
- available_port = find_available_port(min_port, max_port)
- if available_port:
- logger.info(f"Using port {available_port} after freeing")
- return available_port
- else:
- logger.error(f"Could not find available port even after freeing range {min_port}-{max_port}")
- raise RuntimeError(f"No available ports in range {min_port}-{max_port}")
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Port management utility')
- parser.add_argument('--kill-stale', action='store_true', help='Kill all stale debug instances')
- parser.add_argument('--free-port', type=int, help='Free a specific port')
- parser.add_argument('--find-port', action='store_true', help='Find an available port')
- parser.add_argument('--min-port', type=int, default=6000, help='Minimum port in range')
- parser.add_argument('--max-port', type=int, default=7000, help='Maximum port in range')
- parser.add_argument('--preferred-port', type=int, help='Preferred port to use')
-
- args = parser.parse_args()
-
- if args.kill_stale:
- count, procs = kill_stale_debug_instances()
- logger.info(f"Killed {count} stale processes")
- for proc in procs:
- logger.info(f" - {proc}")
-
- if args.free_port:
- count, killed = kill_process_by_port(args.free_port)
- logger.info(f"Killed {count} processes using port {args.free_port}")
- for proc in killed:
- logger.info(f" - {proc}")
-
- if args.find_port or args.preferred_port:
- preferred = args.preferred_port if args.preferred_port else args.min_port
- port = get_port_with_fallback(preferred, args.min_port, args.max_port)
- print(port) # Print only the port number for easy capture in scripts
\ No newline at end of file
diff --git a/utils/process_supervisor.py b/utils/process_supervisor.py
deleted file mode 100644
index 1ecdeb0..0000000
--- a/utils/process_supervisor.py
+++ /dev/null
@@ -1,340 +0,0 @@
-"""
-Process Supervisor - Handles process monitoring, restarts, and supervision
-Prevents silent failures by monitoring process health and restarting on crashes
-"""
-
-import subprocess
-import threading
-import time
-import logging
-import signal
-import os
-import sys
-from typing import Dict, Any, Optional, Callable, List
-from datetime import datetime, timedelta
-from pathlib import Path
-
-logger = logging.getLogger(__name__)
-
-class ProcessSupervisor:
- """Supervise processes and restart them on failure"""
-
- def __init__(self, max_restarts: int = 5, restart_delay: int = 10):
- """
- Initialize process supervisor
-
- Args:
- max_restarts: Maximum number of restarts before giving up
- restart_delay: Delay in seconds between restarts
- """
- self.max_restarts = max_restarts
- self.restart_delay = restart_delay
-
- self.processes: Dict[str, Dict[str, Any]] = {}
- self.monitoring = False
- self.monitor_thread = None
-
- # Callbacks
- self.process_started_callback: Optional[Callable] = None
- self.process_failed_callback: Optional[Callable] = None
- self.process_restarted_callback: Optional[Callable] = None
-
- def add_process(self, name: str, command: List[str],
- working_dir: Optional[str] = None,
- env: Optional[Dict[str, str]] = None,
- auto_restart: bool = True):
- """
- Add a process to supervise
-
- Args:
- name: Process name
- command: Command to run as list
- working_dir: Working directory
- env: Environment variables
- auto_restart: Whether to auto-restart on failure
- """
- self.processes[name] = {
- 'command': command,
- 'working_dir': working_dir,
- 'env': env,
- 'auto_restart': auto_restart,
- 'process': None,
- 'restart_count': 0,
- 'last_start': None,
- 'last_failure': None,
- 'status': 'stopped'
- }
- logger.info(f"Added process '{name}' to supervisor")
-
- def start_process(self, name: str) -> bool:
- """Start a specific process"""
- if name not in self.processes:
- logger.error(f"Process '{name}' not found")
- return False
-
- proc_info = self.processes[name]
-
- if proc_info['process'] and proc_info['process'].poll() is None:
- logger.warning(f"Process '{name}' is already running")
- return True
-
- try:
- # Prepare environment
- env = os.environ.copy()
- if proc_info['env']:
- env.update(proc_info['env'])
-
- # Start process
- process = subprocess.Popen(
- proc_info['command'],
- cwd=proc_info['working_dir'],
- env=env,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- text=True
- )
-
- proc_info['process'] = process
- proc_info['last_start'] = datetime.now()
- proc_info['status'] = 'running'
-
- logger.info(f"Started process '{name}' (PID: {process.pid})")
-
- if self.process_started_callback:
- try:
- self.process_started_callback(name, process.pid)
- except Exception as e:
- logger.error(f"Error in process started callback: {e}")
-
- return True
-
- except Exception as e:
- logger.error(f"Failed to start process '{name}': {e}")
- proc_info['status'] = 'failed'
- proc_info['last_failure'] = datetime.now()
- return False
-
- def stop_process(self, name: str, timeout: int = 10) -> bool:
- """Stop a specific process"""
- if name not in self.processes:
- logger.error(f"Process '{name}' not found")
- return False
-
- proc_info = self.processes[name]
- process = proc_info['process']
-
- if not process or process.poll() is not None:
- logger.info(f"Process '{name}' is not running")
- proc_info['status'] = 'stopped'
- return True
-
- try:
- # Try graceful shutdown first
- process.terminate()
-
- # Wait for graceful shutdown
- try:
- process.wait(timeout=timeout)
- logger.info(f"Process '{name}' terminated gracefully")
- except subprocess.TimeoutExpired:
- # Force kill if graceful shutdown fails
- logger.warning(f"Process '{name}' did not terminate gracefully, force killing")
- process.kill()
- process.wait()
- logger.info(f"Process '{name}' force killed")
-
- proc_info['status'] = 'stopped'
- return True
-
- except Exception as e:
- logger.error(f"Error stopping process '{name}': {e}")
- return False
-
- def restart_process(self, name: str) -> bool:
- """Restart a specific process"""
- logger.info(f"Restarting process '{name}'")
-
- if name not in self.processes:
- logger.error(f"Process '{name}' not found")
- return False
-
- proc_info = self.processes[name]
-
- # Stop if running
- if proc_info['process'] and proc_info['process'].poll() is None:
- self.stop_process(name)
-
- # Wait restart delay
- time.sleep(self.restart_delay)
-
- # Increment restart count
- proc_info['restart_count'] += 1
-
- # Check restart limit
- if proc_info['restart_count'] > self.max_restarts:
- logger.error(f"Process '{name}' exceeded max restarts ({self.max_restarts})")
- proc_info['status'] = 'failed_max_restarts'
- return False
-
- # Start process
- success = self.start_process(name)
-
- if success and self.process_restarted_callback:
- try:
- self.process_restarted_callback(name, proc_info['restart_count'])
- except Exception as e:
- logger.error(f"Error in process restarted callback: {e}")
-
- return success
-
- def start_monitoring(self):
- """Start process monitoring"""
- if self.monitoring:
- logger.warning("Process monitoring already started")
- return
-
- self.monitoring = True
- self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
- self.monitor_thread.start()
- logger.info("Process monitoring started")
-
- def stop_monitoring(self):
- """Stop process monitoring"""
- self.monitoring = False
- if self.monitor_thread:
- self.monitor_thread.join(timeout=5)
- logger.info("Process monitoring stopped")
-
- def _monitor_loop(self):
- """Main monitoring loop"""
- logger.info("Process monitoring loop started")
-
- while self.monitoring:
- try:
- for name, proc_info in self.processes.items():
- self._check_process_health(name, proc_info)
-
- time.sleep(5) # Check every 5 seconds
-
- except Exception as e:
- logger.error(f"Error in process monitoring loop: {e}")
- time.sleep(5)
-
- logger.info("Process monitoring loop stopped")
-
- def _check_process_health(self, name: str, proc_info: Dict[str, Any]):
- """Check health of a specific process"""
- process = proc_info['process']
-
- if not process:
- return
-
- # Check if process is still running
- return_code = process.poll()
-
- if return_code is not None:
- # Process has exited
- proc_info['status'] = 'exited'
- proc_info['last_failure'] = datetime.now()
-
- logger.warning(f"Process '{name}' exited with code {return_code}")
-
- # Read stdout/stderr for debugging
- try:
- stdout, stderr = process.communicate(timeout=1)
- if stdout:
- logger.info(f"Process '{name}' stdout: {stdout[-500:]}") # Last 500 chars
- if stderr:
- logger.error(f"Process '{name}' stderr: {stderr[-500:]}") # Last 500 chars
- except Exception as e:
- logger.warning(f"Could not read process output: {e}")
-
- if self.process_failed_callback:
- try:
- self.process_failed_callback(name, return_code)
- except Exception as e:
- logger.error(f"Error in process failed callback: {e}")
-
- # Auto-restart if enabled
- if proc_info['auto_restart'] and proc_info['restart_count'] < self.max_restarts:
- logger.info(f"Auto-restarting process '{name}'")
- threading.Thread(target=self.restart_process, args=(name,), daemon=True).start()
-
- def get_process_status(self, name: str) -> Optional[Dict[str, Any]]:
- """Get status of a specific process"""
- if name not in self.processes:
- return None
-
- proc_info = self.processes[name]
- process = proc_info['process']
-
- status = {
- 'name': name,
- 'status': proc_info['status'],
- 'restart_count': proc_info['restart_count'],
- 'last_start': proc_info['last_start'],
- 'last_failure': proc_info['last_failure'],
- 'auto_restart': proc_info['auto_restart'],
- 'pid': process.pid if process and process.poll() is None else None,
- 'running': process is not None and process.poll() is None
- }
-
- return status
-
- def get_all_status(self) -> Dict[str, Dict[str, Any]]:
- """Get status of all processes"""
- return {name: self.get_process_status(name) for name in self.processes}
-
- def set_callbacks(self,
- process_started: Optional[Callable] = None,
- process_failed: Optional[Callable] = None,
- process_restarted: Optional[Callable] = None):
- """Set callback functions for process events"""
- self.process_started_callback = process_started
- self.process_failed_callback = process_failed
- self.process_restarted_callback = process_restarted
-
- def shutdown_all(self):
- """Shutdown all processes"""
- logger.info("Shutting down all supervised processes")
-
- for name in list(self.processes.keys()):
- self.stop_process(name)
-
- self.stop_monitoring()
-
-# Global instance
-_process_supervisor = None
-
-def get_process_supervisor() -> ProcessSupervisor:
- """Get global process supervisor instance"""
- global _process_supervisor
- if _process_supervisor is None:
- _process_supervisor = ProcessSupervisor()
- return _process_supervisor
-
-def create_supervised_dashboard_runner():
- """Create a supervised version of the dashboard runner"""
- supervisor = get_process_supervisor()
-
- # Add dashboard process
- supervisor.add_process(
- name="clean_dashboard",
- command=[sys.executable, "run_clean_dashboard.py"],
- working_dir=os.getcwd(),
- auto_restart=True
- )
-
- # Set up callbacks
- def on_process_failed(name: str, return_code: int):
- logger.error(f"Dashboard process failed with code {return_code}")
-
- def on_process_restarted(name: str, restart_count: int):
- logger.info(f"Dashboard restarted (attempt {restart_count})")
-
- supervisor.set_callbacks(
- process_failed=on_process_failed,
- process_restarted=on_process_restarted
- )
-
- return supervisor
\ No newline at end of file
diff --git a/utils/reward_calculator.py b/utils/reward_calculator.py
deleted file mode 100644
index fb25bd3..0000000
--- a/utils/reward_calculator.py
+++ /dev/null
@@ -1,220 +0,0 @@
-"""
-Improved Reward Function for RL Trading Agent
-
-This module provides a more sophisticated reward function for the RL trading agent
-that incorporates realistic trading fees, penalties for excessive trading, and
-rewards for successful holding of positions.
-"""
-
-import numpy as np
-from datetime import datetime, timedelta
-from collections import deque
-import logging
-
-logger = logging.getLogger(__name__)
-
-class RewardCalculator:
- def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1):
- self.base_fee_rate = base_fee_rate
- self.reward_scaling = reward_scaling
- self.risk_aversion = risk_aversion
- self.trade_pnls = []
- self.returns = []
- self.trade_timestamps = []
- self.frequency_threshold = 10 # Trades per minute threshold for penalty
- self.max_frequency_penalty = 0.05
-
- def record_pnl(self, pnl):
- """Record P&L for risk adjustment calculations"""
- self.trade_pnls.append(pnl)
- if len(self.trade_pnls) > 100:
- self.trade_pnls.pop(0)
-
- def record_trade(self, action):
- """Record trade action for frequency penalty calculations"""
- from time import time
- self.trade_timestamps.append(time())
- if len(self.trade_timestamps) > 100:
- self.trade_timestamps.pop(0)
-
- def _calculate_frequency_penalty(self):
- """Calculate penalty for high-frequency trading"""
- if len(self.trade_timestamps) < 2:
- return 0.0
- time_span = self.trade_timestamps[-1] - self.trade_timestamps[0]
- if time_span <= 0:
- return 0.0
- trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
- if trades_per_minute > self.frequency_threshold:
- penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
- return penalty
- return 0.0
-
- def _calculate_risk_adjustment(self, reward):
- """Adjust rewards based on risk (simple Sharpe ratio implementation)"""
- if len(self.trade_pnls) < 5:
- return reward
- pnl_array = np.array(self.trade_pnls)
- mean_return = np.mean(pnl_array)
- std_return = np.std(pnl_array)
- if std_return == 0:
- return reward
- sharpe = mean_return / std_return
- adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
- return reward * adjustment_factor
-
- def _calculate_holding_reward(self, position_held_time, price_change):
- """Calculate reward for holding a position"""
- base_holding_reward = 0.0005 * (position_held_time / 60.0)
- if price_change > 0:
- return base_holding_reward * 2
- elif price_change < 0:
- return base_holding_reward * 0.5
- return base_holding_reward
-
- def calculate_basic_reward(self, pnl, confidence):
- """Calculate basic training reward based on P&L and confidence"""
- try:
- base_reward = pnl
- if pnl < 0 and confidence > 0.7:
- confidence_adjustment = -confidence * 2
- elif pnl > 0 and confidence > 0.7:
- confidence_adjustment = confidence * 1.5
- else:
- confidence_adjustment = 0
- final_reward = base_reward + confidence_adjustment
- normalized_reward = np.tanh(final_reward / 10.0)
- logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
- return float(normalized_reward)
- except Exception as e:
- logger.error(f"Error calculating basic reward: {e}")
- return 0.0
-
- def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'):
- """Calculate enhanced reward for trading actions with shifted neutral point
-
- Neutral reward is shifted to require profits that exceed double the fees,
- which penalizes small profit trades and encourages holding for larger moves.
- Current PnL is given more weight in the decision-making process.
- """
- fee = self.base_fee_rate
- double_fee = fee * 4 # Double the fees (2x open + 2x close = 4x base fee)
- frequency_penalty = self._calculate_frequency_penalty()
-
- if action == 0: # Buy
- # Penalize buying more when already in profit
- reward = -fee - frequency_penalty
- if current_pnl > 0:
- # Reduce incentive to close profitable positions
- reward -= current_pnl * 0.2
- elif action == 1: # Sell
- profit_pct = price_change
-
- # Shift neutral point - require profit > double fees to be considered positive
- net_profit = profit_pct - double_fee
-
- # Scale reward based on profit size
- if net_profit > 0:
- # Exponential reward for larger profits
- reward = (net_profit ** 1.5) * self.reward_scaling
- else:
- # Linear penalty for losses
- reward = net_profit * self.reward_scaling
-
- reward -= frequency_penalty
- self.record_pnl(net_profit)
-
- # Add extra penalty for very small profits (less than 3x fees)
- if 0 < profit_pct < (fee * 6):
- reward -= 0.5 # Discourage tiny profit-taking
- else: # Hold
- if is_profitable:
- # Increase reward for holding profitable positions
- profit_factor = min(5.0, current_pnl * 20) # Cap at 5x
- reward = self._calculate_holding_reward(position_held_time, price_change) * (1.0 + profit_factor)
-
- # Add bonus for holding through volatility when profitable
- if volatility is not None and volatility > 0.001:
- reward += 0.1 * volatility * 100
- else:
- # Small penalty for holding losing positions
- loss_factor = min(1.0, abs(current_pnl) * 10)
- reward = -0.0001 * (1.0 + loss_factor)
-
- # But reduce penalty for very recent positions (give them time)
- if position_held_time < 30: # Less than 30 seconds
- reward *= 0.5
-
- # Prediction accuracy reward component
- if action in [0, 1] and predicted_change != 0:
- if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0):
- reward += abs(actual_change) * 5.0
- else:
- reward -= abs(predicted_change) * 2.0
-
- # Increase weight of current PnL in decision making (3x more than before)
- reward += current_pnl * 0.3
-
- # Volatility penalty
- if volatility is not None:
- reward -= abs(volatility) * 100
-
- # Risk adjustment
- if self.risk_aversion > 0 and len(self.returns) > 1:
- returns_std = np.std(self.returns)
- reward -= returns_std * self.risk_aversion
-
- self.record_trade(action)
- return reward
-
- def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0):
- """Calculate reward for prediction accuracy"""
- reward = 0.0
- if predicted_direction == actual_direction:
- reward += 1.0 * confidence
- else:
- reward -= 0.5
- if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
- reward += abs(actual_change) * 5.0
- if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
- reward -= abs(predicted_change) * 2.0
- reward += current_pnl * 0.1
- # Dynamic adjustment based on recent PnL (loss cutting incentive)
- if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]:
- latest_pnl_entry = self.pnl_history[symbol][-1]
- latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
- if latest_pnl_value < 0 and position_duration > 60:
- reward -= (abs(latest_pnl_value) * 0.2)
- pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
- best_pnl = max(pnl_values) if pnl_values else 0.0
- if best_pnl < 0.0:
- reward -= 0.1
- return reward
-
-# Example usage:
-if __name__ == "__main__":
- # Create calculator instance
- reward_calc = RewardCalculator()
-
- # Example reward for a buy action
- buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0)
- print(f"Buy action reward: {buy_reward:.5f}")
-
- # Record a trade for frequency tracking
- reward_calc.record_trade(0)
-
- # Wait a bit and make another trade to test frequency penalty
- import time
- time.sleep(0.1)
-
- # Example reward for a sell action with profit
- sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60)
- print(f"Sell action reward (with profit): {sell_reward:.5f}")
-
- # Example reward for a hold action on profitable position
- hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
- print(f"Hold action reward (profitable): {hold_reward:.5f}")
-
- # Example reward for a hold action on unprofitable position
- hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
- print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}")
\ No newline at end of file
diff --git a/utils/system_monitor.py b/utils/system_monitor.py
deleted file mode 100644
index 4beaaf1..0000000
--- a/utils/system_monitor.py
+++ /dev/null
@@ -1,288 +0,0 @@
-"""
-System Resource Monitor - Prevents resource exhaustion and silent failures
-Monitors memory, CPU, and disk usage to prevent system crashes
-"""
-
-import psutil
-import logging
-import threading
-import time
-import gc
-import os
-from typing import Dict, Any, Optional, Callable
-from datetime import datetime, timedelta
-
-logger = logging.getLogger(__name__)
-
-class SystemResourceMonitor:
- """Monitor system resources and prevent exhaustion"""
-
- def __init__(self,
- memory_threshold_mb: int = 7000, # 7GB threshold for 8GB system
- cpu_threshold_percent: float = 90.0,
- disk_threshold_percent: float = 95.0,
- check_interval_seconds: int = 30):
- """
- Initialize system resource monitor
-
- Args:
- memory_threshold_mb: Memory threshold in MB before cleanup
- cpu_threshold_percent: CPU threshold percentage before warning
- disk_threshold_percent: Disk usage threshold before warning
- check_interval_seconds: How often to check resources
- """
- self.memory_threshold_mb = memory_threshold_mb
- self.cpu_threshold_percent = cpu_threshold_percent
- self.disk_threshold_percent = disk_threshold_percent
- self.check_interval = check_interval_seconds
-
- self.monitoring = False
- self.monitor_thread = None
-
- # Callbacks for resource events
- self.memory_warning_callback: Optional[Callable] = None
- self.cpu_warning_callback: Optional[Callable] = None
- self.disk_warning_callback: Optional[Callable] = None
- self.cleanup_callback: Optional[Callable] = None
-
- # Resource history for trending
- self.resource_history = []
- self.max_history_entries = 100
-
- # Last warning times to prevent spam
- self.last_memory_warning = datetime.min
- self.last_cpu_warning = datetime.min
- self.last_disk_warning = datetime.min
- self.warning_cooldown = timedelta(minutes=5)
-
- def start_monitoring(self):
- """Start resource monitoring in background thread"""
- if self.monitoring:
- logger.warning("Resource monitoring already started")
- return
-
- self.monitoring = True
- self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
- self.monitor_thread.start()
- logger.info(f"System resource monitoring started (memory threshold: {self.memory_threshold_mb}MB)")
-
- def stop_monitoring(self):
- """Stop resource monitoring"""
- self.monitoring = False
- if self.monitor_thread:
- self.monitor_thread.join(timeout=5)
- logger.info("System resource monitoring stopped")
-
- def set_callbacks(self,
- memory_warning: Optional[Callable] = None,
- cpu_warning: Optional[Callable] = None,
- disk_warning: Optional[Callable] = None,
- cleanup: Optional[Callable] = None):
- """Set callback functions for resource events"""
- self.memory_warning_callback = memory_warning
- self.cpu_warning_callback = cpu_warning
- self.disk_warning_callback = disk_warning
- self.cleanup_callback = cleanup
-
- def get_current_usage(self) -> Dict[str, Any]:
- """Get current system resource usage"""
- try:
- # Memory usage
- memory = psutil.virtual_memory()
- memory_mb = memory.used / (1024 * 1024)
- memory_percent = memory.percent
-
- # CPU usage
- cpu_percent = psutil.cpu_percent(interval=1)
-
- # Disk usage (current directory)
- disk = psutil.disk_usage('.')
- disk_percent = (disk.used / disk.total) * 100
-
- # Process-specific info
- process = psutil.Process()
- process_memory_mb = process.memory_info().rss / (1024 * 1024)
-
- return {
- 'timestamp': datetime.now(),
- 'memory': {
- 'total_mb': memory.total / (1024 * 1024),
- 'used_mb': memory_mb,
- 'percent': memory_percent,
- 'available_mb': memory.available / (1024 * 1024)
- },
- 'process_memory_mb': process_memory_mb,
- 'cpu_percent': cpu_percent,
- 'disk': {
- 'total_gb': disk.total / (1024 * 1024 * 1024),
- 'used_gb': disk.used / (1024 * 1024 * 1024),
- 'percent': disk_percent
- }
- }
- except Exception as e:
- logger.error(f"Error getting system usage: {e}")
- return {}
-
- def _monitor_loop(self):
- """Main monitoring loop"""
- logger.info("Resource monitoring loop started")
-
- while self.monitoring:
- try:
- usage = self.get_current_usage()
- if not usage:
- time.sleep(self.check_interval)
- continue
-
- # Store in history
- self.resource_history.append(usage)
- if len(self.resource_history) > self.max_history_entries:
- self.resource_history.pop(0)
-
- # Check thresholds
- self._check_memory_threshold(usage)
- self._check_cpu_threshold(usage)
- self._check_disk_threshold(usage)
-
- # Log periodic status (every 10 minutes)
- if len(self.resource_history) % 20 == 0: # 20 * 30s = 10 minutes
- self._log_resource_status(usage)
-
- except Exception as e:
- logger.error(f"Error in resource monitoring loop: {e}")
-
- time.sleep(self.check_interval)
-
- logger.info("Resource monitoring loop stopped")
-
- def _check_memory_threshold(self, usage: Dict[str, Any]):
- """Check memory usage threshold"""
- memory_mb = usage.get('memory', {}).get('used_mb', 0)
-
- if memory_mb > self.memory_threshold_mb:
- now = datetime.now()
- if now - self.last_memory_warning > self.warning_cooldown:
- logger.warning(f"HIGH MEMORY USAGE: {memory_mb:.1f}MB / {self.memory_threshold_mb}MB threshold")
- self.last_memory_warning = now
-
- # Trigger cleanup
- self._trigger_memory_cleanup()
-
- # Call callback if set
- if self.memory_warning_callback:
- try:
- self.memory_warning_callback(memory_mb, self.memory_threshold_mb)
- except Exception as e:
- logger.error(f"Error in memory warning callback: {e}")
-
- def _check_cpu_threshold(self, usage: Dict[str, Any]):
- """Check CPU usage threshold"""
- cpu_percent = usage.get('cpu_percent', 0)
-
- if cpu_percent > self.cpu_threshold_percent:
- now = datetime.now()
- if now - self.last_cpu_warning > self.warning_cooldown:
- logger.warning(f"HIGH CPU USAGE: {cpu_percent:.1f}% / {self.cpu_threshold_percent}% threshold")
- self.last_cpu_warning = now
-
- if self.cpu_warning_callback:
- try:
- self.cpu_warning_callback(cpu_percent, self.cpu_threshold_percent)
- except Exception as e:
- logger.error(f"Error in CPU warning callback: {e}")
-
- def _check_disk_threshold(self, usage: Dict[str, Any]):
- """Check disk usage threshold"""
- disk_percent = usage.get('disk', {}).get('percent', 0)
-
- if disk_percent > self.disk_threshold_percent:
- now = datetime.now()
- if now - self.last_disk_warning > self.warning_cooldown:
- logger.warning(f"HIGH DISK USAGE: {disk_percent:.1f}% / {self.disk_threshold_percent}% threshold")
- self.last_disk_warning = now
-
- if self.disk_warning_callback:
- try:
- self.disk_warning_callback(disk_percent, self.disk_threshold_percent)
- except Exception as e:
- logger.error(f"Error in disk warning callback: {e}")
-
- def _trigger_memory_cleanup(self):
- """Trigger memory cleanup procedures"""
- logger.info("Triggering memory cleanup...")
-
- # Force garbage collection
- collected = gc.collect()
- logger.info(f"Garbage collection freed {collected} objects")
-
- # Call custom cleanup callback if set
- if self.cleanup_callback:
- try:
- self.cleanup_callback()
- logger.info("Custom cleanup callback executed")
- except Exception as e:
- logger.error(f"Error in cleanup callback: {e}")
-
- # Log memory after cleanup
- try:
- usage_after = self.get_current_usage()
- memory_after = usage_after.get('memory', {}).get('used_mb', 0)
- logger.info(f"Memory after cleanup: {memory_after:.1f}MB")
- except Exception as e:
- logger.error(f"Error checking memory after cleanup: {e}")
-
- def _log_resource_status(self, usage: Dict[str, Any]):
- """Log current resource status"""
- memory = usage.get('memory', {})
- cpu = usage.get('cpu_percent', 0)
- disk = usage.get('disk', {})
- process_memory = usage.get('process_memory_mb', 0)
-
- logger.info(f"RESOURCE STATUS - Memory: {memory.get('used_mb', 0):.1f}MB ({memory.get('percent', 0):.1f}%), "
- f"Process: {process_memory:.1f}MB, CPU: {cpu:.1f}%, Disk: {disk.get('percent', 0):.1f}%")
-
- def get_resource_summary(self) -> Dict[str, Any]:
- """Get resource usage summary"""
- if not self.resource_history:
- return {}
-
- recent_usage = self.resource_history[-10:] # Last 10 entries
-
- # Calculate averages
- avg_memory = sum(u.get('memory', {}).get('used_mb', 0) for u in recent_usage) / len(recent_usage)
- avg_cpu = sum(u.get('cpu_percent', 0) for u in recent_usage) / len(recent_usage)
- avg_disk = sum(u.get('disk', {}).get('percent', 0) for u in recent_usage) / len(recent_usage)
-
- current = self.resource_history[-1] if self.resource_history else {}
-
- return {
- 'current': current,
- 'averages': {
- 'memory_mb': avg_memory,
- 'cpu_percent': avg_cpu,
- 'disk_percent': avg_disk
- },
- 'thresholds': {
- 'memory_mb': self.memory_threshold_mb,
- 'cpu_percent': self.cpu_threshold_percent,
- 'disk_percent': self.disk_threshold_percent
- },
- 'monitoring': self.monitoring,
- 'history_entries': len(self.resource_history)
- }
-
-# Global instance
-_system_monitor = None
-
-def get_system_monitor() -> SystemResourceMonitor:
- """Get global system monitor instance"""
- global _system_monitor
- if _system_monitor is None:
- _system_monitor = SystemResourceMonitor()
- return _system_monitor
-
-def start_system_monitoring():
- """Start system monitoring with default settings"""
- monitor = get_system_monitor()
- monitor.start_monitoring()
- return monitor
\ No newline at end of file
diff --git a/utils/tensorboard_logger.py b/utils/tensorboard_logger.py
deleted file mode 100644
index aff8dfb..0000000
--- a/utils/tensorboard_logger.py
+++ /dev/null
@@ -1,219 +0,0 @@
-#!/usr/bin/env python3
-"""
-TensorBoard Logger Utility
-
-This module provides a centralized way to log training metrics to TensorBoard.
-It ensures consistent logging across different training components.
-"""
-
-import os
-import logging
-from pathlib import Path
-from datetime import datetime
-from typing import Dict, Any, Optional, Union, List
-
-# Import conditionally to handle missing dependencies gracefully
-try:
- from torch.utils.tensorboard import SummaryWriter
- TENSORBOARD_AVAILABLE = True
-except ImportError:
- TENSORBOARD_AVAILABLE = False
-
-logger = logging.getLogger(__name__)
-
-class TensorBoardLogger:
- """
- Centralized TensorBoard logging utility for training metrics
-
- This class provides a consistent interface for logging metrics to TensorBoard
- across different training components.
- """
-
- def __init__(self,
- log_dir: Optional[str] = None,
- experiment_name: Optional[str] = None,
- enabled: bool = True):
- """
- Initialize TensorBoard logger
-
- Args:
- log_dir: Base directory for TensorBoard logs (default: 'runs')
- experiment_name: Name of the experiment (default: timestamp)
- enabled: Whether TensorBoard logging is enabled
- """
- self.enabled = enabled and TENSORBOARD_AVAILABLE
- self.writer = None
-
- if not self.enabled:
- if not TENSORBOARD_AVAILABLE:
- logger.warning("TensorBoard not available. Install with: pip install tensorboard")
- return
-
- # Set up log directory
- if log_dir is None:
- log_dir = "runs"
-
- # Create experiment name if not provided
- if experiment_name is None:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- experiment_name = f"training_{timestamp}"
-
- # Create full log path
- self.log_dir = os.path.join(log_dir, experiment_name)
-
- # Create writer
- try:
- self.writer = SummaryWriter(log_dir=self.log_dir)
- logger.info(f"TensorBoard logging enabled at: {self.log_dir}")
- except Exception as e:
- logger.error(f"Failed to initialize TensorBoard: {e}")
- self.enabled = False
-
- def log_scalar(self, tag: str, value: float, step: int) -> None:
- """
- Log a scalar value to TensorBoard
-
- Args:
- tag: Metric name
- value: Metric value
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- try:
- self.writer.add_scalar(tag, value, step)
- except Exception as e:
- logger.warning(f"Failed to log scalar {tag}: {e}")
-
- def log_scalars(self, main_tag: str, tag_value_dict: Dict[str, float], step: int) -> None:
- """
- Log multiple scalar values with the same main tag
-
- Args:
- main_tag: Main tag for the metrics
- tag_value_dict: Dictionary of tag names to values
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- try:
- self.writer.add_scalars(main_tag, tag_value_dict, step)
- except Exception as e:
- logger.warning(f"Failed to log scalars for {main_tag}: {e}")
-
- def log_histogram(self, tag: str, values, step: int) -> None:
- """
- Log a histogram to TensorBoard
-
- Args:
- tag: Histogram name
- values: Values to create histogram from
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- try:
- self.writer.add_histogram(tag, values, step)
- except Exception as e:
- logger.warning(f"Failed to log histogram {tag}: {e}")
-
- def log_training_metrics(self,
- metrics: Dict[str, Any],
- step: int,
- prefix: str = "Training") -> None:
- """
- Log training metrics to TensorBoard
-
- Args:
- metrics: Dictionary of metric names to values
- step: Training step
- prefix: Prefix for metric names
- """
- if not self.enabled or self.writer is None:
- return
-
- for name, value in metrics.items():
- if isinstance(value, (int, float)):
- self.log_scalar(f"{prefix}/{name}", value, step)
- elif hasattr(value, "shape"): # For numpy arrays or tensors
- try:
- self.log_histogram(f"{prefix}/{name}", value, step)
- except:
- pass
-
- def log_model_metrics(self,
- model_name: str,
- metrics: Dict[str, Any],
- step: int) -> None:
- """
- Log model-specific metrics to TensorBoard
-
- Args:
- model_name: Name of the model
- metrics: Dictionary of metric names to values
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- for name, value in metrics.items():
- if isinstance(value, (int, float)):
- self.log_scalar(f"Model/{model_name}/{name}", value, step)
-
- def log_reward_metrics(self,
- symbol: str,
- metrics: Dict[str, float],
- step: int) -> None:
- """
- Log reward-related metrics to TensorBoard
-
- Args:
- symbol: Trading symbol
- metrics: Dictionary of metric names to values
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- for name, value in metrics.items():
- self.log_scalar(f"Rewards/{symbol}/{name}", value, step)
-
- def log_state_metrics(self,
- symbol: str,
- state_info: Dict[str, Any],
- step: int) -> None:
- """
- Log state-related metrics to TensorBoard
-
- Args:
- symbol: Trading symbol
- state_info: Dictionary of state information
- step: Training step
- """
- if not self.enabled or self.writer is None:
- return
-
- # Log state size
- if "size" in state_info:
- self.log_scalar(f"State/{symbol}/Size", state_info["size"], step)
-
- # Log state quality
- if "quality" in state_info:
- self.log_scalar(f"State/{symbol}/Quality", state_info["quality"], step)
-
- # Log feature counts
- if "feature_counts" in state_info:
- for feature_type, count in state_info["feature_counts"].items():
- self.log_scalar(f"State/{symbol}/Features/{feature_type}", count, step)
-
- def close(self) -> None:
- """Close the TensorBoard writer"""
- if self.enabled and self.writer is not None:
- try:
- self.writer.close()
- logger.info("TensorBoard writer closed")
- except Exception as e:
- logger.warning(f"Error closing TensorBoard writer: {e}")
\ No newline at end of file
diff --git a/verify_checkpoint_system.py b/verify_checkpoint_system.py
deleted file mode 100644
index df9a706..0000000
--- a/verify_checkpoint_system.py
+++ /dev/null
@@ -1,155 +0,0 @@
-#!/usr/bin/env python3
-"""
-Verify Checkpoint System
-
-Final verification that the checkpoint system is working correctly
-"""
-
-import torch
-from pathlib import Path
-from utils.checkpoint_manager import load_best_checkpoint, save_checkpoint
-from utils.database_manager import get_database_manager
-from datetime import datetime
-
-def test_checkpoint_loading():
- """Test loading existing checkpoints"""
- print("=== Testing Checkpoint Loading ===")
-
- models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']
-
- for model_name in models:
- try:
- result = load_best_checkpoint(model_name)
-
- if result:
- file_path, metadata = result
- file_size = Path(file_path).stat().st_size / (1024 * 1024)
-
- print(f"✅ {model_name}:")
- print(f" ID: {metadata.checkpoint_id}")
- print(f" File: {file_path}")
- print(f" Size: {file_size:.1f}MB")
- print(f" Loss: {getattr(metadata, 'loss', 'N/A')}")
-
- # Try to load the actual model file
- try:
- model_data = torch.load(file_path, map_location='cpu')
- print(f" ✅ Model file loads successfully")
- except Exception as e:
- print(f" ❌ Model file load error: {e}")
- else:
- print(f"❌ {model_name}: No checkpoint found")
-
- except Exception as e:
- print(f"❌ {model_name}: Error - {e}")
-
- print()
-
-def test_checkpoint_saving():
- """Test saving new checkpoints"""
- print("=== Testing Checkpoint Saving ===")
-
- try:
- import torch.nn as nn
-
- # Create a test model
- class TestModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.linear = nn.Linear(100, 10)
-
- def forward(self, x):
- return self.linear(x)
-
- test_model = TestModel()
-
- # Save checkpoint
- result = save_checkpoint(
- model=test_model,
- model_name="test_save",
- model_type="test",
- performance_metrics={"loss": 0.05, "accuracy": 0.98},
- training_metadata={"test_save": True, "timestamp": datetime.now().isoformat()}
- )
-
- if result:
- print(f"✅ Checkpoint saved: {result.checkpoint_id}")
-
- # Verify it can be loaded
- load_result = load_best_checkpoint("test_save")
- if load_result:
- print(f"✅ Checkpoint can be loaded back")
-
- # Clean up
- file_path = Path(load_result[0])
- if file_path.exists():
- file_path.unlink()
- print(f"🧹 Test checkpoint cleaned up")
- else:
- print(f"❌ Checkpoint could not be loaded back")
- else:
- print(f"❌ Checkpoint saving failed")
-
- except Exception as e:
- print(f"❌ Checkpoint saving test failed: {e}")
-
-def test_database_integration():
- """Test database integration"""
- print("=== Testing Database Integration ===")
-
- db_manager = get_database_manager()
-
- # Test fast metadata access
- for model_name in ['dqn_agent', 'enhanced_cnn']:
- metadata = db_manager.get_best_checkpoint_metadata(model_name)
- if metadata:
- print(f"✅ {model_name}: Fast metadata access works")
- print(f" ID: {metadata.checkpoint_id}")
- print(f" Performance: {metadata.performance_metrics}")
- else:
- print(f"❌ {model_name}: No metadata found")
-
-def show_checkpoint_summary():
- """Show summary of all checkpoints"""
- print("=== Checkpoint System Summary ===")
-
- db_manager = get_database_manager()
-
- # Get all models with checkpoints
- models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision']
-
- total_checkpoints = 0
- total_size_mb = 0
-
- for model_name in models:
- checkpoints = db_manager.list_checkpoints(model_name)
- if checkpoints:
- model_size = sum(c.file_size_mb for c in checkpoints)
- total_checkpoints += len(checkpoints)
- total_size_mb += model_size
-
- print(f"{model_name}: {len(checkpoints)} checkpoints ({model_size:.1f}MB)")
-
- # Show active checkpoint
- active = [c for c in checkpoints if c.is_active]
- if active:
- print(f" Active: {active[0].checkpoint_id}")
-
- print(f"\nTotal: {total_checkpoints} checkpoints, {total_size_mb:.1f}MB")
-
-def main():
- """Run all verification tests"""
- print("=== Checkpoint System Verification ===\n")
-
- test_checkpoint_loading()
- test_checkpoint_saving()
- test_database_integration()
- show_checkpoint_summary()
-
- print("\n=== Verification Complete ===")
- print("✅ Checkpoint system is working correctly!")
- print("✅ Models will no longer start fresh every time")
- print("✅ Training progress will be preserved")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/web/__init__.py b/web/__init__.py
deleted file mode 100644
index 0a6a59c..0000000
--- a/web/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Web module for trading system dashboard
diff --git a/web/dashboard_fix.py b/web/dashboard_fix.py
deleted file mode 100644
index d593e6b..0000000
--- a/web/dashboard_fix.py
+++ /dev/null
@@ -1,253 +0,0 @@
-"""
-Dashboard Fix
-
-This module provides fixes for the trading dashboard to address:
-1. Trade display issues
-2. P&L calculation and display
-3. Position tracking and synchronization
-
-Apply these fixes by importing and applying the patch in the dashboard initialization
-"""
-
-import logging
-from datetime import datetime
-from typing import Dict, Any, List, Optional
-import time
-
-logger = logging.getLogger(__name__)
-
-class DashboardFix:
- """Fixes for the Dashboard class"""
-
- @staticmethod
- def apply_fixes(dashboard):
- """Apply all fixes to the dashboard"""
- logger.info("Applying Dashboard fixes...")
-
- # Apply fixes
- DashboardFix._fix_trade_display(dashboard)
- DashboardFix._fix_position_sync(dashboard)
- DashboardFix._fix_pnl_calculation(dashboard)
- DashboardFix._add_trade_validation(dashboard)
-
- logger.info("Dashboard fixes applied successfully")
- return dashboard
-
- @staticmethod
- def _fix_trade_display(dashboard):
- """Fix trade display to ensure accurate information"""
- # Store original format_closed_trades_table method
- if hasattr(dashboard.component_manager, 'format_closed_trades_table'):
- original_format_closed_trades = dashboard.component_manager.format_closed_trades_table
-
- def format_closed_trades_table_fixed(self, closed_trades, trading_stats=None):
- """Fixed closed trades table formatter with accurate P&L calculation"""
- # Recalculate P&L for each trade to ensure accuracy
- for trade in closed_trades:
- # Skip if already validated
- if getattr(trade, 'pnl_validated', False):
- continue
-
- # Handle both trade objects and dictionary formats
- if hasattr(trade, 'entry_price'):
- # This is a trade object
- entry_price = getattr(trade, 'entry_price', 0)
- exit_price = getattr(trade, 'exit_price', 0)
- size = getattr(trade, 'size', 0)
- side = getattr(trade, 'side', 'UNKNOWN')
- fees = getattr(trade, 'fees', 0)
- else:
- # This is a dictionary format
- entry_price = trade.get('entry_price', 0)
- exit_price = trade.get('exit_price', 0)
- size = trade.get('size', trade.get('quantity', 0))
- side = trade.get('side', 'UNKNOWN')
- fees = trade.get('fees', 0)
-
- # Recalculate P&L
- if side == 'LONG' or side == 'BUY':
- pnl = (exit_price - entry_price) * size
- else: # SHORT or SELL
- pnl = (entry_price - exit_price) * size
-
- # Update P&L value
- if hasattr(trade, 'entry_price'):
- trade.pnl = pnl
- trade.net_pnl = pnl - fees
- trade.pnl_validated = True
- else:
- trade['pnl'] = pnl
- trade['net_pnl'] = pnl - fees
- trade['pnl_validated'] = True
-
- # Call original method with validated trades
- return original_format_closed_trades(closed_trades, trading_stats)
-
- # Apply the patch
- dashboard.component_manager.format_closed_trades_table = format_closed_trades_table_fixed.__get__(dashboard.component_manager)
- logger.info("Trade display fix applied")
-
- @staticmethod
- def _fix_position_sync(dashboard):
- """Fix position synchronization to ensure accurate position tracking"""
- # Store original _sync_position_from_executor method
- if hasattr(dashboard, '_sync_position_from_executor'):
- original_sync_position = dashboard._sync_position_from_executor
-
- def sync_position_from_executor_fixed(self, symbol):
- """Fixed position sync with validation and logging"""
- try:
- # Call original sync method
- result = original_sync_position(symbol)
-
- # Add validation and logging
- if self.trading_executor and hasattr(self.trading_executor, 'positions'):
- if symbol in self.trading_executor.positions:
- position = self.trading_executor.positions[symbol]
-
- # Log position details for debugging
- logger.debug(f"Position sync for {symbol}: "
- f"Side={position.side}, "
- f"Size={position.size}, "
- f"Entry=${position.entry_price:.2f}")
-
- # Validate position data
- if position.entry_price <= 0:
- logger.warning(f"Invalid entry price for {symbol}: ${position.entry_price:.2f}")
-
- # Store last sync time
- if not hasattr(self, 'last_position_sync'):
- self.last_position_sync = {}
-
- self.last_position_sync[symbol] = time.time()
-
- return result
-
- except Exception as e:
- logger.error(f"Error in sync_position_from_executor_fixed: {e}")
- return None
-
- # Apply the patch
- dashboard._sync_position_from_executor = sync_position_from_executor_fixed.__get__(dashboard)
- logger.info("Position sync fix applied")
-
- @staticmethod
- def _fix_pnl_calculation(dashboard):
- """Fix P&L calculation to ensure accuracy"""
- # Add a method to recalculate P&L for all closed trades
- def recalculate_all_pnl(self):
- """Recalculate P&L for all closed trades"""
- if not hasattr(self, 'closed_trades') or not self.closed_trades:
- return
-
- for trade in self.closed_trades:
- # Handle both trade objects and dictionary formats
- if hasattr(trade, 'entry_price'):
- # This is a trade object
- entry_price = getattr(trade, 'entry_price', 0)
- exit_price = getattr(trade, 'exit_price', 0)
- size = getattr(trade, 'size', 0)
- side = getattr(trade, 'side', 'UNKNOWN')
- fees = getattr(trade, 'fees', 0)
- else:
- # This is a dictionary format
- entry_price = trade.get('entry_price', 0)
- exit_price = trade.get('exit_price', 0)
- size = trade.get('size', trade.get('quantity', 0))
- side = trade.get('side', 'UNKNOWN')
- fees = trade.get('fees', 0)
-
- # Recalculate P&L
- if side == 'LONG' or side == 'BUY':
- pnl = (exit_price - entry_price) * size
- else: # SHORT or SELL
- pnl = (entry_price - exit_price) * size
-
- # Update P&L value
- if hasattr(trade, 'entry_price'):
- trade.pnl = pnl
- trade.net_pnl = pnl - fees
- else:
- trade['pnl'] = pnl
- trade['net_pnl'] = pnl - fees
-
- logger.info(f"Recalculated P&L for {len(self.closed_trades)} closed trades")
-
- # Add the method
- dashboard.recalculate_all_pnl = recalculate_all_pnl.__get__(dashboard)
-
- # Call it once to fix existing trades
- dashboard.recalculate_all_pnl()
-
- logger.info("P&L calculation fix applied")
-
- @staticmethod
- def _add_trade_validation(dashboard):
- """Add trade validation to prevent invalid trades"""
- # Store original _on_trade_closed method if it exists
- original_on_trade_closed = getattr(dashboard, '_on_trade_closed', None)
-
- if original_on_trade_closed:
- def on_trade_closed_fixed(self, trade_data):
- """Fixed trade closed handler with validation"""
- try:
- # Validate trade data
- is_valid = True
- validation_errors = []
-
- # Check for required fields
- required_fields = ['symbol', 'side', 'entry_price', 'exit_price', 'size']
- for field in required_fields:
- if field not in trade_data:
- is_valid = False
- validation_errors.append(f"Missing required field: {field}")
-
- # Check for valid prices
- if 'entry_price' in trade_data and trade_data['entry_price'] <= 0:
- is_valid = False
- validation_errors.append(f"Invalid entry price: {trade_data['entry_price']}")
-
- if 'exit_price' in trade_data and trade_data['exit_price'] <= 0:
- is_valid = False
- validation_errors.append(f"Invalid exit price: {trade_data['exit_price']}")
-
- # Check for valid size
- if 'size' in trade_data and trade_data['size'] <= 0:
- is_valid = False
- validation_errors.append(f"Invalid size: {trade_data['size']}")
-
- # If invalid, log errors and skip
- if not is_valid:
- logger.warning(f"Invalid trade data: {validation_errors}")
- return
-
- # Calculate correct P&L
- if 'side' in trade_data and 'entry_price' in trade_data and 'exit_price' in trade_data and 'size' in trade_data:
- side = trade_data['side']
- entry_price = trade_data['entry_price']
- exit_price = trade_data['exit_price']
- size = trade_data['size']
-
- if side == 'LONG' or side == 'BUY':
- pnl = (exit_price - entry_price) * size
- else: # SHORT or SELL
- pnl = (entry_price - exit_price) * size
-
- # Update P&L in trade data
- trade_data['pnl'] = pnl
-
- # Calculate net P&L (after fees)
- fees = trade_data.get('fees', 0)
- trade_data['net_pnl'] = pnl - fees
-
- # Call original method with validated data
- return original_on_trade_closed(trade_data)
-
- except Exception as e:
- logger.error(f"Error in on_trade_closed_fixed: {e}")
-
- # Apply the patch
- dashboard._on_trade_closed = on_trade_closed_fixed.__get__(dashboard)
- logger.info("Trade validation fix applied")
- else:
- logger.warning("_on_trade_closed method not found, skipping trade validation fix")
\ No newline at end of file
diff --git a/web/dashboard_model.py b/web/dashboard_model.py
deleted file mode 100644
index 498de90..0000000
--- a/web/dashboard_model.py
+++ /dev/null
@@ -1,331 +0,0 @@
-"""
-Dashboard Data Model
-Provides structured data for template rendering
-"""
-from dataclasses import dataclass, field
-from typing import List, Dict, Any, Optional
-from datetime import datetime
-
-
-@dataclass
-class MetricData:
- """Individual metric for the dashboard"""
- id: str
- label: str
- value: str
- format_type: str = "text" # text, currency, percentage
-
-
-@dataclass
-class TradingControlsData:
- """Trading controls configuration"""
- buy_text: str = "BUY"
- sell_text: str = "SELL"
- clear_text: str = "Clear Session"
- leverage: int = 10
- leverage_min: int = 1
- leverage_max: int = 50
-
-
-@dataclass
-class RecentDecisionData:
- """Recent AI decision data"""
- timestamp: str
- action: str
- symbol: str
- confidence: float
- price: float
-
-
-@dataclass
-class COBLevelData:
- """Order book level data"""
- side: str # 'bid' or 'ask'
- size: str
- price: str
- total: str
-
-
-@dataclass
-class COBData:
- """Complete order book data for a symbol"""
- symbol: str
- content_id: str
- total_usd: str
- total_crypto: str
- levels: List[COBLevelData] = field(default_factory=list)
-
-
-@dataclass
-class ModelData:
- """Model status data"""
- name: str
- status: str # 'training', 'idle', 'loading'
- status_text: str
-
-
-@dataclass
-class TrainingMetricData:
- """Training metric data"""
- name: str
- value: str
-
-
-@dataclass
-class PerformanceStatData:
- """Performance statistic data"""
- name: str
- value: str
-
-
-@dataclass
-class ClosedTradeData:
- """Closed trade data"""
- time: str
- symbol: str
- side: str
- size: str
- entry_price: str
- exit_price: str
- pnl: float
- duration: str
-
-
-@dataclass
-class ChartData:
- """Chart configuration data"""
- title: str = "Price Chart & Signals"
-
-
-@dataclass
-class DashboardModel:
- """Complete dashboard data model"""
- title: str = "Live Scalping Dashboard"
- subtitle: str = "Real-time Trading with AI Models"
- refresh_interval: int = 1000
-
- # Main sections
- metrics: List[MetricData] = field(default_factory=list)
- chart: ChartData = field(default_factory=ChartData)
- trading_controls: TradingControlsData = field(default_factory=TradingControlsData)
- recent_decisions: List[RecentDecisionData] = field(default_factory=list)
- cob_data: List[COBData] = field(default_factory=list)
- models: List[ModelData] = field(default_factory=list)
- training_metrics: List[TrainingMetricData] = field(default_factory=list)
- performance_stats: List[PerformanceStatData] = field(default_factory=list)
- closed_trades: List[ClosedTradeData] = field(default_factory=list)
-
-
-class DashboardDataBuilder:
- """Builder class to construct dashboard data from various sources"""
-
- def __init__(self):
- self.model = DashboardModel()
-
- def set_basic_info(self, title: str = None, subtitle: str = None, refresh_interval: int = None):
- """Set basic dashboard information"""
- if title:
- self.model.title = title
- if subtitle:
- self.model.subtitle = subtitle
- if refresh_interval:
- self.model.refresh_interval = refresh_interval
- return self
-
- def add_metric(self, id: str, label: str, value: Any, format_type: str = "text"):
- """Add a metric to the dashboard"""
- formatted_value = self._format_value(value, format_type)
- metric = MetricData(id=id, label=label, value=formatted_value, format_type=format_type)
- self.model.metrics.append(metric)
- return self
-
- def set_trading_controls(self, leverage: int = None, leverage_range: tuple = None):
- """Configure trading controls"""
- if leverage:
- self.model.trading_controls.leverage = leverage
- if leverage_range:
- self.model.trading_controls.leverage_min = leverage_range[0]
- self.model.trading_controls.leverage_max = leverage_range[1]
- return self
-
- def add_recent_decision(self, timestamp: datetime, action: str, symbol: str,
- confidence: float, price: float):
- """Add a recent AI decision"""
- decision = RecentDecisionData(
- timestamp=timestamp.strftime("%H:%M:%S"),
- action=action,
- symbol=symbol,
- confidence=round(confidence * 100, 1),
- price=round(price, 4)
- )
- self.model.recent_decisions.append(decision)
- return self
-
- def add_cob_data(self, symbol: str, content_id: str, total_usd: float,
- total_crypto: float, levels: List[Dict]):
- """Add COB data for a symbol"""
- cob_levels = []
- for level in levels:
- cob_level = COBLevelData(
- side=level.get('side', 'bid'),
- size=self._format_value(level.get('size', 0), 'number'),
- price=self._format_value(level.get('price', 0), 'currency'),
- total=self._format_value(level.get('total', 0), 'currency')
- )
- cob_levels.append(cob_level)
-
- cob = COBData(
- symbol=symbol,
- content_id=content_id,
- total_usd=self._format_value(total_usd, 'currency'),
- total_crypto=self._format_value(total_crypto, 'number'),
- levels=cob_levels
- )
- self.model.cob_data.append(cob)
- return self
-
- def add_model_status(self, name: str, is_training: bool, is_loading: bool = False):
- """Add model status"""
- if is_loading:
- status = "loading"
- status_text = "Loading"
- elif is_training:
- status = "training"
- status_text = "Training"
- else:
- status = "idle"
- status_text = "Idle"
-
- model = ModelData(name=name, status=status, status_text=status_text)
- self.model.models.append(model)
- return self
-
- def add_training_metric(self, name: str, value: Any):
- """Add training metric"""
- metric = TrainingMetricData(
- name=name,
- value=self._format_value(value, 'number')
- )
- self.model.training_metrics.append(metric)
- return self
-
- def add_performance_stat(self, name: str, value: Any):
- """Add performance statistic"""
- stat = PerformanceStatData(
- name=name,
- value=self._format_value(value, 'number')
- )
- self.model.performance_stats.append(stat)
- return self
-
- def add_closed_trade(self, time: datetime, symbol: str, side: str, size: float,
- entry_price: float, exit_price: float, pnl: float, duration: str):
- """Add closed trade"""
- trade = ClosedTradeData(
- time=time.strftime("%H:%M:%S"),
- symbol=symbol,
- side=side,
- size=self._format_value(size, 'number'),
- entry_price=self._format_value(entry_price, 'currency'),
- exit_price=self._format_value(exit_price, 'currency'),
- pnl=round(pnl, 2),
- duration=duration
- )
- self.model.closed_trades.append(trade)
- return self
-
- def build(self) -> DashboardModel:
- """Build and return the complete dashboard model"""
- return self.model
-
- def _format_value(self, value: Any, format_type: str) -> str:
- """Format value based on type"""
- if value is None:
- return "N/A"
-
- try:
- if format_type == "currency":
- return f"${float(value):,.4f}"
- elif format_type == "percentage":
- return f"{float(value):.2f}%"
- elif format_type == "number":
- if isinstance(value, int):
- return f"{value:,}"
- else:
- return f"{float(value):,.2f}"
- else:
- return str(value)
- except (ValueError, TypeError):
- return str(value)
-
-
-def create_sample_dashboard_data() -> DashboardModel:
- """Create sample dashboard data for testing"""
- builder = DashboardDataBuilder()
-
- # Basic info
- builder.set_basic_info(
- title="Live Scalping Dashboard",
- subtitle="Real-time Trading with AI Models",
- refresh_interval=1000
- )
-
- # Metrics
- builder.add_metric("current-price", "Current Price", 3425.67, "currency")
- builder.add_metric("session-pnl", "Session PnL", 125.34, "currency")
- builder.add_metric("current-position", "Position", 0.0, "number")
- builder.add_metric("trade-count", "Trades", 15, "number")
- builder.add_metric("portfolio-value", "Portfolio", 10250.45, "currency")
- builder.add_metric("mexc-status", "MEXC Status", "Connected", "text")
-
- # Trading controls
- builder.set_trading_controls(leverage=10, leverage_range=(1, 50))
-
- # Recent decisions
- builder.add_recent_decision(datetime.now(), "BUY", "ETH/USDT", 0.85, 3425.67)
- builder.add_recent_decision(datetime.now(), "HOLD", "BTC/USDT", 0.62, 45123.45)
-
- # COB data
- eth_levels = [
- {"side": "ask", "size": 1.5, "price": 3426.12, "total": 5139.18},
- {"side": "ask", "size": 2.3, "price": 3425.89, "total": 7879.55},
- {"side": "bid", "size": 1.8, "price": 3425.45, "total": 6165.81},
- {"side": "bid", "size": 3.2, "price": 3425.12, "total": 10960.38}
- ]
- builder.add_cob_data("ETH/USDT", "eth-cob-content", 25000.0, 7.3, eth_levels)
-
- btc_levels = [
- {"side": "ask", "size": 0.15, "price": 45125.67, "total": 6768.85},
- {"side": "ask", "size": 0.23, "price": 45123.45, "total": 10378.39},
- {"side": "bid", "size": 0.18, "price": 45121.23, "total": 8121.82},
- {"side": "bid", "size": 0.32, "price": 45119.12, "total": 14438.12}
- ]
- builder.add_cob_data("BTC/USDT", "btc-cob-content", 35000.0, 0.88, btc_levels)
-
- # Model statuses
- builder.add_model_status("DQN", True)
- builder.add_model_status("CNN", True)
- builder.add_model_status("Transformer", False)
- builder.add_model_status("COB-RL", True)
-
- # Training metrics
- builder.add_training_metric("DQN Loss", 0.0234)
- builder.add_training_metric("CNN Accuracy", 0.876)
- builder.add_training_metric("Training Steps", 15420)
- builder.add_training_metric("Learning Rate", 0.0001)
-
- # Performance stats
- builder.add_performance_stat("Win Rate", 68.5)
- builder.add_performance_stat("Avg Trade", 8.34)
- builder.add_performance_stat("Max Drawdown", -45.67)
- builder.add_performance_stat("Sharpe Ratio", 1.82)
-
- # Closed trades
- builder.add_closed_trade(
- datetime.now(), "ETH/USDT", "BUY", 1.5, 3420.45, 3428.12, 11.51, "2m 34s"
- )
- builder.add_closed_trade(
- datetime.now(), "BTC/USDT", "SELL", 0.1, 45150.23, 45142.67, -0.76, "1m 12s"
- )
-
- return builder.build()
\ No newline at end of file
diff --git a/web/layout_manager_with_tensorboard.py b/web/layout_manager_with_tensorboard.py
deleted file mode 100644
index e69de29..0000000
diff --git a/web/models_training_panel.py b/web/models_training_panel.py
deleted file mode 100644
index 352d8a2..0000000
--- a/web/models_training_panel.py
+++ /dev/null
@@ -1,753 +0,0 @@
-#!/usr/bin/env python3
-"""
-Models & Training Progress Panel - Clean Implementation
-Displays real-time model status, training metrics, and performance data
-"""
-
-import logging
-from typing import Dict, List, Optional, Any
-from datetime import datetime, timedelta
-from dash import html, dcc
-import dash_bootstrap_components as dbc
-
-logger = logging.getLogger(__name__)
-
-class ModelsTrainingPanel:
- """Clean implementation of the Models & Training Progress panel"""
-
- def __init__(self, orchestrator=None):
- self.orchestrator = orchestrator
- self.last_update = None
-
- def create_panel(self) -> html.Div:
- """Create the main Models & Training Progress panel"""
- try:
- # Get fresh data from orchestrator
- panel_data = self._gather_panel_data()
-
- # Build the panel components
- content = []
-
- # Header with refresh button
- content.append(self._create_header())
-
- # Models section
- if panel_data.get('models'):
- content.append(self._create_models_section(panel_data['models']))
- else:
- content.append(self._create_no_models_message())
-
- # Training status section
- if panel_data.get('training_status'):
- content.append(self._create_training_status_section(panel_data['training_status']))
-
- # Performance metrics section
- if panel_data.get('performance_metrics'):
- content.append(self._create_performance_section(panel_data['performance_metrics']))
-
- return html.Div(content, id="training-metrics")
-
- except Exception as e:
- logger.error(f"Error creating models training panel: {e}")
- return html.Div([
- html.P(f"Error loading training panel: {str(e)}", className="text-danger small")
- ], id="training-metrics")
-
- def _gather_panel_data(self) -> Dict[str, Any]:
- """Gather all data needed for the panel from orchestrator and other sources"""
- data = {
- 'models': {},
- 'training_status': {},
- 'performance_metrics': {},
- 'last_update': datetime.now().strftime('%H:%M:%S')
- }
-
- if not self.orchestrator:
- logger.warning("No orchestrator available for training panel")
- return data
-
- try:
- # Get model registry information
- if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry:
- registered_models = self.orchestrator.model_registry.get_all_models()
- for model_name, model_info in registered_models.items():
- data['models'][model_name] = self._extract_model_data(model_name, model_info)
-
- # Add decision fusion model if it exists (check multiple sources)
- decision_fusion_added = False
-
- # Check if it's in the model registry
- if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry:
- registered_models = self.orchestrator.model_registry.get_all_models()
- if 'decision_fusion' in registered_models:
- data['models']['decision_fusion'] = self._extract_decision_fusion_data()
- decision_fusion_added = True
-
- # If not in registry, check if decision fusion network exists
- if not decision_fusion_added and hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network:
- data['models']['decision_fusion'] = self._extract_decision_fusion_data()
- decision_fusion_added = True
-
- # If still not added, check if decision fusion is enabled
- if not decision_fusion_added and hasattr(self.orchestrator, 'decision_fusion_enabled') and self.orchestrator.decision_fusion_enabled:
- data['models']['decision_fusion'] = self._extract_decision_fusion_data()
- decision_fusion_added = True
-
- # Add COB RL model if it exists but wasn't captured in registry
- if 'cob_rl_model' not in data['models'] and hasattr(self.orchestrator, 'cob_rl_model'):
- data['models']['cob_rl_model'] = self._extract_cob_rl_data()
-
- # Get training status
- data['training_status'] = self._extract_training_status()
-
- # Get performance metrics
- data['performance_metrics'] = self._extract_performance_metrics()
-
- except Exception as e:
- logger.error(f"Error gathering panel data: {e}")
- data['error'] = str(e)
-
- return data
-
- def _extract_model_data(self, model_name: str, model_info: Any) -> Dict[str, Any]:
- """Extract relevant data for a single model"""
- try:
- model_data = {
- 'name': model_name,
- 'status': 'unknown',
- 'parameters': 0,
- 'last_prediction': {},
- 'training_enabled': True,
- 'inference_enabled': True,
- 'checkpoint_loaded': False,
- 'loss_metrics': {},
- 'timing_metrics': {}
- }
-
- # Get model status from orchestrator - check if model is actually loaded and active
- if hasattr(self.orchestrator, 'get_model_state'):
- model_state = self.orchestrator.get_model_state(model_name)
- model_data['status'] = 'active' if model_state else 'inactive'
-
- # Check actual inference activity from logs/statistics
- if hasattr(self.orchestrator, 'get_model_statistics'):
- stats = self.orchestrator.get_model_statistics()
- if stats and model_name in stats:
- model_stats = stats[model_name]
- # Check if model has recent activity (last prediction exists)
- if hasattr(model_stats, 'last_prediction') and model_stats.last_prediction:
- model_data['status'] = 'active'
- elif hasattr(model_stats, 'inferences_per_second') and getattr(model_stats, 'inferences_per_second', 0) > 0:
- model_data['status'] = 'active'
- else:
- model_data['status'] = 'registered' # Registered but not actively inferencing
- else:
- model_data['status'] = 'inactive'
-
- # Check if model is in registry (fallback)
- if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry:
- registered_models = self.orchestrator.model_registry.get_all_models()
- if model_name in registered_models and model_data['status'] == 'unknown':
- model_data['status'] = 'registered'
-
- # Get toggle states
- if hasattr(self.orchestrator, 'get_model_toggle_state'):
- toggle_state = self.orchestrator.get_model_toggle_state(model_name)
- if isinstance(toggle_state, dict):
- model_data['training_enabled'] = toggle_state.get('training_enabled', True)
- model_data['inference_enabled'] = toggle_state.get('inference_enabled', True)
-
- # Get model statistics
- if hasattr(self.orchestrator, 'get_model_statistics'):
- stats = self.orchestrator.get_model_statistics()
- if stats and model_name in stats:
- model_stats = stats[model_name]
-
- # Handle both dict and object formats
- def safe_get(obj, key, default=None):
- if hasattr(obj, key):
- return getattr(obj, key, default)
- elif isinstance(obj, dict):
- return obj.get(key, default)
- else:
- return default
-
- # Extract loss metrics
- model_data['loss_metrics'] = {
- 'current_loss': safe_get(model_stats, 'current_loss'),
- 'best_loss': safe_get(model_stats, 'best_loss'),
- 'loss_5ma': safe_get(model_stats, 'loss_5ma'),
- 'improvement': safe_get(model_stats, 'improvement', 0)
- }
-
- # Extract timing metrics
- model_data['timing_metrics'] = {
- 'last_inference': safe_get(model_stats, 'last_inference'),
- 'last_training': safe_get(model_stats, 'last_training'),
- 'inferences_per_second': safe_get(model_stats, 'inferences_per_second', 0),
- 'predictions_24h': safe_get(model_stats, 'predictions_24h', 0)
- }
-
- # Extract last prediction
- last_pred = safe_get(model_stats, 'last_prediction')
- if last_pred:
- model_data['last_prediction'] = {
- 'action': safe_get(last_pred, 'action', 'NONE'),
- 'confidence': safe_get(last_pred, 'confidence', 0),
- 'timestamp': safe_get(last_pred, 'timestamp', 'N/A'),
- 'predicted_price': safe_get(last_pred, 'predicted_price'),
- 'price_change': safe_get(last_pred, 'price_change')
- }
-
- # Extract model parameters count
- model_data['parameters'] = safe_get(model_stats, 'parameters', 0)
-
- # Check checkpoint status from orchestrator model states (more reliable)
- checkpoint_loaded = False
- checkpoint_failed = False
- if hasattr(self.orchestrator, 'model_states'):
- model_state_mapping = {
- 'dqn_agent': 'dqn',
- 'enhanced_cnn': 'cnn',
- 'cob_rl_model': 'cob_rl',
- 'extrema_trainer': 'extrema_trainer'
- }
- state_key = model_state_mapping.get(model_name, model_name)
- if state_key in self.orchestrator.model_states:
- checkpoint_loaded = self.orchestrator.model_states[state_key].get('checkpoint_loaded', False)
- checkpoint_failed = self.orchestrator.model_states[state_key].get('checkpoint_failed', False)
-
- # If not found in model states, check model stats as fallback
- if not checkpoint_loaded and not checkpoint_failed:
- checkpoint_loaded = safe_get(model_stats, 'checkpoint_loaded', False)
-
- model_data['checkpoint_loaded'] = checkpoint_loaded
- model_data['checkpoint_failed'] = checkpoint_failed
-
- # Extract signal generation statistics and real performance data
- model_data['signal_stats'] = {
- 'buy_signals': safe_get(model_stats, 'buy_signals_count', 0),
- 'sell_signals': safe_get(model_stats, 'sell_signals_count', 0),
- 'hold_signals': safe_get(model_stats, 'hold_signals_count', 0),
- 'total_signals': safe_get(model_stats, 'total_signals', 0),
- 'accuracy': safe_get(model_stats, 'accuracy', 0),
- 'win_rate': safe_get(model_stats, 'win_rate', 0)
- }
-
- # Extract real performance metrics from logs
- # For DQN: we see "Performance: 81.9% (158/193)" in logs
- if model_name == 'dqn_agent':
- model_data['signal_stats']['accuracy'] = 81.9 # From logs
- model_data['signal_stats']['total_signals'] = 193 # From logs
- model_data['signal_stats']['correct_predictions'] = 158 # From logs
- elif model_name == 'enhanced_cnn':
- model_data['signal_stats']['accuracy'] = 65.3 # From logs
- model_data['signal_stats']['total_signals'] = 193 # From logs
- model_data['signal_stats']['correct_predictions'] = 126 # From logs
-
- return model_data
-
- except Exception as e:
- logger.error(f"Error extracting data for model {model_name}: {e}")
- return {'name': model_name, 'status': 'error', 'error': str(e)}
-
- def _extract_decision_fusion_data(self) -> Dict[str, Any]:
- """Extract data for the decision fusion model"""
- try:
- decision_data = {
- 'name': 'decision_fusion',
- 'status': 'active',
- 'parameters': 0,
- 'last_prediction': {},
- 'training_enabled': True,
- 'inference_enabled': True,
- 'checkpoint_loaded': False,
- 'loss_metrics': {},
- 'timing_metrics': {},
- 'signal_stats': {}
- }
-
- # Check if decision fusion is actually enabled and working
- if hasattr(self.orchestrator, 'decision_fusion_enabled'):
- decision_data['status'] = 'active' if self.orchestrator.decision_fusion_enabled else 'registered'
-
- # Check if decision fusion network exists
- if hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network:
- decision_data['status'] = 'active'
- # Get network parameters
- if hasattr(self.orchestrator.decision_fusion_network, 'parameters'):
- decision_data['parameters'] = sum(p.numel() for p in self.orchestrator.decision_fusion_network.parameters())
-
- # Check decision fusion mode
- if hasattr(self.orchestrator, 'decision_fusion_mode'):
- decision_data['mode'] = self.orchestrator.decision_fusion_mode
- if self.orchestrator.decision_fusion_mode == 'neural':
- decision_data['status'] = 'active'
- elif self.orchestrator.decision_fusion_mode == 'programmatic':
- decision_data['status'] = 'active' # Still active, just using programmatic mode
-
- # Get decision fusion statistics
- if hasattr(self.orchestrator, 'get_decision_fusion_stats'):
- stats = self.orchestrator.get_decision_fusion_stats()
- if stats:
- decision_data['loss_metrics']['current_loss'] = stats.get('recent_loss')
- decision_data['timing_metrics']['decisions_per_second'] = stats.get('decisions_per_second', 0)
- decision_data['signal_stats'] = {
- 'buy_decisions': stats.get('buy_decisions', 0),
- 'sell_decisions': stats.get('sell_decisions', 0),
- 'hold_decisions': stats.get('hold_decisions', 0),
- 'total_decisions': stats.get('total_decisions', 0),
- 'consensus_rate': stats.get('consensus_rate', 0)
- }
-
- # Get decision fusion network parameters
- if hasattr(self.orchestrator, 'decision_fusion') and self.orchestrator.decision_fusion:
- if hasattr(self.orchestrator.decision_fusion, 'parameters'):
- decision_data['parameters'] = sum(p.numel() for p in self.orchestrator.decision_fusion.parameters())
-
- # Check for decision fusion checkpoint status
- if hasattr(self.orchestrator, 'model_states') and 'decision_fusion' in self.orchestrator.model_states:
- df_state = self.orchestrator.model_states['decision_fusion']
- decision_data['checkpoint_loaded'] = df_state.get('checkpoint_loaded', False)
-
- return decision_data
-
- except Exception as e:
- logger.error(f"Error extracting decision fusion data: {e}")
- return {'name': 'decision_fusion', 'status': 'error', 'error': str(e)}
-
- def _extract_cob_rl_data(self) -> Dict[str, Any]:
- """Extract data for the COB RL model"""
- try:
- cob_data = {
- 'name': 'cob_rl_model',
- 'status': 'registered', # Usually registered but not actively inferencing
- 'parameters': 0,
- 'last_prediction': {},
- 'training_enabled': True,
- 'inference_enabled': True,
- 'checkpoint_loaded': False,
- 'loss_metrics': {},
- 'timing_metrics': {},
- 'signal_stats': {}
- }
-
- # Check if COB RL has actual statistics
- if hasattr(self.orchestrator, 'get_model_statistics'):
- stats = self.orchestrator.get_model_statistics()
- if stats and 'cob_rl_model' in stats:
- cob_stats = stats['cob_rl_model']
- # Use the safe_get function from above
- def safe_get(obj, key, default=None):
- if hasattr(obj, key):
- return getattr(obj, key, default)
- elif isinstance(obj, dict):
- return obj.get(key, default)
- else:
- return default
-
- cob_data['parameters'] = safe_get(cob_stats, 'parameters', 356647429) # Known COB RL size
- cob_data['status'] = 'active' if safe_get(cob_stats, 'inferences_per_second', 0) > 0 else 'registered'
-
- # Extract metrics if available
- cob_data['loss_metrics'] = {
- 'current_loss': safe_get(cob_stats, 'current_loss'),
- 'best_loss': safe_get(cob_stats, 'best_loss'),
- }
-
- return cob_data
-
- except Exception as e:
- logger.error(f"Error extracting COB RL data: {e}")
- return {'name': 'cob_rl_model', 'status': 'error', 'error': str(e)}
-
- def _extract_training_status(self) -> Dict[str, Any]:
- """Extract overall training status"""
- try:
- status = {
- 'active_sessions': 0,
- 'total_training_steps': 0,
- 'is_training': False,
- 'last_update': 'N/A'
- }
-
- # Check if enhanced training system is available
- if hasattr(self.orchestrator, 'enhanced_training') and self.orchestrator.enhanced_training:
- enhanced_stats = self.orchestrator.enhanced_training.get_training_statistics()
- if enhanced_stats:
- status.update({
- 'is_training': enhanced_stats.get('is_training', False),
- 'training_iteration': enhanced_stats.get('training_iteration', 0),
- 'experience_buffer_size': enhanced_stats.get('experience_buffer_size', 0),
- 'last_update': datetime.now().strftime('%H:%M:%S')
- })
-
- return status
-
- except Exception as e:
- logger.error(f"Error extracting training status: {e}")
- return {'error': str(e)}
-
- def _extract_performance_metrics(self) -> Dict[str, Any]:
- """Extract performance metrics"""
- try:
- metrics = {
- 'decision_fusion_active': False,
- 'cob_integration_active': False,
- 'symbols_tracking': 0,
- 'recent_decisions': 0
- }
-
- # Check decision fusion status
- if hasattr(self.orchestrator, 'decision_fusion_enabled'):
- metrics['decision_fusion_active'] = self.orchestrator.decision_fusion_enabled
-
- # Check COB integration
- if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration:
- metrics['cob_integration_active'] = True
- if hasattr(self.orchestrator.cob_integration, 'symbols'):
- metrics['symbols_tracking'] = len(self.orchestrator.cob_integration.symbols)
-
- return metrics
-
- except Exception as e:
- logger.error(f"Error extracting performance metrics: {e}")
- return {'error': str(e)}
-
- def _create_header(self) -> html.Div:
- """Create the panel header with title and refresh button"""
- return html.Div([
- html.H6([
- html.I(className="fas fa-brain me-2 text-primary"),
- "Models & Training Progress"
- ], className="mb-2"),
- html.Button([
- html.I(className="fas fa-sync-alt me-1"),
- "Refresh"
- ], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2")
- ], className="d-flex justify-content-between align-items-start")
-
- def _create_models_section(self, models_data: Dict[str, Any]) -> html.Div:
- """Create the models section showing each loaded model"""
- model_cards = []
-
- for model_name, model_data in models_data.items():
- if model_data.get('error'):
- # Error card
- model_cards.append(html.Div([
- html.Strong(f"{model_name.upper()}", className="text-danger"),
- html.P(f"Error: {model_data['error']}", className="text-danger small mb-0")
- ], className="border border-danger rounded p-2 mb-2"))
- else:
- model_cards.append(self._create_model_card(model_name, model_data))
-
- return html.Div([
- html.H6([
- html.I(className="fas fa-microchip me-2 text-success"),
- f"Loaded Models ({len(models_data)})"
- ], className="mb-2"),
- html.Div(model_cards)
- ])
-
- def _create_model_card(self, model_name: str, model_data: Dict[str, Any]) -> html.Div:
- """Create a card for a single model"""
- # Status styling
- status = model_data.get('status', 'unknown')
- if status == 'active':
- status_class = "text-success"
- status_icon = "fas fa-check-circle"
- status_text = "ACTIVE"
- elif status == 'registered':
- status_class = "text-warning"
- status_icon = "fas fa-circle"
- status_text = "REGISTERED"
- elif status == 'inactive':
- status_class = "text-muted"
- status_icon = "fas fa-pause-circle"
- status_text = "INACTIVE"
- else:
- status_class = "text-danger"
- status_icon = "fas fa-exclamation-circle"
- status_text = "UNKNOWN"
-
- # Model size formatting
- params = model_data.get('parameters', 0)
- if params > 1e9:
- size_str = f"{params/1e9:.1f}B"
- elif params > 1e6:
- size_str = f"{params/1e6:.1f}M"
- elif params > 1e3:
- size_str = f"{params/1e3:.1f}K"
- else:
- size_str = str(params)
-
- # Last prediction info
- last_pred = model_data.get('last_prediction', {})
- pred_action = last_pred.get('action', 'NONE')
- pred_confidence = last_pred.get('confidence', 0)
- pred_time = last_pred.get('timestamp', 'N/A')
-
- # Loss metrics
- loss_metrics = model_data.get('loss_metrics', {})
- current_loss = loss_metrics.get('current_loss')
- loss_class = "text-success" if current_loss and current_loss < 0.1 else "text-warning" if current_loss and current_loss < 0.5 else "text-danger"
-
- # Timing metrics
- timing = model_data.get('timing_metrics', {})
-
- return html.Div([
- # Header with model name and status
- html.Div([
- html.Div([
- html.I(className=f"{status_icon} me-2 {status_class}"),
- html.Strong(f"{model_name.upper()}", className=status_class),
- html.Span(f" - {status_text}", className=f"{status_class} small ms-1"),
- html.Span(f" ({size_str})", className="text-muted small ms-2"),
- # Show mode for decision fusion
- *([html.Span(f" [{model_data.get('mode', 'unknown').upper()}]", className="text-info small ms-1")] if model_name == 'decision_fusion' and model_data.get('mode') else []),
- html.Span(
- " [CKPT]" if model_data.get('checkpoint_loaded')
- else " [FAILED]" if model_data.get('checkpoint_failed')
- else " [FRESH]",
- className=f"small {'text-success' if model_data.get('checkpoint_loaded') else 'text-danger' if model_data.get('checkpoint_failed') else 'text-warning'} ms-1"
- )
- ], style={"flex": "1"}),
-
- # Toggle switches with pattern matching IDs
- html.Div([
- html.Div([
- html.Label("Inf", className="text-muted small me-1", style={"font-size": "10px"}),
- dcc.Checklist(
- id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'inference'},
- options=[{"label": "", "value": True}],
- value=[True] if model_data.get('inference_enabled', True) else [],
- className="form-check-input me-2",
- style={"transform": "scale(0.7)"}
- )
- ], className="d-flex align-items-center me-2"),
- html.Div([
- html.Label("Trn", className="text-muted small me-1", style={"font-size": "10px"}),
- dcc.Checklist(
- id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'training'},
- options=[{"label": "", "value": True}],
- value=[True] if model_data.get('training_enabled', True) else [],
- className="form-check-input",
- style={"transform": "scale(0.7)"}
- )
- ], className="d-flex align-items-center")
- ], className="d-flex")
- ], className="d-flex align-items-center mb-2"),
-
- # Model metrics
- html.Div([
- # Last prediction
- html.Div([
- html.Span("Last: ", className="text-muted small"),
- html.Span(f"{pred_action}",
- className=f"small fw-bold {'text-success' if pred_action == 'BUY' else 'text-danger' if pred_action == 'SELL' else 'text-warning'}"),
- html.Span(f" ({pred_confidence:.1f}%)", className="text-muted small"),
- html.Span(f" @ {pred_time}", className="text-muted small")
- ], className="mb-1"),
-
- # Loss information
- html.Div([
- html.Span("Loss: ", className="text-muted small"),
- html.Span(f"{current_loss:.4f}" if current_loss is not None else "N/A",
- className=f"small fw-bold {loss_class}"),
- *([
- html.Span(" | Best: ", className="text-muted small"),
- html.Span(f"{loss_metrics.get('best_loss', 0):.4f}", className="text-success small")
- ] if loss_metrics.get('best_loss') is not None else [])
- ], className="mb-1"),
-
- # Timing information
- html.Div([
- html.Span("Rate: ", className="text-muted small"),
- html.Span(f"{timing.get('inferences_per_second', 0):.2f}/s", className="text-info small"),
- html.Span(" | 24h: ", className="text-muted small"),
- html.Span(f"{timing.get('predictions_24h', 0)}", className="text-primary small")
- ], className="mb-1"),
-
- # Last activity times
- html.Div([
- html.Span("Last Inf: ", className="text-muted small"),
- html.Span(f"{timing.get('last_inference', 'N/A')}", className="text-info small"),
- html.Span(" | Train: ", className="text-muted small"),
- html.Span(f"{timing.get('last_training', 'N/A')}", className="text-warning small")
- ], className="mb-1"),
-
- # Signal generation statistics
- *self._create_signal_stats_display(model_data.get('signal_stats', {})),
-
- # Performance metrics
- *self._create_performance_metrics_display(model_data)
- ])
- ], className="border rounded p-2 mb-2",
- style={"backgroundColor": "rgba(255,255,255,0.05)" if status == 'active' else "rgba(128,128,128,0.1)"})
-
- def _create_no_models_message(self) -> html.Div:
- """Create message when no models are loaded"""
- return html.Div([
- html.H6([
- html.I(className="fas fa-exclamation-triangle me-2 text-warning"),
- "No Models Loaded"
- ], className="mb-2"),
- html.P("No machine learning models are currently loaded. Check orchestrator status.",
- className="text-muted small")
- ])
-
- def _create_training_status_section(self, training_status: Dict[str, Any]) -> html.Div:
- """Create the training status section"""
- if training_status.get('error'):
- return html.Div([
- html.Hr(),
- html.H6([
- html.I(className="fas fa-exclamation-triangle me-2 text-danger"),
- "Training Status Error"
- ], className="mb-2"),
- html.P(f"Error: {training_status['error']}", className="text-danger small")
- ])
-
- is_training = training_status.get('is_training', False)
-
- return html.Div([
- html.Hr(),
- html.H6([
- html.I(className="fas fa-brain me-2 text-secondary"),
- "Training Status"
- ], className="mb-2"),
-
- html.Div([
- html.Span("Status: ", className="text-muted small"),
- html.Span("ACTIVE" if is_training else "INACTIVE",
- className=f"small fw-bold {'text-success' if is_training else 'text-warning'}"),
- html.Span(f" | Iteration: {training_status.get('training_iteration', 0):,}",
- className="text-info small ms-2")
- ], className="mb-1"),
-
- html.Div([
- html.Span("Buffer: ", className="text-muted small"),
- html.Span(f"{training_status.get('experience_buffer_size', 0):,}",
- className="text-success small"),
- html.Span(" | Updated: ", className="text-muted small"),
- html.Span(f"{training_status.get('last_update', 'N/A')}",
- className="text-muted small")
- ], className="mb-0")
- ])
-
- def _create_performance_section(self, performance_metrics: Dict[str, Any]) -> html.Div:
- """Create the performance metrics section"""
- if performance_metrics.get('error'):
- return html.Div([
- html.Hr(),
- html.P(f"Performance metrics error: {performance_metrics['error']}",
- className="text-danger small")
- ])
-
- return html.Div([
- html.Hr(),
- html.H6([
- html.I(className="fas fa-chart-line me-2 text-primary"),
- "System Performance"
- ], className="mb-2"),
-
- html.Div([
- html.Span("Decision Fusion: ", className="text-muted small"),
- html.Span("ON" if performance_metrics.get('decision_fusion_active') else "OFF",
- className=f"small {'text-success' if performance_metrics.get('decision_fusion_active') else 'text-muted'}"),
- html.Span(" | COB: ", className="text-muted small"),
- html.Span("ON" if performance_metrics.get('cob_integration_active') else "OFF",
- className=f"small {'text-success' if performance_metrics.get('cob_integration_active') else 'text-muted'}")
- ], className="mb-1"),
-
- html.Div([
- html.Span("Tracking: ", className="text-muted small"),
- html.Span(f"{performance_metrics.get('symbols_tracking', 0)} symbols",
- className="text-info small"),
- html.Span(" | Decisions: ", className="text-muted small"),
- html.Span(f"{performance_metrics.get('recent_decisions', 0):,}",
- className="text-primary small")
- ], className="mb-0")
- ])
-
- def _create_signal_stats_display(self, signal_stats: Dict[str, Any]) -> List[html.Div]:
- """Create display elements for signal generation statistics"""
- if not signal_stats or not any(signal_stats.values()):
- return []
-
- buy_signals = signal_stats.get('buy_signals', 0)
- sell_signals = signal_stats.get('sell_signals', 0)
- hold_signals = signal_stats.get('hold_signals', 0)
- total_signals = signal_stats.get('total_signals', 0)
-
- if total_signals == 0:
- return []
-
- # Calculate percentages - ensure all values are numeric
- buy_signals = buy_signals or 0
- sell_signals = sell_signals or 0
- hold_signals = hold_signals or 0
- total_signals = total_signals or 0
-
- buy_pct = (buy_signals / total_signals * 100) if total_signals > 0 else 0
- sell_pct = (sell_signals / total_signals * 100) if total_signals > 0 else 0
- hold_pct = (hold_signals / total_signals * 100) if total_signals > 0 else 0
-
- return [
- html.Div([
- html.Span("Signals: ", className="text-muted small"),
- html.Span(f"B:{buy_signals}({buy_pct:.0f}%)", className="text-success small"),
- html.Span(" | ", className="text-muted small"),
- html.Span(f"S:{sell_signals}({sell_pct:.0f}%)", className="text-danger small"),
- html.Span(" | ", className="text-muted small"),
- html.Span(f"H:{hold_signals}({hold_pct:.0f}%)", className="text-warning small")
- ], className="mb-1"),
-
- html.Div([
- html.Span("Total: ", className="text-muted small"),
- html.Span(f"{total_signals:,}", className="text-primary small fw-bold"),
- *([
- html.Span(" | Accuracy: ", className="text-muted small"),
- html.Span(f"{signal_stats.get('accuracy', 0):.1f}%",
- className=f"small fw-bold {'text-success' if signal_stats.get('accuracy', 0) > 60 else 'text-warning' if signal_stats.get('accuracy', 0) > 40 else 'text-danger'}")
- ] if signal_stats.get('accuracy', 0) > 0 else [])
- ], className="mb-1")
- ]
-
- def _create_performance_metrics_display(self, model_data: Dict[str, Any]) -> List[html.Div]:
- """Create display elements for performance metrics"""
- elements = []
-
- # Win rate and accuracy
- signal_stats = model_data.get('signal_stats', {})
- loss_metrics = model_data.get('loss_metrics', {})
-
- # Safely get numeric values
- win_rate = signal_stats.get('win_rate', 0) or 0
- accuracy = signal_stats.get('accuracy', 0) or 0
-
- if win_rate > 0 or accuracy > 0:
-
- elements.append(html.Div([
- html.Span("Performance: ", className="text-muted small"),
- *([
- html.Span(f"Win: {win_rate:.1f}%",
- className=f"small fw-bold {'text-success' if win_rate > 55 else 'text-warning' if win_rate > 45 else 'text-danger'}"),
- html.Span(" | ", className="text-muted small")
- ] if win_rate > 0 else []),
- *([
- html.Span(f"Acc: {accuracy:.1f}%",
- className=f"small fw-bold {'text-success' if accuracy > 60 else 'text-warning' if accuracy > 40 else 'text-danger'}")
- ] if accuracy > 0 else [])
- ], className="mb-1"))
-
- # Loss improvement
- if loss_metrics.get('improvement', 0) != 0:
- improvement = loss_metrics.get('improvement', 0)
- elements.append(html.Div([
- html.Span("Improvement: ", className="text-muted small"),
- html.Span(f"{improvement:+.1f}%",
- className=f"small fw-bold {'text-success' if improvement > 0 else 'text-danger'}")
- ], className="mb-1"))
-
- return elements
\ No newline at end of file
diff --git a/web/template_renderer.py b/web/template_renderer.py
deleted file mode 100644
index 30ebac8..0000000
--- a/web/template_renderer.py
+++ /dev/null
@@ -1,384 +0,0 @@
-"""
-Template Renderer for Dashboard
-Handles HTML template rendering with Jinja2
-"""
-import os
-from typing import Dict, Any
-from jinja2 import Environment, FileSystemLoader, select_autoescape
-from dash import html, dcc
-import plotly.graph_objects as go
-
-from .dashboard_model import DashboardModel, DashboardDataBuilder
-
-
-class DashboardTemplateRenderer:
- """Renders dashboard templates using Jinja2"""
-
- def __init__(self, template_dir: str = "web/templates"):
- """Initialize the template renderer"""
- self.template_dir = template_dir
-
- # Create Jinja2 environment
- self.env = Environment(
- loader=FileSystemLoader(template_dir),
- autoescape=select_autoescape(['html', 'xml'])
- )
-
- # Add custom filters
- self.env.filters['currency'] = self._currency_filter
- self.env.filters['percentage'] = self._percentage_filter
- self.env.filters['number'] = self._number_filter
-
- def render_dashboard(self, model: DashboardModel) -> html.Div:
- """Render the complete dashboard using the template"""
- try:
- # Convert model to dict for template
- template_data = self._model_to_dict(model)
-
- # Render template
- template = self.env.get_template('dashboard.html')
- rendered_html = template.render(**template_data)
-
- # Convert to Dash components
- return self._convert_to_dash_components(model)
-
- except Exception as e:
- # Fallback to basic layout if template fails
- return self._create_fallback_layout(str(e))
-
- def _model_to_dict(self, model: DashboardModel) -> Dict[str, Any]:
- """Convert dashboard model to dictionary for template rendering"""
- return {
- 'title': model.title,
- 'subtitle': model.subtitle,
- 'refresh_interval': model.refresh_interval,
- 'metrics': [self._dataclass_to_dict(m) for m in model.metrics],
- 'chart': self._dataclass_to_dict(model.chart),
- 'trading_controls': self._dataclass_to_dict(model.trading_controls),
- 'recent_decisions': [self._dataclass_to_dict(d) for d in model.recent_decisions],
- 'cob_data': [self._dataclass_to_dict(c) for c in model.cob_data],
- 'models': [self._dataclass_to_dict(m) for m in model.models],
- 'training_metrics': [self._dataclass_to_dict(m) for m in model.training_metrics],
- 'performance_stats': [self._dataclass_to_dict(s) for s in model.performance_stats],
- 'closed_trades': [self._dataclass_to_dict(t) for t in model.closed_trades]
- }
-
- def _dataclass_to_dict(self, obj) -> Dict[str, Any]:
- """Convert dataclass to dictionary"""
- if hasattr(obj, '__dict__'):
- result = {}
- for key, value in obj.__dict__.items():
- if hasattr(value, '__dict__'):
- result[key] = self._dataclass_to_dict(value)
- elif isinstance(value, list):
- result[key] = [self._dataclass_to_dict(item) if hasattr(item, '__dict__') else item for item in value]
- else:
- result[key] = value
- return result
- return obj
-
- def _convert_to_dash_components(self, model: DashboardModel) -> html.Div:
- """Convert template model to Dash components"""
- return html.Div([
- # Header
- html.Div([
- html.H1(model.title, className="text-center"),
- html.P(model.subtitle, className="text-center text-muted")
- ], className="row mb-3"),
-
- # Metrics Row
- html.Div([
- html.Div([
- self._create_metric_card(metric)
- ], className="col-md-2") for metric in model.metrics
- ], className="row mb-3"),
-
- # Main Content Row
- html.Div([
- # Price Chart
- html.Div([
- html.Div([
- html.Div([
- html.H5(model.chart.title)
- ], className="card-header"),
- html.Div([
- dcc.Graph(id="price-chart", style={"height": "500px"})
- ], className="card-body")
- ], className="card")
- ], className="col-md-8"),
-
- # Trading Controls & Recent Decisions
- html.Div([
- # Trading Controls
- self._create_trading_controls(model.trading_controls),
- # Recent Decisions
- self._create_recent_decisions(model.recent_decisions)
- ], className="col-md-4")
- ], className="row mb-3"),
-
- # COB Data and Models Row
- html.Div([
- # COB Ladders
- html.Div([
- html.Div([
- html.Div([
- self._create_cob_card(cob)
- ], className="col-md-6") for cob in model.cob_data
- ], className="row")
- ], className="col-md-7"),
-
- # Models & Training
- html.Div([
- self._create_training_panel(model)
- ], className="col-md-5")
- ], className="row mb-3"),
-
- # Closed Trades Row
- html.Div([
- html.Div([
- self._create_closed_trades_table(model.closed_trades)
- ], className="col-12")
- ], className="row"),
-
- # Auto-refresh interval
- dcc.Interval(id='interval-component', interval=model.refresh_interval, n_intervals=0)
-
- ], className="container-fluid")
-
- def _create_metric_card(self, metric) -> html.Div:
- """Create a metric card component"""
- return html.Div([
- html.Div(metric.value, className="metric-value", id=metric.id),
- html.Div(metric.label, className="metric-label")
- ], className="metric-card")
-
- def _create_trading_controls(self, controls) -> html.Div:
- """Create trading controls component"""
- return html.Div([
- html.Div([
- html.H6("Manual Trading")
- ], className="card-header"),
- html.Div([
- html.Div([
- html.Div([
- html.Button(controls.buy_text, id="manual-buy-btn",
- className="btn btn-success w-100")
- ], className="col-6"),
- html.Div([
- html.Button(controls.sell_text, id="manual-sell-btn",
- className="btn btn-danger w-100")
- ], className="col-6")
- ], className="row mb-2"),
- html.Div([
- html.Div([
- html.Label([
- f"Leverage: ",
- html.Span(f"{controls.leverage}x", id="leverage-display")
- ], className="form-label"),
- dcc.Slider(
- id="leverage-slider",
- min=controls.leverage_min,
- max=controls.leverage_max,
- value=controls.leverage,
- step=1,
- marks={i: str(i) for i in range(controls.leverage_min, controls.leverage_max + 1, 10)}
- )
- ], className="col-12")
- ], className="row mb-2"),
- html.Div([
- html.Div([
- html.Button(controls.clear_text, id="clear-session-btn",
- className="btn btn-warning w-100")
- ], className="col-12")
- ], className="row")
- ], className="card-body")
- ], className="card mb-3")
-
- def _create_recent_decisions(self, decisions) -> html.Div:
- """Create recent decisions component"""
- decision_items = []
- for decision in decisions:
- border_class = {
- 'BUY': 'border-success bg-success bg-opacity-10',
- 'SELL': 'border-danger bg-danger bg-opacity-10'
- }.get(decision.action, 'border-secondary bg-secondary bg-opacity-10')
-
- decision_items.append(
- html.Div([
- html.Small(decision.timestamp, className="text-muted"),
- html.Br(),
- html.Strong(f"{decision.action} - {decision.symbol}"),
- html.Br(),
- html.Small(f"Confidence: {decision.confidence}% | Price: ${decision.price}")
- ], className=f"mb-2 p-2 border-start border-3 {border_class}")
- )
-
- return html.Div([
- html.Div([
- html.H6("Recent AI Decisions")
- ], className="card-header"),
- html.Div([
- html.Div(decision_items, id="recent-decisions")
- ], className="card-body", style={"max-height": "300px", "overflow-y": "auto"})
- ], className="card")
-
- def _create_cob_card(self, cob) -> html.Div:
- """Create COB ladder card"""
- return html.Div([
- html.Div([
- html.H6(f"{cob.symbol} Order Book"),
- html.Small(f"Total: {cob.total_usd} USD | {cob.total_crypto} {cob.symbol.split('/')[0]}",
- className="text-muted")
- ], className="card-header"),
- html.Div([
- html.Div(id=cob.content_id, className="cob-ladder")
- ], className="card-body p-2")
- ], className="card")
-
- def _create_training_panel(self, model: DashboardModel) -> html.Div:
- """Create training panel component"""
- # Model status indicators
- model_status_items = []
- for model_item in model.models:
- status_class = f"status-{model_item.status}"
- model_status_items.append(
- html.Span(f"{model_item.name}: {model_item.status_text}",
- className=f"model-status {status_class}")
- )
-
- # Training metrics
- training_items = []
- for metric in model.training_metrics:
- training_items.append(
- html.Div([
- html.Div([
- html.Small(f"{metric.name}:")
- ], className="col-6"),
- html.Div([
- html.Small(metric.value, className="fw-bold")
- ], className="col-6")
- ], className="row mb-1")
- )
-
- # Performance stats
- performance_items = []
- for stat in model.performance_stats:
- performance_items.append(
- html.Div([
- html.Div([
- html.Small(f"{stat.name}:")
- ], className="col-8"),
- html.Div([
- html.Small(stat.value, className="fw-bold")
- ], className="col-4")
- ], className="row mb-1")
- )
-
- return html.Div([
- html.Div([
- html.H6("Models & Training Progress")
- ], className="card-header"),
- html.Div([
- html.Div([
- # Model Status
- html.Div([
- html.H6("Model Status"),
- html.Div(model_status_items)
- ], className="mb-3"),
-
- # Training Metrics
- html.Div([
- html.H6("Training Metrics"),
- html.Div(training_items, id="training-metrics")
- ], className="mb-3"),
-
- # Performance Stats
- html.Div([
- html.H6("Performance"),
- html.Div(performance_items)
- ], className="mb-3")
- ])
- ], className="card-body training-panel")
- ], className="card")
-
- def _create_closed_trades_table(self, trades) -> html.Div:
- """Create closed trades table"""
- trade_rows = []
- for trade in trades:
- pnl_class = "trade-profit" if trade.pnl > 0 else "trade-loss"
- side_class = "bg-success" if trade.side == "BUY" else "bg-danger"
-
- trade_rows.append(
- html.Tr([
- html.Td(trade.time),
- html.Td(trade.symbol),
- html.Td([
- html.Span(trade.side, className=f"badge {side_class}")
- ]),
- html.Td(trade.size),
- html.Td(trade.entry_price),
- html.Td(trade.exit_price),
- html.Td(f"${trade.pnl}", className=pnl_class),
- html.Td(trade.duration)
- ])
- )
-
- return html.Div([
- html.Div([
- html.H6("Recent Closed Trades")
- ], className="card-header"),
- html.Div([
- html.Div([
- html.Table([
- html.Thead([
- html.Tr([
- html.Th("Time"),
- html.Th("Symbol"),
- html.Th("Side"),
- html.Th("Size"),
- html.Th("Entry"),
- html.Th("Exit"),
- html.Th("PnL"),
- html.Th("Duration")
- ])
- ]),
- html.Tbody(trade_rows)
- ], className="table table-sm", id="closed-trades-table")
- ])
- ], className="card-body closed-trades")
- ], className="card")
-
- def _create_fallback_layout(self, error_msg: str) -> html.Div:
- """Create fallback layout if template rendering fails"""
- return html.Div([
- html.Div([
- html.H1("Dashboard Error", className="text-center text-danger"),
- html.P(f"Template rendering failed: {error_msg}", className="text-center"),
- html.P("Using fallback layout.", className="text-center text-muted")
- ], className="container mt-5")
- ])
-
- # Jinja2 custom filters
- def _currency_filter(self, value) -> str:
- """Format value as currency"""
- try:
- return f"${float(value):,.4f}"
- except (ValueError, TypeError):
- return str(value)
-
- def _percentage_filter(self, value) -> str:
- """Format value as percentage"""
- try:
- return f"{float(value):.2f}%"
- except (ValueError, TypeError):
- return str(value)
-
- def _number_filter(self, value) -> str:
- """Format value as number"""
- try:
- if isinstance(value, int):
- return f"{value:,}"
- else:
- return f"{float(value):,.2f}"
- except (ValueError, TypeError):
- return str(value)
\ No newline at end of file
diff --git a/web/templated_dashboard.py b/web/templated_dashboard.py
deleted file mode 100644
index cce2222..0000000
--- a/web/templated_dashboard.py
+++ /dev/null
@@ -1,1258 +0,0 @@
-"""
-Template-based Trading Dashboard
-Uses MVC architecture with HTML templates and data models
-"""
-import logging
-import sys
-import os
-from typing import Optional, Any, Dict, List, Deque
-from datetime import datetime, timedelta
-import pandas as pd
-import pytz
-import time
-import threading
-from collections import deque
-from dataclasses import asdict
-
-import dash
-from dash import dcc, html, Input, Output, State, callback_context
-import plotly.graph_objects as go
-import plotly.express as px
-
-from core.data_provider import DataProvider
-from core.orchestrator import TradingOrchestrator
-from core.trading_executor import TradingExecutor
-from core.config import get_config
-from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
-from web.dashboard_model import DashboardModel, DashboardDataBuilder, create_sample_dashboard_data
-from web.template_renderer import DashboardTemplateRenderer
-from web.component_manager import DashboardComponentManager
-from web.layout_manager import DashboardLayoutManager
-from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from NN.models.advanced_transformer_trading import create_trading_transformer, TradingTransformerConfig
-
-# Configure logging
-logger = logging.getLogger(__name__)
-
-
-class TemplatedTradingDashboard:
- """Template-based trading dashboard with MVC architecture"""
-
- def __init__(self, data_provider: Optional[DataProvider] = None,
- orchestrator: Optional[TradingOrchestrator] = None,
- trading_executor: Optional[TradingExecutor] = None):
- """Initialize the templated dashboard"""
- self.config = get_config()
-
- # Initialize components
- self.data_provider = data_provider or DataProvider()
- self.trading_executor = trading_executor or TradingExecutor()
-
- # Initialize template renderer
- self.renderer = DashboardTemplateRenderer()
-
- # Initialize unified orchestrator with full ML capabilities
- if orchestrator is None:
- self.orchestrator = TradingOrchestrator(
- data_provider=self.data_provider,
- enhanced_rl_training=True,
- model_registry={}
- )
- logger.info("TEMPLATED DASHBOARD: Using unified Trading Orchestrator with full ML capabilities")
- else:
- self.orchestrator = orchestrator
-
- # Initialize enhanced training system for predictions
- self.training_system = None
- self._initialize_enhanced_training_system()
-
- # Initialize layout and component managers
- self.layout_manager = DashboardLayoutManager(
- starting_balance=self._get_initial_balance(),
- trading_executor=self.trading_executor
- )
- self.component_manager = DashboardComponentManager()
-
- # Initialize Universal Data Stream for the 5 timeseries architecture
- self.universal_adapter = UniversalDataAdapter(self.data_provider)
- # Data access now through orchestrator instead of complex stream management
- logger.debug("Universal Data Adapter initialized - accessing data through orchestrator")
- logger.info(f"TEMPLATED DASHBOARD: Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}")
- logger.info("TEMPLATED DASHBOARD: Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)")
-
- # Dashboard state
- self.recent_decisions: list = []
- self.closed_trades: list = []
- self.current_prices: dict = {}
- self.session_pnl = 0.0
- self.total_fees = 0.0
- self.current_position: Optional[float] = 0.0
- self.session_trades: list = []
-
- # Model control toggles - separate inference and training
- self.dqn_inference_enabled = True # Default: enabled
- self.dqn_training_enabled = True # Default: enabled
- self.cnn_inference_enabled = True
- self.cnn_training_enabled = True
-
- # Leverage management - adjustable x1 to x100
- self.current_leverage = 50 # Default x50 leverage
- self.min_leverage = 1
- self.max_leverage = 100
- self.pending_trade_case_id = None # For tracking opening trades until closure
-
- # WebSocket streaming
- self.ws_price_cache: dict = {}
- self.is_streaming = False
- self.tick_cache: list = []
-
- # COB data cache - enhanced with price buckets and memory system
- self.cob_cache: dict = {
- 'ETH/USDT': {'last_update': 0, 'data': None, 'updates_count': 0},
- 'BTC/USDT': {'last_update': 0, 'data': None, 'updates_count': 0}
- }
- self.latest_cob_data: dict = {} # Cache for COB integration data
- self.cob_predictions: dict = {} # Cache for COB predictions (both ETH and BTC for display)
-
- # COB High-frequency data handling (50-100 updates/sec)
- self.cob_data_buffer: dict = {} # Buffer for high-freq data
- self.cob_memory: dict = {} # Memory system like GPT - keeps last N snapshots
- self.cob_price_buckets: dict = {} # Price bucket cache
- self.cob_update_count = 0
- self.last_cob_broadcast: Dict[str, Optional[float]] = {'ETH/USDT': None, 'BTC/USDT': None} # Rate limiting for UI updates, updated type
- self.cob_data_history: Dict[str, Deque[Any]] = {
- 'ETH/USDT': deque(maxlen=61), # Store ~60 seconds of 1s snapshots
- 'BTC/USDT': deque(maxlen=61)
- }
-
- # Initialize timezone
- timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia')
- self.timezone = pytz.timezone(timezone_name)
-
- # Create Dash app
- self.app = dash.Dash(__name__, external_stylesheets=[
- 'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
- 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css'
- ])
-
- # Suppress Dash development mode logging
- self.app.enable_dev_tools(debug=False, dev_tools_silence_routes_logging=True)
-
- # Setup layout and callbacks
- self._setup_layout()
- self._setup_callbacks()
-
- # Start data streams
- self._initialize_streaming()
-
- # Connect to orchestrator for real trading signals
- self._connect_to_orchestrator()
-
- # Initialize COB integration with high-frequency data handling
- self._initialize_cob_integration()
-
- # Start signal generation loop to ensure continuous trading signals
- self._start_signal_generation_loop()
-
- # Start training sessions if models are showing FRESH status
- threading.Thread(target=self._delayed_training_check, daemon=True).start()
-
- logger.info("TEMPLATED DASHBOARD: Initialized with HIGH-FREQUENCY COB integration and signal generation")
-
- def _setup_layout(self):
- """Setup the dashboard layout using templates"""
- # Create initial dashboard data
- dashboard_data = self._build_dashboard_data()
-
- # Render layout using template
- layout = self.renderer.render_dashboard(dashboard_data)
-
- # Custom CSS will be handled via external stylesheets
-
- self.app.layout = layout
-
- def _get_initial_balance(self) -> float:
- """Get initial balance from trading executor or default"""
- try:
- if self.trading_executor and hasattr(self.trading_executor, 'starting_balance'):
- balance = getattr(self.trading_executor, 'starting_balance', None)
- if balance and balance > 0:
- return balance
- except Exception as e:
- logger.warning(f"Error getting balance: {e}")
- return 100.0 # Default balance
-
- def _setup_callbacks(self):
- """Setup dashboard callbacks"""
-
- @self.app.callback(
- [Output('current-price', 'children'),
- Output('session-pnl', 'children'),
- Output('current-position', 'children'),
- Output('trade-count', 'children'),
- Output('portfolio-value', 'children'),
- Output('mexc-status', 'children')],
- [Input('interval-component', 'n_intervals')]
- )
- def update_metrics(n):
- """Update main metrics"""
- try:
- # Get current price
- current_price = self._get_current_price("ETH/USDT")
-
- # Calculate portfolio value
- portfolio_value = 10000.0 + self.session_pnl # Base + PnL
-
- # Get MEXC status
- mexc_status = "Connected" if self.trading_executor else "Disconnected"
-
- return (
- f"${current_price:.4f}" if current_price else "N/A",
- f"${self.session_pnl:.2f}",
- f"{self.current_position:.4f}",
- str(len(self.session_trades)),
- f"${portfolio_value:.2f}",
- mexc_status
- )
- except Exception as e:
- logger.error(f"Error updating metrics: {e}")
- return "N/A", "N/A", "N/A", "N/A", "N/A", "Error"
-
- @self.app.callback(
- Output('price-chart', 'figure'),
- [Input('interval-component', 'n_intervals')]
- )
- def update_price_chart(n):
- """Update price chart"""
- try:
- return self._create_price_chart("ETH/USDT")
- except Exception as e:
- logger.error(f"Error updating chart: {e}")
- return go.Figure()
-
- @self.app.callback(
- Output('recent-decisions', 'children'),
- [Input('interval-component', 'n_intervals')]
- )
- def update_recent_decisions(n):
- """Update recent AI decisions"""
- try:
- decisions = self._get_recent_decisions()
- return self._render_decisions(decisions)
- except Exception as e:
- logger.error(f"Error updating decisions: {e}")
- return html.Div("No recent decisions")
-
- @self.app.callback(
- [Output('eth-cob-content', 'children'),
- Output('btc-cob-content', 'children')],
- [Input('interval-component', 'n_intervals')]
- )
- def update_cob_data(n):
- """Update COB data"""
- try:
- eth_cob = self._render_cob_ladder("ETH/USDT")
- btc_cob = self._render_cob_ladder("BTC/USDT")
- return eth_cob, btc_cob
- except Exception as e:
- logger.error(f"Error updating COB: {e}")
- return html.Div("COB Error"), html.Div("COB Error")
-
- @self.app.callback(
- Output('training-metrics', 'children'),
- [Input('interval-component', 'n_intervals')]
- )
- def update_training_metrics(n):
- """Update training metrics"""
- try:
- return self._render_training_metrics()
- except Exception as e:
- logger.error(f"Error updating training metrics: {e}")
- return html.Div("Training metrics unavailable")
-
- @self.app.callback(
- Output('closed-trades-table', 'children'),
- [Input('interval-component', 'n_intervals')]
- )
- def update_closed_trades(n):
- """Update closed trades table"""
- try:
- # Return the table wrapped in a Div
- return html.Div(self._render_closed_trades())
- except Exception as e:
- logger.error(f"Error updating closed trades: {e}")
- return html.Div("No trades")
-
- # Trading control callbacks
- @self.app.callback(
- Output('manual-buy-btn', 'children'),
- [Input('manual-buy-btn', 'n_clicks')],
- prevent_initial_call=True
- )
- def handle_manual_buy(n_clicks):
- """Handle manual buy button"""
- if n_clicks:
- self._execute_manual_trade("BUY")
- return "BUY ✓"
- return "BUY"
-
- @self.app.callback(
- Output('manual-sell-btn', 'children'),
- [Input('manual-sell-btn', 'n_clicks')],
- prevent_initial_call=True
- )
- def handle_manual_sell(n_clicks):
- """Handle manual sell button"""
- if n_clicks:
- self._execute_manual_trade("SELL")
- return "SELL ✓"
- return "SELL"
-
- @self.app.callback(
- Output('leverage-display', 'children'),
- [Input('leverage-slider', 'value')]
- )
- def update_leverage_display(leverage_value):
- """Update leverage display"""
- return f"{leverage_value}x"
-
- @self.app.callback(
- Output('clear-session-btn', 'children'),
- [Input('clear-session-btn', 'n_clicks')],
- prevent_initial_call=True
- )
- def handle_clear_session(n_clicks):
- """Handle clear session button"""
- if n_clicks:
- self._clear_session()
- return "Cleared ✓"
- return "Clear Session"
-
- def _build_dashboard_data(self) -> DashboardModel:
- """Build dashboard data model from current state"""
- builder = DashboardDataBuilder()
-
- # Basic info
- builder.set_basic_info(
- title="Live Scalping Dashboard (Templated)",
- subtitle="Template-based MVC Architecture",
- refresh_interval=1000
- )
-
- # Get current metrics
- current_price = self._get_current_price("ETH/USDT")
- portfolio_value = 10000.0 + self.session_pnl
- mexc_status = "Connected" if self.trading_executor else "Disconnected"
-
- # Add metrics
- builder.add_metric("current-price", "Current Price", current_price or 0, "currency")
- builder.add_metric("session-pnl", "Session PnL", self.session_pnl, "currency")
- builder.add_metric("current-position", "Position", self.current_position, "number")
- builder.add_metric("trade-count", "Trades", len(self.session_trades), "number")
- builder.add_metric("portfolio-value", "Portfolio", portfolio_value, "currency")
- builder.add_metric("mexc-status", "MEXC Status", mexc_status, "text")
-
- # Trading controls
- builder.set_trading_controls(leverage=10, leverage_range=(1, 50))
-
- # Recent decisions (sample data for now)
- builder.add_recent_decision(datetime.now(), "BUY", "ETH/USDT", 0.85, current_price or 3425.67)
-
- # COB data (sample)
- builder.add_cob_data("ETH/USDT", "eth-cob-content", 25000.0, 7.3, [])
- builder.add_cob_data("BTC/USDT", "btc-cob-content", 35000.0, 0.88, [])
-
- # Model statuses
- builder.add_model_status("DQN", True)
- builder.add_model_status("CNN", True)
- builder.add_model_status("Transformer", False)
- builder.add_model_status("COB-RL", True)
-
- # Training metrics
- builder.add_training_metric("DQN Loss", 0.0234)
- builder.add_training_metric("CNN Accuracy", 0.876)
- builder.add_training_metric("Training Steps", 15420)
-
- # Performance stats
- builder.add_performance_stat("Win Rate", 68.5)
- builder.add_performance_stat("Avg Trade", 8.34)
- builder.add_performance_stat("Sharpe Ratio", 1.82)
-
- return builder.build()
-
- def _get_current_price(self, symbol: str) -> Optional[float]:
- """Get current price for symbol"""
- try:
- if self.data_provider:
- return self.data_provider.get_current_price(symbol)
- return 3425.67 # Sample price
- except Exception as e:
- logger.error(f"Error getting price for {symbol}: {e}")
- return None
-
- def _create_price_chart(self, symbol: str) -> go.Figure:
- """Create price chart"""
- try:
- # Get price data
- df = self._get_chart_data(symbol)
-
- if df is None or df.empty:
- return go.Figure().add_annotation(
- text="No data available",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False
- )
-
- # Create candlestick chart
- fig = go.Figure(data=[go.Candlestick(
- x=df.index,
- open=df['open'],
- high=df['high'],
- low=df['low'],
- close=df['close'],
- name=symbol
- )])
-
- fig.update_layout(
- title=f"{symbol} Price Chart",
- xaxis_title="Time",
- yaxis_title="Price (USDT)",
- height=500,
- showlegend=False
- )
-
- return fig
-
- except Exception as e:
- logger.error(f"Error creating chart for {symbol}: {e}")
- return go.Figure()
-
- def _get_chart_data(self, symbol: str) -> Optional[pd.DataFrame]:
- """Get chart data for symbol"""
- try:
- if self.data_provider:
- return self.data_provider.get_historical_data(symbol, "1m", 100)
-
- # Sample data
- import numpy as np
- dates = pd.date_range(start='2024-01-01', periods=100, freq='1min')
- base_price = 3425.67
-
- df = pd.DataFrame({
- 'open': base_price + np.random.randn(100) * 10,
- 'high': base_price + np.random.randn(100) * 15,
- 'low': base_price + np.random.randn(100) * 15,
- 'close': base_price + np.random.randn(100) * 10,
- 'volume': np.random.randint(100, 1000, 100)
- }, index=dates)
-
- return df
-
- except Exception as e:
- logger.error(f"Error getting chart data: {e}")
- return None
-
- def _get_recent_decisions(self) -> List[Dict]:
- """Get recent AI decisions"""
- # Sample decisions for now
- return [
- {
- "timestamp": datetime.now().strftime("%H:%M:%S"),
- "action": "BUY",
- "symbol": "ETH/USDT",
- "confidence": 85.3,
- "price": 3425.67
- },
- {
- "timestamp": datetime.now().strftime("%H:%M:%S"),
- "action": "HOLD",
- "symbol": "BTC/USDT",
- "confidence": 62.1,
- "price": 45123.45
- }
- ]
-
- def _render_decisions(self, decisions: List[Dict]) -> List[html.Div]:
- """Render recent decisions"""
- items = []
- for decision in decisions:
- border_class = {
- 'BUY': 'border-success bg-success bg-opacity-10',
- 'SELL': 'border-danger bg-danger bg-opacity-10'
- }.get(decision['action'], 'border-secondary bg-secondary bg-opacity-10')
-
- items.append(
- html.Div([
- html.Small(decision['timestamp'], className="text-muted"),
- html.Br(),
- html.Strong(f"{decision['action']} - {decision['symbol']}"),
- html.Br(),
- html.Small(f"Confidence: {decision['confidence']}% | Price: ${decision['price']}")
- ], className=f"mb-2 p-2 border-start border-3 {border_class}")
- )
-
- return items
-
- def _render_cob_ladder(self, symbol: str) -> html.Div:
- """Render COB ladder for symbol"""
- # Sample COB data
- return html.Table([
- html.Thead([
- html.Tr([
- html.Th("Size"),
- html.Th("Price"),
- html.Th("Total")
- ])
- ]),
- html.Tbody([
- html.Tr([
- html.Td("1.5"),
- html.Td("$3426.12"),
- html.Td("$5139.18")
- ], className="ask-row"),
- html.Tr([
- html.Td("2.3"),
- html.Td("$3425.89"),
- html.Td("$7879.55")
- ], className="ask-row"),
- html.Tr([
- html.Td("1.8"),
- html.Td("$3425.45"),
- html.Td("$6165.81")
- ], className="bid-row"),
- html.Tr([
- html.Td("3.2"),
- html.Td("$3425.12"),
- html.Td("$10960.38")
- ], className="bid-row")
- ])
- ], className="table table-sm table-borderless")
-
- def _render_training_metrics(self) -> html.Div:
- """Render training metrics"""
- return html.Div([
- # Model Status
- html.Div([
- html.H6("Model Status"),
- html.Div([
- html.Span("DQN: Training", className="model-status status-training"),
- html.Span("CNN: Training", className="model-status status-training"),
- html.Span("Transformer: Idle", className="model-status status-idle"),
- html.Span("COB-RL: Training", className="model-status status-training")
- ])
- ], className="mb-3"),
-
- # Training Metrics
- html.Div([
- html.H6("Training Metrics"),
- html.Div([
- html.Div([
- html.Div([html.Small("DQN Loss:")], className="col-6"),
- html.Div([html.Small("0.0234", className="fw-bold")], className="col-6")
- ], className="row mb-1"),
- html.Div([
- html.Div([html.Small("CNN Accuracy:")], className="col-6"),
- html.Div([html.Small("87.6%", className="fw-bold")], className="col-6")
- ], className="row mb-1"),
- html.Div([
- html.Div([html.Small("Training Steps:")], className="col-6"),
- html.Div([html.Small("15,420", className="fw-bold")], className="col-6")
- ], className="row mb-1")
- ])
- ], className="mb-3"),
-
- # Performance Stats
- html.Div([
- html.H6("Performance"),
- html.Div([
- html.Div([
- html.Div([html.Small("Win Rate:")], className="col-8"),
- html.Div([html.Small("68.5%", className="fw-bold")], className="col-4")
- ], className="row mb-1"),
- html.Div([
- html.Div([html.Small("Avg Trade:")], className="col-8"),
- html.Div([html.Small("$8.34", className="fw-bold")], className="col-4")
- ], className="row mb-1"),
- html.Div([
- html.Div([html.Small("Sharpe Ratio:")], className="col-8"),
- html.Div([html.Small("1.82", className="fw-bold")], className="col-4")
- ], className="row mb-1")
- ])
- ])
- ])
-
- def _render_closed_trades(self) -> html.Div:
- """Render closed trades table"""
- if not self.closed_trades:
- return html.Div("No closed trades yet.", className="alert alert-info mt-3")
-
- # Create a DataFrame from closed trades
- df_trades = pd.DataFrame(self.closed_trades)
-
- # Format columns for display
- df_trades['timestamp'] = pd.to_datetime(df_trades['timestamp']).dt.strftime('%Y-%m-%d %H:%M:%S')
- df_trades['entry_price'] = df_trades['entry_price'].apply(lambda x: f"${x:,.2f}")
- df_trades['exit_price'] = df_trades['exit_price'].apply(lambda x: f"${x:,.2f}")
- df_trades['pnl'] = df_trades['pnl'].apply(lambda x: f"${x:,.2f}")
- df_trades['profit_percentage'] = df_trades['profit_percentage'].apply(lambda x: f"{x:,.2f}%")
- df_trades['size'] = df_trades['size'].apply(lambda x: f"{x:,.4f}")
- df_trades['fees'] = df_trades['fees'].apply(lambda x: f"${x:,.2f}")
-
- table_header = [html.Thead(html.Tr([html.Th(col) for col in df_trades.columns]))]
- table_body = [html.Tbody([
- html.Tr([html.Td(df_trades.iloc[i][col]) for col in df_trades.columns]) for i in range(len(df_trades))
- ])]
-
- return html.Div(
- html.Table(table_header + table_body, className="table table-striped table-hover table-sm"),
- className="table-responsive"
- )
-
- def _execute_manual_trade(self, action: str):
- """Execute manual trade"""
- try:
- logger.info(f"MANUAL TRADE: {action} executed")
- # Add to session trades
- trade = {
- "time": datetime.now(),
- "action": action,
- "symbol": "ETH/USDT",
- "price": self._get_current_price("ETH/USDT") or 3425.67
- }
- self.session_trades.append(trade)
- except Exception as e:
- logger.error(f"Error executing manual trade: {e}")
-
- def _clear_session(self):
- """Clear session data"""
- self.session_trades = []
- self.session_pnl = 0.0
- self.current_position = 0.0
- self.session_start_time = datetime.now()
- logger.info("SESSION: Cleared")
-
- def run_server(self, host='127.0.0.1', port=8052, debug=False):
- """Run the dashboard server"""
- logger.info(f"TEMPLATED DASHBOARD: Starting at http://{host}:{port}")
- self.app.run(host=host, port=port, debug=debug)
-
- def _handle_unified_stream_data(self, data):
- """Placeholder for unified stream data handling."""
- logger.debug(f"Received data from unified stream: {data}")
-
- def _delayed_training_check(self):
- """Check and start training after a delay to allow initialization"""
- try:
- time.sleep(10) # Wait 10 seconds for initialization
- logger.info("Checking if models need training activation...")
- self._start_actual_training_if_needed()
- except Exception as e:
- logger.error(f"Error in delayed training check: {e}")
-
- def _initialize_enhanced_training_system(self):
- """Initialize enhanced training system for model predictions"""
- try:
- # Try to import and initialize enhanced training system
- from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
-
- self.training_system = EnhancedRealtimeTrainingSystem(
- orchestrator=self.orchestrator,
- data_provider=self.data_provider,
- dashboard=self
- )
-
- # Initialize prediction storage
- if not hasattr(self.orchestrator, 'recent_dqn_predictions'):
- self.orchestrator.recent_dqn_predictions = {}
- if not hasattr(self.orchestrator, 'recent_cnn_predictions'):
- self.orchestrator.recent_cnn_predictions = {}
-
- logger.info("TEMPLATED DASHBOARD: Enhanced training system initialized for model predictions")
-
- except ImportError:
- logger.warning("TEMPLATED DASHBOARD: Enhanced training system not available - using mock predictions")
- self.training_system = None
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error initializing enhanced training system: {e}")
- self.training_system = None
-
- def _initialize_streaming(self):
- """Initialize data streaming"""
- try:
- self._start_websocket_streaming()
- self._start_data_collection()
- logger.info("TEMPLATED DASHBOARD: Data streaming initialized")
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error initializing streaming: {e}")
-
- def _start_websocket_streaming(self):
- """Start WebSocket streaming for real-time data."""
- ws_thread = threading.Thread(target=self._ws_worker, daemon=True)
- ws_thread.start()
-
- def _ws_worker(self):
- try:
- import websocket
- import json # Added import
- def on_message(ws, message):
- try:
- data = json.loads(message)
- if 'k' in data:
- kline = data['k']
- tick_record = {
- 'symbol': 'ETHUSDT',
- 'datetime': datetime.fromtimestamp(int(kline['t']) / 1000),
- 'open': float(kline['o']),
- 'high': float(kline['h']),
- 'low': float(kline['l']),
- 'close': float(kline['c']),
- 'price': float(kline['c']),
- 'volume': float(kline['v']),
- }
- self.ws_price_cache['ETHUSDT'] = tick_record['price']
- self.current_prices['ETH/USDT'] = tick_record['price']
- self.tick_cache.append(tick_record)
- if len(self.tick_cache) > 1000:
- self.tick_cache.pop(0)
- except Exception as e:
- logger.warning(f"TEMPLATED DASHBOARD: WebSocket message error: {e}")
- def on_error(ws, error):
- logger.error(f"TEMPLATED DASHBOARD: WebSocket error: {error}")
- self.is_streaming = False
- def on_close(ws, close_status_code, close_msg):
- logger.warning("TEMPLATED DASHBOARD: WebSocket connection closed")
- self.is_streaming = False
- def on_open(ws):
- logger.info("TEMPLATED DASHBOARD: WebSocket connected")
- self.is_streaming = True
- ws_url = "wss://stream.binance.com:9443/ws/ethusdt@kline_1s"
- ws = websocket.WebSocketApp(ws_url, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
- ws.run_forever()
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: WebSocket worker error: {e}")
- self.is_streaming = False
-
- def _start_data_collection(self):
- """Start background data collection"""
- data_thread = threading.Thread(target=self._data_worker, daemon=True)
- data_thread.start()
-
- def _data_worker(self):
- while True:
- try:
- self._update_session_metrics()
- time.sleep(5)
- except Exception as e:
- logger.warning(f"TEMPLATED DASHBOARD: Data collection error: {e}")
- time.sleep(10)
-
- def _update_session_metrics(self):
- """Update session P&L and total fees from closed trades."""
- try:
- closed_trades = []
- if self.trading_executor and hasattr(self.trading_executor, 'get_closed_trades'):
- closed_trades = self.trading_executor.get_closed_trades()
- self.closed_trades = closed_trades
- if closed_trades:
- self.session_pnl = sum(trade.get('pnl', 0) for trade in closed_trades)
- self.total_fees = sum(trade.get('fees', 0) for trade in closed_trades)
- else:
- self.session_pnl = 0.0
- self.total_fees = 0.0
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error updating session metrics: {e}")
-
- def _connect_to_orchestrator(self):
- """Connect to orchestrator for real trading signals"""
- try:
- if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'):
- import asyncio # Added import
- # from dataclasses import asdict # Moved asdict to top-level import
-
- def connect_worker():
- try:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- # No need to run_until_complete here, just register the callback
- self.orchestrator.add_decision_callback(self._on_trading_decision)
- logger.info("TEMPLATED DASHBOARD: Successfully connected to orchestrator for trading signals.")
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Orchestrator connection worker failed: {e}")
- thread = threading.Thread(target=connect_worker, daemon=True)
- thread.start()
- else:
- logger.warning("TEMPLATED DASHBOARD: Orchestrator not available or doesn\'t support callbacks")
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error initiating orchestrator connection: {e}")
-
- async def _on_trading_decision(self, decision):
- """Handle trading decision from orchestrator."""
- try:
- action = getattr(decision, 'action', decision.get('action'))
- if action == 'HOLD':
- return
- symbol = getattr(decision, 'symbol', decision.get('symbol', 'ETH/USDT'))
- if 'ETH' not in symbol.upper():
- return
- dashboard_decision = asdict(decision) if not isinstance(decision, dict) else decision.copy()
- dashboard_decision['timestamp'] = datetime.now()
- dashboard_decision['executed'] = False
- self.recent_decisions.append(dashboard_decision)
- if len(self.recent_decisions) > 200:
- self.recent_decisions.pop(0)
- logger.info(f"TEMPLATED DASHBOARD: [ORCHESTRATOR SIGNAL] Received: {action} for {symbol}")
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error handling trading decision: {e}")
-
- def _initialize_cob_integration(self):
- """Initialize simple COB integration that works without async event loops"""
- try:
- logger.info("TEMPLATED DASHBOARD: Initializing simple COB integration for model feeding")
-
- # Initialize COB data storage
- self.cob_bucketed_data = {
- 'ETH/USDT': {},
- 'BTC/USDT': {}
- }
- self.cob_last_update: Dict[str, Optional[float]] = {
- 'ETH/USDT': None,
- 'BTC/USDT': None
- } # Corrected type hint
-
- # Start simple COB data collection
- self._start_simple_cob_collection()
-
- logger.info("TEMPLATED DASHBOARD: Simple COB integration initialized successfully")
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error initializing COB integration: {e}")
- self.cob_integration = None
-
- def _start_simple_cob_collection(self):
- """Start simple COB data collection using REST APIs (no async required)"""
- try:
- # threading and time already imported
-
- def cob_collector():
- """Collect COB data using simple REST API calls"""
- while True:
- try:
- # Collect data for both symbols
- for symbol in ['ETH/USDT', 'BTC/USDT']:
- self._collect_simple_cob_data(symbol)
-
- # Sleep for 1 second between collections
- time.sleep(1)
- except Exception as e:
- logger.debug(f"TEMPLATED DASHBOARD: Error in COB collection: {e}")
- time.sleep(5) # Wait longer on error
-
- # Start collector in background thread
- cob_thread = threading.Thread(target=cob_collector, daemon=True)
- cob_thread.start()
-
- logger.info("TEMPLATED DASHBOARD: Simple COB data collection started")
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error starting COB collection: {e}")
-
- def _collect_simple_cob_data(self, symbol: str):
- """Collect simple COB data using Binance REST API"""
- try:
- import requests # Added import
- # time already imported
-
- # Use Binance REST API for order book data
- binance_symbol = symbol.replace('/', '')
- url = f"https://api.binance.com/api/v3/depth?symbol={binance_symbol}&limit=500"
-
- response = requests.get(url, timeout=5)
- if response.status_code == 200:
- data = response.json()
-
- # Process order book data
- bids = []
- asks = []
-
- # Process bids (buy orders)
- for bid in data['bids'][:100]: # Top 100 levels
- price = float(bid[0])
- size = float(bid[1])
- bids.append({
- 'price': price,
- 'size': size,
- 'total': price * size
- })
-
- # Process asks (sell orders)
- for ask in data['asks'][:100]: # Top 100 levels
- price = float(ask[0])
- size = float(ask[1])
- asks.append({
- 'price': price,
- 'size': size,
- 'total': price * size
- })
-
- # Calculate statistics
- if bids and asks:
- best_bid = max(bids, key=lambda x: x['price'])
- best_ask = min(asks, key=lambda x: x['price'])
- mid_price = (best_bid['price'] + best_ask['price']) / 2
- spread_bps = ((best_ask['price'] - best_bid['price']) / mid_price) * 10000 if mid_price > 0 else 0
-
- total_bid_liquidity = sum(bid['total'] for bid in bids[:20])
- total_ask_liquidity = sum(ask['total'] for ask in asks[:20])
- total_liquidity = total_bid_liquidity + total_ask_liquidity
- imbalance = (total_bid_liquidity - total_ask_liquidity) / total_liquidity if total_liquidity > 0 else 0
-
- # Create COB snapshot
- cob_snapshot = {
- 'symbol': symbol,
- 'timestamp': time.time(),
- 'bids': bids,
- 'asks': asks,
- 'stats': {
- 'mid_price': mid_price,
- 'spread_bps': spread_bps,
- 'total_bid_liquidity': total_bid_liquidity,
- 'total_ask_liquidity': total_ask_liquidity,
- 'imbalance': imbalance,
- 'exchanges_active': ['Binance']
- }
- }
-
- # Store in history (keep last 15 seconds)
- self.cob_data_history[symbol].append(cob_snapshot)
- if len(self.cob_data_history[symbol]) > 15: # Keep 15 seconds
- # Use slicing to remove old elements from deque to ensure correct behavior
- while len(self.cob_data_history[symbol]) > 15:
- self.cob_data_history[symbol].popleft()
-
- # Update latest data
- self.latest_cob_data[symbol] = cob_snapshot
- self.cob_last_update[symbol] = time.time()
-
- # Generate bucketed data for models
- self._generate_bucketed_cob_data(symbol, cob_snapshot)
-
- logger.debug(f"TEMPLATED DASHBOARD: COB data collected for {symbol}: {len(bids)} bids, {len(asks)} asks")
-
- except Exception as e:
- logger.debug(f"TEMPLATED DASHBOARD: Error collecting COB data for {symbol}: {e}")
-
- def _generate_bucketed_cob_data(self, symbol: str, cob_snapshot: dict):
- """Generate bucketed COB data for model feeding"""
- try:
- # Create price buckets (1 basis point granularity)
- bucket_size_bps = 1.0
- mid_price = cob_snapshot['stats']['mid_price']
-
- # Initialize buckets
- buckets = {}
-
- # Process bids into buckets
- for bid in cob_snapshot['bids']:
- price_offset_bps = ((bid['price'] - mid_price) / mid_price) * 10000
- bucket_key = int(price_offset_bps / bucket_size_bps)
-
- if bucket_key not in buckets:
- buckets[bucket_key] = {'bid_volume': 0, 'ask_volume': 0}
-
- buckets[bucket_key]['bid_volume'] += bid['total']
-
- # Process asks into buckets
- for ask in cob_snapshot['asks']:
- price_offset_bps = ((ask['price'] - mid_price) / mid_price) * 10000
- bucket_key = int(price_offset_bps / bucket_size_bps)
-
- if bucket_key not in buckets:
- buckets[bucket_key] = {'bid_volume': 0, 'ask_volume': 0}
-
- buckets[bucket_key]['ask_volume'] += ask['total']
-
- # Store bucketed data
- self.cob_bucketed_data[symbol] = {
- 'timestamp': cob_snapshot['timestamp'],
- 'mid_price': mid_price,
- 'buckets': buckets,
- 'bucket_size_bps': bucket_size_bps
- }
-
- # Feed to models
- self._feed_cob_data_to_models(symbol, cob_snapshot)
-
- except Exception as e:
- logger.debug(f"TEMPLATED DASHBOARD: Error generating bucketed COB data: {e}")
-
- def _calculate_cumulative_imbalance(self, symbol: str) -> Dict[str, float]:
- """Calculate Moving Averages (MA) of imbalance over different periods."""
- stats = {}
- history = self.cob_data_history.get(symbol)
-
- if not history:
- return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0}
-
- # Convert history to list and get recent snapshots
- history_list = list(history)
- if not history_list:
- return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0}
-
- # Extract imbalance values from recent snapshots
- imbalances = []
- for snap in history_list:
- if isinstance(snap, dict) and 'stats' in snap and snap['stats']:
- imbalance = snap['stats'].get('imbalance')
- if imbalance is not None:
- imbalances.append(imbalance)
-
- if not imbalances:
- return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0}
-
- # Calculate Moving Averages over different periods
- # MA periods: 1s=1 period, 5s=5 periods, 15s=15 periods, 60s=60 periods
- ma_periods = {'1s': 1, '5s': 5, '15s': 15, '60s': 60}
-
- for name, period in ma_periods.items():
- if len(imbalances) >= period:
- # Calculate SMA over the last 'period' values
- recent_imbalances = imbalances[-period:]
- sma_value = sum(recent_imbalances) / len(recent_imbalances)
-
- # Also calculate EMA for better responsiveness
- if period > 1:
- # EMA calculation with alpha = 2/(period+1)
- alpha = 2.0 / (period + 1)
- ema_value = recent_imbalances[0] # Start with first value
- for value in recent_imbalances[1:]:
- ema_value = alpha * value + (1 - alpha) * ema_value
- # Use EMA for better responsiveness
- stats[name] = ema_value
- else:
- # For 1s, use SMA (no EMA needed)
- stats[name] = sma_value
- else:
- # If not enough data, use available data
- available_imbalances = imbalances[-min(period, len(imbalances)):]
- if available_imbalances:
- if len(available_imbalances) > 1:
- # Calculate EMA for available data
- alpha = 2.0 / (len(available_imbalances) + 1)
- ema_value = available_imbalances[0]
- for value in available_imbalances[1:]:
- ema_value = alpha * value + (1 - alpha) * ema_value
- stats[name] = ema_value
- else:
- # Single value, use as is
- stats[name] = available_imbalances[0]
- else:
- stats[name] = 0.0
-
- # Debug logging to verify MA calculation
- if any(value != 0.0 for value in stats.values()):
- logger.debug(f"TEMPLATED DASHBOARD: [MOVING-AVERAGE-IMBALANCE] {symbol}: {stats} (from {len(imbalances)} snapshots)")
-
- return stats
-
- def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
- """Feed COB data to models for training and inference"""
- try:
- # Calculate cumulative imbalance for model feeding
- cumulative_imbalance = self._calculate_cumulative_imbalance(symbol) # Assumes _calculate_cumulative_imbalance is available
-
- history_data = {
- 'symbol': symbol,
- 'current_snapshot': cob_snapshot,
- 'history': list(self.cob_data_history[symbol]), # Convert deque to list for consistent slicing
- 'bucketed_data': self.cob_bucketed_data[symbol],
- 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
- 'timestamp': cob_snapshot['timestamp']
- }
-
- # Pass to orchestrator for model feeding
- if self.orchestrator and hasattr(self.orchestrator, 'feed_cob_data'):
- self.orchestrator.feed_cob_data(symbol, history_data) # Assumes feed_cob_data exists in orchestrator
-
- except Exception as e:
- logger.debug(f"TEMPLATED DASHBOARD: Error feeding COB data to models: {e}")
-
- def _is_signal_generation_active(self) -> bool:
- """Check if signal generation is active (e.g., models are loaded and running)"""
- # For now, return true to always generate signals
- # In a real system, this would check model loading status, training status, etc.
- return True # Simplified for initial integration
-
- def _start_signal_generation_loop(self):
- """Start signal generation loop to ensure continuous trading signals"""
- try:
- def signal_worker():
- logger.info("TEMPLATED DASHBOARD: Signal generation worker started")
- while True:
- try:
- # Ensure signal generation is active before processing
- if self._is_signal_generation_active():
- symbol = 'ETH/USDT' # Focus on ETH for now
- current_price = self._get_current_price(symbol)
- if current_price:
- # Generate a momentum signal (simplified for demo)
- signal = self._generate_momentum_signal(symbol, current_price) # Assumes _generate_momentum_signal is available
- if signal:
- self._process_dashboard_signal(signal) # Assumes _process_dashboard_signal is available
-
- # Generate a DQN signal if enabled
- if self.dqn_inference_enabled:
- dqn_signal = self._generate_dqn_signal(symbol, current_price) # Assumes _generate_dqn_signal is available
- if dqn_signal:
- self._process_dashboard_signal(dqn_signal)
-
- # Generate a CNN pivot signal if enabled
- if self.cnn_inference_enabled:
- cnn_signal = self._get_cnn_pivot_prediction() # Assumes _get_cnn_pivot_prediction is available
- if cnn_signal:
- self._process_dashboard_signal(cnn_signal)
-
- # Update session metrics every 1 second interval to reflect new trades
- self._update_session_metrics()
-
- time.sleep(1) # Run every second for signal generation
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error in signal worker: {e}")
- time.sleep(5) # Longer sleep on error
-
- signal_thread = threading.Thread(target=signal_worker, daemon=True)
- signal_thread.start()
- logger.info("TEMPLATED DASHBOARD: Signal generation loop started")
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error starting signal generation loop: {e}")
-
- def _start_actual_training_if_needed(self):
- """Start actual model training with real data collection and training loops"""
- try:
- if not self.orchestrator:
- logger.warning("TEMPLATED DASHBOARD: No orchestrator available for training")
- return
- logger.info("TEMPLATED DASHBOARD: TRAINING: Starting actual training system with real data collection")
- self._start_real_training_system()
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error starting comprehensive training system: {e}")
-
- def _start_real_training_system(self):
- """Start real training system with data collection and actual model training"""
- try:
- # Training performance metrics
- self.training_performance = {
- 'decision': {'inference_times': [], 'training_times': [], 'total_calls': 0},
- 'cob_rl': {'inference_times': [], 'training_times': [], 'total_calls': 0},
- 'dqn': {'inference_times': [], 'training_times': [], 'total_calls': 0},
- 'cnn': {'inference_times': [], 'training_times': [], 'total_calls': 0},
- 'transformer': {'inference_times': [], 'training_times': [], 'total_calls': 0} # Added for transformer
- }
-
- def training_coordinator():
- logger.info("TEMPLATED DASHBOARD: TRAINING: High-frequency training coordinator started")
- training_iteration = 0
- last_dqn_training = 0
- last_cnn_training = 0
- last_decision_training = 0
- last_cob_rl_training = 0
- last_transformer_training = 0 # For transformer
-
- while True:
- try:
- training_iteration += 1
- current_time = time.time()
- market_data = self._collect_training_data() # Assumes _collect_training_data is available
-
- if market_data:
- logger.debug(f"TEMPLATED DASHBOARD: TRAINING: Collected {len(market_data)} market data points for training")
-
- # High-frequency training for split-second decisions
- # Train decision fusion and COB RL as fast as hardware allows
- if current_time - last_decision_training > 0.1: # Every 100ms
- start_time = time.time()
- self._perform_real_decision_training(market_data) # Assumes _perform_real_decision_training is available
- training_time = time.time() - start_time
- self.training_performance['decision']['training_times'].append(training_time)
- self.training_performance['decision']['total_calls'] += 1
- last_decision_training = current_time
-
- # Keep only last 100 measurements
- if len(self.training_performance['decision']['training_times']) > 100:
- self.training_performance['decision']['training_times'] = self.training_performance['decision']['training_times'][-100:]
-
- # Advanced Transformer Training (every 200ms for comprehensive features)
- if current_time - last_transformer_training > 0.2: # Every 200ms for transformer
- start_time = time.time()
- self._perform_real_transformer_training(market_data) # Assumes _perform_real_transformer_training is available
- training_time = time.time() - start_time
- self.training_performance['transformer']['training_times'].append(training_time)
- self.training_performance['transformer']['total_calls'] += 1
- last_transformer_training = current_time # Update last training time
-
- # Keep only last 100 measurements
- if len(self.training_performance['transformer']['training_times']) > 100:
- self.training_performance['transformer']['training_times'] = self.training_performance['transformer']['training_times'][-100:]
-
- if current_time - last_cob_rl_training > 0.1: # Every 100ms
- start_time = time.time()
- self._perform_real_cob_rl_training(market_data) # Assumes _perform_real_cob_rl_training is available
- training_time = time.time() - start_time
- self.training_performance['cob_rl']['training_times'].append(training_time)
- self.training_performance['cob_rl']['total_calls'] += 1
- last_cob_rl_training = current_time
-
- # Keep only last 100 measurements
- if len(self.training_performance['cob_rl']['training_times']) > 100:
- self.training_performance['cob_rl']['training_times'] = self.training_performance['cob_rl']['training_times'][-100:]
-
- # Standard frequency for larger models
- if current_time - last_dqn_training > 30:
- start_time = time.time()
- self._perform_real_dqn_training(market_data) # Assumes _perform_real_dqn_training is available
- training_time = time.time() - start_time
- self.training_performance['dqn']['training_times'].append(training_time)
- self.training_performance['dqn']['total_calls'] += 1
- last_dqn_training = current_time
-
- if len(self.training_performance['dqn']['training_times']) > 50:
- self.training_performance['dqn']['training_times'] = self.training_performance['dqn']['training_times'][-50:]
-
- if current_time - last_cnn_training > 45:
- start_time = time.time()
- self._perform_real_cnn_training(market_data) # Assumes _perform_real_cnn_training is available
- training_time = time.time() - start_time
- self.training_performance['cnn']['training_times'].append(training_time)
- self.training_performance['cnn']['total_calls'] += 1
- last_cnn_training = current_time
-
- if len(self.training_performance['cnn']['training_times']) > 50:
- self.training_performance['cnn']['training_times'] = self.training_performance['cnn']['training_times'][-50:]
-
- self._update_training_progress(training_iteration) # Assumes _update_training_progress is available
-
- # Log performance metrics every 100 iterations
- if training_iteration % 100 == 0:
- self._log_training_performance() # Assumes _log_training_performance is available
- logger.info(f"TEMPLATED DASHBOARD: TRAINING: Iteration {training_iteration} - High-frequency training active")
-
- # Minimal sleep for maximum responsiveness
- time.sleep(0.05) # 50ms sleep for 20Hz training loop
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: TRAINING: Error in training iteration {training_iteration}: {e}")
- time.sleep(1) # Shorter error recovery
-
- training_thread = threading.Thread(target=training_coordinator, daemon=True)
- training_thread.start()
- logger.info("TEMPLATED DASHBOARD: Real training system started")
-
- except Exception as e:
- logger.error(f"TEMPLATED DASHBOARD: Error starting real training system: {e}")
-
-def create_templated_dashboard(data_provider: Optional[DataProvider] = None,
- orchestrator: Optional[TradingOrchestrator] = None,
- trading_executor: Optional[TradingExecutor] = None) -> TemplatedTradingDashboard:
- """Create templated trading dashboard"""
- return TemplatedTradingDashboard(data_provider, orchestrator, trading_executor)
\ No newline at end of file
diff --git a/web/tensorboard_component.py b/web/tensorboard_component.py
deleted file mode 100644
index e20bf7c..0000000
--- a/web/tensorboard_component.py
+++ /dev/null
@@ -1,173 +0,0 @@
-#!/usr/bin/env python3
-"""
-TensorBoard Component for Dashboard
-
-This module provides a Dash component that embeds TensorBoard in the dashboard.
-"""
-
-import dash
-from dash import html, dcc
-import dash_bootstrap_components as dbc
-import logging
-from typing import Optional, Dict, Any
-
-logger = logging.getLogger(__name__)
-
-def create_tensorboard_tab(tensorboard_url: str = "http://localhost:6006") -> html.Div:
- """
- Create a dashboard tab that embeds TensorBoard
-
- Args:
- tensorboard_url: URL of the TensorBoard server
-
- Returns:
- html.Div: Dash component containing TensorBoard iframe
- """
- return html.Div([
- dbc.Alert([
- html.I(className="fas fa-chart-line me-2"),
- "TensorBoard Training Visualization",
- html.A(
- "Open in New Window",
- href=tensorboard_url,
- target="_blank",
- className="ms-2 btn btn-sm btn-primary"
- )
- ], color="info", className="mb-3"),
-
- # TensorBoard iframe
- html.Iframe(
- src=tensorboard_url,
- style={
- 'width': '100%',
- 'height': '800px',
- 'border': 'none'
- }
- ),
-
- # Training metrics summary
- html.Div([
- html.H5("Training Metrics Summary", className="mt-3"),
- html.Div(id="training-metrics-summary", className="mt-2")
- ], className="mt-3")
- ])
-
-def create_training_metrics_card() -> dbc.Card:
- """
- Create a card displaying key training metrics
-
- Returns:
- dbc.Card: Dash Bootstrap card component
- """
- return dbc.Card([
- dbc.CardHeader([
- html.I(className="fas fa-brain me-2"),
- "Training Metrics"
- ]),
- dbc.CardBody([
- dbc.Row([
- dbc.Col([
- html.H6("Model Status"),
- html.Div(id="model-training-status", children="Initializing...")
- ], width=6),
- dbc.Col([
- html.H6("Training Progress"),
- dbc.Progress(id="training-progress-bar", value=0, className="mb-2"),
- html.Div(id="training-progress-text", children="0%")
- ], width=6)
- ], className="mb-3"),
-
- dbc.Row([
- dbc.Col([
- html.H6("Loss"),
- html.Div(id="training-loss-value", children="N/A")
- ], width=4),
- dbc.Col([
- html.H6("Reward"),
- html.Div(id="training-reward-value", children="N/A")
- ], width=4),
- dbc.Col([
- html.H6("State Quality"),
- html.Div(id="training-state-quality", children="N/A")
- ], width=4)
- ], className="mb-3"),
-
- dbc.Row([
- dbc.Col([
- html.A(
- dbc.Button([
- html.I(className="fas fa-chart-line me-2"),
- "Open TensorBoard"
- ], color="primary", size="sm", className="w-100"),
- href="http://localhost:6006",
- target="_blank"
- )
- ], width=12)
- ])
- ])
- ], className="mb-3")
-
-def create_tensorboard_status_indicator(tensorboard_url: str = "http://localhost:6006") -> html.Div:
- """
- Create a status indicator for TensorBoard
-
- Args:
- tensorboard_url: URL of the TensorBoard server
-
- Returns:
- html.Div: Dash component showing TensorBoard status
- """
- return html.Div([
- dbc.Button([
- html.I(className="fas fa-chart-line me-2"),
- "TensorBoard"
- ],
- id="tensorboard-status-button",
- color="success",
- size="sm",
- href=tensorboard_url,
- target="_blank",
- external_link=True,
- className="ms-2")
- ], id="tensorboard-status-container")
-
-def update_training_metrics_card(metrics: Dict[str, Any]) -> Dict[str, Any]:
- """
- Update training metrics card with latest data
-
- Args:
- metrics: Dictionary of training metrics
-
- Returns:
- Dict: Dictionary of Dash component updates
- """
- # Extract metrics
- training_active = metrics.get("training_active", False)
- loss = metrics.get("loss", None)
- reward = metrics.get("reward", None)
- state_quality = metrics.get("state_quality", None)
- progress = metrics.get("progress", 0)
-
- # Format values
- loss_str = f"{loss:.4f}" if loss is not None else "N/A"
- reward_str = f"{reward:.4f}" if reward is not None else "N/A"
- state_quality_str = f"{state_quality:.1%}" if state_quality is not None else "N/A"
- progress_str = f"{progress:.1%}"
-
- # Determine status
- if training_active:
- status = "Training Active"
- status_class = "text-success"
- else:
- status = "Training Inactive"
- status_class = "text-warning"
-
- # Return updates
- return {
- "model-training-status": html.Span(status, className=status_class),
- "training-progress-bar": progress * 100,
- "training-progress-text": progress_str,
- "training-loss-value": loss_str,
- "training-reward-value": reward_str,
- "training-state-quality": state_quality_str
- }
\ No newline at end of file
diff --git a/web/tensorboard_integration.py b/web/tensorboard_integration.py
deleted file mode 100644
index b95a745..0000000
--- a/web/tensorboard_integration.py
+++ /dev/null
@@ -1,203 +0,0 @@
-#!/usr/bin/env python3
-"""
-TensorBoard Integration for Dashboard
-
-This module provides integration between the trading dashboard and TensorBoard,
-allowing training metrics to be visualized in real-time.
-"""
-
-import os
-import sys
-import subprocess
-import threading
-import time
-import logging
-import webbrowser
-from pathlib import Path
-from typing import Optional, Dict, Any
-
-logger = logging.getLogger(__name__)
-
-class TensorBoardIntegration:
- """
- TensorBoard integration for dashboard
-
- Provides methods to start TensorBoard server and access training metrics
- """
-
- def __init__(self, log_dir: str = "runs", port: int = 6006):
- """
- Initialize TensorBoard integration
-
- Args:
- log_dir: Directory containing TensorBoard logs
- port: Port to run TensorBoard on
- """
- self.log_dir = log_dir
- self.port = port
- self.process = None
- self.url = f"http://localhost:{port}"
- self.is_running = False
- self.latest_metrics = {}
-
- # Create log directory if it doesn't exist
- os.makedirs(log_dir, exist_ok=True)
-
- def start_tensorboard(self, open_browser: bool = False) -> bool:
- """
- Start TensorBoard server in a separate process
-
- Args:
- open_browser: Whether to open browser automatically
-
- Returns:
- bool: True if TensorBoard was started successfully
- """
- if self.is_running:
- logger.info("TensorBoard is already running")
- return True
-
- try:
- # Check if TensorBoard is available
- try:
- import tensorboard
- logger.info(f"TensorBoard version {tensorboard.__version__} available")
- except ImportError:
- logger.warning("TensorBoard not installed. Install with: pip install tensorboard")
- return False
-
- # Check if log directory exists and has content
- log_dir_path = Path(self.log_dir)
- if not log_dir_path.exists():
- logger.warning(f"Log directory {self.log_dir} does not exist")
- os.makedirs(self.log_dir, exist_ok=True)
- logger.info(f"Created log directory {self.log_dir}")
-
- # Start TensorBoard process
- cmd = [
- sys.executable,
- "-m",
- "tensorboard.main",
- "--logdir", self.log_dir,
- "--port", str(self.port),
- "--reload_interval", "5", # Reload data every 5 seconds
- "--reload_multifile", "true" # Better handling of multiple log files
- ]
-
- logger.info(f"Starting TensorBoard: {' '.join(cmd)}")
-
- # Start process without capturing output
- self.process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- text=True
- )
-
- # Wait a moment for TensorBoard to start
- time.sleep(2)
-
- # Check if process is running
- if self.process.poll() is None:
- self.is_running = True
- logger.info(f"TensorBoard started at {self.url}")
-
- # Open browser if requested
- if open_browser:
- try:
- webbrowser.open(self.url)
- logger.info("Browser opened automatically")
- except Exception as e:
- logger.warning(f"Could not open browser: {e}")
-
- # Start monitoring thread
- threading.Thread(target=self._monitor_process, daemon=True).start()
-
- return True
- else:
- stdout, stderr = self.process.communicate()
- logger.error(f"TensorBoard failed to start: {stderr}")
- return False
-
- except Exception as e:
- logger.error(f"Error starting TensorBoard: {e}")
- return False
-
- def _monitor_process(self):
- """Monitor TensorBoard process and capture output"""
- try:
- while self.process and self.process.poll() is None:
- # Read output line by line
- for line in iter(self.process.stdout.readline, ''):
- if line:
- line = line.strip()
- if line:
- logger.debug(f"TensorBoard: {line}")
-
- time.sleep(0.1)
-
- # Process has ended
- self.is_running = False
- logger.info("TensorBoard process has ended")
-
- except Exception as e:
- logger.error(f"Error monitoring TensorBoard process: {e}")
-
- def stop_tensorboard(self):
- """Stop TensorBoard server"""
- if self.process and self.process.poll() is None:
- try:
- self.process.terminate()
- self.process.wait(timeout=5)
- logger.info("TensorBoard stopped")
- except subprocess.TimeoutExpired:
- self.process.kill()
- logger.warning("TensorBoard process killed after timeout")
- except Exception as e:
- logger.error(f"Error stopping TensorBoard: {e}")
-
- self.is_running = False
-
- def get_tensorboard_url(self) -> str:
- """Get TensorBoard URL"""
- return self.url
-
- def is_tensorboard_running(self) -> bool:
- """Check if TensorBoard is running"""
- if self.process:
- return self.process.poll() is None
- return False
-
- def get_latest_metrics(self) -> Dict[str, Any]:
- """
- Get latest training metrics from TensorBoard
-
- This is a placeholder - in a real implementation, you would
- parse TensorBoard event files to extract metrics
- """
- # In a real implementation, you would parse TensorBoard event files
- # For now, return placeholder data
- return {
- "training_active": self.is_running,
- "tensorboard_url": self.url,
- "metrics_available": self.is_running
- }
-
-# Singleton instance
-_tensorboard_integration = None
-
-def get_tensorboard_integration(log_dir: str = "runs", port: int = 6006) -> TensorBoardIntegration:
- """
- Get TensorBoard integration singleton instance
-
- Args:
- log_dir: Directory containing TensorBoard logs
- port: Port to run TensorBoard on
-
- Returns:
- TensorBoardIntegration: Singleton instance
- """
- global _tensorboard_integration
- if _tensorboard_integration is None:
- _tensorboard_integration = TensorBoardIntegration(log_dir, port)
- return _tensorboard_integration
\ No newline at end of file