BIG CLEANUP
CLEANUP_TODO.md (new file, 56 lines)
@@ -0,0 +1,56 @@
Cleanup run summary:
- Deleted files: 183
- NN\__init__.py
- NN\models\__init__.py
- NN\models\cnn_model.py
- NN\models\transformer_model.py
- NN\start_tensorboard.py
- NN\training\enhanced_rl_training_integration.py
- NN\training\example_checkpoint_usage.py
- NN\training\integrate_checkpoint_management.py
- NN\utils\__init__.py
- NN\utils\data_interface.py
- NN\utils\multi_data_interface.py
- NN\utils\realtime_analyzer.py
- NN\utils\signal_interpreter.py
- NN\utils\trading_env.py
- _dev\cleanup_models_now.py
- _tools\build_keep_set.py
- apply_trading_fixes.py
- apply_trading_fixes_to_main.py
- audit_training_system.py
- balance_trading_signals.py
- check_live_trading.py
- check_mexc_symbols.py
- cleanup_checkpoint_db.py
- cleanup_checkpoints.py
- core\__init__.py
- core\api_rate_limiter.py
- core\async_handler.py
- core\bookmap_data_provider.py
- core\bookmap_integration.py
- core\cnn_monitor.py
- core\cnn_training_pipeline.py
- core\config_sync.py
- core\enhanced_cnn_adapter.py
- core\enhanced_cob_websocket.py
- core\enhanced_orchestrator.py
- core\enhanced_training_integration.py
- core\exchanges\__init__.py
- core\exchanges\binance_interface.py
- core\exchanges\bybit\debug\test_bybit_balance.py
- core\exchanges\bybit_interface.py
- core\exchanges\bybit_rest_client.py
- core\exchanges\deribit_interface.py
- core\exchanges\mexc\debug\final_mexc_order_test.py
- core\exchanges\mexc\debug\fix_mexc_orders.py
- core\exchanges\mexc\debug\fix_mexc_orders_v2.py
- core\exchanges\mexc\debug\fix_mexc_orders_v3.py
- core\exchanges\mexc\debug\test_mexc_interface_debug.py
- core\exchanges\mexc\debug\test_mexc_order_signature.py
- core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
- core\exchanges\mexc\debug\test_mexc_signature_debug.py
... and 133 more
- Removed test directories: 1
- tests
- Kept (excluded): 1
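For context (not part of the committed files): the summary above reads like the output of a keep-set subtraction, where every tracked Python module that is not reachable from the dashboards gets removed. Below is a minimal sketch of such a pass, assuming the KEEP_SET.txt and DELETE_CANDIDATES.txt files added in this commit as inputs; the repository's actual cleanup tooling (_dev\cleanup_models_now.py, _tools\build_keep_set.py) may well work differently.

import shutil
from pathlib import Path

def run_cleanup(repo_root: str = ".", test_dirs: tuple = ("tests",)) -> dict:
    """Delete every candidate file that is not in the keep set, plus whole test directories."""
    root = Path(repo_root)
    keep = {line.strip() for line in (root / "KEEP_SET.txt").read_text().splitlines() if line.strip()}
    candidates = [line.strip() for line in (root / "DELETE_CANDIDATES.txt").read_text().splitlines() if line.strip()]

    deleted, excluded = [], []
    for rel in candidates:
        if rel in keep:
            excluded.append(rel)                   # still reachable from a dashboard: keep it
            continue
        target = root / rel.replace("\\", "/")     # the lists use Windows-style separators
        if target.is_file():
            target.unlink()
            deleted.append(rel)

    removed_dirs = []
    for d in test_dirs:
        if (root / d).is_dir():
            shutil.rmtree(root / d)
            removed_dirs.append(d)

    return {"deleted_files": len(deleted),
            "removed_test_dirs": len(removed_dirs),
            "kept_excluded": len(excluded)}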
DELETE_CANDIDATES.txt (new file, 184 lines)
@@ -0,0 +1,184 @@
NN\__init__.py
NN\models\__init__.py
NN\models\cnn_model.py
NN\models\transformer_model.py
NN\start_tensorboard.py
NN\training\enhanced_realtime_training.py
NN\training\enhanced_rl_training_integration.py
NN\training\example_checkpoint_usage.py
NN\training\integrate_checkpoint_management.py
NN\utils\__init__.py
NN\utils\data_interface.py
NN\utils\multi_data_interface.py
NN\utils\realtime_analyzer.py
NN\utils\signal_interpreter.py
NN\utils\trading_env.py
_dev\cleanup_models_now.py
_tools\build_keep_set.py
apply_trading_fixes.py
apply_trading_fixes_to_main.py
audit_training_system.py
balance_trading_signals.py
check_live_trading.py
check_mexc_symbols.py
cleanup_checkpoint_db.py
cleanup_checkpoints.py
core\__init__.py
core\api_rate_limiter.py
core\async_handler.py
core\bookmap_data_provider.py
core\bookmap_integration.py
core\cnn_monitor.py
core\cnn_training_pipeline.py
core\config_sync.py
core\enhanced_cnn_adapter.py
core\enhanced_cob_websocket.py
core\enhanced_orchestrator.py
core\enhanced_training_integration.py
core\exchanges\__init__.py
core\exchanges\binance_interface.py
core\exchanges\bybit\debug\test_bybit_balance.py
core\exchanges\bybit_interface.py
core\exchanges\bybit_rest_client.py
core\exchanges\deribit_interface.py
core\exchanges\mexc\debug\final_mexc_order_test.py
core\exchanges\mexc\debug\fix_mexc_orders.py
core\exchanges\mexc\debug\fix_mexc_orders_v2.py
core\exchanges\mexc\debug\fix_mexc_orders_v3.py
core\exchanges\mexc\debug\test_mexc_interface_debug.py
core\exchanges\mexc\debug\test_mexc_order_signature.py
core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
core\exchanges\mexc\debug\test_mexc_signature_debug.py
core\exchanges\mexc\debug\test_small_mexc_order.py
core\exchanges\mexc\test_live_trading.py
core\exchanges\mexc_interface.py
core\exchanges\trading_agent_test.py
core\mexc_webclient\__init__.py
core\mexc_webclient\auto_browser.py
core\mexc_webclient\browser_automation.py
core\mexc_webclient\mexc_futures_client.py
core\mexc_webclient\session_manager.py
core\mexc_webclient\test_mexc_futures_webclient.py
core\model_output_manager.py
core\negative_case_trainer.py
core\nn_decision_fusion.py
core\prediction_tracker.py
core\realtime_tick_processor.py
core\retrospective_trainer.py
core\rl_training_pipeline.py
core\robust_cob_provider.py
core\shared_cob_service.py
core\shared_data_manager.py
core\tick_aggregator.py
core\trading_action.py
core\trading_executor_fix.py
core\training_data_collector.py
core\williams_market_structure.py
dataprovider_realtime.py
debug\test_fixed_issues.py
debug\test_trading_fixes.py
debug\trade_audit.py
debug_training_methods.py
docs\exchanges\bybit\examples.py
example_usage_simplified_data_provider.py
kill_stale_processes.py
launch_training.py
main.py
main_clean.py
migrate_existing_models.py
model_manager.py
position_sync_enhancement.py
read_logs.py
reset_db_manager.py
reset_models_and_fix_mapping.py
run_clean_dashboard.py
run_continuous_training.py
run_crash_safe_dashboard.py
run_enhanced_rl_training.py
run_enhanced_training_dashboard.py
run_integrated_rl_cob_dashboard.py
run_mexc_browser.py
run_optimized_cob_system.py
run_realtime_rl_cob_trader.py
run_simple_dashboard.py
run_stable_dashboard.py
run_templated_dashboard.py
run_tensorboard.py
run_tests.py
scripts\kill_stale_processes.py
scripts\restart_dashboard_with_learning.py
scripts\restart_main_overnight.py
setup_mexc_browser.py
start_monitoring.py
start_overnight_training.py
system_stability_audit.py
test_build_base_data_performance.py
test_bybit_eth_futures.py
test_bybit_eth_futures_fixed.py
test_bybit_eth_live.py
test_bybit_public_api.py
test_cache_fix.py
test_cnn_integration.py
test_cob_dashboard.py
test_cob_data_quality.py
test_cob_websocket_only.py
test_continuous_cnn_training.py
test_dashboard_data_flow.py
test_dashboard_performance.py
test_data_integration.py
test_data_provider_integration.py
test_db_migration.py
test_deribit_integration.py
test_device_fix.py
test_device_training_fix.py
test_enhanced_cnn_adapter.py
test_enhanced_cob_websocket.py
test_enhanced_data_provider_websocket.py
test_enhanced_inference_logging.py
test_enhanced_training_integration.py
test_enhanced_training_simple.py
test_fifo_queues.py
test_hold_position_fix.py
test_imbalance_calculation.py
test_improved_data_integration.py
test_integrated_standardized_provider.py
test_leverage_fix.py
test_massive_dqn.py
test_mexc_order_fix.py
test_model_output_manager.py
test_model_registry.py
test_model_statistics.py
test_model_stats.py
test_model_training.py
test_orchestrator_fix.py
test_order_sync_and_fees.py
test_position_based_rewards.py
test_profitability_reward_system.py
test_training_data_collection.py
test_training_fixes.py
test_websocket_cob_data.py
tests\cob\test_cob_comparison.py
tests\cob\test_cob_data_stability.py
tests\test_training.py
tests\test_training_integration.py
tests\test_training_status.py
tests\test_universal_data_format.py
tests\test_universal_stream_integration.py
trading_main.py
utils\__init__.py
utils\async_task_manager.py
utils\launch_tensorboard.py
utils\model_utils.py
utils\port_manager.py
utils\process_supervisor.py
utils\reward_calculator.py
utils\system_monitor.py
utils\tensorboard_logger.py
utils\text_logger.py
verify_checkpoint_system.py
web\__init__.py
web\dashboard_fix.py
web\dashboard_model.py
web\layout_manager_with_tensorboard.py
web\tensorboard_component.py
web\tensorboard_integration.py
DEPENDENCY_TREE.md (new file, 84 lines)
@@ -0,0 +1,84 @@
Dependency tree from dashboards (module -> deps):
- NN\models\advanced_transformer_trading.py
- NN\models\cob_rl_model.py
- models\__init__.py
- NN\models\dqn_agent.py
- utils\checkpoint_manager.py
- utils\training_integration.py
- NN\models\enhanced_cnn.py
- NN\models\model_interfaces.py
- NN\models\standardized_cnn.py
- core\data_models.py
- core\cob_integration.py
- core\config.py
- safe_logging.py
- core\data_models.py
- core\data_provider.py
- utils\cache_manager.py
- utils\timezone_utils.py
- core\exchanges\exchange_factory.py
- core\exchanges\exchange_interface.py
- core\extrema_trainer.py
- utils\checkpoint_manager.py
- utils\training_integration.py
- core\multi_exchange_cob_provider.py
- core\orchestrator.py
- NN\models\advanced_transformer_trading.py
- NN\models\cob_rl_model.py
- NN\models\dqn_agent.py
- NN\models\enhanced_cnn.py
- NN\models\model_interfaces.py
- NN\models\standardized_cnn.py
- core\data_models.py
- core\extrema_trainer.py
- enhanced_realtime_training.py
- models\__init__.py
- utils\checkpoint_manager.py
- utils\database_manager.py
- utils\inference_logger.py
- core\overnight_training_coordinator.py
- core\realtime_rl_cob_trader.py
- NN\models\cob_rl_model.py
- core\trading_executor.py
- utils\checkpoint_manager.py
- core\standardized_data_provider.py
- core\trade_data_manager.py
- core\trading_executor.py
- core\data_provider.py
- core\exchanges\exchange_factory.py
- core\exchanges\exchange_interface.py
- core\training_integration.py
- core\universal_data_adapter.py
- enhanced_realtime_training.py
- models\__init__.py
- safe_logging.py
- utils\cache_manager.py
- utils\checkpoint_manager.py
- utils\database_manager.py
- utils\inference_logger.py
- utils\timezone_utils.py
- utils\training_integration.py
- web\clean_dashboard.py
- NN\models\advanced_transformer_trading.py
- NN\models\standardized_cnn.py
- core\cob_integration.py
- core\config.py
- core\data_models.py
- core\data_provider.py
- core\multi_exchange_cob_provider.py
- core\orchestrator.py
- core\overnight_training_coordinator.py
- core\realtime_rl_cob_trader.py
- core\standardized_data_provider.py
- core\trade_data_manager.py
- core\trading_executor.py
- core\training_integration.py
- core\universal_data_adapter.py
- utils\checkpoint_manager.py
- utils\timezone_utils.py
- web\component_manager.py
- web\layout_manager.py
- web\cob_realtime_dashboard.py
- core\cob_integration.py
- web\component_manager.py
- web\layout_manager.py
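A tree like the one above can be produced by statically walking imports outward from the two dashboard entry points. The following is a rough sketch under the assumption that dotted module names map directly onto files under the repository root; the generator actually used here (likely _tools\build_keep_set.py) may work differently.

import ast
from pathlib import Path

def local_deps(py_file: Path, root: Path) -> list:
    """Collect imports of py_file that resolve to files inside the repo (stdlib/third-party ignored)."""
    parsed = ast.parse(py_file.read_text(encoding="utf-8", errors="ignore"))
    deps = []
    for node in ast.walk(parsed):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            names = [node.module]
        else:
            continue
        for name in names:
            base = root / name.replace(".", "/")
            for target in (base.with_suffix(".py"), base / "__init__.py"):
                if target.is_file():
                    deps.append(target.relative_to(root))
                    break
    return deps

def build_tree(entry_points, root="."):
    """Walk the import graph from the dashboard entry points, collecting module -> deps."""
    rootp = Path(root)
    tree, queue = {}, [Path(p) for p in entry_points]
    while queue:
        mod = queue.pop()
        key = str(mod)
        if key in tree:
            continue
        deps = local_deps(rootp / mod, rootp)
        tree[key] = sorted(str(d) for d in deps)
        queue.extend(deps)
    return tree

# tree = build_tree(["web/clean_dashboard.py", "web/cob_realtime_dashboard.py"])
# The keep set is then simply the set of keys of the returned mapping.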
KEEP_SET.txt (new file, 35 lines)
@@ -0,0 +1,35 @@
NN\models\advanced_transformer_trading.py
NN\models\cob_rl_model.py
NN\models\dqn_agent.py
NN\models\enhanced_cnn.py
NN\models\model_interfaces.py
NN\models\standardized_cnn.py
core\cob_integration.py
core\config.py
core\data_models.py
core\data_provider.py
core\exchanges\exchange_factory.py
core\exchanges\exchange_interface.py
core\extrema_trainer.py
core\multi_exchange_cob_provider.py
core\orchestrator.py
core\overnight_training_coordinator.py
core\realtime_rl_cob_trader.py
core\standardized_data_provider.py
core\trade_data_manager.py
core\trading_executor.py
core\training_integration.py
core\universal_data_adapter.py
enhanced_realtime_training.py
models\__init__.py
safe_logging.py
utils\cache_manager.py
utils\checkpoint_manager.py
utils\database_manager.py
utils\inference_logger.py
utils\timezone_utils.py
utils\training_integration.py
web\clean_dashboard.py
web\cob_realtime_dashboard.py
web\component_manager.py
web\layout_manager.py
NN\__init__.py (deleted file)
@@ -1,16 +0,0 @@
"""
Neural Network Trading System
============================

A comprehensive neural network trading system that uses deep learning models
to analyze cryptocurrency price data and generate trading signals.

The system consists of:
1. Data Interface: Connects to realtime trading data
2. CNN Model: Deep convolutional neural network for feature extraction
3. Transformer Model: Processes high-level features for improved pattern recognition
4. MoE: Mixture of Experts model that combines multiple neural networks
"""

__version__ = '0.1.0'
__author__ = 'Gogo2 Project'
NN\models\__init__.py (deleted file)
@@ -1,27 +0,0 @@
"""
Neural Network Models
====================

This package contains the neural network models used in the trading system:
- CNN Model: Deep convolutional neural network for feature extraction
- DQN Agent: Deep Q-Network for reinforcement learning
- COB RL Model: Specialized RL model for order book data
- Advanced Transformer: High-performance transformer for trading

PyTorch implementation only.
"""

# Import core models
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.standardized_cnn import StandardizedCNN  # Use the unified CNN model

# Import model interfaces
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface

# Export the unified StandardizedCNN as CNNModel for compatibility
CNNModel = StandardizedCNN

__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
           'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
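A note on the compatibility export above: because CNNModel is bound to the very same class object as StandardizedCNN, old call sites keep working unchanged. A hypothetical legacy caller:

from NN.models import CNNModel, StandardizedCNN

# The alias is the same class object, so imports and isinstance checks written
# against the retired name continue to work after the legacy CNN classes were removed.
assert CNNModel is StandardizedCNN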
NN\models\cnn_model.py (deleted file)
@@ -1,201 +0,0 @@
|
||||
# """
|
||||
# Legacy CNN Model Compatibility Layer
|
||||
|
||||
# This module provides compatibility redirects to the unified StandardizedCNN model.
|
||||
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
|
||||
# in favor of the StandardizedCNN architecture.
|
||||
# """
|
||||
|
||||
# import logging
|
||||
# import warnings
|
||||
# from typing import Tuple, Dict, Any, Optional
|
||||
# import torch
|
||||
# import numpy as np
|
||||
|
||||
# # Import the standardized CNN model
|
||||
# from .standardized_cnn import StandardizedCNN
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# # Compatibility aliases and wrappers
|
||||
# class EnhancedCNNModel:
|
||||
# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# warnings.warn(
|
||||
# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Create StandardizedCNN with default parameters
|
||||
# self.standardized_cnn = StandardizedCNN()
|
||||
# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
||||
|
||||
# def __getattr__(self, name):
|
||||
# """Delegate all method calls to StandardizedCNN"""
|
||||
# return getattr(self.standardized_cnn, name)
|
||||
|
||||
|
||||
# class CNNModelTrainer:
|
||||
# """Legacy compatibility wrapper for CNN training"""
|
||||
|
||||
# def __init__(self, model=None, *args, **kwargs):
|
||||
# warnings.warn(
|
||||
# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# if isinstance(model, EnhancedCNNModel):
|
||||
# self.model = model.standardized_cnn
|
||||
# else:
|
||||
# self.model = StandardizedCNN()
|
||||
# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
|
||||
|
||||
# def train_step(self, x, y, *args, **kwargs):
|
||||
# """Legacy train step wrapper"""
|
||||
# try:
|
||||
# # Convert to BaseDataInput format if needed
|
||||
# if hasattr(x, 'get_feature_vector'):
|
||||
# # Already BaseDataInput
|
||||
# base_input = x
|
||||
# else:
|
||||
# # Create mock BaseDataInput for legacy compatibility
|
||||
# from core.data_models import BaseDataInput
|
||||
# base_input = BaseDataInput()
|
||||
# # Set mock feature vector
|
||||
# if isinstance(x, torch.Tensor):
|
||||
# feature_vector = x.flatten().cpu().numpy()
|
||||
# else:
|
||||
# feature_vector = np.array(x).flatten()
|
||||
|
||||
# # Pad or truncate to expected size
|
||||
# expected_size = self.model.expected_feature_dim
|
||||
# if len(feature_vector) < expected_size:
|
||||
# padding = np.zeros(expected_size - len(feature_vector))
|
||||
# feature_vector = np.concatenate([feature_vector, padding])
|
||||
# else:
|
||||
# feature_vector = feature_vector[:expected_size]
|
||||
|
||||
# base_input._feature_vector = feature_vector
|
||||
|
||||
# # Convert target to string format
|
||||
# if isinstance(y, torch.Tensor):
|
||||
# y_val = y.item() if y.numel() == 1 else y.argmax().item()
|
||||
# else:
|
||||
# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
|
||||
|
||||
# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
|
||||
# target = target_map.get(y_val, 'HOLD')
|
||||
|
||||
# # Use StandardizedCNN training
|
||||
# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
||||
# loss = self.model.train_step([base_input], [target], optimizer)
|
||||
|
||||
# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Legacy train_step error: {e}")
|
||||
# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
|
||||
|
||||
|
||||
# # class CNNModel:
|
||||
# # """Legacy compatibility wrapper for CNN model interface"""
|
||||
|
||||
# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
|
||||
# # warnings.warn(
|
||||
# # "CNNModel is deprecated. Use StandardizedCNN directly.",
|
||||
# # DeprecationWarning,
|
||||
# # stacklevel=2
|
||||
# # )
|
||||
# # self.input_shape = input_shape
|
||||
# # self.output_size = output_size
|
||||
# # self.standardized_cnn = StandardizedCNN()
|
||||
# # self.trainer = CNNModelTrainer(self.standardized_cnn)
|
||||
# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
||||
|
||||
# # def build_model(self, **kwargs):
|
||||
# # """Legacy build method - no-op for StandardizedCNN"""
|
||||
# # return self
|
||||
|
||||
# # def predict(self, X):
|
||||
# # """Legacy predict method"""
|
||||
# # try:
|
||||
# # # Convert input to BaseDataInput
|
||||
# # from core.data_models import BaseDataInput
|
||||
# # base_input = BaseDataInput()
|
||||
|
||||
# # if isinstance(X, np.ndarray):
|
||||
# # feature_vector = X.flatten()
|
||||
# # else:
|
||||
# # feature_vector = np.array(X).flatten()
|
||||
|
||||
# # # Pad or truncate to expected size
|
||||
# # expected_size = self.standardized_cnn.expected_feature_dim
|
||||
# # if len(feature_vector) < expected_size:
|
||||
# # padding = np.zeros(expected_size - len(feature_vector))
|
||||
# # feature_vector = np.concatenate([feature_vector, padding])
|
||||
# # else:
|
||||
# # feature_vector = feature_vector[:expected_size]
|
||||
|
||||
# # base_input._feature_vector = feature_vector
|
||||
|
||||
# # # Get prediction from StandardizedCNN
|
||||
# # result = self.standardized_cnn.predict_from_base_input(base_input)
|
||||
|
||||
# # # Convert to legacy format
|
||||
# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
|
||||
# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
|
||||
# # pred_proba = np.array([result.predictions['action_probabilities']])
|
||||
|
||||
# # return pred_class, pred_proba
|
||||
|
||||
# # except Exception as e:
|
||||
# # logger.error(f"Legacy predict error: {e}")
|
||||
# # # Return safe defaults
|
||||
# # pred_class = np.array([2]) # HOLD
|
||||
# # pred_proba = np.array([[0.33, 0.33, 0.34]])
|
||||
# # return pred_class, pred_proba
|
||||
|
||||
# # def fit(self, X, y, **kwargs):
|
||||
# # """Legacy fit method"""
|
||||
# # try:
|
||||
# # return self.trainer.train_step(X, y)
|
||||
# # except Exception as e:
|
||||
# # logger.error(f"Legacy fit error: {e}")
|
||||
# # return self
|
||||
|
||||
# # def save(self, filepath: str):
|
||||
# # """Legacy save method"""
|
||||
# # try:
|
||||
# # torch.save(self.standardized_cnn.state_dict(), filepath)
|
||||
# # logger.info(f"StandardizedCNN saved to {filepath}")
|
||||
# # except Exception as e:
|
||||
# # logger.error(f"Error saving model: {e}")
|
||||
|
||||
|
||||
# def create_enhanced_cnn_model(input_size: int = 60,
|
||||
# feature_dim: int = 50,
|
||||
# output_size: int = 3,
|
||||
# base_channels: int = 256,
|
||||
# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
|
||||
# """Legacy compatibility function - returns StandardizedCNN"""
|
||||
# warnings.warn(
|
||||
# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# model = StandardizedCNN()
|
||||
# trainer = CNNModelTrainer(model)
|
||||
|
||||
# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
|
||||
# return model, trainer
|
||||
|
||||
|
||||
# # Export compatibility symbols
|
||||
# __all__ = [
|
||||
# 'EnhancedCNNModel',
|
||||
# 'CNNModelTrainer',
|
||||
# # 'CNNModel',
|
||||
# 'create_enhanced_cnn_model'
|
||||
# ]
|
||||
NN\models\transformer_model.py (deleted file)
@@ -1,821 +0,0 @@
|
||||
"""
|
||||
Transformer Neural Network for timeseries analysis
|
||||
|
||||
This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
|
||||
It also includes a Mixture of Experts model that combines predictions from multiple models.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Model, load_model
|
||||
from tensorflow.keras.layers import (
|
||||
Input, Dense, Dropout, BatchNormalization,
|
||||
Concatenate, Layer, LayerNormalization, MultiHeadAttention,
|
||||
Add, GlobalAveragePooling1D, Conv1D, Reshape
|
||||
)
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
||||
import datetime
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TransformerBlock(Layer):
|
||||
"""
|
||||
Transformer block implementation with multi-head attention and feed-forward networks.
|
||||
"""
|
||||
def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
|
||||
super(TransformerBlock, self).__init__()
|
||||
self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
|
||||
self.ffn = tf.keras.Sequential([
|
||||
Dense(ff_dim, activation="relu"),
|
||||
Dense(embed_dim),
|
||||
])
|
||||
self.layernorm1 = LayerNormalization(epsilon=1e-6)
|
||||
self.layernorm2 = LayerNormalization(epsilon=1e-6)
|
||||
self.dropout1 = Dropout(rate)
|
||||
self.dropout2 = Dropout(rate)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
attn_output = self.att(inputs, inputs)
|
||||
attn_output = self.dropout1(attn_output, training=training)
|
||||
out1 = self.layernorm1(inputs + attn_output)
|
||||
ffn_output = self.ffn(out1)
|
||||
ffn_output = self.dropout2(ffn_output, training=training)
|
||||
return self.layernorm2(out1 + ffn_output)
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config()
|
||||
config.update({
|
||||
'att': self.att,
|
||||
'ffn': self.ffn,
|
||||
'layernorm1': self.layernorm1,
|
||||
'layernorm2': self.layernorm2,
|
||||
'dropout1': self.dropout1,
|
||||
'dropout2': self.dropout2
|
||||
})
|
||||
return config
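As a reading aid (not from the original file): call() above implements the standard post-norm Transformer encoder block,

\[
\begin{aligned}
h &= \mathrm{LayerNorm}\big(x + \mathrm{Dropout}(\mathrm{MultiHeadAttention}(x, x))\big),\\
y &= \mathrm{LayerNorm}\big(h + \mathrm{Dropout}(\mathrm{FFN}(h))\big), \qquad \mathrm{FFN}(h) = \mathrm{ReLU}(h W_1 + b_1)\, W_2 + b_2 .
\end{aligned}
\]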
|
||||
|
||||
class PositionalEncoding(Layer):
|
||||
"""
|
||||
Positional encoding layer to add position information to input embeddings.
|
||||
"""
|
||||
def __init__(self, position, d_model):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.position = position
|
||||
self.d_model = d_model
|
||||
self.pos_encoding = self.positional_encoding(position, d_model)
|
||||
|
||||
def get_angles(self, position, i, d_model):
|
||||
angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
|
||||
return position * angles
|
||||
|
||||
def positional_encoding(self, position, d_model):
|
||||
angle_rads = self.get_angles(
|
||||
position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
|
||||
i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
|
||||
d_model=d_model
|
||||
)
|
||||
|
||||
# Apply sin to even indices in the array
|
||||
sines = tf.math.sin(angle_rads[:, 0::2])
|
||||
|
||||
# Apply cos to odd indices in the array
|
||||
cosines = tf.math.cos(angle_rads[:, 1::2])
|
||||
|
||||
pos_encoding = tf.concat([sines, cosines], axis=-1)
|
||||
pos_encoding = pos_encoding[tf.newaxis, ...]
|
||||
|
||||
return tf.cast(pos_encoding, tf.float32)
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config()
|
||||
config.update({
|
||||
'position': self.position,
|
||||
'd_model': self.d_model,
|
||||
'pos_encoding': self.pos_encoding
|
||||
})
|
||||
return config
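For reference (our summary, not part of the file): this is the sinusoidal encoding from the original Transformer paper,

\[
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{\,2i/d_{\mathrm{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{\,2i/d_{\mathrm{model}}}}\right),
\]

with one layout difference: positional_encoding() concatenates the sine block and the cosine block along the feature axis instead of interleaving them, which changes the channel ordering but not the information carried.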
|
||||
|
||||
class TransformerModel:
|
||||
"""
|
||||
Transformer Neural Network for time series analysis.
|
||||
|
||||
This model uses self-attention mechanisms to capture relationships between
|
||||
different time points in the input data.
|
||||
"""
|
||||
|
||||
def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
|
||||
"""
|
||||
Initialize the Transformer model.
|
||||
|
||||
Args:
|
||||
ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
|
||||
feature_input_shape (int): Shape of additional feature input (e.g., from CNN)
|
||||
output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
|
||||
model_dir (str): Directory to save trained models
|
||||
"""
|
||||
self.ts_input_shape = ts_input_shape
|
||||
self.feature_input_shape = feature_input_shape
|
||||
self.output_size = output_size
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
self.history = None
|
||||
|
||||
# Create model directory if it doesn't exist
|
||||
os.makedirs(self.model_dir, exist_ok=True)
|
||||
|
||||
logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
|
||||
f"feature input shape {feature_input_shape}, and output size {output_size}")
|
||||
|
||||
def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
|
||||
"""
|
||||
Build the Transformer model architecture.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Embedding dimension for transformer
|
||||
num_heads (int): Number of attention heads
|
||||
ff_dim (int): Hidden dimension of the feed forward network
|
||||
num_transformer_blocks (int): Number of transformer blocks
|
||||
dropout_rate (float): Dropout rate for regularization
|
||||
learning_rate (float): Learning rate for Adam optimizer
|
||||
|
||||
Returns:
|
||||
The compiled model
|
||||
"""
|
||||
# Time series input
|
||||
ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")
|
||||
|
||||
# Additional feature input (e.g., from CNN)
|
||||
feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")
|
||||
|
||||
# Process time series with transformer
|
||||
# First, project the input to the embedding dimension
|
||||
x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)
|
||||
|
||||
# Add positional encoding
|
||||
x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)
|
||||
|
||||
# Add transformer blocks
|
||||
for _ in range(num_transformer_blocks):
|
||||
x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)
|
||||
|
||||
# Global pooling to get a single vector representation
|
||||
x = GlobalAveragePooling1D()(x)
|
||||
x = Dropout(dropout_rate)(x)
|
||||
|
||||
# Combine with additional features
|
||||
combined = Concatenate()([x, feature_inputs])
|
||||
|
||||
# Dense layers for final classification/regression
|
||||
x = Dense(64, activation="relu")(combined)
|
||||
x = BatchNormalization()(x)
|
||||
x = Dropout(dropout_rate)(x)
|
||||
|
||||
# Output layer
|
||||
if self.output_size == 1:
|
||||
# Binary classification (up/down)
|
||||
outputs = Dense(1, activation='sigmoid', name='output')(x)
|
||||
loss = 'binary_crossentropy'
|
||||
metrics = ['accuracy']
|
||||
elif self.output_size == 3:
|
||||
# Multi-class classification (buy/hold/sell)
|
||||
outputs = Dense(3, activation='softmax', name='output')(x)
|
||||
loss = 'categorical_crossentropy'
|
||||
metrics = ['accuracy']
|
||||
else:
|
||||
# Regression
|
||||
outputs = Dense(self.output_size, activation='linear', name='output')(x)
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
|
||||
# Create and compile model
|
||||
self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)
|
||||
|
||||
# Compile with Adam optimizer
|
||||
self.model.compile(
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
loss=loss,
|
||||
metrics=metrics
|
||||
)
|
||||
|
||||
# Log model summary
|
||||
self.model.summary(print_fn=lambda x: logger.info(x))
|
||||
|
||||
return self.model
|
||||
|
||||
def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
|
||||
callbacks=None, class_weights=None):
|
||||
"""
|
||||
Train the Transformer model on the provided data.
|
||||
|
||||
Args:
|
||||
X_ts (numpy.ndarray): Time series input features
|
||||
X_features (numpy.ndarray): Additional input features
|
||||
y (numpy.ndarray): Target labels
|
||||
batch_size (int): Batch size
|
||||
epochs (int): Number of epochs
|
||||
validation_split (float): Fraction of data to use for validation
|
||||
callbacks (list): List of Keras callbacks
|
||||
class_weights (dict): Class weights for imbalanced datasets
|
||||
|
||||
Returns:
|
||||
History object containing training metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.build_model()
|
||||
|
||||
# Default callbacks if none provided
|
||||
if callbacks is None:
|
||||
# Create a timestamp for model checkpoints
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
callbacks = [
|
||||
EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
restore_best_weights=True
|
||||
),
|
||||
ReduceLROnPlateau(
|
||||
monitor='val_loss',
|
||||
factor=0.5,
|
||||
patience=5,
|
||||
min_lr=1e-6
|
||||
),
|
||||
ModelCheckpoint(
|
||||
filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
|
||||
monitor='val_loss',
|
||||
save_best_only=True
|
||||
)
|
||||
]
|
||||
|
||||
# Check if y needs to be one-hot encoded for multi-class
|
||||
if self.output_size == 3 and len(y.shape) == 1:
|
||||
y = tf.keras.utils.to_categorical(y, num_classes=3)
|
||||
|
||||
# Train the model
|
||||
logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
|
||||
self.history = self.model.fit(
|
||||
[X_ts, X_features], y,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
validation_split=validation_split,
|
||||
callbacks=callbacks,
|
||||
class_weight=class_weights,
|
||||
verbose=2
|
||||
)
|
||||
|
||||
# Save the trained model
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
|
||||
self.model.save(model_path)
|
||||
logger.info(f"Model saved to {model_path}")
|
||||
|
||||
# Save training history
|
||||
history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
|
||||
with open(history_path, 'w') as f:
|
||||
# Convert numpy values to Python native types for JSON serialization
|
||||
history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
|
||||
json.dump(history_dict, f, indent=2)
|
||||
|
||||
return self.history
|
||||
|
||||
def evaluate(self, X_ts, X_features, y):
|
||||
"""
|
||||
Evaluate the model on test data.
|
||||
|
||||
Args:
|
||||
X_ts (numpy.ndarray): Time series input features
|
||||
X_features (numpy.ndarray): Additional input features
|
||||
y (numpy.ndarray): Target labels
|
||||
|
||||
Returns:
|
||||
dict: Evaluation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model has not been built or trained yet")
|
||||
|
||||
# Convert y to one-hot encoding for multi-class
|
||||
if self.output_size == 3 and len(y.shape) == 1:
|
||||
y = tf.keras.utils.to_categorical(y, num_classes=3)
|
||||
|
||||
# Evaluate model
|
||||
logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
|
||||
eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)
|
||||
|
||||
metrics = {}
|
||||
for metric, value in zip(self.model.metrics_names, eval_results):
|
||||
metrics[metric] = value
|
||||
logger.info(f"{metric}: {value:.4f}")
|
||||
|
||||
return metrics
|
||||
|
||||
def predict(self, X_ts, X_features=None):
|
||||
"""
|
||||
Make predictions on new data.
|
||||
|
||||
Args:
|
||||
X_ts (numpy.ndarray): Time series input features
|
||||
X_features (numpy.ndarray): Additional input features
|
||||
|
||||
Returns:
|
||||
tuple: (y_pred, y_proba) where:
|
||||
y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
|
||||
y_proba is the class probability
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model has not been built or trained yet")
|
||||
|
||||
# Ensure X_ts has the right shape
|
||||
if len(X_ts.shape) == 2:
|
||||
# Single sample, add batch dimension
|
||||
X_ts = np.expand_dims(X_ts, axis=0)
|
||||
|
||||
# Ensure X_features has the right shape
|
||||
if X_features is None:
|
||||
# Extract features from time series data if no external features provided
|
||||
X_features = self._extract_features_from_timeseries(X_ts)
|
||||
elif len(X_features.shape) == 1:
|
||||
# Single sample, add batch dimension
|
||||
X_features = np.expand_dims(X_features, axis=0)
|
||||
|
||||
def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
|
||||
"""Extract meaningful features from time series data instead of using dummy zeros"""
|
||||
try:
|
||||
batch_size = X_ts.shape[0]
|
||||
features = []
|
||||
|
||||
for i in range(batch_size):
|
||||
sample = X_ts[i] # Shape: (timesteps, features)
|
||||
|
||||
# Extract statistical features from each feature dimension
|
||||
sample_features = []
|
||||
|
||||
for feature_idx in range(sample.shape[1]):
|
||||
feature_data = sample[:, feature_idx]
|
||||
|
||||
# Basic statistical features
|
||||
sample_features.extend([
|
||||
np.mean(feature_data), # Mean
|
||||
np.std(feature_data), # Standard deviation
|
||||
np.min(feature_data), # Minimum
|
||||
np.max(feature_data), # Maximum
|
||||
np.percentile(feature_data, 25), # 25th percentile
|
||||
np.percentile(feature_data, 75), # 75th percentile
|
||||
])
|
||||
|
||||
# Trend features
|
||||
if len(feature_data) > 1:
|
||||
# Linear trend (slope)
|
||||
x = np.arange(len(feature_data))
|
||||
slope = np.polyfit(x, feature_data, 1)[0]
|
||||
sample_features.append(slope)
|
||||
|
||||
# Rate of change
|
||||
rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
|
||||
sample_features.append(rate_of_change)
|
||||
else:
|
||||
sample_features.extend([0.0, 0.0])
|
||||
|
||||
# Pad or truncate to expected feature size
|
||||
while len(sample_features) < self.feature_input_shape:
|
||||
sample_features.append(0.0)
|
||||
sample_features = sample_features[:self.feature_input_shape]
|
||||
|
||||
features.append(sample_features)
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting features from time series: {e}")
|
||||
# Fallback to zeros if extraction fails
|
||||
return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
|
||||
|
||||
# Get predictions
|
||||
y_proba = self.model.predict([X_ts, X_features])
|
||||
|
||||
# Process based on output type
|
||||
if self.output_size == 1:
|
||||
# Binary classification
|
||||
y_pred = (y_proba > 0.5).astype(int).flatten()
|
||||
return y_pred, y_proba.flatten()
|
||||
elif self.output_size == 3:
|
||||
# Multi-class classification
|
||||
y_pred = np.argmax(y_proba, axis=1)
|
||||
return y_pred, y_proba
|
||||
else:
|
||||
# Regression
|
||||
return y_proba, y_proba
|
||||
|
||||
def save(self, filepath=None):
|
||||
"""
|
||||
Save the model to disk.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to save the model
|
||||
|
||||
Returns:
|
||||
str: Path where the model was saved
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model has not been built yet")
|
||||
|
||||
if filepath is None:
|
||||
# Create a default filepath with timestamp
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5")
|
||||
|
||||
self.model.save(filepath)
|
||||
logger.info(f"Model saved to {filepath}")
|
||||
return filepath
|
||||
|
||||
def load(self, filepath):
|
||||
"""
|
||||
Load a saved model from disk.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the saved model
|
||||
|
||||
Returns:
|
||||
The loaded model
|
||||
"""
|
||||
# Register custom layers
|
||||
custom_objects = {
|
||||
'TransformerBlock': TransformerBlock,
|
||||
'PositionalEncoding': PositionalEncoding
|
||||
}
|
||||
|
||||
self.model = load_model(filepath, custom_objects=custom_objects)
|
||||
logger.info(f"Model loaded from {filepath}")
|
||||
return self.model
|
||||
|
||||
def plot_training_history(self):
|
||||
"""
|
||||
Plot training history (loss and metrics).
|
||||
|
||||
Returns:
|
||||
str: Path to the saved plot
|
||||
"""
|
||||
if self.history is None:
|
||||
raise ValueError("Model has not been trained yet")
|
||||
|
||||
plt.figure(figsize=(12, 5))
|
||||
|
||||
# Plot loss
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.plot(self.history.history['loss'], label='Training Loss')
|
||||
if 'val_loss' in self.history.history:
|
||||
plt.plot(self.history.history['val_loss'], label='Validation Loss')
|
||||
plt.title('Model Loss')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
|
||||
# Plot accuracy
|
||||
plt.subplot(1, 2, 2)
|
||||
|
||||
if 'accuracy' in self.history.history:
|
||||
plt.plot(self.history.history['accuracy'], label='Training Accuracy')
|
||||
if 'val_accuracy' in self.history.history:
|
||||
plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
|
||||
plt.title('Model Accuracy')
|
||||
plt.ylabel('Accuracy')
|
||||
elif 'mae' in self.history.history:
|
||||
plt.plot(self.history.history['mae'], label='Training MAE')
|
||||
if 'val_mae' in self.history.history:
|
||||
plt.plot(self.history.history['val_mae'], label='Validation MAE')
|
||||
plt.title('Model MAE')
|
||||
plt.ylabel('MAE')
|
||||
|
||||
plt.xlabel('Epoch')
|
||||
plt.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save figure
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
|
||||
plt.savefig(fig_path)
|
||||
plt.close()
|
||||
|
||||
logger.info(f"Training history plot saved to {fig_path}")
|
||||
return fig_path
|
||||
|
||||
|
||||
class MixtureOfExpertsModel:
|
||||
"""
|
||||
Mixture of Experts (MoE) model.
|
||||
|
||||
This model combines predictions from multiple expert models (such as CNN and Transformer)
|
||||
using a weighted ensemble approach.
|
||||
"""
|
||||
|
||||
def __init__(self, output_size=1, model_dir="NN/models/saved"):
|
||||
"""
|
||||
Initialize the MoE model.
|
||||
|
||||
Args:
|
||||
output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
|
||||
model_dir (str): Directory to save trained models
|
||||
"""
|
||||
self.output_size = output_size
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
self.history = None
|
||||
self.experts = {}
|
||||
|
||||
# Create model directory if it doesn't exist
|
||||
os.makedirs(self.model_dir, exist_ok=True)
|
||||
|
||||
logger.info(f"Initialized Mixture of Experts model with output size {output_size}")
|
||||
|
||||
def add_expert(self, name, model):
|
||||
"""
|
||||
Add an expert model to the MoE.
|
||||
|
||||
Args:
|
||||
name (str): Name of the expert model
|
||||
model: The expert model instance
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.experts[name] = model
|
||||
logger.info(f"Added expert model '{name}' to MoE")
|
||||
|
||||
def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001):
|
||||
"""
|
||||
Build the MoE model by combining expert models.
|
||||
|
||||
Args:
|
||||
ts_input_shape (tuple): Shape of time series input data
|
||||
expert_weights (dict): Weights for each expert model
|
||||
learning_rate (float): Learning rate for Adam optimizer
|
||||
|
||||
Returns:
|
||||
The compiled model
|
||||
"""
|
||||
# Time series input
|
||||
ts_inputs = Input(shape=ts_input_shape, name="ts_input")
|
||||
|
||||
# Additional feature input (from CNN)
|
||||
feature_inputs = Input(shape=(64,), name="feature_input") # Default size for features
|
||||
|
||||
# Process with each expert model
|
||||
expert_outputs = []
|
||||
expert_names = []
|
||||
|
||||
for name, expert in self.experts.items():
|
||||
# Skip if expert model is not valid or doesn't have a call/predict method
|
||||
if expert is None:
|
||||
logger.warning(f"Expert model '{name}' is None, skipping")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Different handling based on model type
|
||||
if name == 'cnn':
|
||||
# CNN model takes only time series input
|
||||
expert_output = expert(ts_inputs)
|
||||
expert_outputs.append(expert_output)
|
||||
expert_names.append(name)
|
||||
elif name == 'transformer':
|
||||
# Transformer model takes both time series and feature inputs
|
||||
expert_output = expert([ts_inputs, feature_inputs])
|
||||
expert_outputs.append(expert_output)
|
||||
expert_names.append(name)
|
||||
else:
|
||||
logger.warning(f"Unknown expert model type: {name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding expert '{name}': {str(e)}")
|
||||
|
||||
if not expert_outputs:
|
||||
logger.error("No valid expert models found")
|
||||
return None
|
||||
|
||||
# Use expert weighting
|
||||
if expert_weights is None:
|
||||
# Equal weighting
|
||||
weights = [1.0 / len(expert_outputs)] * len(expert_outputs)
|
||||
else:
|
||||
# User-provided weights
|
||||
weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names]
|
||||
# Normalize weights
|
||||
weights = [w / sum(weights) for w in weights]
|
||||
|
||||
# Combine expert outputs using weighted average
|
||||
if len(expert_outputs) == 1:
|
||||
# Only one expert, use its output directly
|
||||
combined_output = expert_outputs[0]
|
||||
else:
|
||||
# Multiple experts, compute weighted average
|
||||
weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
|
||||
combined_output = Add()(weighted_outputs)
|
||||
|
||||
# Create the MoE model
|
||||
moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)
|
||||
|
||||
# Compile the model
|
||||
if self.output_size == 1:
|
||||
# Binary classification
|
||||
moe_model.compile(
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
elif self.output_size == 3:
|
||||
# Multi-class classification for BUY/HOLD/SELL
|
||||
moe_model.compile(
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
else:
|
||||
# Regression
|
||||
moe_model.compile(
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
loss='mse',
|
||||
metrics=['mae']
|
||||
)
|
||||
|
||||
self.model = moe_model
|
||||
|
||||
# Log model summary
|
||||
self.model.summary(print_fn=lambda x: logger.info(x))
|
||||
|
||||
logger.info(f"Built MoE model with weights: {weights}")
|
||||
return self.model
|
||||
|
||||
def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
|
||||
callbacks=None, class_weights=None):
|
||||
"""
|
||||
Train the MoE model on the provided data.
|
||||
|
||||
Args:
|
||||
X_ts (numpy.ndarray): Time series input features
|
||||
X_features (numpy.ndarray): Additional input features
|
||||
y (numpy.ndarray): Target labels
|
||||
batch_size (int): Batch size
|
||||
epochs (int): Number of epochs
|
||||
validation_split (float): Fraction of data to use for validation
|
||||
callbacks (list): List of Keras callbacks
|
||||
class_weights (dict): Class weights for imbalanced datasets
|
||||
|
||||
Returns:
|
||||
History object containing training metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
logger.error("MoE model has not been built yet")
|
||||
return None
|
||||
|
||||
# Default callbacks if none provided
|
||||
if callbacks is None:
|
||||
# Create a timestamp for model checkpoints
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
callbacks = [
|
||||
EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
restore_best_weights=True
|
||||
),
|
||||
ReduceLROnPlateau(
|
||||
monitor='val_loss',
|
||||
factor=0.5,
|
||||
patience=5,
|
||||
min_lr=1e-6
|
||||
),
|
||||
ModelCheckpoint(
|
||||
filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
|
||||
monitor='val_loss',
|
||||
save_best_only=True
|
||||
)
|
||||
]
|
||||
|
||||
# Check if y needs to be one-hot encoded for multi-class
|
||||
if self.output_size == 3 and len(y.shape) == 1:
|
||||
y = tf.keras.utils.to_categorical(y, num_classes=3)
|
||||
|
||||
# Train the model
|
||||
logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
|
||||
self.history = self.model.fit(
|
||||
[X_ts, X_features], y,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
validation_split=validation_split,
|
||||
callbacks=callbacks,
|
||||
class_weight=class_weights,
|
||||
verbose=2
|
||||
)
|
||||
|
||||
# Save the trained model
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
|
||||
self.model.save(model_path)
|
||||
logger.info(f"Model saved to {model_path}")
|
||||
|
||||
# Save training history
|
||||
history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
|
||||
with open(history_path, 'w') as f:
|
||||
# Convert numpy values to Python native types for JSON serialization
|
||||
history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
|
||||
json.dump(history_dict, f, indent=2)
|
||||
|
||||
return self.history
|
||||
|
||||
def predict(self, X_ts, X_features=None):
|
||||
"""
|
||||
Make predictions on new data.
|
||||
|
||||
Args:
|
||||
X_ts (numpy.ndarray): Time series input features
|
||||
X_features (numpy.ndarray): Additional input features
|
||||
|
||||
Returns:
|
||||
tuple: (y_pred, y_proba) where:
|
||||
y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
|
||||
y_proba is the class probability
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model has not been built or trained yet")
|
||||
|
||||
# Ensure X_ts has the right shape
|
||||
if len(X_ts.shape) == 2:
|
||||
# Single sample, add batch dimension
|
||||
X_ts = np.expand_dims(X_ts, axis=0)
|
||||
|
||||
# Ensure X_features has the right shape
|
||||
if X_features is None:
|
||||
# Create dummy features with zeros
|
||||
X_features = np.zeros((X_ts.shape[0], 64)) # Default size
|
||||
elif len(X_features.shape) == 1:
|
||||
# Single sample, add batch dimension
|
||||
X_features = np.expand_dims(X_features, axis=0)
|
||||
|
||||
# Get predictions
|
||||
y_proba = self.model.predict([X_ts, X_features])
|
||||
|
||||
# Process based on output type
|
||||
if self.output_size == 1:
|
||||
# Binary classification
|
||||
y_pred = (y_proba > 0.5).astype(int).flatten()
|
||||
return y_pred, y_proba.flatten()
|
||||
elif self.output_size == 3:
|
||||
# Multi-class classification
|
||||
y_pred = np.argmax(y_proba, axis=1)
|
||||
return y_pred, y_proba
|
||||
else:
|
||||
# Regression
|
||||
return y_proba, y_proba
|
||||
|
||||
def save(self, filepath=None):
|
||||
"""
|
||||
Save the model to disk.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to save the model
|
||||
|
||||
Returns:
|
||||
str: Path where the model was saved
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model has not been built yet")
|
||||
|
||||
if filepath is None:
|
||||
# Create a default filepath with timestamp
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5")
|
||||
|
||||
self.model.save(filepath)
|
||||
logger.info(f"Model saved to {filepath}")
|
||||
return filepath
|
||||
|
||||
def load(self, filepath):
|
||||
"""
|
||||
Load a saved model from disk.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the saved model
|
||||
|
||||
Returns:
|
||||
The loaded model
|
||||
"""
|
||||
# Register custom layers
|
||||
custom_objects = {
|
||||
'TransformerBlock': TransformerBlock,
|
||||
'PositionalEncoding': PositionalEncoding
|
||||
}
|
||||
|
||||
self.model = load_model(filepath, custom_objects=custom_objects)
|
||||
logger.info(f"Model loaded from {filepath}")
|
||||
return self.model
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# This would be a complete implementation in a real system
|
||||
print("Transformer and MoE models defined, but not implemented here.")
|
||||
NN\start_tensorboard.py (deleted file)
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Start TensorBoard for monitoring neural network training
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import webbrowser
|
||||
from time import sleep
|
||||
|
||||
def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
|
||||
"""
|
||||
Start TensorBoard in a subprocess
|
||||
|
||||
Args:
|
||||
logdir: Directory containing TensorBoard logs
|
||||
port: Port to run TensorBoard on
|
||||
open_browser: Whether to open a browser automatically
|
||||
"""
|
||||
# Make sure the log directory exists
|
||||
os.makedirs(logdir, exist_ok=True)
|
||||
|
||||
# Create command
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"tensorboard.main",
|
||||
f"--logdir={logdir}",
|
||||
f"--port={port}",
|
||||
"--bind_all"
|
||||
]
|
||||
|
||||
print(f"Starting TensorBoard with logs from {logdir} on port {port}")
|
||||
print(f"Command: {' '.join(cmd)}")
|
||||
|
||||
# Start TensorBoard in a subprocess
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True
|
||||
)
|
||||
|
||||
# Wait for TensorBoard to start up
|
||||
for line in process.stdout:
|
||||
print(line.strip())
|
||||
if "TensorBoard" in line and "http://" in line:
|
||||
# TensorBoard is running, extract the URL
|
||||
url = None
|
||||
for part in line.split():
|
||||
if part.startswith(("http://", "https://")):
|
||||
url = part
|
||||
break
|
||||
|
||||
# Open browser if requested and URL found
|
||||
if open_browser and url:
|
||||
print(f"Opening TensorBoard in browser: {url}")
|
||||
webbrowser.open(url)
|
||||
|
||||
break
|
||||
|
||||
# Return the process for the caller to manage
|
||||
return process
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
|
||||
parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
|
||||
parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
|
||||
parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Start TensorBoard
|
||||
process = start_tensorboard(args.logdir, args.port, not args.no_browser)
|
||||
|
||||
try:
|
||||
# Keep the script running until Ctrl+C
|
||||
print("TensorBoard is running. Press Ctrl+C to stop.")
|
||||
while True:
|
||||
sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Stopping TensorBoard...")
|
||||
process.terminate()
|
||||
process.wait()
|
||||
NN\training\enhanced_rl_training_integration.py (deleted file)
@@ -1,490 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Enhanced RL Training Integration - Comprehensive Fix
|
||||
|
||||
This script addresses the critical RL training audit issues:
|
||||
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
|
||||
2. Disconnected Training Pipeline - Provides proper data flow integration
|
||||
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
|
||||
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
|
||||
5. Williams Market Structure Integration - Proper feature extraction
|
||||
6. Real-time Data Integration - Live market data to RL
|
||||
|
||||
Usage:
|
||||
python enhanced_rl_training_integration.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging, get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
from utils.tensorboard_logger import TensorBoardLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedRLTrainingIntegrator:
|
||||
"""
|
||||
Comprehensive RL Training Integrator
|
||||
|
||||
Fixes all audit issues by ensuring proper data flow and feature completeness.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the enhanced RL training integrator"""
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger.info("=" * 70)
|
||||
logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Get configuration
|
||||
self.config = get_config()
|
||||
|
||||
# Initialize core components
|
||||
self.data_provider = DataProvider()
|
||||
self.enhanced_orchestrator = None
|
||||
self.trading_executor = TradingExecutor()
|
||||
self.dashboard = None
|
||||
|
||||
# Training metrics
|
||||
self.training_stats = {
|
||||
'total_episodes': 0,
|
||||
'successful_state_builds': 0,
|
||||
'enhanced_reward_calculations': 0,
|
||||
'comprehensive_features_used': 0,
|
||||
'pivot_features_extracted': 0,
|
||||
'cob_features_available': 0
|
||||
}
|
||||
|
||||
# Initialize TensorBoard logger
|
||||
experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self.tb_logger = TensorBoardLogger(
|
||||
log_dir="runs",
|
||||
experiment_name=experiment_name,
|
||||
enabled=True
|
||||
)
|
||||
logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
|
||||
|
||||
logger.info("Enhanced RL Training Integrator initialized")
|
||||
|
||||
async def start_integration(self):
|
||||
"""Start the comprehensive RL training integration"""
|
||||
try:
|
||||
logger.info("Starting comprehensive RL training integration...")
|
||||
|
||||
# 1. Initialize Enhanced Orchestrator with comprehensive features
|
||||
await self._initialize_enhanced_orchestrator()
|
||||
|
||||
# 2. Create enhanced dashboard with proper connections
|
||||
await self._create_enhanced_dashboard()
|
||||
|
||||
# 3. Verify comprehensive state building
|
||||
await self._verify_comprehensive_state_building()
|
||||
|
||||
# 4. Test enhanced reward calculation
|
||||
await self._test_enhanced_reward_calculation()
|
||||
|
||||
# 5. Validate Williams market structure integration
|
||||
await self._validate_williams_integration()
|
||||
|
||||
# 6. Start live training with comprehensive features
|
||||
await self._start_live_comprehensive_training()
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE")
|
||||
logger.info("=" * 70)
|
||||
self._log_integration_stats()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in RL training integration: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _initialize_enhanced_orchestrator(self):
|
||||
"""Initialize enhanced orchestrator with comprehensive RL capabilities"""
|
||||
try:
|
||||
logger.info("[STEP 1] Initializing Enhanced Orchestrator...")
|
||||
|
||||
# Create enhanced orchestrator with RL training enabled
|
||||
self.enhanced_orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
enhanced_rl_training=True,
|
||||
model_registry={} # Will be populated as needed
|
||||
)
|
||||
|
||||
# Start COB integration for real-time market microstructure
|
||||
await self.enhanced_orchestrator.start_cob_integration()
|
||||
|
||||
# Start real-time processing
|
||||
await self.enhanced_orchestrator.start_realtime_processing()
|
||||
|
||||
logger.info("[SUCCESS] Enhanced Orchestrator initialized with:")
|
||||
logger.info(" - Comprehensive RL state building: ENABLED")
|
||||
logger.info(" - Enhanced pivot-based rewards: ENABLED")
|
||||
logger.info(" - COB integration: ENABLED")
|
||||
logger.info(" - Williams market structure: ENABLED")
|
||||
logger.info(" - Real-time tick processing: ENABLED")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing enhanced orchestrator: {e}")
|
||||
raise
|
||||
|
||||
async def _create_enhanced_dashboard(self):
|
||||
"""Create dashboard with enhanced orchestrator connections"""
|
||||
try:
|
||||
logger.info("[STEP 2] Creating Enhanced Dashboard...")
|
||||
|
||||
# Create trading dashboard with enhanced orchestrator
|
||||
self.dashboard = TradingDashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.enhanced_orchestrator, # Use enhanced orchestrator
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
|
||||
# Verify enhanced connections
|
||||
has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state')
|
||||
has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward')
|
||||
has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation')
|
||||
|
||||
logger.info("[SUCCESS] Enhanced Dashboard created with:")
|
||||
logger.info(f" - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}")
|
||||
logger.info(f" - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}")
|
||||
logger.info(f" - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}")
|
||||
|
||||
if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]):
|
||||
logger.warning("Some enhanced features are missing - this will cause fallbacks to basic training")
|
||||
else:
|
||||
logger.info(" - ALL ENHANCED FEATURES AVAILABLE!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating enhanced dashboard: {e}")
|
||||
raise
|
||||
|
||||
async def _verify_comprehensive_state_building(self):
|
||||
"""Verify that comprehensive RL state building works correctly"""
|
||||
try:
|
||||
logger.info("[STEP 3] Verifying Comprehensive State Building...")
|
||||
|
||||
# Test comprehensive state building for ETH
|
||||
eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT')
|
||||
|
||||
if eth_state is not None:
|
||||
logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features")
|
||||
|
||||
# Verify feature count
|
||||
if len(eth_state) == 13400:
|
||||
logger.info(" - PERFECT: Exactly 13,400 features as required!")
|
||||
self.training_stats['comprehensive_features_used'] += 1
|
||||
else:
|
||||
logger.warning(f" - MISMATCH: Expected 13,400 features, got {len(eth_state)}")
|
||||
|
||||
# Analyze feature distribution
|
||||
self._analyze_state_features(eth_state)
|
||||
self.training_stats['successful_state_builds'] += 1
|
||||
|
||||
else:
|
||||
logger.error(" - FAILED: Comprehensive state building returned None")
|
||||
|
||||
# Test for BTC reference
|
||||
btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT')
|
||||
if btc_state is not None:
|
||||
logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features")
|
||||
self.training_stats['successful_state_builds'] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying comprehensive state building: {e}")
|
||||
|
||||
def _analyze_state_features(self, state_vector: np.ndarray):
|
||||
"""Analyze the comprehensive state feature distribution"""
|
||||
try:
|
||||
# Calculate feature statistics
|
||||
non_zero_features = np.count_nonzero(state_vector)
|
||||
zero_features = len(state_vector) - non_zero_features
|
||||
feature_mean = np.mean(state_vector)
|
||||
feature_std = np.std(state_vector)
|
||||
feature_min = np.min(state_vector)
|
||||
feature_max = np.max(state_vector)
|
||||
|
||||
logger.info(" - Feature Analysis:")
|
||||
logger.info(f" * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)")
|
||||
logger.info(f" * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)")
|
||||
logger.info(f" * Mean: {feature_mean:.6f}")
|
||||
logger.info(f" * Std: {feature_std:.6f}")
|
||||
logger.info(f" * Range: [{feature_min:.6f}, {feature_max:.6f}]")
|
||||
|
||||
# Log feature statistics to TensorBoard
|
||||
step = self.training_stats['total_episodes']
|
||||
self.tb_logger.log_scalars('Features/Distribution', {
|
||||
'non_zero_percentage': non_zero_features/len(state_vector)*100,
|
||||
'mean': feature_mean,
|
||||
'std': feature_std,
|
||||
'min': feature_min,
|
||||
'max': feature_max
|
||||
}, step)
|
||||
|
||||
# Log feature histogram to TensorBoard
|
||||
self.tb_logger.log_histogram('Features/Values', state_vector, step)
|
||||
|
||||
# Check if features are properly distributed
|
||||
if non_zero_features > len(state_vector) * 0.1: # At least 10% non-zero
|
||||
logger.info(" * GOOD: Features are well distributed")
|
||||
else:
|
||||
logger.warning(" * WARNING: Too many zero features - data may be incomplete")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error analyzing state features: {e}")
|
||||
|
||||
async def _test_enhanced_reward_calculation(self):
|
||||
"""Test enhanced pivot-based reward calculation"""
|
||||
try:
|
||||
logger.info("[STEP 4] Testing Enhanced Reward Calculation...")
|
||||
|
||||
# Create mock trade data for testing
|
||||
trade_decision = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.75,
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
trade_outcome = {
|
||||
'net_pnl': 50.0,
|
||||
'exit_price': 2550.0,
|
||||
'duration': timedelta(minutes=15)
|
||||
}
|
||||
|
||||
# Get market data for reward calculation
|
||||
market_data = {
|
||||
'volatility': 0.03,
|
||||
'order_flow_direction': 'bullish',
|
||||
'order_flow_strength': 0.8
|
||||
}
|
||||
|
||||
# Test enhanced reward calculation
|
||||
if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'):
|
||||
enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward(
|
||||
trade_decision, market_data, trade_outcome
|
||||
)
|
||||
|
||||
logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}")
|
||||
logger.info(" - Enhanced pivot-based reward system: WORKING")
|
||||
self.training_stats['enhanced_reward_calculations'] += 1
|
||||
|
||||
# Log reward metrics to TensorBoard
|
||||
step = self.training_stats['enhanced_reward_calculations']
|
||||
self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step)
|
||||
|
||||
# Log reward components to TensorBoard
|
||||
self.tb_logger.log_scalars('Rewards/Components', {
|
||||
'pnl_component': trade_outcome['net_pnl'],
|
||||
'confidence': trade_decision['confidence'],
|
||||
'volatility': market_data['volatility'],
|
||||
'order_flow_strength': market_data['order_flow_strength']
|
||||
}, step)
|
||||
|
||||
else:
|
||||
logger.error(" - FAILED: Enhanced reward calculation method not available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing enhanced reward calculation: {e}")
|
||||
|
||||
async def _validate_williams_integration(self):
|
||||
"""Validate Williams market structure integration"""
|
||||
try:
|
||||
logger.info("[STEP 5] Validating Williams Market Structure Integration...")
|
||||
|
||||
# Test Williams pivot feature extraction
|
||||
try:
|
||||
from training.williams_market_structure import extract_pivot_features, analyze_pivot_context
|
||||
|
||||
# Get test market data
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Test pivot feature extraction
|
||||
pivot_features = extract_pivot_features(df)
|
||||
|
||||
if pivot_features is not None:
|
||||
logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features")
|
||||
self.training_stats['pivot_features_extracted'] += 1
|
||||
|
||||
# Test pivot context analysis
|
||||
market_data = {'ohlcv_data': df}
|
||||
pivot_context = analyze_pivot_context(
|
||||
market_data, datetime.now(), 'BUY'
|
||||
)
|
||||
|
||||
if pivot_context is not None:
|
||||
logger.info("[SUCCESS] Williams pivot context analysis: WORKING")
|
||||
logger.info(f" - Near pivot: {pivot_context.get('near_pivot', False)}")
|
||||
logger.info(f" - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}")
|
||||
else:
|
||||
logger.warning(" - Williams pivot context analysis returned None")
|
||||
else:
|
||||
logger.warning(" - Williams pivot feature extraction returned None")
|
||||
else:
|
||||
logger.warning(" - No market data available for Williams testing")
|
||||
|
||||
except ImportError:
|
||||
logger.error(" - Williams market structure module not available")
|
||||
except Exception as e:
|
||||
logger.error(f" - Error in Williams integration: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating Williams integration: {e}")
|
||||
|
||||
async def _start_live_comprehensive_training(self):
|
||||
"""Start live training with comprehensive feature integration"""
|
||||
try:
|
||||
logger.info("[STEP 6] Starting Live Comprehensive Training...")
|
||||
|
||||
# Run a few training iterations to verify integration
|
||||
for iteration in range(5):
|
||||
logger.info(f"Training iteration {iteration + 1}/5")
|
||||
|
||||
# Make coordinated decisions using enhanced orchestrator
|
||||
decisions = await self.enhanced_orchestrator.make_coordinated_decisions()
|
||||
|
||||
# Track iteration metrics for TensorBoard
|
||||
iteration_metrics = {
|
||||
'decisions_count': len(decisions),
|
||||
'confidence_avg': 0.0,
|
||||
'state_size_avg': 0.0,
|
||||
'successful_states': 0
|
||||
}
|
||||
|
||||
# Process each decision
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||
|
||||
# Track confidence for TensorBoard
|
||||
iteration_metrics['confidence_avg'] += decision.confidence
|
||||
|
||||
# Build comprehensive state for this decision
|
||||
comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol)
|
||||
|
||||
if comprehensive_state is not None:
|
||||
state_size = len(comprehensive_state)
|
||||
logger.info(f" - Comprehensive state: {state_size} features")
|
||||
self.training_stats['total_episodes'] += 1
|
||||
|
||||
# Track state size for TensorBoard
|
||||
iteration_metrics['state_size_avg'] += state_size
|
||||
iteration_metrics['successful_states'] += 1
|
||||
|
||||
# Log individual state metrics to TensorBoard
|
||||
self.tb_logger.log_state_metrics(
|
||||
symbol=symbol,
|
||||
state_info={
|
||||
'size': state_size,
|
||||
'quality': 1.0 if state_size == 13400 else 0.8,
|
||||
'feature_counts': {
|
||||
'total': state_size,
|
||||
'non_zero': np.count_nonzero(comprehensive_state)
|
||||
}
|
||||
},
|
||||
step=self.training_stats['total_episodes']
|
||||
)
|
||||
else:
|
||||
logger.warning(f" - Failed to build comprehensive state for {symbol}")
|
||||
|
||||
# Calculate averages for TensorBoard
|
||||
if decisions:
|
||||
iteration_metrics['confidence_avg'] /= len(decisions)
|
||||
|
||||
if iteration_metrics['successful_states'] > 0:
|
||||
iteration_metrics['state_size_avg'] /= iteration_metrics['successful_states']
|
||||
|
||||
# Log iteration metrics to TensorBoard
|
||||
self.tb_logger.log_scalars('Training/Iteration', {
|
||||
'iteration': iteration + 1,
|
||||
'decisions_count': iteration_metrics['decisions_count'],
|
||||
'confidence_avg': iteration_metrics['confidence_avg'],
|
||||
'state_size_avg': iteration_metrics['state_size_avg'],
|
||||
'successful_states': iteration_metrics['successful_states']
|
||||
}, iteration + 1)
|
||||
|
||||
# Wait between iterations
|
||||
await asyncio.sleep(2)
|
||||
|
||||
logger.info("[SUCCESS] Live comprehensive training demonstration complete")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live comprehensive training: {e}")
|
||||
|
||||
def _log_integration_stats(self):
|
||||
"""Log comprehensive integration statistics"""
|
||||
logger.info("INTEGRATION STATISTICS:")
|
||||
logger.info(f" - Total training episodes: {self.training_stats['total_episodes']}")
|
||||
logger.info(f" - Successful state builds: {self.training_stats['successful_state_builds']}")
|
||||
logger.info(f" - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}")
|
||||
logger.info(f" - Comprehensive features used: {self.training_stats['comprehensive_features_used']}")
|
||||
logger.info(f" - Pivot features extracted: {self.training_stats['pivot_features_extracted']}")
|
||||
|
||||
# Calculate success rates
|
||||
state_success_rate = 0
|
||||
if self.training_stats['total_episodes'] > 0:
|
||||
state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100
|
||||
logger.info(f" - State building success rate: {state_success_rate:.1f}%")
|
||||
|
||||
# Log final statistics to TensorBoard
|
||||
self.tb_logger.log_scalars('Integration/Statistics', {
|
||||
'total_episodes': self.training_stats['total_episodes'],
|
||||
'successful_state_builds': self.training_stats['successful_state_builds'],
|
||||
'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'],
|
||||
'comprehensive_features_used': self.training_stats['comprehensive_features_used'],
|
||||
'pivot_features_extracted': self.training_stats['pivot_features_extracted'],
|
||||
'state_success_rate': state_success_rate
|
||||
}, 0) # Use step 0 for final summary stats
|
||||
|
||||
# Integration status
|
||||
if self.training_stats['comprehensive_features_used'] > 0:
|
||||
logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅")
|
||||
logger.info("The system is now using the full 13,400 feature comprehensive state.")
|
||||
|
||||
# Log success status to TensorBoard
|
||||
self.tb_logger.log_scalar('Integration/Success', 1.0, 0)
|
||||
else:
|
||||
logger.warning("STATUS: Integration partially successful - some fallbacks may occur")
|
||||
|
||||
# Log partial success status to TensorBoard
|
||||
self.tb_logger.log_scalar('Integration/Success', 0.5, 0)
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
try:
|
||||
# Create and run the enhanced RL training integrator
|
||||
integrator = EnhancedRLTrainingIntegrator()
|
||||
await integrator.start_integration()
|
||||
|
||||
logger.info("Enhanced RL training integration completed successfully!")
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Integration interrupted by user")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in integration: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
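# The 13,400-feature contract this script keeps checking can be guarded wherever a
# state vector is consumed; a minimal sketch (the expected size is taken from the
# verification step above, nothing else is assumed):
#   state = orchestrator.build_comprehensive_rl_state('ETH/USDT')
#   assert state is not None and len(state) == 13400, "comprehensive RL state incomplete"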
@@ -1,148 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example: Using the Checkpoint Management System
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ExampleCNN(nn.Module):
|
||||
def __init__(self, input_channels=5, num_classes=3):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
|
||||
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(64, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.relu(self.conv1(x))
|
||||
x = torch.relu(self.conv2(x))
|
||||
x = self.pool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.fc(x)
|
||||
|
||||
def example_cnn_training():
|
||||
logger.info("=== CNN Training Example ===")
|
||||
|
||||
model = ExampleCNN()
|
||||
training_integration = get_training_integration()
|
||||
|
||||
for epoch in range(5): # Simulate 5 epochs
|
||||
# Simulate training metrics
|
||||
train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
|
||||
train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
|
||||
val_loss = train_loss + np.random.normal(0, 0.05)
|
||||
val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)
|
||||
|
||||
# Clamp values to realistic ranges
|
||||
train_acc = max(0.0, min(1.0, train_acc))
|
||||
val_acc = max(0.0, min(1.0, val_acc))
|
||||
train_loss = max(0.1, train_loss)
|
||||
val_loss = max(0.1, val_loss)
|
||||
|
||||
logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")
|
||||
|
||||
# Save checkpoint
|
||||
saved = training_integration.save_cnn_checkpoint(
|
||||
cnn_model=model,
|
||||
model_name="example_cnn",
|
||||
epoch=epoch + 1,
|
||||
train_accuracy=train_acc,
|
||||
val_accuracy=val_acc,
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
training_time_hours=0.1 * (epoch + 1)
|
||||
)
|
||||
|
||||
if saved:
|
||||
logger.info(f" Checkpoint saved for epoch {epoch+1}")
|
||||
else:
|
||||
logger.info(f" Checkpoint not saved (performance not improved)")
|
||||
|
||||
# Load the best checkpoint
|
||||
logger.info("\\nLoading best checkpoint...")
|
||||
best_result = load_best_checkpoint("example_cnn")
|
||||
if best_result:
|
||||
file_path, metadata = best_result
|
||||
logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Performance score: {metadata.performance_score:.4f}")
|
||||
|
||||
def example_manual_checkpoint():
|
||||
logger.info("\\n=== Manual Checkpoint Example ===")
|
||||
|
||||
model = nn.Linear(10, 3)
|
||||
|
||||
performance_metrics = {
|
||||
'accuracy': 0.85,
|
||||
'val_accuracy': 0.82,
|
||||
'loss': 0.45,
|
||||
'val_loss': 0.48
|
||||
}
|
||||
|
||||
training_metadata = {
|
||||
'epoch': 25,
|
||||
'training_time_hours': 2.5,
|
||||
'total_parameters': sum(p.numel() for p in model.parameters())
|
||||
}
|
||||
|
||||
logger.info("Saving checkpoint manually...")
|
||||
metadata = save_checkpoint(
|
||||
model=model,
|
||||
model_name="example_manual",
|
||||
model_type="cnn",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata,
|
||||
force_save=True
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
|
||||
logger.info(f" Performance score: {metadata.performance_score:.4f}")
|
||||
|
||||
def show_checkpoint_stats():
|
||||
logger.info("\\n=== Checkpoint Statistics ===")
|
||||
|
||||
checkpoint_manager = get_checkpoint_manager()
|
||||
stats = checkpoint_manager.get_checkpoint_stats()
|
||||
|
||||
logger.info(f"Total models: {stats['total_models']}")
|
||||
logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
|
||||
logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
|
||||
|
||||
for model_name, model_stats in stats['models'].items():
|
||||
logger.info(f"\\n{model_name}:")
|
||||
logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
|
||||
logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
|
||||
logger.info(f" Best performance: {model_stats['best_performance']:.4f}")
|
||||
|
||||
def main():
|
||||
logger.info(" Checkpoint Management System Examples")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
example_cnn_training()
|
||||
example_manual_checkpoint()
|
||||
show_checkpoint_stats()
|
||||
|
||||
logger.info("\\n All examples completed successfully!")
|
||||
logger.info("\\nTo use in your training:")
|
||||
logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
|
||||
logger.info("2. Or use: from utils.training_integration import get_training_integration")
|
||||
logger.info("3. Save checkpoints during training with performance metrics")
|
||||
logger.info("4. Load best checkpoints for inference or continued training")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in examples: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
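# Hedged sketch of the "load best checkpoints for inference" step the log lines above
# mention. The on-disk layout is an assumption -- the real CheckpointManager may wrap
# the weights in extra metadata:
#   file_path, metadata = load_best_checkpoint("example_cnn")
#   checkpoint = torch.load(file_path, map_location="cpu")
#   model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))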
@@ -1,517 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Checkpoint Management Integration
|
||||
|
||||
This script demonstrates how to integrate the checkpoint management system
|
||||
across all training pipelines in the gogo2 project.
|
||||
|
||||
Features:
|
||||
- DQN Agent training with automatic checkpointing
|
||||
- CNN Model training with checkpoint management
|
||||
- ExtremaTrainer with checkpoint persistence
|
||||
- NegativeCaseTrainer with checkpoint integration
|
||||
- Unified training orchestration with checkpoint coordination
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('logs/checkpoint_integration.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
# Import training components
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
from core.negative_case_trainer import NegativeCaseTrainer
|
||||
from core.data_provider import DataProvider
|
||||
from core.config import get_config
|
||||
|
||||
class CheckpointIntegratedTrainingSystem:
|
||||
"""Unified training system with comprehensive checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the checkpoint-integrated training system"""
|
||||
self.config = get_config()
|
||||
self.running = False
|
||||
|
||||
# Checkpoint management
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.training_integration = get_training_integration()
|
||||
|
||||
# Data provider
|
||||
self.data_provider = DataProvider(
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
timeframes=['1s', '1m', '1h', '1d']
|
||||
)
|
||||
|
||||
# Training components with checkpoint management
|
||||
self.dqn_agent = None
|
||||
self.cnn_trainer = None
|
||||
self.extrema_trainer = None
|
||||
self.negative_case_trainer = None
|
||||
|
||||
# Training statistics
|
||||
self.training_stats = {
|
||||
'start_time': None,
|
||||
'total_training_sessions': 0,
|
||||
'checkpoints_saved': 0,
|
||||
'models_loaded': 0,
|
||||
'best_performances': {}
|
||||
}
|
||||
|
||||
logger.info("Checkpoint-Integrated Training System initialized")
|
||||
|
||||
async def initialize_components(self):
|
||||
"""Initialize all training components with checkpoint management"""
|
||||
try:
|
||||
logger.info("Initializing training components with checkpoint management...")
|
||||
|
||||
# Initialize data provider
|
||||
await self.data_provider.start_real_time_streaming()
|
||||
logger.info("Data provider streaming started")
|
||||
|
||||
# Initialize DQN Agent with checkpoint management
|
||||
logger.info("Initializing DQN Agent with checkpoints...")
|
||||
self.dqn_agent = DQNAgent(
|
||||
state_shape=(100,), # Example state shape
|
||||
n_actions=3,
|
||||
model_name="integrated_dqn_agent",
|
||||
enable_checkpoints=True
|
||||
)
|
||||
logger.info("✅ DQN Agent initialized with checkpoint management")
|
||||
|
||||
# Initialize StandardizedCNN Model with checkpoint management
|
||||
logger.info("Initializing StandardizedCNN Model with checkpoints...")
|
||||
self.cnn_trainer = StandardizedCNN(model_name="integrated_cnn_model")  # bind as cnn_trainer so _train_cnn_model and _load_all_checkpoints below actually use it
|
||||
logger.info("✅ StandardizedCNN Model initialized with checkpoint management")
|
||||
|
||||
# Initialize ExtremaTrainer with checkpoint management
|
||||
logger.info("Initializing ExtremaTrainer with checkpoints...")
|
||||
self.extrema_trainer = ExtremaTrainer(
|
||||
data_provider=self.data_provider,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
model_name="integrated_extrema_trainer",
|
||||
enable_checkpoints=True
|
||||
)
|
||||
await self.extrema_trainer.initialize_context_data()
|
||||
logger.info("✅ ExtremaTrainer initialized with checkpoint management")
|
||||
|
||||
# Initialize NegativeCaseTrainer with checkpoint management
|
||||
logger.info("Initializing NegativeCaseTrainer with checkpoints...")
|
||||
self.negative_case_trainer = NegativeCaseTrainer(
|
||||
model_name="integrated_negative_case_trainer",
|
||||
enable_checkpoints=True
|
||||
)
|
||||
logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")
|
||||
|
||||
# Load existing checkpoints for all components
|
||||
self.training_stats['models_loaded'] = await self._load_all_checkpoints()
|
||||
|
||||
logger.info("All training components initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing components: {e}")
|
||||
raise
|
||||
|
||||
async def _load_all_checkpoints(self) -> int:
|
||||
"""Load checkpoints for all training components"""
|
||||
loaded_count = 0
|
||||
|
||||
try:
|
||||
# DQN Agent checkpoint loading is handled in __init__
|
||||
if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
|
||||
loaded_count += 1
|
||||
logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")
|
||||
|
||||
# CNN Trainer checkpoint loading is handled in __init__
|
||||
if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
|
||||
loaded_count += 1
|
||||
logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")
|
||||
|
||||
# ExtremaTrainer checkpoint loading is handled in __init__
|
||||
if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
|
||||
loaded_count += 1
|
||||
logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")
|
||||
|
||||
# NegativeCaseTrainer checkpoint loading is handled in __init__
|
||||
if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
|
||||
loaded_count += 1
|
||||
logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")
|
||||
|
||||
return loaded_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading checkpoints: {e}")
|
||||
return 0
|
||||
|
||||
async def run_integrated_training_loop(self):
|
||||
"""Run the integrated training loop with checkpoint coordination"""
|
||||
logger.info("Starting integrated training loop with checkpoint management...")
|
||||
|
||||
self.running = True
|
||||
self.training_stats['start_time'] = datetime.now()
|
||||
|
||||
training_cycle = 0
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
training_cycle += 1
|
||||
cycle_start = time.time()
|
||||
|
||||
logger.info(f"=== Training Cycle {training_cycle} ===")
|
||||
|
||||
# DQN Training
|
||||
dqn_results = await self._train_dqn_agent()
|
||||
|
||||
# CNN Training
|
||||
cnn_results = await self._train_cnn_model()
|
||||
|
||||
# Extrema Detection Training
|
||||
extrema_results = await self._train_extrema_detector()
|
||||
|
||||
# Negative Case Training (runs in background)
|
||||
negative_results = await self._process_negative_cases()
|
||||
|
||||
# Coordinate checkpoint saving
|
||||
await self._coordinate_checkpoint_saving(
|
||||
dqn_results, cnn_results, extrema_results, negative_results
|
||||
)
|
||||
|
||||
# Update statistics
|
||||
self.training_stats['total_training_sessions'] += 1
|
||||
|
||||
# Log cycle summary
|
||||
cycle_duration = time.time() - cycle_start
|
||||
logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
|
||||
|
||||
# Wait before next cycle
|
||||
await asyncio.sleep(60) # 1-minute cycles
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training loop: {e}")
|
||||
finally:
|
||||
await self.shutdown()
|
||||
|
||||
async def _train_dqn_agent(self) -> Dict[str, Any]:
|
||||
"""Train DQN agent with automatic checkpointing"""
|
||||
try:
|
||||
if not self.dqn_agent:
|
||||
return {'status': 'skipped', 'reason': 'no_agent'}
|
||||
|
||||
# Simulate DQN training episode
|
||||
episode_reward = 0.0
|
||||
|
||||
# Add some training experiences (simulate real training)
|
||||
for _ in range(10): # Simulate 10 training steps
|
||||
state = np.random.randn(100).astype(np.float32)
|
||||
action = np.random.randint(0, 3)
|
||||
reward = np.random.randn() * 0.1
|
||||
next_state = np.random.randn(100).astype(np.float32)
|
||||
done = np.random.random() < 0.1
|
||||
|
||||
self.dqn_agent.remember(state, action, reward, next_state, done)
|
||||
episode_reward += reward
|
||||
|
||||
# Train if enough experiences
|
||||
loss = 0.0
|
||||
if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
|
||||
loss = self.dqn_agent.replay()
|
||||
|
||||
# Save checkpoint (automatic based on performance)
|
||||
checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)
|
||||
|
||||
if checkpoint_saved:
|
||||
self.training_stats['checkpoints_saved'] += 1
|
||||
|
||||
return {
|
||||
'status': 'completed',
|
||||
'episode_reward': episode_reward,
|
||||
'loss': loss,
|
||||
'checkpoint_saved': checkpoint_saved,
|
||||
'episode': self.dqn_agent.episode_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training DQN agent: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
async def _train_cnn_model(self) -> Dict[str, Any]:
|
||||
"""Train CNN model with automatic checkpointing"""
|
||||
try:
|
||||
if not self.cnn_trainer:
|
||||
return {'status': 'skipped', 'reason': 'no_trainer'}
|
||||
|
||||
# Simulate CNN training step
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
batch_size = 32
|
||||
input_size = 60
|
||||
feature_dim = 50
|
||||
|
||||
# Generate synthetic training data
|
||||
x = torch.randn(batch_size, input_size, feature_dim)
|
||||
y = torch.randint(0, 3, (batch_size,))
|
||||
|
||||
# Training step
|
||||
results = self.cnn_trainer.train_step(x, y)
|
||||
|
||||
# Simulate validation
|
||||
val_x = torch.randn(16, input_size, feature_dim)
|
||||
val_y = torch.randint(0, 3, (16,))
|
||||
val_results = self.cnn_trainer.train_step(val_x, val_y)
|
||||
|
||||
# Save checkpoint (automatic based on performance)
|
||||
checkpoint_saved = self.cnn_trainer.save_checkpoint(
|
||||
train_accuracy=results.get('accuracy', 0.5),
|
||||
val_accuracy=val_results.get('accuracy', 0.5),
|
||||
train_loss=results.get('total_loss', 1.0),
|
||||
val_loss=val_results.get('total_loss', 1.0)
|
||||
)
|
||||
|
||||
if checkpoint_saved:
|
||||
self.training_stats['checkpoints_saved'] += 1
|
||||
|
||||
return {
|
||||
'status': 'completed',
|
||||
'train_accuracy': results.get('accuracy', 0.5),
|
||||
'val_accuracy': val_results.get('accuracy', 0.5),
|
||||
'train_loss': results.get('total_loss', 1.0),
|
||||
'val_loss': val_results.get('total_loss', 1.0),
|
||||
'checkpoint_saved': checkpoint_saved,
|
||||
'epoch': self.cnn_trainer.epoch_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training CNN model: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
async def _train_extrema_detector(self) -> Dict[str, Any]:
|
||||
"""Train extrema detector with automatic checkpointing"""
|
||||
try:
|
||||
if not self.extrema_trainer:
|
||||
return {'status': 'skipped', 'reason': 'no_trainer'}
|
||||
|
||||
# Update context data and detect extrema
|
||||
update_results = self.extrema_trainer.update_context_data()
|
||||
|
||||
# Get training data
|
||||
extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)
|
||||
|
||||
# Simulate training accuracy improvement
|
||||
if extrema_data:
|
||||
self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
|
||||
self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
|
||||
self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2
|
||||
|
||||
# Save checkpoint (automatic based on performance)
|
||||
checkpoint_saved = self.extrema_trainer.save_checkpoint()
|
||||
|
||||
if checkpoint_saved:
|
||||
self.training_stats['checkpoints_saved'] += 1
|
||||
|
||||
return {
|
||||
'status': 'completed',
|
||||
'extrema_detected': len(extrema_data),
|
||||
'context_updates': sum(1 for success in update_results.values() if success),
|
||||
'checkpoint_saved': checkpoint_saved,
|
||||
'session': self.extrema_trainer.training_session_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training extrema detector: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
async def _process_negative_cases(self) -> Dict[str, Any]:
|
||||
"""Process negative cases with automatic checkpointing"""
|
||||
try:
|
||||
if not self.negative_case_trainer:
|
||||
return {'status': 'skipped', 'reason': 'no_trainer'}
|
||||
|
||||
# Simulate adding a negative case
|
||||
if np.random.random() < 0.1: # 10% chance of negative case
|
||||
trade_info = {
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'BUY',
|
||||
'price': 2000.0,
|
||||
'pnl': -50.0, # Loss
|
||||
'value': 1000.0,
|
||||
'confidence': 0.7,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
market_data = {
|
||||
'exit_price': 1950.0,
|
||||
'state_before': {},
|
||||
'state_after': {},
|
||||
'tick_data': [],
|
||||
'technical_indicators': {}
|
||||
}
|
||||
|
||||
case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)
|
||||
|
||||
# Simulate loss improvement
|
||||
loss_improvement = np.random.random() * 0.1
|
||||
|
||||
# Save checkpoint (automatic based on performance)
|
||||
checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)
|
||||
|
||||
if checkpoint_saved:
|
||||
self.training_stats['checkpoints_saved'] += 1
|
||||
|
||||
return {
|
||||
'status': 'completed',
|
||||
'case_added': case_id,
|
||||
'loss_improvement': loss_improvement,
|
||||
'checkpoint_saved': checkpoint_saved,
|
||||
'session': self.negative_case_trainer.training_session_count
|
||||
}
|
||||
else:
|
||||
return {'status': 'no_cases'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing negative cases: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
|
||||
extrema_results: Dict, negative_results: Dict):
|
||||
"""Coordinate checkpoint saving across all components"""
|
||||
try:
|
||||
# Count successful checkpoints
|
||||
checkpoints_saved = sum([
|
||||
dqn_results.get('checkpoint_saved', False),
|
||||
cnn_results.get('checkpoint_saved', False),
|
||||
extrema_results.get('checkpoint_saved', False),
|
||||
negative_results.get('checkpoint_saved', False)
|
||||
])
|
||||
|
||||
if checkpoints_saved > 0:
|
||||
logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")
|
||||
|
||||
# Update best performances
|
||||
if 'episode_reward' in dqn_results:
|
||||
current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
|
||||
if dqn_results['episode_reward'] > current_best:
|
||||
self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']
|
||||
|
||||
if 'val_accuracy' in cnn_results:
|
||||
current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
|
||||
if cnn_results['val_accuracy'] > current_best:
|
||||
self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']
|
||||
|
||||
# Log checkpoint statistics every 10 cycles
|
||||
if self.training_stats['total_training_sessions'] % 10 == 0:
|
||||
await self._log_checkpoint_statistics()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error coordinating checkpoint saving: {e}")
|
||||
|
||||
async def _log_checkpoint_statistics(self):
|
||||
"""Log comprehensive checkpoint statistics"""
|
||||
try:
|
||||
stats = get_checkpoint_stats()
|
||||
|
||||
logger.info("=== Checkpoint Statistics ===")
|
||||
logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
|
||||
logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
|
||||
logger.info(f"Models managed: {len(stats['models'])}")
|
||||
|
||||
for model_name, model_stats in stats['models'].items():
|
||||
logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, "
|
||||
f"{model_stats['total_size_mb']:.2f} MB, "
|
||||
f"best: {model_stats['best_performance']:.4f}")
|
||||
|
||||
logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
|
||||
logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
|
||||
logger.info(f"Best performances: {self.training_stats['best_performances']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging checkpoint statistics: {e}")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the training system and save final checkpoints"""
|
||||
logger.info("Shutting down checkpoint-integrated training system...")
|
||||
|
||||
self.running = False
|
||||
|
||||
try:
|
||||
# Force save checkpoints for all components
|
||||
if self.dqn_agent:
|
||||
self.dqn_agent.save_checkpoint(0.0, force_save=True)
|
||||
|
||||
if self.cnn_trainer:
|
||||
self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)
|
||||
|
||||
if self.extrema_trainer:
|
||||
self.extrema_trainer.save_checkpoint(force_save=True)
|
||||
|
||||
if self.negative_case_trainer:
|
||||
self.negative_case_trainer.save_checkpoint(force_save=True)
|
||||
|
||||
# Final statistics
|
||||
await self._log_checkpoint_statistics()
|
||||
|
||||
logger.info("Checkpoint-integrated training system shutdown complete")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
|
||||
async def main():
|
||||
"""Main function to run the checkpoint-integrated training system"""
|
||||
logger.info("🚀 Starting Checkpoint-Integrated Training System")
|
||||
|
||||
# Create and initialize the training system
|
||||
training_system = CheckpointIntegratedTrainingSystem()
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
def signal_handler(signum, frame):
|
||||
logger.info("Received shutdown signal")
|
||||
asyncio.create_task(training_system.shutdown())
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
await training_system.initialize_components()
|
||||
|
||||
# Run the integrated training loop
|
||||
await training_system.run_integrated_training_loop()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main: {e}")
|
||||
raise
|
||||
finally:
|
||||
await training_system.shutdown()
|
||||
|
||||
logger.info("✅ Checkpoint management integration complete!")
|
||||
logger.info("All training pipelines now support automatic checkpointing")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ensure logs directory exists
|
||||
Path("logs").mkdir(exist_ok=True)
|
||||
|
||||
# Run the checkpoint-integrated training system
|
||||
asyncio.run(main())
|
||||
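# The save_checkpoint(...) calls above are described as "automatic based on performance".
# A minimal sketch of that gating rule, assuming a simple best-score-so-far policy
# (the real utils/checkpoint_manager.py may rank, rotate and prune differently):
def should_save_checkpoint(new_score: float, best_score: float,
                           min_delta: float = 1e-4, force_save: bool = False) -> bool:
    """Return True when the new score improves on the best seen so far."""
    return force_save or new_score > best_score + min_delta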
@@ -1,13 +0,0 @@
|
||||
"""
|
||||
Neural Network Utilities
|
||||
======================
|
||||
|
||||
This package contains utility functions and classes used in the neural network trading system:
|
||||
- Data Interface: Connects to realtime trading data and processes it for the neural network models
|
||||
"""
|
||||
|
||||
from .data_interface import DataInterface
|
||||
from .trading_env import TradingEnvironment
|
||||
from .signal_interpreter import SignalInterpreter
|
||||
|
||||
__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
|
||||
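# With this __init__ in place the helpers resolve at package level, e.g.:
#   from NN.utils import DataInterface, TradingEnvironment, SignalInterpreter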
@@ -1,123 +0,0 @@
|
||||
"""
|
||||
Enhanced Data Interface with additional NN trading parameters
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from .data_interface import DataInterface
|
||||
|
||||
class MultiDataInterface(DataInterface):
|
||||
"""
|
||||
Enhanced data interface that supports window_size and output_size parameters
|
||||
for neural network trading models.
|
||||
"""
|
||||
|
||||
def __init__(self, symbol: str,
|
||||
timeframes: List[str],
|
||||
window_size: int = 20,
|
||||
output_size: int = 3,
|
||||
data_dir: str = "NN/data"):
|
||||
"""
|
||||
Initialize with window_size and output_size for NN predictions.
|
||||
"""
|
||||
super().__init__(symbol, timeframes, data_dir)
|
||||
self.window_size = window_size
|
||||
self.output_size = output_size
|
||||
self.scalers = {} # Store scalers for each timeframe
|
||||
self.min_window_threshold = 100 # Minimum candles needed for training
|
||||
|
||||
def get_feature_count(self) -> int:
|
||||
"""
|
||||
Get number of features (OHLCV) for NN input.
|
||||
"""
|
||||
return 5 # open, high, low, close, volume
|
||||
|
||||
def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Prepare training data with windowed sequences"""
|
||||
# Get historical data for primary timeframe
|
||||
primary_tf = self.timeframes[0]
|
||||
df = self.get_historical_data(timeframe=primary_tf,
|
||||
n_candles=self.min_window_threshold + 1000)
|
||||
|
||||
if df is None or len(df) < self.min_window_threshold:
|
||||
raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")
|
||||
|
||||
# Prepare OHLCV sequences
|
||||
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values
|
||||
|
||||
# Create sequences and labels
|
||||
X = []
|
||||
y = []
|
||||
|
||||
for i in range(len(ohlcv) - self.window_size - self.output_size):
|
||||
# Input sequence
|
||||
seq = ohlcv[i:i+self.window_size]
|
||||
X.append(seq)
|
||||
|
||||
# Output target (price movement direction)
|
||||
close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3] # Close prices
|
||||
price_changes = np.diff(close_prices)
|
||||
|
||||
if self.output_size == 1:
|
||||
# Binary classification (up/down)
|
||||
label = 1 if price_changes[0] > 0 else 0
|
||||
elif self.output_size == 3:
|
||||
# 3-class classification (buy/hold/sell)
|
||||
if price_changes[0] > 0.002: # Significant rise
|
||||
label = 0 # Buy
|
||||
elif price_changes[0] < -0.002: # Significant drop
|
||||
label = 2 # Sell
|
||||
else:
|
||||
label = 1 # Hold
|
||||
else:
|
||||
raise ValueError(f"Unsupported output_size: {self.output_size}")
|
||||
|
||||
y.append(label)
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(X)
|
||||
y = np.array(y)
|
||||
|
||||
# Split into train/validation (80/20)
|
||||
split_idx = int(0.8 * len(X))
|
||||
X_train, y_train = X[:split_idx], y[:split_idx]
|
||||
X_val, y_val = X[split_idx:], y[split_idx:]
|
||||
|
||||
return X_train, y_train, X_val, y_val
|
||||
|
||||
def prepare_prediction_data(self) -> np.ndarray:
|
||||
"""Prepare most recent window for predictions"""
|
||||
primary_tf = self.timeframes[0]
|
||||
df = self.get_historical_data(timeframe=primary_tf,
|
||||
n_candles=self.window_size,
|
||||
use_cache=False)
|
||||
|
||||
if df is None or len(df) < self.window_size:
|
||||
raise ValueError(f"Need at least {self.window_size} candles for prediction")
|
||||
|
||||
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
|
||||
return np.array([ohlcv]) # Add batch dimension
|
||||
|
||||
def process_predictions(self, predictions: np.ndarray):
|
||||
"""Convert prediction probabilities to trading signals"""
|
||||
signals = []
|
||||
for pred in predictions:
|
||||
if self.output_size == 1:
|
||||
signal = "BUY" if pred[0] > 0.5 else "SELL"
|
||||
confidence = np.abs(pred[0] - 0.5) * 2 # Convert to 0-1 scale
|
||||
elif self.output_size == 3:
|
||||
action_idx = np.argmax(pred)
|
||||
signal = ["BUY", "HOLD", "SELL"][action_idx]
|
||||
confidence = pred[action_idx]
|
||||
else:
|
||||
signal = "HOLD"
|
||||
confidence = 0.0
|
||||
|
||||
signals.append({
|
||||
'action': signal,
|
||||
'confidence': confidence,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return signals
|
||||
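# Minimal usage sketch of the MultiDataInterface above. The symbol, timeframes and the
# model object are placeholders; any classifier exposing fit()/predict() over
# (batch, window_size, 5) arrays would do:
#   di = MultiDataInterface(symbol="ETH/USDT", timeframes=["1m"], window_size=20, output_size=3)
#   X_train, y_train, X_val, y_val = di.prepare_training_data()
#   model.fit(X_train, y_train, validation_data=(X_val, y_val))   # hypothetical model API
#   signals = di.process_predictions(model.predict(di.prepare_prediction_data()))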
@@ -1,364 +0,0 @@
|
||||
"""
|
||||
Realtime Analyzer for Neural Network Trading System
|
||||
|
||||
This module implements real-time analysis of market data using trained neural network models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import os
|
||||
import pandas as pd
|
||||
from collections import deque
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RealtimeAnalyzer:
|
||||
"""
|
||||
Handles real-time analysis of market data using trained neural network models.
|
||||
|
||||
Features:
|
||||
- Connects to real-time data sources (websockets)
|
||||
- Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
|
||||
- Uses trained models to analyze all timeframes
|
||||
- Generates trading signals
|
||||
- Manages risk and position sizing
|
||||
- Logs all trading decisions
|
||||
"""
|
||||
|
||||
def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
|
||||
"""
|
||||
Initialize the realtime analyzer.
|
||||
|
||||
Args:
|
||||
data_interface (DataInterface): Preconfigured data interface
|
||||
model: Trained neural network model
|
||||
symbol (str): Trading pair symbol
|
||||
timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
|
||||
"""
|
||||
self.data_interface = data_interface
|
||||
self.model = model
|
||||
self.symbol = symbol
|
||||
self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
|
||||
self.running = False
|
||||
self.data_queue = Queue()
|
||||
self.prediction_interval = 10 # Seconds between predictions
|
||||
self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
|
||||
self.ws = None
|
||||
self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks
|
||||
self.candle_cache = {
|
||||
'1s': deque(maxlen=5000),
|
||||
'1m': deque(maxlen=5000),
|
||||
'1h': deque(maxlen=5000),
|
||||
'1d': deque(maxlen=5000)
|
||||
}
|
||||
|
||||
logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")
|
||||
|
||||
def start(self):
|
||||
"""Start the realtime analysis process."""
|
||||
if self.running:
|
||||
logger.warning("Realtime analyzer already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
# Start WebSocket connection thread
|
||||
self.ws_thread = Thread(target=self._run_websocket, daemon=True)
|
||||
self.ws_thread.start()
|
||||
|
||||
# Start data processing thread
|
||||
self.processing_thread = Thread(target=self._process_data, daemon=True)
|
||||
self.processing_thread.start()
|
||||
|
||||
# Start analysis thread
|
||||
self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
|
||||
self.analysis_thread.start()
|
||||
|
||||
logger.info("Realtime analysis started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the realtime analysis process."""
|
||||
self.running = False
|
||||
if self.ws:
|
||||
asyncio.run(self.ws.close())
|
||||
if hasattr(self, 'ws_thread'):
|
||||
self.ws_thread.join(timeout=1)
|
||||
if hasattr(self, 'processing_thread'):
|
||||
self.processing_thread.join(timeout=1)
|
||||
if hasattr(self, 'analysis_thread'):
|
||||
self.analysis_thread.join(timeout=1)
|
||||
logger.info("Realtime analysis stopped")
|
||||
|
||||
def _run_websocket(self):
|
||||
"""Thread function for running WebSocket connection."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._connect_websocket())
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to WebSocket and receive data."""
|
||||
while self.running:
|
||||
try:
|
||||
logger.info(f"Connecting to WebSocket: {self.ws_url}")
|
||||
async with websockets.connect(self.ws_url) as ws:
|
||||
self.ws = ws
|
||||
logger.info("WebSocket connected")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
message = await ws.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
if 'e' in data and data['e'] == 'trade':
|
||||
tick = {
|
||||
'timestamp': data['T'],
|
||||
'price': float(data['p']),
|
||||
'volume': float(data['q']),
|
||||
'symbol': self.symbol
|
||||
}
|
||||
self.tick_storage.append(tick)
|
||||
self.data_queue.put(tick)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("WebSocket connection closed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving WebSocket message: {str(e)}")
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket connection error: {str(e)}")
|
||||
time.sleep(5) # Wait before reconnecting
|
||||
|
||||
def _process_data(self):
|
||||
"""Process incoming tick data into candles for all timeframes."""
|
||||
logger.info("Starting data processing thread")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process any new ticks
|
||||
while not self.data_queue.empty():
|
||||
tick = self.data_queue.get()
|
||||
|
||||
# Convert timestamp to datetime
|
||||
timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
|
||||
|
||||
# Process for each timeframe
|
||||
for timeframe in self.timeframes:
|
||||
interval = self._get_interval_seconds(timeframe)
|
||||
if interval is None:
|
||||
continue
|
||||
|
||||
# Round timestamp to nearest candle interval
|
||||
candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
|
||||
|
||||
# Get or create candle for this timeframe
|
||||
if not self.candle_cache[timeframe]:
|
||||
# First candle for this timeframe
|
||||
candle = {
|
||||
'timestamp': candle_ts,
|
||||
'open': tick['price'],
|
||||
'high': tick['price'],
|
||||
'low': tick['price'],
|
||||
'close': tick['price'],
|
||||
'volume': tick['volume']
|
||||
}
|
||||
self.candle_cache[timeframe].append(candle)
|
||||
else:
|
||||
# Update existing candle
|
||||
last_candle = self.candle_cache[timeframe][-1]
|
||||
|
||||
if last_candle['timestamp'] == candle_ts:
|
||||
# Update current candle
|
||||
last_candle['high'] = max(last_candle['high'], tick['price'])
|
||||
last_candle['low'] = min(last_candle['low'], tick['price'])
|
||||
last_candle['close'] = tick['price']
|
||||
last_candle['volume'] += tick['volume']
|
||||
else:
|
||||
# New candle
|
||||
candle = {
|
||||
'timestamp': candle_ts,
|
||||
'open': tick['price'],
|
||||
'high': tick['price'],
|
||||
'low': tick['price'],
|
||||
'close': tick['price'],
|
||||
'volume': tick['volume']
|
||||
}
|
||||
self.candle_cache[timeframe].append(candle)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data processing: {str(e)}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_interval_seconds(self, timeframe):
|
||||
"""Convert timeframe string to seconds."""
|
||||
intervals = {
|
||||
'1s': 1,
|
||||
'1m': 60,
|
||||
'1h': 3600,
|
||||
'1d': 86400
|
||||
}
|
||||
return intervals.get(timeframe)
|
||||
|
||||
def _analyze_data(self):
|
||||
"""Thread function for analyzing data and generating signals."""
|
||||
logger.info("Starting analysis thread")
|
||||
|
||||
last_prediction_time = 0
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Only make predictions at the specified interval
|
||||
if current_time - last_prediction_time < self.prediction_interval:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Prepare input data from all timeframes
|
||||
input_data = {}
|
||||
valid = True
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
if not self.candle_cache[timeframe]:
|
||||
logger.warning(f"No data available for timeframe {timeframe}")
|
||||
valid = False
|
||||
break
|
||||
|
||||
# Get last N candles for this timeframe
|
||||
candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
|
||||
|
||||
# Convert to numpy array
|
||||
ohlcv = np.array([
|
||||
[c['open'], c['high'], c['low'], c['close'], c['volume']]
|
||||
for c in candles
|
||||
])
|
||||
|
||||
# Normalize data
|
||||
ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
|
||||
input_data[timeframe] = ohlcv_normalized
|
||||
|
||||
if not valid:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Make prediction using the model
|
||||
try:
|
||||
prediction = self.model.predict(input_data)
|
||||
|
||||
# Get latest timestamp from 1s timeframe
|
||||
latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
|
||||
|
||||
# Process prediction
|
||||
self._process_prediction(
|
||||
prediction=prediction,
|
||||
timeframe='multi',
|
||||
timestamp=latest_ts
|
||||
)
|
||||
|
||||
last_prediction_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Error making prediction: {str(e)}")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in analysis: {str(e)}")
|
||||
time.sleep(1)
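
For clarity, the normalization step above is a per-column z-score over the cached candle window. A self-contained sketch with illustrative values:

```python
import numpy as np

# Each column (open, high, low, close, volume) is centered and scaled independently;
# the 1e-8 term keeps a flat column from dividing by zero.
ohlcv = np.array([
    [100.0, 101.0,  99.5, 100.5, 10.0],
    [100.5, 102.0, 100.0, 101.5, 12.0],
    [101.5, 103.0, 101.0, 102.5,  8.0],
])
normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
```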
|
||||
|
||||
def _process_prediction(self, prediction, timeframe, timestamp):
|
||||
"""
|
||||
Process model prediction and generate trading signals.
|
||||
|
||||
Args:
|
||||
prediction: Model prediction output
|
||||
timeframe (str): Timeframe the prediction is for ('multi' for combined)
|
||||
timestamp: Timestamp of the prediction (ms)
|
||||
"""
|
||||
# Convert prediction to trading signal
|
||||
signal, confidence = self._prediction_to_signal(prediction)
|
||||
|
||||
# Convert timestamp to datetime
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp / 1000)
|
||||
except Exception:  # invalid or overflowing timestamp - fall back to current time
|
||||
dt = datetime.now()
|
||||
|
||||
# Log the signal with all timeframes
|
||||
logger.info(
|
||||
f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
|
||||
f"Timestamp: {dt}, "
|
||||
f"Signal: {signal} (Confidence: {confidence:.2f})"
|
||||
)
|
||||
|
||||
# In a real implementation, we would execute trades here
|
||||
# For now, we'll just log the signals
|
||||
|
||||
def _prediction_to_signal(self, prediction):
|
||||
"""
|
||||
Convert model prediction to trading signal and confidence.
|
||||
|
||||
Args:
|
||||
prediction: Model prediction output (can be dict for multi-timeframe)
|
||||
|
||||
Returns:
|
||||
tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
|
||||
confidence is probability (0-1)
|
||||
"""
|
||||
if isinstance(prediction, dict):
|
||||
# Multi-timeframe prediction - combine signals
|
||||
signals = []
|
||||
confidences = []
|
||||
|
||||
for tf, pred in prediction.items():
|
||||
if len(pred.shape) == 1:
|
||||
# Binary classification
|
||||
signal = "BUY" if pred[0] > 0.5 else "SELL"
|
||||
confidence = pred[0] if signal == "BUY" else 1 - pred[0]
|
||||
else:
|
||||
# Multi-class
|
||||
class_idx = np.argmax(pred)
|
||||
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||
confidence = pred[class_idx]
|
||||
|
||||
signals.append(signal)
|
||||
confidences.append(confidence)
|
||||
|
||||
# Simple voting system - count BUY/SELL signals
|
||||
buy_count = signals.count("BUY")
|
||||
sell_count = signals.count("SELL")
|
||||
|
||||
if buy_count > sell_count:
|
||||
final_signal = "BUY"
|
||||
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
|
||||
elif sell_count > buy_count:
|
||||
final_signal = "SELL"
|
||||
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
|
||||
else:
|
||||
final_signal = "HOLD"
|
||||
final_confidence = np.mean(confidences)
|
||||
|
||||
return final_signal, final_confidence
|
||||
|
||||
else:
|
||||
# Single prediction
|
||||
if len(prediction.shape) == 1:
|
||||
# Binary classification
|
||||
signal = "BUY" if prediction[0] > 0.5 else "SELL"
|
||||
confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
|
||||
else:
|
||||
# Multi-class
|
||||
class_idx = np.argmax(prediction)
|
||||
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||
confidence = prediction[class_idx]
|
||||
|
||||
return signal, confidence
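
A short worked example of the voting logic above, using hypothetical per-timeframe outputs:

```python
# Hypothetical signals from three timeframes after the per-prediction conversion.
signals = ["BUY", "BUY", "SELL"]
confidences = [0.70, 0.60, 0.80]

# BUY outvotes SELL 2-1, so the final signal is BUY and the confidence is the
# mean over the BUY votes only: (0.70 + 0.60) / 2 = 0.65.
buy_conf = sum(c for s, c in zip(signals, confidences) if s == "BUY") / signals.count("BUY")
assert abs(buy_conf - 0.65) < 1e-9
```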
|
||||
@@ -1,98 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Immediate Model Cleanup Script
|
||||
|
||||
This script will clean up all existing model files and prepare the system
|
||||
for fresh training with the new model management system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from model_manager import ModelManager
|
||||
|
||||
def main():
|
||||
"""Run the model cleanup"""
|
||||
|
||||
# Configure logging for better output
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("GOGO2 MODEL CLEANUP SYSTEM")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("This script will:")
|
||||
print("1. Delete ALL existing model files (.pt, .pth)")
|
||||
print("2. Remove ALL checkpoint directories")
|
||||
print("3. Clear model backup directories")
|
||||
print("4. Reset the model registry")
|
||||
print("5. Create clean directory structure")
|
||||
print()
|
||||
print("WARNING: This action cannot be undone!")
|
||||
print()
|
||||
|
||||
# Calculate current space usage first
|
||||
try:
|
||||
manager = ModelManager()
|
||||
storage_stats = manager.get_storage_stats()
|
||||
print(f"Current storage usage:")
|
||||
print(f"- Models: {storage_stats['total_models']}")
|
||||
print(f"- Size: {storage_stats['actual_size_mb']:.1f}MB")
|
||||
print()
|
||||
except Exception as e:
|
||||
print(f"Error checking current storage: {e}")
|
||||
print()
|
||||
|
||||
# Ask for confirmation
|
||||
print("Type 'CLEANUP' to proceed with the cleanup:")
|
||||
user_input = input("> ").strip()
|
||||
|
||||
if user_input != "CLEANUP":
|
||||
print("Cleanup cancelled. No changes made.")
|
||||
return
|
||||
|
||||
print()
|
||||
print("Starting cleanup...")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
# Create manager and run cleanup
|
||||
manager = ModelManager()
|
||||
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("CLEANUP COMPLETE")
|
||||
print("=" * 60)
|
||||
print(f"Files deleted: {cleanup_result['deleted_files']}")
|
||||
print(f"Space freed: {cleanup_result['freed_space_mb']:.1f} MB")
|
||||
print(f"Directories cleaned: {len(cleanup_result['deleted_directories'])}")
|
||||
|
||||
if cleanup_result['errors']:
|
||||
print(f"Errors encountered: {len(cleanup_result['errors'])}")
|
||||
print("Errors:")
|
||||
for error in cleanup_result['errors'][:5]: # Show first 5 errors
|
||||
print(f" - {error}")
|
||||
if len(cleanup_result['errors']) > 5:
|
||||
print(f" ... and {len(cleanup_result['errors']) - 5} more")
|
||||
|
||||
print()
|
||||
print("System is now ready for fresh model training!")
|
||||
print("The following directories have been created:")
|
||||
print("- models/best_models/")
|
||||
print("- models/cnn/")
|
||||
print("- models/rl/")
|
||||
print("- models/checkpoints/")
|
||||
print("- NN/models/saved/")
|
||||
print()
|
||||
print("New models will be automatically managed by the ModelManager.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during cleanup: {e}")
|
||||
logging.exception("Cleanup failed")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
74
_tools/delete_candidates.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
EXCLUDE_PREFIXES = (
|
||||
'COBY' + os.sep, # Do not touch COBY subsystem
|
||||
)
|
||||
EXCLUDE_FILES = {
|
||||
'NN' + os.sep + 'training' + os.sep + 'enhanced_realtime_training.py',
|
||||
}
|
||||
|
||||
delete_list_path = ROOT / 'DELETE_CANDIDATES.txt'
|
||||
|
||||
deleted_files: list[str] = []
|
||||
kept_files: list[str] = []
|
||||
|
||||
if delete_list_path.exists():
|
||||
for line in delete_list_path.read_text(encoding='utf-8').splitlines():
|
||||
rel = line.strip()
|
||||
if not rel:
|
||||
continue
|
||||
# Skip excluded prefixes
|
||||
if any(rel.startswith(p) for p in EXCLUDE_PREFIXES):
|
||||
kept_files.append(rel)
|
||||
continue
|
||||
# Skip explicitly excluded files
|
||||
if rel in EXCLUDE_FILES:
|
||||
kept_files.append(rel)
|
||||
continue
|
||||
fp = ROOT / rel
|
||||
if fp.exists() and fp.is_file():
|
||||
try:
|
||||
fp.unlink()
|
||||
deleted_files.append(rel)
|
||||
except Exception:
|
||||
kept_files.append(rel)
|
||||
|
||||
# Remove tests directories outside COBY
|
||||
removed_dirs: list[str] = []
|
||||
for d in ROOT.rglob('tests'):
|
||||
try:
|
||||
rel = str(d.relative_to(ROOT))
|
||||
except Exception:
|
||||
continue
|
||||
if any(rel.startswith(p) for p in EXCLUDE_PREFIXES):
|
||||
continue
|
||||
if d.is_dir():
|
||||
try:
|
||||
shutil.rmtree(d)
|
||||
removed_dirs.append(rel)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Write cleanup log / todo
|
||||
log_lines = []
|
||||
log_lines.append('Cleanup run summary:')
|
||||
log_lines.append(f'- Deleted files: {len(deleted_files)}')
|
||||
for x in deleted_files[:50]:
|
||||
log_lines.append(f' - {x}')
|
||||
if len(deleted_files) > 50:
|
||||
log_lines.append(f' ... and {len(deleted_files)-50} more')
|
||||
log_lines.append(f'- Removed test directories: {len(removed_dirs)}')
|
||||
for x in removed_dirs[:50]:
|
||||
log_lines.append(f' - {x}')
|
||||
log_lines.append(f'- Kept (excluded): {len(kept_files)}')
|
||||
|
||||
(ROOT / 'CLEANUP_TODO.md').write_text('\n'.join(log_lines), encoding='utf-8')
|
||||
|
||||
print(f'Deleted files: {len(deleted_files)}')
|
||||
print(f'Removed test dirs: {len(removed_dirs)}')
|
||||
print(f'Kept (excluded): {len(kept_files)}')
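
For reference, the exclusion test above is a plain prefix match on OS-native relative paths. A minimal sketch (the candidate path is illustrative):

```python
import os

EXCLUDE_PREFIXES = ('COBY' + os.sep,)
rel = os.path.join('COBY', 'api', 'rest_api.py')   # hypothetical line from DELETE_CANDIDATES.txt
skip = any(rel.startswith(p) for p in EXCLUDE_PREFIXES)
# skip is True, so the file is recorded under kept_files instead of being deleted.
```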
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apply Trading System Fixes
|
||||
|
||||
This script applies fixes to the trading system to address:
|
||||
1. Duplicate entry prices
|
||||
2. P&L calculation issues
|
||||
3. Position tracking problems
|
||||
4. Trade display issues
|
||||
|
||||
Usage:
|
||||
python apply_trading_fixes.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/trading_fixes.log')
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def apply_fixes():
|
||||
"""Apply all fixes to the trading system"""
|
||||
logger.info("=" * 70)
|
||||
logger.info("APPLYING TRADING SYSTEM FIXES")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Import fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
logger.info("Fix modules imported successfully")
|
||||
except ImportError as e:
|
||||
logger.error(f"Error importing fix modules: {e}")
|
||||
return False
|
||||
|
||||
# Apply fixes to trading executor
|
||||
try:
|
||||
# Import trading executor
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Create a test instance to apply fixes
|
||||
test_executor = TradingExecutor()
|
||||
|
||||
# Apply fixes
|
||||
TradingExecutorFix.apply_fixes(test_executor)
|
||||
|
||||
logger.info("Trading executor fixes applied successfully to test instance")
|
||||
|
||||
# Verify fixes
|
||||
if hasattr(test_executor, 'price_cache_timestamp'):
|
||||
logger.info("✅ Price caching fix verified")
|
||||
else:
|
||||
logger.warning("❌ Price caching fix not verified")
|
||||
|
||||
if hasattr(test_executor, 'trade_cooldown_seconds'):
|
||||
logger.info("✅ Trade cooldown fix verified")
|
||||
else:
|
||||
logger.warning("❌ Trade cooldown fix not verified")
|
||||
|
||||
if hasattr(test_executor, '_check_trade_cooldown'):
|
||||
logger.info("✅ Trade cooldown check method verified")
|
||||
else:
|
||||
logger.warning("❌ Trade cooldown check method not verified")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying trading executor fixes: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Create patch for main.py
|
||||
try:
|
||||
main_patch = """
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
"""
|
||||
|
||||
# Write patch instructions
|
||||
with open('patch_instructions.txt', 'w') as f:
|
||||
f.write("""
|
||||
TRADING SYSTEM FIX INSTRUCTIONS
|
||||
==============================
|
||||
|
||||
To apply the fixes to your trading system, follow these steps:
|
||||
|
||||
1. Add the following code to main.py just before the dashboard.run_server() call:
|
||||
|
||||
```python
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
```
|
||||
|
||||
2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method:
|
||||
|
||||
```python
|
||||
# Apply dashboard fixes if available
|
||||
try:
|
||||
from web.dashboard_fix import DashboardFix
|
||||
DashboardFix.apply_fixes(self)
|
||||
logger.info("✅ Dashboard fixes applied during initialization")
|
||||
except ImportError:
|
||||
logger.warning("Dashboard fixes not available")
|
||||
```
|
||||
|
||||
3. Run the system with the fixes applied:
|
||||
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
4. Monitor the logs for any issues with the fixes.
|
||||
|
||||
These fixes address:
|
||||
- Duplicate entry prices
|
||||
- P&L calculation issues
|
||||
- Position tracking problems
|
||||
- Trade display issues
|
||||
- Rapid consecutive trades
|
||||
""")
|
||||
|
||||
logger.info("Patch instructions written to patch_instructions.txt")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating patch: {e}")
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SYSTEM FIXES READY TO APPLY")
|
||||
logger.info("See patch_instructions.txt for instructions")
|
||||
logger.info("=" * 70)
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Apply fixes
|
||||
success = apply_fixes()
|
||||
|
||||
if success:
|
||||
print("\nTrading system fixes ready to apply!")
|
||||
print("See patch_instructions.txt for instructions")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\nError preparing trading system fixes")
|
||||
sys.exit(1)
|
||||
@@ -1,218 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apply Trading System Fixes to Main.py
|
||||
|
||||
This script applies the trading system fixes directly to main.py
|
||||
to address the issues with duplicate entry prices and P&L calculation.
|
||||
|
||||
Usage:
|
||||
python apply_trading_fixes_to_main.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/apply_fixes.log')
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def backup_file(file_path):
|
||||
"""Create a backup of a file"""
|
||||
try:
|
||||
backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
shutil.copy2(file_path, backup_path)
|
||||
logger.info(f"Created backup: {backup_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating backup of {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def apply_fixes_to_main():
|
||||
"""Apply fixes to main.py"""
|
||||
main_py_path = "main.py"
|
||||
|
||||
if not os.path.exists(main_py_path):
|
||||
logger.error(f"File {main_py_path} not found")
|
||||
return False
|
||||
|
||||
# Create backup
|
||||
if not backup_file(main_py_path):
|
||||
logger.error("Failed to create backup, aborting")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read main.py
|
||||
with open(main_py_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the position to insert the fixes
|
||||
# Look for the line before dashboard.run_server()
|
||||
run_server_pattern = r"dashboard\.run_server\("
|
||||
match = re.search(run_server_pattern, content)
|
||||
|
||||
if not match:
|
||||
logger.error("Could not find dashboard.run_server() call in main.py")
|
||||
return False
|
||||
|
||||
# Find the position to insert the fixes (before the run_server call)
|
||||
insert_pos = content.rfind("\n", 0, match.start())
|
||||
|
||||
if insert_pos == -1:
|
||||
logger.error("Could not find insertion point in main.py")
|
||||
return False
|
||||
|
||||
# Prepare the fixes to insert
|
||||
fixes_code = """
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
|
||||
"""
|
||||
|
||||
# Insert the fixes
|
||||
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||
|
||||
# Write the modified content back to main.py
|
||||
with open(main_py_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
logger.info(f"Successfully applied fixes to {main_py_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying fixes to {main_py_path}: {e}")
|
||||
return False
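
The splice above hinges on one detail: the fix block is inserted at the newline just before the `dashboard.run_server(` line. A minimal sketch of that technique (the content string is illustrative):

```python
import re

content = "setup()\ndashboard.run_server(host='0.0.0.0')\n"
match = re.search(r"dashboard\.run_server\(", content)
insert_pos = content.rfind("\n", 0, match.start())      # newline preceding the call
patched = content[:insert_pos] + "\n# fixes go here" + content[insert_pos:]
# patched now reads: setup() / # fixes go here / dashboard.run_server(...)
```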
|
||||
|
||||
def apply_fixes_to_dashboard():
|
||||
"""Apply fixes to web/clean_dashboard.py"""
|
||||
dashboard_py_path = "web/clean_dashboard.py"
|
||||
|
||||
if not os.path.exists(dashboard_py_path):
|
||||
logger.error(f"File {dashboard_py_path} not found")
|
||||
return False
|
||||
|
||||
# Create backup
|
||||
if not backup_file(dashboard_py_path):
|
||||
logger.error("Failed to create backup, aborting")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read dashboard.py
|
||||
with open(dashboard_py_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the position to insert the fixes
|
||||
# Look for the __init__ method
|
||||
init_pattern = r"def __init__\(self,"
|
||||
match = re.search(init_pattern, content)
|
||||
|
||||
if not match:
|
||||
logger.error("Could not find __init__ method in dashboard.py")
|
||||
return False
|
||||
|
||||
# Find the end of the __init__ method
|
||||
init_end_pattern = r"logger\.debug\(.*\)"
|
||||
init_end_matches = list(re.finditer(init_end_pattern, content[match.end():]))
|
||||
|
||||
if not init_end_matches:
|
||||
logger.error("Could not find end of __init__ method in dashboard.py")
|
||||
return False
|
||||
|
||||
# Get the last logger.debug line in the __init__ method
|
||||
last_debug_match = init_end_matches[-1]
|
||||
insert_pos = match.end() + last_debug_match.end()
|
||||
|
||||
# Prepare the fixes to insert
|
||||
fixes_code = """
|
||||
|
||||
# Apply dashboard fixes if available
|
||||
try:
|
||||
from web.dashboard_fix import DashboardFix
|
||||
DashboardFix.apply_fixes(self)
|
||||
logger.info("✅ Dashboard fixes applied during initialization")
|
||||
except ImportError:
|
||||
logger.warning("Dashboard fixes not available")
|
||||
"""
|
||||
|
||||
# Insert the fixes
|
||||
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||
|
||||
# Write the modified content back to dashboard.py
|
||||
with open(dashboard_py_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
logger.info(f"Successfully applied fixes to {dashboard_py_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying fixes to {dashboard_py_path}: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
logger.info("=" * 70)
|
||||
logger.info("APPLYING TRADING SYSTEM FIXES TO MAIN.PY")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Apply fixes to main.py
|
||||
main_success = apply_fixes_to_main()
|
||||
|
||||
# Apply fixes to dashboard.py
|
||||
dashboard_success = apply_fixes_to_dashboard()
|
||||
|
||||
if main_success and dashboard_success:
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SYSTEM FIXES APPLIED SUCCESSFULLY")
|
||||
logger.info("=" * 70)
|
||||
logger.info("The following issues have been fixed:")
|
||||
logger.info("1. Duplicate entry prices")
|
||||
logger.info("2. P&L calculation issues")
|
||||
logger.info("3. Position tracking problems")
|
||||
logger.info("4. Trade display issues")
|
||||
logger.info("5. Rapid consecutive trades")
|
||||
logger.info("=" * 70)
|
||||
logger.info("You can now run the trading system with the fixes applied:")
|
||||
logger.info("python main.py")
|
||||
logger.info("=" * 70)
|
||||
return 0
|
||||
else:
|
||||
logger.error("=" * 70)
|
||||
logger.error("FAILED TO APPLY SOME FIXES")
|
||||
logger.error("=" * 70)
|
||||
logger.error("Please check the logs for details")
|
||||
logger.error("=" * 70)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Balance Trading Signals - Analyze and fix SHORT signal bias
|
||||
|
||||
This script analyzes the trading signals from the orchestrator and adjusts
|
||||
the model weights to balance BUY and SELL signals.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import get_config, setup_logging
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def analyze_trading_signals():
|
||||
"""Analyze trading signals from the orchestrator"""
|
||||
logger.info("Analyzing trading signals...")
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||
|
||||
# Get recent decisions
|
||||
symbols = orchestrator.symbols
|
||||
all_decisions = {}
|
||||
|
||||
for symbol in symbols:
|
||||
decisions = orchestrator.get_recent_decisions(symbol)
|
||||
all_decisions[symbol] = decisions
|
||||
|
||||
# Count actions
|
||||
action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
for decision in decisions:
|
||||
action_counts[decision.action] += 1
|
||||
|
||||
total_decisions = sum(action_counts.values())
|
||||
if total_decisions > 0:
|
||||
buy_percent = action_counts['BUY'] / total_decisions * 100
|
||||
sell_percent = action_counts['SELL'] / total_decisions * 100
|
||||
hold_percent = action_counts['HOLD'] / total_decisions * 100
|
||||
|
||||
logger.info(f"Symbol: {symbol}")
|
||||
logger.info(f" Total decisions: {total_decisions}")
|
||||
logger.info(f" BUY: {action_counts['BUY']} ({buy_percent:.1f}%)")
|
||||
logger.info(f" SELL: {action_counts['SELL']} ({sell_percent:.1f}%)")
|
||||
logger.info(f" HOLD: {action_counts['HOLD']} ({hold_percent:.1f}%)")
|
||||
|
||||
# Check for bias
|
||||
if sell_percent > buy_percent * 2: # If SELL signals are more than twice BUY signals
|
||||
logger.warning(f" SELL bias detected: {sell_percent:.1f}% vs {buy_percent:.1f}%")
|
||||
|
||||
# Adjust model weights to balance signals
|
||||
logger.info(" Adjusting model weights to balance signals...")
|
||||
|
||||
# Get current model weights
|
||||
model_weights = orchestrator.model_weights
|
||||
logger.info(f" Current model weights: {model_weights}")
|
||||
|
||||
# Identify models with SELL bias
|
||||
model_predictions = {}
|
||||
for model_name in model_weights:
|
||||
model_predictions[model_name] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
|
||||
# Analyze recent decisions to identify biased models
|
||||
for decision in decisions:
|
||||
reasoning = decision.reasoning
|
||||
if 'models_used' in reasoning:
|
||||
for model_name in reasoning['models_used']:
|
||||
if model_name in model_predictions:
|
||||
model_predictions[model_name][decision.action] += 1
|
||||
|
||||
# Calculate bias for each model
|
||||
model_bias = {}
|
||||
for model_name, actions in model_predictions.items():
|
||||
total = sum(actions.values())
|
||||
if total > 0:
|
||||
buy_pct = actions['BUY'] / total * 100
|
||||
sell_pct = actions['SELL'] / total * 100
|
||||
|
||||
# Calculate bias score (-100 to 100, negative = SELL bias, positive = BUY bias)
|
||||
bias_score = buy_pct - sell_pct
|
||||
model_bias[model_name] = bias_score
|
||||
|
||||
logger.info(f" Model {model_name}: Bias score = {bias_score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)")
|
||||
|
||||
# Adjust weights based on bias
|
||||
adjusted_weights = {}
|
||||
for model_name, weight in model_weights.items():
|
||||
if model_name in model_bias:
|
||||
bias = model_bias[model_name]
|
||||
|
||||
# If model has strong SELL bias, reduce its weight
|
||||
if bias < -30: # Strong SELL bias
|
||||
adjusted_weights[model_name] = max(0.05, weight * 0.7) # Reduce weight by 30%
|
||||
logger.info(f" Reducing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} due to SELL bias")
|
||||
# If model has BUY bias, increase its weight to balance
|
||||
elif bias > 10: # BUY bias
|
||||
adjusted_weights[model_name] = min(0.5, weight * 1.3) # Increase weight by 30%
|
||||
logger.info(f" Increasing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} to balance SELL bias")
|
||||
else:
|
||||
adjusted_weights[model_name] = weight
|
||||
else:
|
||||
adjusted_weights[model_name] = weight
|
||||
|
||||
# Save adjusted weights
|
||||
save_adjusted_weights(adjusted_weights)
|
||||
|
||||
logger.info(f" Adjusted weights: {adjusted_weights}")
|
||||
logger.info(" Weights saved to 'adjusted_model_weights.json'")
|
||||
|
||||
# Recommend next steps
|
||||
logger.info("\nRecommended actions:")
|
||||
logger.info("1. Update the model weights in the orchestrator")
|
||||
logger.info("2. Monitor trading signals for balance")
|
||||
logger.info("3. Consider retraining models with balanced data")
|
||||
|
||||
def save_adjusted_weights(weights):
|
||||
"""Save adjusted weights to a file"""
|
||||
output = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'weights': weights,
|
||||
'notes': 'Adjusted to balance BUY/SELL signals'
|
||||
}
|
||||
|
||||
with open('adjusted_model_weights.json', 'w') as f:
|
||||
json.dump(output, f, indent=2)
|
||||
|
||||
def apply_balanced_weights():
|
||||
"""Apply balanced weights to the orchestrator"""
|
||||
try:
|
||||
# Check if weights file exists
|
||||
if not os.path.exists('adjusted_model_weights.json'):
|
||||
logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.")
|
||||
return False
|
||||
|
||||
# Load adjusted weights
|
||||
with open('adjusted_model_weights.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
weights = data.get('weights', {})
|
||||
if not weights:
|
||||
logger.error("No weights found in the file.")
|
||||
return False
|
||||
|
||||
logger.info(f"Loaded adjusted weights: {weights}")
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||
|
||||
# Apply weights
|
||||
for model_name, weight in weights.items():
|
||||
if model_name in orchestrator.model_weights:
|
||||
orchestrator.model_weights[model_name] = weight
|
||||
|
||||
# Save updated weights
|
||||
orchestrator._save_orchestrator_state()
|
||||
|
||||
logger.info("Applied balanced weights to orchestrator.")
|
||||
logger.info("Restart the trading system for changes to take effect.")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying balanced weights: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SIGNAL BALANCE ANALYZER")
|
||||
logger.info("=" * 70)
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == 'apply':
|
||||
apply_balanced_weights()
|
||||
else:
|
||||
analyze_trading_signals()
|
||||
@@ -1,163 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import importlib
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
from safe_logging import setup_safe_logging
|
||||
|
||||
# Configure logging
|
||||
setup_safe_logging()
|
||||
logger = logging.getLogger("check_live_trading")
|
||||
|
||||
def check_dependencies():
|
||||
"""Check if all required dependencies are installed"""
|
||||
required_packages = [
|
||||
"numpy", "pandas", "matplotlib", "mplfinance", "torch",
|
||||
"dotenv", "ccxt", "websockets", "tensorboard",
|
||||
"sklearn", "PIL", "asyncio"
|
||||
]
|
||||
|
||||
missing_packages = []
|
||||
|
||||
for package in required_packages:
|
||||
try:
|
||||
if package == "dotenv":
|
||||
importlib.import_module("dotenv")
|
||||
elif package == "PIL":
|
||||
importlib.import_module("PIL")
|
||||
else:
|
||||
importlib.import_module(package)
|
||||
logger.info(f"✅ {package} is installed")
|
||||
except ImportError:
|
||||
missing_packages.append(package)
|
||||
logger.error(f"❌ {package} is NOT installed")
|
||||
|
||||
if missing_packages:
|
||||
logger.error(f"Missing packages: {', '.join(missing_packages)}")
|
||||
logger.info("Install missing packages with: pip install -r requirements.txt")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_api_keys():
|
||||
"""Check if API keys are configured"""
|
||||
load_dotenv()
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or api_key == "your_api_key_here" or not secret_key or secret_key == "your_secret_key_here":
|
||||
logger.error("❌ API keys are not properly configured in .env file")
|
||||
logger.info("Please update your .env file with valid MEXC API keys")
|
||||
return False
|
||||
|
||||
logger.info("✅ API keys are configured")
|
||||
return True
|
||||
|
||||
def check_model_files():
|
||||
"""Check if trained model files exist"""
|
||||
model_files = [
|
||||
"models/trading_agent_best_pnl.pt",
|
||||
"models/trading_agent_best_reward.pt",
|
||||
"models/trading_agent_final.pt"
|
||||
]
|
||||
|
||||
missing_models = []
|
||||
|
||||
for model_file in model_files:
|
||||
if os.path.exists(model_file):
|
||||
logger.info(f"✅ Model file exists: {model_file}")
|
||||
else:
|
||||
missing_models.append(model_file)
|
||||
logger.error(f"❌ Model file missing: {model_file}")
|
||||
|
||||
if missing_models:
|
||||
logger.warning("Some model files are missing. You need to train the model first.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def check_exchange_connection():
|
||||
"""Test connection to MEXC exchange"""
|
||||
try:
|
||||
import ccxt
|
||||
|
||||
# Load API keys
|
||||
load_dotenv()
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if api_key == "your_api_key_here" or secret_key == "your_secret_key_here":
|
||||
logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test")
|
||||
return False
|
||||
|
||||
# Initialize exchange
|
||||
exchange = ccxt.mexc({
|
||||
'apiKey': api_key,
|
||||
'secret': secret_key,
|
||||
'enableRateLimit': True
|
||||
})
|
||||
|
||||
# Test connection by fetching markets
|
||||
markets = exchange.fetch_markets()
|
||||
logger.info(f"✅ Successfully connected to MEXC exchange")
|
||||
logger.info(f"✅ Found {len(markets)} markets")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to connect to MEXC exchange: {str(e)}")
|
||||
return False
|
||||
|
||||
def check_directories():
|
||||
"""Check if required directories exist"""
|
||||
required_dirs = ["models", "runs", "trade_logs"]
|
||||
|
||||
for directory in required_dirs:
|
||||
if not os.path.exists(directory):
|
||||
logger.info(f"Creating directory: {directory}")
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
logger.info("✅ All required directories exist")
|
||||
return True
|
||||
|
||||
async def main():
|
||||
"""Run all checks"""
|
||||
logger.info("Running pre-flight checks for live trading...")
|
||||
|
||||
checks = [
|
||||
("Dependencies", check_dependencies()),
|
||||
("API Keys", check_api_keys()),
|
||||
("Model Files", check_model_files()),
|
||||
("Directories", check_directories()),
|
||||
("Exchange Connection", await check_exchange_connection())
|
||||
]
|
||||
|
||||
# Count failed checks
|
||||
failed_checks = sum(1 for _, result in checks if not result)
|
||||
|
||||
# Print summary
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY")
|
||||
logger.info("="*50)
|
||||
|
||||
for check_name, result in checks:
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
logger.info(f"{check_name}: {status}")
|
||||
|
||||
logger.info("="*50)
|
||||
|
||||
if failed_checks == 0:
|
||||
logger.info("🚀 All checks passed! You're ready for live trading.")
|
||||
logger.info("\nRun live trading with:")
|
||||
logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m")
|
||||
logger.info("\nFor real trading (after updating API keys):")
|
||||
logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50")
|
||||
return 0
|
||||
else:
|
||||
logger.error(f"❌ {failed_checks} check(s) failed. Please fix the issues before running live trading.")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check MEXC Available Trading Symbols
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def check_mexc_symbols():
|
||||
"""Check available trading symbols on MEXC"""
|
||||
try:
|
||||
logger.info("=== MEXC SYMBOL AVAILABILITY CHECK ===")
|
||||
|
||||
# Initialize trading executor
|
||||
executor = TradingExecutor("config.yaml")
|
||||
|
||||
if not executor.exchange:
|
||||
logger.error("Failed to initialize exchange")
|
||||
return
|
||||
|
||||
# Get all supported symbols
|
||||
logger.info("Fetching all supported symbols from MEXC...")
|
||||
supported_symbols = executor.exchange.get_api_symbols()
|
||||
|
||||
logger.info(f"Total supported symbols: {len(supported_symbols)}")
|
||||
|
||||
# Filter ETH-related symbols
|
||||
eth_symbols = [s for s in supported_symbols if 'ETH' in s]
|
||||
logger.info(f"ETH-related symbols ({len(eth_symbols)}):")
|
||||
for symbol in sorted(eth_symbols):
|
||||
logger.info(f" {symbol}")
|
||||
|
||||
# Filter USDT pairs
|
||||
usdt_symbols = [s for s in supported_symbols if s.endswith('USDT')]
|
||||
logger.info(f"USDT pairs ({len(usdt_symbols)}):")
|
||||
for symbol in sorted(usdt_symbols)[:20]: # Show first 20
|
||||
logger.info(f" {symbol}")
|
||||
if len(usdt_symbols) > 20:
|
||||
logger.info(f" ... and {len(usdt_symbols) - 20} more")
|
||||
|
||||
# Filter USDC pairs
|
||||
usdc_symbols = [s for s in supported_symbols if s.endswith('USDC')]
|
||||
logger.info(f"USDC pairs ({len(usdc_symbols)}):")
|
||||
for symbol in sorted(usdc_symbols):
|
||||
logger.info(f" {symbol}")
|
||||
|
||||
# Check specific symbols we're interested in
|
||||
test_symbols = ['ETHUSDT', 'ETHUSDC', 'BTCUSDT', 'BTCUSDC']
|
||||
logger.info("Checking specific symbols:")
|
||||
for symbol in test_symbols:
|
||||
if symbol in supported_symbols:
|
||||
logger.info(f" ✅ {symbol} - SUPPORTED")
|
||||
else:
|
||||
logger.info(f" ❌ {symbol} - NOT SUPPORTED")
|
||||
|
||||
# Show a sample of all available symbols
|
||||
logger.info("Sample of all available symbols:")
|
||||
for symbol in sorted(supported_symbols)[:30]:
|
||||
logger.info(f" {symbol}")
|
||||
if len(supported_symbols) > 30:
|
||||
logger.info(f" ... and {len(supported_symbols) - 30} more")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking MEXC symbols: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_mexc_symbols()
|
||||
@@ -1,108 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Cleanup Checkpoint Database
|
||||
|
||||
Remove invalid database entries and ensure consistency
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from utils.database_manager import get_database_manager
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def cleanup_invalid_checkpoints():
|
||||
"""Remove database entries for non-existent checkpoint files"""
|
||||
print("=== Cleaning Up Invalid Checkpoint Entries ===")
|
||||
|
||||
db_manager = get_database_manager()
|
||||
|
||||
# Get all checkpoints from database
|
||||
all_models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision']
|
||||
|
||||
removed_count = 0
|
||||
|
||||
for model_name in all_models:
|
||||
checkpoints = db_manager.list_checkpoints(model_name)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
file_path = Path(checkpoint.file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
print(f"Removing invalid entry: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
|
||||
|
||||
# Remove from database by setting as inactive and creating a new active one if needed
|
||||
try:
|
||||
# For now, we'll just report - the system will handle missing files gracefully
|
||||
logger.warning(f"Invalid checkpoint file: {checkpoint.file_path}")
|
||||
removed_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove invalid checkpoint: {e}")
|
||||
else:
|
||||
print(f"Valid checkpoint: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
|
||||
|
||||
print(f"Found {removed_count} invalid checkpoint entries")
|
||||
|
||||
def verify_checkpoint_loading():
|
||||
"""Test that checkpoint loading works correctly"""
|
||||
print("\n=== Verifying Checkpoint Loading ===")
|
||||
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
models_to_test = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']
|
||||
|
||||
for model_name in models_to_test:
|
||||
try:
|
||||
result = load_best_checkpoint(model_name)
|
||||
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
file_exists = Path(file_path).exists()
|
||||
|
||||
print(f"{model_name}:")
|
||||
print(f" ✅ Checkpoint found: {metadata.checkpoint_id}")
|
||||
print(f" 📁 File exists: {file_exists}")
|
||||
print(f" 📊 Loss: {getattr(metadata, 'loss', 'N/A')}")
|
||||
print(f" 💾 Size: {Path(file_path).stat().st_size / (1024*1024):.1f}MB" if file_exists else " 💾 Size: N/A")
|
||||
else:
|
||||
print(f"{model_name}: ❌ No valid checkpoint found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"{model_name}: ❌ Error loading checkpoint: {e}")
|
||||
|
||||
def test_checkpoint_system_integration():
|
||||
"""Test integration with the orchestrator"""
|
||||
print("\n=== Testing Orchestrator Integration ===")
|
||||
|
||||
try:
|
||||
# Test database manager integration
|
||||
from utils.database_manager import get_database_manager
|
||||
db_manager = get_database_manager()
|
||||
|
||||
# Test fast metadata access
|
||||
for model_name in ['dqn_agent', 'enhanced_cnn']:
|
||||
metadata = db_manager.get_best_checkpoint_metadata(model_name)
|
||||
if metadata:
|
||||
print(f"{model_name}: ✅ Fast metadata access works")
|
||||
print(f" ID: {metadata.checkpoint_id}")
|
||||
print(f" Loss: {metadata.performance_metrics.get('loss', 'N/A')}")
|
||||
else:
|
||||
print(f"{model_name}: ❌ No metadata found")
|
||||
|
||||
print("\n✅ Checkpoint system is ready for use!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Integration test failed: {e}")
|
||||
|
||||
def main():
|
||||
"""Main cleanup process"""
|
||||
cleanup_invalid_checkpoints()
|
||||
verify_checkpoint_loading()
|
||||
test_checkpoint_system_integration()
|
||||
|
||||
print("\n=== Cleanup Complete ===")
|
||||
print("The checkpoint system should now work without 'file not found' errors!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,186 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Checkpoint Cleanup and Migration Script
|
||||
|
||||
This script helps clean up existing checkpoints and migrate to the new
|
||||
checkpoint management system with W&B integration.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import torch
|
||||
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointCleanup:
|
||||
def __init__(self):
|
||||
self.saved_models_dir = Path("NN/models/saved")
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
|
||||
def analyze_existing_checkpoints(self) -> Dict[str, Any]:
|
||||
logger.info("Analyzing existing checkpoint files...")
|
||||
|
||||
analysis = {
|
||||
'total_files': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'model_types': {},
|
||||
'file_patterns': {},
|
||||
'potential_duplicates': []
|
||||
}
|
||||
|
||||
if not self.saved_models_dir.exists():
|
||||
logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
|
||||
return analysis
|
||||
|
||||
for pt_file in self.saved_models_dir.rglob("*.pt"):
|
||||
try:
|
||||
file_size_mb = pt_file.stat().st_size / (1024 * 1024)
|
||||
analysis['total_files'] += 1
|
||||
analysis['total_size_mb'] += file_size_mb
|
||||
|
||||
filename = pt_file.name
|
||||
|
||||
if 'cnn' in filename.lower():
|
||||
model_type = 'cnn'
|
||||
elif 'dqn' in filename.lower() or 'rl' in filename.lower():
|
||||
model_type = 'rl'
|
||||
elif 'agent' in filename.lower():
|
||||
model_type = 'rl'
|
||||
else:
|
||||
model_type = 'unknown'
|
||||
|
||||
if model_type not in analysis['model_types']:
|
||||
analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}
|
||||
|
||||
analysis['model_types'][model_type]['count'] += 1
|
||||
analysis['model_types'][model_type]['size_mb'] += file_size_mb
|
||||
|
||||
base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
|
||||
if base_name not in analysis['file_patterns']:
|
||||
analysis['file_patterns'][base_name] = []
|
||||
|
||||
analysis['file_patterns'][base_name].append({
|
||||
'path': str(pt_file),
|
||||
'size_mb': file_size_mb,
|
||||
'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing {pt_file}: {e}")
|
||||
|
||||
for base_name, files in analysis['file_patterns'].items():
|
||||
if len(files) > 5: # More than 5 files with same base name
|
||||
analysis['potential_duplicates'].append({
|
||||
'base_name': base_name,
|
||||
'count': len(files),
|
||||
'total_size_mb': sum(f['size_mb'] for f in files),
|
||||
'files': files
|
||||
})
|
||||
|
||||
logger.info(f"Analysis complete:")
|
||||
logger.info(f" Total files: {analysis['total_files']}")
|
||||
logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB")
|
||||
logger.info(f" Model types: {analysis['model_types']}")
|
||||
logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}")
|
||||
|
||||
return analysis
|
||||
|
||||
def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
|
||||
logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")
|
||||
|
||||
cleanup_results = {
|
||||
'removed': 0,
|
||||
'kept': 0,
|
||||
'space_saved_mb': 0.0,
|
||||
'details': []
|
||||
}
|
||||
|
||||
analysis = self.analyze_existing_checkpoints()
|
||||
|
||||
for duplicate_group in analysis['potential_duplicates']:
|
||||
base_name = duplicate_group['base_name']
|
||||
files = duplicate_group['files']
|
||||
|
||||
# Sort by modification time (newest first)
|
||||
files.sort(key=lambda x: x['modified'], reverse=True)
|
||||
|
||||
logger.info(f"Processing {base_name}: {len(files)} files")
|
||||
|
||||
# Keep only the 5 newest files
|
||||
for i, file_info in enumerate(files):
|
||||
if i < 5: # Keep first 5 (newest)
|
||||
cleanup_results['kept'] += 1
|
||||
cleanup_results['details'].append({
|
||||
'action': 'kept',
|
||||
'file': file_info['path']
|
||||
})
|
||||
else: # Remove the rest
|
||||
if not dry_run:
|
||||
try:
|
||||
Path(file_info['path']).unlink()
|
||||
logger.info(f"Removed: {file_info['path']}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing {file_info['path']}: {e}")
|
||||
continue
|
||||
|
||||
cleanup_results['removed'] += 1
|
||||
cleanup_results['space_saved_mb'] += file_info['size_mb']
|
||||
cleanup_results['details'].append({
|
||||
'action': 'removed',
|
||||
'file': file_info['path'],
|
||||
'size_mb': file_info['size_mb']
|
||||
})
|
||||
|
||||
logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
|
||||
logger.info(f" Kept: {cleanup_results['kept']}")
|
||||
logger.info(f" Removed: {cleanup_results['removed']}")
|
||||
logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB")
|
||||
|
||||
return cleanup_results
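
The retention rule above keeps only the five newest files per base name. A minimal sketch with dummy modification times:

```python
files = [{'path': f'model_{i}.pt', 'modified': i} for i in range(8)]  # dummy mtimes
files.sort(key=lambda x: x['modified'], reverse=True)
keep, remove = files[:5], files[5:]    # 5 newest kept, the remaining 3 flagged for removal
```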
|
||||
|
||||
def main():
|
||||
logger.info("=== Checkpoint Cleanup Tool ===")
|
||||
|
||||
cleanup = CheckpointCleanup()
|
||||
|
||||
# Analyze existing checkpoints
|
||||
logger.info("\\n1. Analyzing existing checkpoints...")
|
||||
analysis = cleanup.analyze_existing_checkpoints()
|
||||
|
||||
if analysis['total_files'] == 0:
|
||||
logger.info("No checkpoint files found.")
|
||||
return
|
||||
|
||||
# Show potential space savings
|
||||
total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5)
|
||||
if total_duplicates > 0:
|
||||
logger.info(f"\\nFound {total_duplicates} files that could be cleaned up")
|
||||
|
||||
# Dry run first
|
||||
logger.info("\\n2. Simulating cleanup...")
|
||||
dry_run_results = cleanup.cleanup_duplicates(dry_run=True)
|
||||
|
||||
if dry_run_results['removed'] > 0:
|
||||
proceed = input(f"\\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
|
||||
f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y'
|
||||
|
||||
if proceed:
|
||||
logger.info("\\n3. Performing actual cleanup...")
|
||||
cleanup_results = cleanup.cleanup_duplicates(dry_run=False)
|
||||
logger.info("\\n=== Cleanup Complete ===")
|
||||
else:
|
||||
logger.info("Cleanup cancelled.")
|
||||
else:
|
||||
logger.info("No files to remove.")
|
||||
else:
|
||||
logger.info("No duplicate files found that need cleanup.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,402 +0,0 @@
|
||||
"""
|
||||
API Rate Limiter and Error Handler
|
||||
|
||||
This module provides robust rate limiting and error handling for API requests,
|
||||
specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors)
|
||||
and other exchange API limitations.
|
||||
|
||||
Features:
|
||||
- Exponential backoff for rate limiting
|
||||
- IP rotation and proxy support
|
||||
- Request queuing and throttling
|
||||
- Error recovery strategies
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""Configuration for rate limiting"""
|
||||
requests_per_second: float = 0.5 # Very conservative for Binance
|
||||
requests_per_minute: int = 20
|
||||
requests_per_hour: int = 1000
|
||||
|
||||
# Backoff configuration
|
||||
initial_backoff: float = 1.0
|
||||
max_backoff: float = 300.0 # 5 minutes max
|
||||
backoff_multiplier: float = 2.0
|
||||
|
||||
# Error handling
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 5.0
|
||||
|
||||
# IP blocking detection
|
||||
block_detection_threshold: int = 3 # 3 consecutive 418s = blocked
|
||||
block_recovery_time: int = 3600 # 1 hour recovery time
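
These backoff fields translate into the wait schedule applied on repeated HTTP 429 responses further down (`initial_backoff * backoff_multiplier ** consecutive_errors`, capped at `max_backoff`). A quick illustration with the defaults above:

```python
initial, multiplier, cap = 1.0, 2.0, 300.0
schedule = [min(initial * multiplier ** n, cap) for n in range(1, 6)]
# [2.0, 4.0, 8.0, 16.0, 32.0] seconds for the first five consecutive errors
```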
|
||||
|
||||
@dataclass
|
||||
class APIEndpoint:
|
||||
"""API endpoint configuration"""
|
||||
name: str
|
||||
base_url: str
|
||||
rate_limit: RateLimitConfig
|
||||
last_request_time: float = 0.0
|
||||
request_count_minute: int = 0
|
||||
request_count_hour: int = 0
|
||||
consecutive_errors: int = 0
|
||||
blocked_until: Optional[datetime] = None
|
||||
|
||||
# Request history for rate limiting
|
||||
request_history: deque = field(default_factory=lambda: deque(maxlen=3600)) # 1 hour history
|
||||
|
||||
class APIRateLimiter:
|
||||
"""Thread-safe API rate limiter with error handling"""
|
||||
|
||||
def __init__(self, config: RateLimitConfig = None):
|
||||
self.config = config or RateLimitConfig()
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Endpoint tracking
|
||||
self.endpoints: Dict[str, APIEndpoint] = {}
|
||||
|
||||
# Global rate limiting
|
||||
self.global_request_history = deque(maxlen=3600)
|
||||
self.global_blocked_until: Optional[datetime] = None
|
||||
|
||||
# Request session with retry strategy
|
||||
self.session = self._create_session()
|
||||
|
||||
# Background cleanup thread
|
||||
self.cleanup_thread = None
|
||||
self.is_running = False
|
||||
|
||||
logger.info("API Rate Limiter initialized")
|
||||
logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m")
|
||||
|
||||
def _create_session(self) -> requests.Session:
|
||||
"""Create requests session with retry strategy"""
|
||||
session = requests.Session()
|
||||
|
||||
# Retry strategy
|
||||
retry_strategy = Retry(
|
||||
total=self.config.max_retries,
|
||||
backoff_factor=1,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
allowed_methods=["HEAD", "GET", "OPTIONS"]
|
||||
)
|
||||
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
# Headers to appear more legitimate
|
||||
session.headers.update({
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Accept': 'application/json',
|
||||
'Accept-Language': 'en-US,en;q=0.9',
|
||||
'Accept-Encoding': 'gzip, deflate, br',
|
||||
'Connection': 'keep-alive',
|
||||
'Upgrade-Insecure-Requests': '1',
|
||||
})
|
||||
|
||||
return session
|
||||
|
||||
def register_endpoint(self, name: str, base_url: str, rate_limit: RateLimitConfig = None):
|
||||
"""Register an API endpoint for rate limiting"""
|
||||
with self.lock:
|
||||
self.endpoints[name] = APIEndpoint(
|
||||
name=name,
|
||||
base_url=base_url,
|
||||
rate_limit=rate_limit or self.config
|
||||
)
|
||||
logger.info(f"Registered endpoint: {name} -> {base_url}")
|
||||
|
||||
def start_background_cleanup(self):
|
||||
"""Start background cleanup thread"""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
|
||||
self.cleanup_thread.start()
|
||||
logger.info("Started background cleanup thread")
|
||||
|
||||
def stop_background_cleanup(self):
|
||||
"""Stop background cleanup thread"""
|
||||
self.is_running = False
|
||||
if self.cleanup_thread:
|
||||
self.cleanup_thread.join(timeout=5)
|
||||
logger.info("Stopped background cleanup thread")
|
||||
|
||||
def _cleanup_worker(self):
|
||||
"""Background worker to clean up old request history"""
|
||||
while self.is_running:
|
||||
try:
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - 3600 # 1 hour ago
|
||||
|
||||
with self.lock:
|
||||
# Clean global history
|
||||
while (self.global_request_history and
|
||||
self.global_request_history[0] < cutoff_time):
|
||||
self.global_request_history.popleft()
|
||||
|
||||
# Clean endpoint histories
|
||||
for endpoint in self.endpoints.values():
|
||||
while (endpoint.request_history and
|
||||
endpoint.request_history[0] < cutoff_time):
|
||||
endpoint.request_history.popleft()
|
||||
|
||||
# Reset counters
|
||||
endpoint.request_count_minute = len([
|
||||
t for t in endpoint.request_history
|
||||
if t > current_time - 60
|
||||
])
|
||||
endpoint.request_count_hour = len(endpoint.request_history)
|
||||
|
||||
time.sleep(60) # Clean every minute
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup worker: {e}")
|
||||
time.sleep(30)
|
||||
|
||||
def can_make_request(self, endpoint_name: str) -> tuple[bool, float]:
|
||||
"""
|
||||
Check if we can make a request to the endpoint
|
||||
|
||||
Returns:
|
||||
(can_make_request, wait_time_seconds)
|
||||
"""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
|
||||
# Check global blocking
|
||||
if self.global_blocked_until and datetime.now() < self.global_blocked_until:
|
||||
wait_time = (self.global_blocked_until - datetime.now()).total_seconds()
|
||||
return False, wait_time
|
||||
|
||||
# Get endpoint
|
||||
endpoint = self.endpoints.get(endpoint_name)
|
||||
if not endpoint:
|
||||
logger.warning(f"Unknown endpoint: {endpoint_name}")
|
||||
return False, 60.0
|
||||
|
||||
# Check endpoint blocking
|
||||
if endpoint.blocked_until and datetime.now() < endpoint.blocked_until:
|
||||
wait_time = (endpoint.blocked_until - datetime.now()).total_seconds()
|
||||
return False, wait_time
|
||||
|
||||
# Check rate limits
|
||||
config = endpoint.rate_limit
|
||||
|
||||
# Per-second rate limit
|
||||
time_since_last = current_time - endpoint.last_request_time
|
||||
if time_since_last < (1.0 / config.requests_per_second):
|
||||
wait_time = (1.0 / config.requests_per_second) - time_since_last
|
||||
return False, wait_time
|
||||
|
||||
# Per-minute rate limit
|
||||
minute_requests = len([
|
||||
t for t in endpoint.request_history
|
||||
if t > current_time - 60
|
||||
])
|
||||
if minute_requests >= config.requests_per_minute:
|
||||
return False, 60.0
|
||||
|
||||
# Per-hour rate limit
|
||||
if len(endpoint.request_history) >= config.requests_per_hour:
|
||||
return False, 3600.0
|
||||
|
||||
return True, 0.0
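
A minimal usage sketch of this limiter, assuming only the methods defined in this module (the endpoint URL is illustrative):

```python
limiter = APIRateLimiter(RateLimitConfig(requests_per_second=0.5))
limiter.register_endpoint('binance', 'https://api.binance.com')

ok, wait_s = limiter.can_make_request('binance')
if ok:
    resp = limiter.make_request('binance', 'https://api.binance.com/api/v3/time')
else:
    print(f"rate limited, retry in {wait_s:.1f}s")
```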
|
||||
|
||||
    def make_request(self, endpoint_name: str, url: str, method: str = 'GET',
                     **kwargs) -> Optional[requests.Response]:
        """
        Make a rate-limited request with error handling

        Args:
            endpoint_name: Name of the registered endpoint
            url: Full URL to request
            method: HTTP method
            **kwargs: Additional arguments for requests

        Returns:
            Response object or None if failed
        """
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                logger.error(f"Unknown endpoint: {endpoint_name}")
                return None

            # Check if we can make the request
            can_request, wait_time = self.can_make_request(endpoint_name)
            if not can_request:
                logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s")
                time.sleep(min(wait_time, 30))  # Cap wait time
                return None

            # Record request attempt
            current_time = time.time()
            endpoint.last_request_time = current_time
            endpoint.request_history.append(current_time)
            self.global_request_history.append(current_time)

            # Add jitter to avoid thundering herd
            jitter = random.uniform(0.1, 0.5)
            time.sleep(jitter)

        # Make the request (outside of lock to avoid blocking other threads)
        try:
            # Set timeout
            kwargs.setdefault('timeout', 10)

            # Make request
            response = self.session.request(method, url, **kwargs)

            # Handle response
            with self.lock:
                if response.status_code == 200:
                    # Success - reset error counter
                    endpoint.consecutive_errors = 0
                    return response

                elif response.status_code == 418:
                    # Binance "I'm a teapot" - rate limited/blocked
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}")

                    if endpoint.consecutive_errors >= endpoint.rate_limit.block_detection_threshold:
                        # We're likely IP blocked
                        block_time = datetime.now() + timedelta(seconds=endpoint.rate_limit.block_recovery_time)
                        endpoint.blocked_until = block_time
                        logger.error(f"Endpoint {endpoint_name} blocked until {block_time}")

                    return None

                elif response.status_code == 429:
                    # Too many requests
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}")

                    # Implement exponential backoff
                    backoff_time = min(
                        endpoint.rate_limit.initial_backoff * (endpoint.rate_limit.backoff_multiplier ** endpoint.consecutive_errors),
                        endpoint.rate_limit.max_backoff
                    )

                    block_time = datetime.now() + timedelta(seconds=backoff_time)
                    endpoint.blocked_until = block_time
                    logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s")

                    return None

                else:
                    # Other error
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP {response.status_code} for {endpoint_name}: {response.text[:200]}")
                    return None

        except requests.exceptions.RequestException as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Request exception for {endpoint_name}: {e}")
            return None

        except Exception as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Unexpected error for {endpoint_name}: {e}")
            return None

    def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]:
        """Get status information for an endpoint"""
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                return {'error': 'Unknown endpoint'}

            current_time = time.time()

            return {
                'name': endpoint.name,
                'base_url': endpoint.base_url,
                'consecutive_errors': endpoint.consecutive_errors,
                'blocked_until': endpoint.blocked_until.isoformat() if endpoint.blocked_until else None,
                'requests_last_minute': len([t for t in endpoint.request_history if t > current_time - 60]),
                'requests_last_hour': len(endpoint.request_history),
                'last_request_time': endpoint.last_request_time,
                'can_make_request': self.can_make_request(endpoint_name)[0]
            }

    def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status for all endpoints"""
        return {name: self.get_endpoint_status(name) for name in self.endpoints.keys()}

    def reset_endpoint(self, endpoint_name: str):
        """Reset an endpoint's error state"""
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if endpoint:
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
                logger.info(f"Reset endpoint: {endpoint_name}")

    def reset_all_endpoints(self):
        """Reset all endpoints' error states"""
        with self.lock:
            for endpoint in self.endpoints.values():
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
            self.global_blocked_until = None
            logger.info("Reset all endpoints")


# Global rate limiter instance
_global_rate_limiter = None


def get_rate_limiter() -> APIRateLimiter:
    """Get global rate limiter instance"""
    global _global_rate_limiter
    if _global_rate_limiter is None:
        _global_rate_limiter = APIRateLimiter()
        _global_rate_limiter.start_background_cleanup()

        # Register common endpoints
        _global_rate_limiter.register_endpoint(
            'binance_api',
            'https://api.binance.com',
            RateLimitConfig(
                requests_per_second=0.2,  # Very conservative
                requests_per_minute=10,
                requests_per_hour=500
            )
        )

        _global_rate_limiter.register_endpoint(
            'mexc_api',
            'https://api.mexc.com',
            RateLimitConfig(
                requests_per_second=0.5,
                requests_per_minute=20,
                requests_per_hour=1000
            )
        )

    return _global_rate_limiter
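A minimal usage sketch for this deleted module (the caller, URL, and parameters are assumptions for illustration): every outbound call goes through make_request, so the 418/429 handling above applies automatically, and a repeated 429 backs the endpoint off by roughly initial_backoff * backoff_multiplier ** consecutive_errors, capped at max_backoff.

# Hypothetical caller of the deleted core/api_rate_limiter.py module.
from core.api_rate_limiter import get_rate_limiter

limiter = get_rate_limiter()   # registers the conservative 'binance_api' and 'mexc_api' endpoints
response = limiter.make_request(
    'binance_api',
    'https://api.binance.com/api/v3/ticker/price',
    params={'symbol': 'ETHUSDT'},
)
if response is not None:
    print(response.json())
print(limiter.get_endpoint_status('binance_api'))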
@@ -1,442 +0,0 @@
"""
Async Handler for UI Stability Fix

Properly handles all async operations in the dashboard with single event loop management,
proper exception handling, and timeout support to prevent async/await errors.
"""

import asyncio
import logging
import threading
import time
from typing import Any, Callable, Coroutine, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor
import functools
import weakref

logger = logging.getLogger(__name__)


class AsyncOperationError(Exception):
    """Exception raised for async operation errors"""
    pass


class AsyncHandler:
    """
    Centralized async operation handler with single event loop management
    and proper exception handling for async operations.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Initialize the async handler

        Args:
            loop: Optional event loop to use. If None, creates a new one.
        """
        self._loop = loop
        self._thread = None
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler")
        self._running = False
        self._callbacks = weakref.WeakSet()
        self._timeout_default = 30.0  # Default timeout for operations

        # Start the event loop in a separate thread if not provided
        if self._loop is None:
            self._start_event_loop_thread()

        logger.info("AsyncHandler initialized with event loop management")

    def _start_event_loop_thread(self):
        """Start the event loop in a separate thread"""
        def run_event_loop():
            """Run the event loop in a separate thread"""
            try:
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
                self._running = True
                logger.debug("Event loop started in separate thread")
                self._loop.run_forever()
            except Exception as e:
                logger.error(f"Error in event loop thread: {e}")
            finally:
                self._running = False
                logger.debug("Event loop thread stopped")

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop")
        self._thread.start()

        # Wait for the loop to be ready
        timeout = 5.0
        start_time = time.time()
        while not self._running and (time.time() - start_time) < timeout:
            time.sleep(0.1)

        if not self._running:
            raise AsyncOperationError("Failed to start event loop within timeout")

    def is_running(self) -> bool:
        """Check if the async handler is running"""
        return self._running and self._loop is not None and not self._loop.is_closed()

    def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any:
        """
        Run an async coroutine safely with proper error handling and timeout

        Args:
            coro: The coroutine to run
            timeout: Timeout in seconds (uses default if None)

        Returns:
            The result of the coroutine

        Raises:
            AsyncOperationError: If the operation fails or times out
        """
        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")

        timeout = timeout or self._timeout_default

        try:
            # Schedule the coroutine on the event loop
            future = asyncio.run_coroutine_threadsafe(
                asyncio.wait_for(coro, timeout=timeout),
                self._loop
            )

            # Wait for the result with timeout
            result = future.result(timeout=timeout + 1.0)  # Add buffer to future timeout
            logger.debug("Async operation completed successfully")
            return result

        except asyncio.TimeoutError:
            logger.error(f"Async operation timed out after {timeout} seconds")
            raise AsyncOperationError(f"Operation timed out after {timeout} seconds")
        except Exception as e:
            logger.error(f"Async operation failed: {e}")
            raise AsyncOperationError(f"Async operation failed: {e}")

    def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None:
        """
        Schedule a coroutine to run asynchronously without waiting for result

        Args:
            coro: The coroutine to schedule
            callback: Optional callback to call with the result
        """
        if not self.is_running():
            logger.warning("Cannot schedule coroutine: AsyncHandler is not running")
            return

        async def wrapped_coro():
            """Wrapper to handle exceptions and callbacks"""
            try:
                result = await coro
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(f"Error in coroutine callback: {e}")
                return result
            except Exception as e:
                logger.error(f"Error in scheduled coroutine: {e}")
                if callback:
                    try:
                        callback(None)  # Call callback with None on error
                    except Exception as cb_e:
                        logger.error(f"Error in error callback: {cb_e}")

        try:
            asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop)
            logger.debug("Coroutine scheduled successfully")
        except Exception as e:
            logger.error(f"Failed to schedule coroutine: {e}")

    def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Create an asyncio task safely with proper error handling

        Args:
            coro: The coroutine to create a task for
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        if not self.is_running():
            logger.warning("Cannot create task: AsyncHandler is not running")
            return None

        async def create_task():
            """Create the task in the event loop"""
            try:
                task = asyncio.create_task(coro, name=name)
                logger.debug(f"Task created: {name or 'unnamed'}")
                return task
            except Exception as e:
                logger.error(f"Failed to create task {name}: {e}")
                return None

        try:
            future = asyncio.run_coroutine_threadsafe(create_task(), self._loop)
            return future.result(timeout=5.0)
        except Exception as e:
            logger.error(f"Failed to create task {name}: {e}")
            return None

    async def handle_orchestrator_connection(self, orchestrator) -> bool:
        """
        Handle orchestrator connection with proper async patterns

        Args:
            orchestrator: The orchestrator instance to connect to

        Returns:
            True if connection successful, False otherwise
        """
        try:
            logger.info("Connecting to orchestrator...")

            # Add decision callback if orchestrator supports it
            if hasattr(orchestrator, 'add_decision_callback'):
                await orchestrator.add_decision_callback(self._handle_trading_decision)
                logger.info("Decision callback added to orchestrator")

            # Start COB integration if available
            if hasattr(orchestrator, 'start_cob_integration'):
                await orchestrator.start_cob_integration()
                logger.info("COB integration started")

            # Start continuous trading if available
            if hasattr(orchestrator, 'start_continuous_trading'):
                await orchestrator.start_continuous_trading()
                logger.info("Continuous trading started")

            logger.info("Successfully connected to orchestrator")
            return True

        except Exception as e:
            logger.error(f"Failed to connect to orchestrator: {e}")
            return False

    async def handle_cob_integration(self, cob_integration) -> bool:
        """
        Handle COB integration startup with proper async patterns

        Args:
            cob_integration: The COB integration instance

        Returns:
            True if startup successful, False otherwise
        """
        try:
            logger.info("Starting COB integration...")

            if hasattr(cob_integration, 'start'):
                await cob_integration.start()
                logger.info("COB integration started successfully")
                return True
            else:
                logger.warning("COB integration does not have start method")
                return False

        except Exception as e:
            logger.error(f"Failed to start COB integration: {e}")
            return False

    async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None:
        """
        Handle trading decision with proper async patterns

        Args:
            decision: The trading decision dictionary
        """
        try:
            logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}")

            # Process the decision (this would be customized based on needs)
            # For now, just log it
            symbol = decision.get('symbol', 'UNKNOWN')
            action = decision.get('action', 'HOLD')
            confidence = decision.get('confidence', 0.0)

            logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})")

        except Exception as e:
            logger.error(f"Error handling trading decision: {e}")

    def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
        """
        Run a blocking function in the thread pool executor

        Args:
            func: The function to run
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            The result of the function
        """
        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")

        try:
            # Create a partial function with the arguments
            partial_func = functools.partial(func, *args, **kwargs)

            # Create a coroutine that runs the function in executor
            async def run_in_executor_coro():
                return await self._loop.run_in_executor(self._executor, partial_func)

            # Run the coroutine
            future = asyncio.run_coroutine_threadsafe(run_in_executor_coro(), self._loop)

            result = future.result(timeout=self._timeout_default)
            logger.debug("Executor function completed successfully")
            return result

        except Exception as e:
            logger.error(f"Error running function in executor: {e}")
            raise AsyncOperationError(f"Executor function failed: {e}")

    def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Add a periodic task that runs at specified intervals

        Args:
            coro_func: Function that returns a coroutine to run periodically
            interval: Interval in seconds between runs
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        async def periodic_runner():
            """Run the coroutine periodically"""
            task_name = name or "periodic_task"
            logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)")

            try:
                while True:
                    try:
                        coro = coro_func()
                        await coro
                        logger.debug(f"Periodic task {task_name} completed")
                    except Exception as e:
                        logger.error(f"Error in periodic task {task_name}: {e}")

                    await asyncio.sleep(interval)

            except asyncio.CancelledError:
                logger.info(f"Periodic task {task_name} cancelled")
                raise
            except Exception as e:
                logger.error(f"Fatal error in periodic task {task_name}: {e}")

        return self.create_task_safely(periodic_runner(), name=f"periodic_{name}")

    def stop(self) -> None:
        """Stop the async handler and clean up resources"""
        try:
            logger.info("Stopping AsyncHandler...")

            if self._loop and not self._loop.is_closed():
                # Cancel all tasks
                if self._loop.is_running():
                    asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), self._loop)

                # Stop the event loop
                self._loop.call_soon_threadsafe(self._loop.stop)

            # Shutdown executor
            if self._executor:
                self._executor.shutdown(wait=True)

            # Wait for thread to finish
            if self._thread and self._thread.is_alive():
                self._thread.join(timeout=5.0)

            self._running = False
            logger.info("AsyncHandler stopped successfully")

        except Exception as e:
            logger.error(f"Error stopping AsyncHandler: {e}")

    async def _cancel_all_tasks(self) -> None:
        """Cancel all running tasks"""
        try:
            tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
            if tasks:
                logger.info(f"Cancelling {len(tasks)} running tasks")
                for task in tasks:
                    task.cancel()

                # Wait for tasks to be cancelled
                await asyncio.gather(*tasks, return_exceptions=True)
                logger.debug("All tasks cancelled")
        except Exception as e:
            logger.error(f"Error cancelling tasks: {e}")

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.stop()


class AsyncContextManager:
    """
    Context manager for async operations that ensures proper cleanup
    """

    def __init__(self, async_handler: AsyncHandler):
        self.async_handler = async_handler
        self.active_tasks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Cancel any active tasks
        for task in self.active_tasks:
            if not task.done():
                task.cancel()

    def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """Create a task and track it for cleanup"""
        task = self.async_handler.create_task_safely(coro, name)
        if task:
            self.active_tasks.append(task)
        return task


def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler:
    """
    Factory function to create an AsyncHandler instance

    Args:
        loop: Optional event loop to use

    Returns:
        AsyncHandler instance
    """
    return AsyncHandler(loop=loop)


def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any:
    """
    Convenience function to run a coroutine safely with a temporary AsyncHandler

    Args:
        coro: The coroutine to run
        timeout: Timeout in seconds

    Returns:
        The result of the coroutine
    """
    with AsyncHandler() as handler:
        return handler.run_async_safely(coro, timeout=timeout)
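A short sketch of how synchronous dashboard code might have driven this handler (the calling code and the fetch_balance coroutine are assumptions for illustration, not part of the deleted file):

# Illustrative only: blocking, fire-and-forget, and periodic use of AsyncHandler from sync code.
import asyncio
import time
from core.async_handler import AsyncHandler

async def fetch_balance():
    await asyncio.sleep(0.1)            # stand-in for an exchange or API call
    return {'USDT': 1000.0}

with AsyncHandler() as handler:
    balance = handler.run_async_safely(fetch_balance(), timeout=5.0)   # block until the result is ready
    handler.schedule_coroutine(fetch_balance(), callback=print)        # fire-and-forget with a callback
    handler.add_periodic_task(fetch_balance, interval=2.0, name="balance_poll")
    time.sleep(5)                        # let the background work run before __exit__ stops the loop
print(balance)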
@@ -1,952 +0,0 @@
"""
Bookmap Order Book Data Provider

This module integrates with Bookmap to gather:
- Current Order Book (COB) data
- Session Volume Profile (SVP) data
- Order book sweeps and momentum trades detection
- Real-time order size heatmap matrix (last 10 minutes)
- Level 2 market depth analysis

The data is processed and fed to CNN and DQN networks for enhanced trading decisions.
"""

import asyncio
import json
import logging
import time
import websockets
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from collections import deque, defaultdict
from dataclasses import dataclass
from threading import Thread, Lock
import requests

logger = logging.getLogger(__name__)

@dataclass
class OrderBookLevel:
    """Represents a single order book level"""
    price: float
    size: float
    orders: int
    side: str  # 'bid' or 'ask'
    timestamp: datetime

@dataclass
class OrderBookSnapshot:
    """Complete order book snapshot"""
    symbol: str
    timestamp: datetime
    bids: List[OrderBookLevel]
    asks: List[OrderBookLevel]
    spread: float
    mid_price: float

@dataclass
class VolumeProfileLevel:
    """Volume profile level data"""
    price: float
    volume: float
    buy_volume: float
    sell_volume: float
    trades_count: int
    vwap: float

@dataclass
class OrderFlowSignal:
    """Order flow signal detection"""
    timestamp: datetime
    signal_type: str  # 'sweep', 'absorption', 'iceberg', 'momentum'
    price: float
    volume: float
    confidence: float
    description: str

class BookmapDataProvider:
    """
    Real-time order book data provider using Bookmap-style analysis

    Features:
    - Level 2 order book monitoring
    - Order flow detection (sweeps, absorptions)
    - Volume profile analysis
    - Order size heatmap generation
    - Market microstructure analysis
    """

    def __init__(self, symbols: List[str] = None, depth_levels: int = 20):
        """
        Initialize Bookmap data provider

        Args:
            symbols: List of symbols to monitor
            depth_levels: Number of order book levels to track
        """
        self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
        self.depth_levels = depth_levels
        self.is_streaming = False

        # Order book data storage
        self.order_books: Dict[str, OrderBookSnapshot] = {}
        self.order_book_history: Dict[str, deque] = {}
        self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}

        # Heatmap data (10-minute rolling window)
        self.heatmap_window = timedelta(minutes=10)
        self.order_heatmaps: Dict[str, deque] = {}
        self.price_levels: Dict[str, List[float]] = {}

        # Order flow detection
        self.flow_signals: Dict[str, deque] = {}
        self.sweep_threshold = 0.8  # Minimum confidence for sweep detection
        self.absorption_threshold = 0.7  # Minimum confidence for absorption

        # Market microstructure metrics
        self.bid_ask_spreads: Dict[str, deque] = {}
        self.order_book_imbalances: Dict[str, deque] = {}
        self.liquidity_metrics: Dict[str, Dict] = {}

        # WebSocket connections
        self.websocket_tasks: Dict[str, asyncio.Task] = {}
        self.data_lock = Lock()

        # Callbacks for CNN/DQN integration
        self.cnn_callbacks: List[Callable] = []
        self.dqn_callbacks: List[Callable] = []

        # Performance tracking
        self.update_counts = defaultdict(int)
        self.last_update_times = {}

        # Initialize data structures
        for symbol in self.symbols:
            self.order_book_history[symbol] = deque(maxlen=1000)
            self.order_heatmaps[symbol] = deque(maxlen=600)  # 10 min at 1s intervals
            self.flow_signals[symbol] = deque(maxlen=500)
            self.bid_ask_spreads[symbol] = deque(maxlen=1000)
            self.order_book_imbalances[symbol] = deque(maxlen=1000)
            self.liquidity_metrics[symbol] = {
                'total_bid_size': 0.0,
                'total_ask_size': 0.0,
                'weighted_mid': 0.0,
                'liquidity_ratio': 1.0
            }

        logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols")
        logger.info(f"Tracking {depth_levels} order book levels per side")

    def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
        """Add callback for CNN model updates"""
        self.cnn_callbacks.append(callback)
        logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")

    def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
        """Add callback for DQN model updates"""
        self.dqn_callbacks.append(callback)
        logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
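For illustration (hypothetical consumer code, not part of the deleted file), a callback registered through add_cnn_callback simply receives the symbol and a payload dictionary, either periodic feature updates or individual flow signals:

# Illustrative only: wiring model-side callbacks into the provider.
def on_cnn_update(symbol, payload):
    # payload is either {'features': ..., 'type': 'orderbook'} or a flow-signal dict
    print(symbol, payload.get('type'), payload.get('signal_type'))

provider = BookmapDataProvider(symbols=['ETHUSDT'])
provider.add_cnn_callback(on_cnn_update)
provider.add_dqn_callback(lambda symbol, payload: None)   # DQN side is wired the same way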
    async def start_streaming(self):
        """Start real-time order book streaming"""
        if self.is_streaming:
            logger.warning("Bookmap streaming already active")
            return

        self.is_streaming = True
        logger.info("Starting Bookmap order book streaming")

        # Start order book streams for each symbol
        for symbol in self.symbols:
            # Order book depth stream
            depth_task = asyncio.create_task(self._stream_order_book_depth(symbol))
            self.websocket_tasks[f"{symbol}_depth"] = depth_task

            # Trade stream for order flow analysis
            trade_task = asyncio.create_task(self._stream_trades(symbol))
            self.websocket_tasks[f"{symbol}_trades"] = trade_task

        # Start analysis threads
        analysis_task = asyncio.create_task(self._continuous_analysis())
        self.websocket_tasks["analysis"] = analysis_task

        logger.info(f"Started streaming for {len(self.symbols)} symbols")

    async def stop_streaming(self):
        """Stop order book streaming"""
        if not self.is_streaming:
            return

        logger.info("Stopping Bookmap streaming")
        self.is_streaming = False

        # Cancel all tasks
        for name, task in self.websocket_tasks.items():
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        self.websocket_tasks.clear()
        logger.info("Bookmap streaming stopped")

    async def _stream_order_book_depth(self, symbol: str):
        """Stream order book depth data"""
        binance_symbol = symbol.lower()
        url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"

        while self.is_streaming:
            try:
                async with websockets.connect(url) as websocket:
                    logger.info(f"Order book depth WebSocket connected for {symbol}")

                    async for message in websocket:
                        if not self.is_streaming:
                            break

                        try:
                            data = json.loads(message)
                            await self._process_depth_update(symbol, data)
                        except Exception as e:
                            logger.warning(f"Error processing depth for {symbol}: {e}")

            except Exception as e:
                logger.error(f"Depth WebSocket error for {symbol}: {e}")
                if self.is_streaming:
                    await asyncio.sleep(2)

    async def _stream_trades(self, symbol: str):
        """Stream trade data for order flow analysis"""
        binance_symbol = symbol.lower()
        url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"

        while self.is_streaming:
            try:
                async with websockets.connect(url) as websocket:
                    logger.info(f"Trade WebSocket connected for {symbol}")

                    async for message in websocket:
                        if not self.is_streaming:
                            break

                        try:
                            data = json.loads(message)
                            await self._process_trade_update(symbol, data)
                        except Exception as e:
                            logger.warning(f"Error processing trade for {symbol}: {e}")

            except Exception as e:
                logger.error(f"Trade WebSocket error for {symbol}: {e}")
                if self.is_streaming:
                    await asyncio.sleep(2)
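A minimal driver for the streaming lifecycle above might look like this (sketch only; it assumes a plain asyncio entry point rather than whatever runner the dashboard actually used):

# Illustrative only: start the streams, collect for a minute, then shut down cleanly.
import asyncio

async def main():
    provider = BookmapDataProvider(symbols=['ETHUSDT'])
    await provider.start_streaming()
    await asyncio.sleep(60)              # collect one minute of depth and trade data
    await provider.stop_streaming()
    print(provider.get_statistics())

asyncio.run(main())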
||||
async def _process_depth_update(self, symbol: str, data: Dict):
|
||||
"""Process order book depth update"""
|
||||
try:
|
||||
timestamp = datetime.now()
|
||||
|
||||
# Parse bids and asks
|
||||
bids = []
|
||||
asks = []
|
||||
|
||||
for bid_data in data.get('bids', []):
|
||||
price = float(bid_data[0])
|
||||
size = float(bid_data[1])
|
||||
bids.append(OrderBookLevel(
|
||||
price=price,
|
||||
size=size,
|
||||
orders=1, # Binance doesn't provide order count
|
||||
side='bid',
|
||||
timestamp=timestamp
|
||||
))
|
||||
|
||||
for ask_data in data.get('asks', []):
|
||||
price = float(ask_data[0])
|
||||
size = float(ask_data[1])
|
||||
asks.append(OrderBookLevel(
|
||||
price=price,
|
||||
size=size,
|
||||
orders=1,
|
||||
side='ask',
|
||||
timestamp=timestamp
|
||||
))
|
||||
|
||||
# Sort order book levels
|
||||
bids.sort(key=lambda x: x.price, reverse=True)
|
||||
asks.sort(key=lambda x: x.price)
|
||||
|
||||
# Calculate spread and mid price
|
||||
if bids and asks:
|
||||
best_bid = bids[0].price
|
||||
best_ask = asks[0].price
|
||||
spread = best_ask - best_bid
|
||||
mid_price = (best_bid + best_ask) / 2
|
||||
else:
|
||||
spread = 0.0
|
||||
mid_price = 0.0
|
||||
|
||||
# Create order book snapshot
|
||||
snapshot = OrderBookSnapshot(
|
||||
symbol=symbol,
|
||||
timestamp=timestamp,
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
spread=spread,
|
||||
mid_price=mid_price
|
||||
)
|
||||
|
||||
with self.data_lock:
|
||||
self.order_books[symbol] = snapshot
|
||||
self.order_book_history[symbol].append(snapshot)
|
||||
|
||||
# Update liquidity metrics
|
||||
self._update_liquidity_metrics(symbol, snapshot)
|
||||
|
||||
# Update order book imbalance
|
||||
self._calculate_order_book_imbalance(symbol, snapshot)
|
||||
|
||||
# Update heatmap data
|
||||
self._update_order_heatmap(symbol, snapshot)
|
||||
|
||||
# Update counters
|
||||
self.update_counts[f"{symbol}_depth"] += 1
|
||||
self.last_update_times[f"{symbol}_depth"] = timestamp
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing depth update for {symbol}: {e}")
|
||||
|
||||
async def _process_trade_update(self, symbol: str, data: Dict):
|
||||
"""Process trade data for order flow analysis"""
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
|
||||
price = float(data['p'])
|
||||
quantity = float(data['q'])
|
||||
is_buyer_maker = data['m']
|
||||
|
||||
# Analyze for order flow signals
|
||||
await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
|
||||
|
||||
# Update volume profile
|
||||
self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
|
||||
|
||||
self.update_counts[f"{symbol}_trades"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trade for {symbol}: {e}")
|
||||
|
||||
def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Update liquidity metrics from order book snapshot"""
|
||||
try:
|
||||
total_bid_size = sum(level.size for level in snapshot.bids)
|
||||
total_ask_size = sum(level.size for level in snapshot.asks)
|
||||
|
||||
# Calculate weighted mid price
|
||||
if snapshot.bids and snapshot.asks:
|
||||
bid_weight = total_bid_size / (total_bid_size + total_ask_size)
|
||||
ask_weight = total_ask_size / (total_bid_size + total_ask_size)
|
||||
weighted_mid = (snapshot.bids[0].price * ask_weight +
|
||||
snapshot.asks[0].price * bid_weight)
|
||||
else:
|
||||
weighted_mid = snapshot.mid_price
|
||||
|
||||
# Liquidity ratio (bid/ask balance)
|
||||
if total_ask_size > 0:
|
||||
liquidity_ratio = total_bid_size / total_ask_size
|
||||
else:
|
||||
liquidity_ratio = 1.0
|
||||
|
||||
self.liquidity_metrics[symbol] = {
|
||||
'total_bid_size': total_bid_size,
|
||||
'total_ask_size': total_ask_size,
|
||||
'weighted_mid': weighted_mid,
|
||||
'liquidity_ratio': liquidity_ratio,
|
||||
'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
|
||||
|
||||
def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Calculate order book imbalance ratio"""
|
||||
try:
|
||||
if not snapshot.bids or not snapshot.asks:
|
||||
return
|
||||
|
||||
# Calculate imbalance for top N levels
|
||||
n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
|
||||
|
||||
total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
|
||||
total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
|
||||
|
||||
if total_bid_size + total_ask_size > 0:
|
||||
imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
|
||||
else:
|
||||
imbalance = 0.0
|
||||
|
||||
self.order_book_imbalances[symbol].append({
|
||||
'timestamp': snapshot.timestamp,
|
||||
'imbalance': imbalance,
|
||||
'bid_size': total_bid_size,
|
||||
'ask_size': total_ask_size
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating imbalance for {symbol}: {e}")
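To make the imbalance formula above concrete (illustrative numbers only): with 30 units resting on the top five bid levels and 10 on the top five asks, the ratio is 0.5, a bid-heavy book; the value always lies between -1 (all asks) and +1 (all bids).

total_bid_size, total_ask_size = 30.0, 10.0
imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
print(imbalance)   # 0.5 -> bid-heavy; -1.0 would mean all asks, +1.0 all bids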
|
||||
|
||||
def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Update order size heatmap matrix"""
|
||||
try:
|
||||
# Create heatmap entry
|
||||
heatmap_entry = {
|
||||
'timestamp': snapshot.timestamp,
|
||||
'mid_price': snapshot.mid_price,
|
||||
'levels': {}
|
||||
}
|
||||
|
||||
# Add bid levels
|
||||
for level in snapshot.bids:
|
||||
price_offset = level.price - snapshot.mid_price
|
||||
heatmap_entry['levels'][price_offset] = {
|
||||
'side': 'bid',
|
||||
'size': level.size,
|
||||
'price': level.price
|
||||
}
|
||||
|
||||
# Add ask levels
|
||||
for level in snapshot.asks:
|
||||
price_offset = level.price - snapshot.mid_price
|
||||
heatmap_entry['levels'][price_offset] = {
|
||||
'side': 'ask',
|
||||
'size': level.size,
|
||||
'price': level.price
|
||||
}
|
||||
|
||||
self.order_heatmaps[symbol].append(heatmap_entry)
|
||||
|
||||
# Clean old entries (keep 10 minutes)
|
||||
cutoff_time = snapshot.timestamp - self.heatmap_window
|
||||
while (self.order_heatmaps[symbol] and
|
||||
self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
|
||||
self.order_heatmaps[symbol].popleft()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating heatmap for {symbol}: {e}")
|
||||
|
||||
def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
|
||||
"""Update volume profile with new trade"""
|
||||
try:
|
||||
# Initialize if not exists
|
||||
if symbol not in self.volume_profiles:
|
||||
self.volume_profiles[symbol] = []
|
||||
|
||||
# Find or create price level
|
||||
price_level = None
|
||||
for level in self.volume_profiles[symbol]:
|
||||
if abs(level.price - price) < 0.01: # Price tolerance
|
||||
price_level = level
|
||||
break
|
||||
|
||||
if not price_level:
|
||||
price_level = VolumeProfileLevel(
|
||||
price=price,
|
||||
volume=0.0,
|
||||
buy_volume=0.0,
|
||||
sell_volume=0.0,
|
||||
trades_count=0,
|
||||
vwap=price
|
||||
)
|
||||
self.volume_profiles[symbol].append(price_level)
|
||||
|
||||
# Update volume profile
|
||||
volume = price * quantity
|
||||
old_total = price_level.volume
|
||||
|
||||
price_level.volume += volume
|
||||
price_level.trades_count += 1
|
||||
|
||||
if is_buyer_maker:
|
||||
price_level.sell_volume += volume
|
||||
else:
|
||||
price_level.buy_volume += volume
|
||||
|
||||
# Update VWAP
|
||||
if price_level.volume > 0:
|
||||
price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating volume profile for {symbol}: {e}")
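The incremental VWAP update above keeps a running average weighted by notional (price * quantity), mirroring the `volume` variable in the method. A quick check with invented numbers:

vwap, old_total = 100.0, 1000.0        # existing level: VWAP 100.0 over 1000 notional
price, quantity = 101.0, 10.0          # new trade
volume = price * quantity              # 1010 notional, as in the method above
new_total = old_total + volume
vwap = ((vwap * old_total) + (price * volume)) / new_total
print(round(vwap, 4))                  # ~100.5025: the level's VWAP is pulled toward the trade price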
|
||||
|
||||
async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
|
||||
quantity: float, is_buyer_maker: bool):
|
||||
"""Analyze order flow for sweep and absorption patterns"""
|
||||
try:
|
||||
# Get recent order book data
|
||||
if symbol not in self.order_book_history or not self.order_book_history[symbol]:
|
||||
return
|
||||
|
||||
recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots
|
||||
|
||||
# Check for order book sweeps
|
||||
sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
|
||||
if sweep_signal:
|
||||
self.flow_signals[symbol].append(sweep_signal)
|
||||
await self._notify_flow_signal(symbol, sweep_signal)
|
||||
|
||||
# Check for absorption patterns
|
||||
absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
|
||||
if absorption_signal:
|
||||
self.flow_signals[symbol].append(absorption_signal)
|
||||
await self._notify_flow_signal(symbol, absorption_signal)
|
||||
|
||||
# Check for momentum trades
|
||||
momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
|
||||
if momentum_signal:
|
||||
self.flow_signals[symbol].append(momentum_signal)
|
||||
await self._notify_flow_signal(symbol, momentum_signal)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing order flow for {symbol}: {e}")
|
||||
|
||||
def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
|
||||
price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
|
||||
"""Detect order book sweep patterns"""
|
||||
try:
|
||||
if len(snapshots) < 2:
|
||||
return None
|
||||
|
||||
before_snapshot = snapshots[-2]
|
||||
after_snapshot = snapshots[-1]
|
||||
|
||||
# Check if multiple levels were consumed
|
||||
if is_buyer_maker: # Sell order, check ask side
|
||||
levels_consumed = 0
|
||||
total_consumed_size = 0
|
||||
|
||||
for level in before_snapshot.asks[:5]: # Check top 5 levels
|
||||
if level.price <= price:
|
||||
levels_consumed += 1
|
||||
total_consumed_size += level.size
|
||||
|
||||
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
|
||||
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='sweep',
|
||||
price=price,
|
||||
volume=quantity * price,
|
||||
confidence=confidence,
|
||||
description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
|
||||
)
|
||||
else: # Buy order, check bid side
|
||||
levels_consumed = 0
|
||||
total_consumed_size = 0
|
||||
|
||||
for level in before_snapshot.bids[:5]:
|
||||
if level.price >= price:
|
||||
levels_consumed += 1
|
||||
total_consumed_size += level.size
|
||||
|
||||
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
|
||||
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='sweep',
|
||||
price=price,
|
||||
volume=quantity * price,
|
||||
confidence=confidence,
|
||||
description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting sweep for {symbol}: {e}")
|
||||
return None
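Concretely (illustrative values, not from the original code), the sweep confidence used above saturates quickly with the number of levels consumed:

for levels_consumed in (2, 3, 5):
    confidence = min(0.9, levels_consumed / 5.0 + 0.3)
    print(levels_consumed, confidence)   # 2 -> 0.7, 3 -> 0.9 (capped), 5 -> 0.9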
|
||||
|
||||
def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
|
||||
price: float, quantity: float) -> Optional[OrderFlowSignal]:
|
||||
"""Detect absorption patterns where large orders are absorbed without price movement"""
|
||||
try:
|
||||
if len(snapshots) < 3:
|
||||
return None
|
||||
|
||||
# Check if large order was absorbed with minimal price impact
|
||||
volume_threshold = 10000 # $10K minimum for absorption
|
||||
price_impact_threshold = 0.001 # 0.1% max price impact
|
||||
|
||||
trade_value = price * quantity
|
||||
if trade_value < volume_threshold:
|
||||
return None
|
||||
|
||||
# Calculate price impact
|
||||
price_before = snapshots[-3].mid_price
|
||||
price_after = snapshots[-1].mid_price
|
||||
price_impact = abs(price_after - price_before) / price_before
|
||||
|
||||
if price_impact < price_impact_threshold:
|
||||
confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='absorption',
|
||||
price=price,
|
||||
volume=trade_value,
|
||||
confidence=confidence,
|
||||
description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting absorption for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
|
||||
is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
|
||||
"""Detect momentum trades based on size and direction"""
|
||||
try:
|
||||
trade_value = price * quantity
|
||||
momentum_threshold = 25000 # $25K minimum for momentum classification
|
||||
|
||||
if trade_value < momentum_threshold:
|
||||
return None
|
||||
|
||||
# Calculate confidence based on trade size
|
||||
confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
|
||||
|
||||
direction = "sell" if is_buyer_maker else "buy"
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='momentum',
|
||||
price=price,
|
||||
volume=trade_value,
|
||||
confidence=confidence,
|
||||
description=f"Large {direction}: ${trade_value:.0f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting momentum for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
|
||||
"""Notify CNN and DQN models of order flow signals"""
|
||||
try:
|
||||
signal_data = {
|
||||
'signal_type': signal.signal_type,
|
||||
'price': signal.price,
|
||||
'volume': signal.volume,
|
||||
'confidence': signal.confidence,
|
||||
'timestamp': signal.timestamp,
|
||||
'description': signal.description
|
||||
}
|
||||
|
||||
# Notify CNN callbacks
|
||||
for callback in self.cnn_callbacks:
|
||||
try:
|
||||
callback(symbol, signal_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in CNN callback: {e}")
|
||||
|
||||
# Notify DQN callbacks
|
||||
for callback in self.dqn_callbacks:
|
||||
try:
|
||||
callback(symbol, signal_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in DQN callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying flow signal: {e}")
|
||||
|
||||
async def _continuous_analysis(self):
|
||||
"""Continuous analysis of market microstructure"""
|
||||
while self.is_streaming:
|
||||
try:
|
||||
await asyncio.sleep(1) # Analyze every second
|
||||
|
||||
for symbol in self.symbols:
|
||||
# Generate CNN features
|
||||
cnn_features = self.get_cnn_features(symbol)
|
||||
if cnn_features is not None:
|
||||
for callback in self.cnn_callbacks:
|
||||
try:
|
||||
callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in CNN feature callback: {e}")
|
||||
|
||||
# Generate DQN state features
|
||||
dqn_features = self.get_dqn_state_features(symbol)
|
||||
if dqn_features is not None:
|
||||
for callback in self.dqn_callbacks:
|
||||
try:
|
||||
callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in DQN state callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous analysis: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Generate CNN input features from order book data"""
|
||||
try:
|
||||
if symbol not in self.order_books:
|
||||
return None
|
||||
|
||||
snapshot = self.order_books[symbol]
|
||||
features = []
|
||||
|
||||
# Order book features (40 features: 20 levels x 2 sides)
|
||||
for i in range(min(20, len(snapshot.bids))):
|
||||
bid = snapshot.bids[i]
|
||||
features.append(bid.size)
|
||||
features.append(bid.price - snapshot.mid_price) # Price offset
|
||||
|
||||
# Pad if not enough bid levels
|
||||
while len(features) < 40:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
for i in range(min(20, len(snapshot.asks))):
|
||||
ask = snapshot.asks[i]
|
||||
features.append(ask.size)
|
||||
features.append(ask.price - snapshot.mid_price) # Price offset
|
||||
|
||||
# Pad if not enough ask levels
|
||||
while len(features) < 80:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Liquidity metrics (10 features)
|
||||
metrics = self.liquidity_metrics.get(symbol, {})
|
||||
features.extend([
|
||||
metrics.get('total_bid_size', 0.0),
|
||||
metrics.get('total_ask_size', 0.0),
|
||||
metrics.get('liquidity_ratio', 1.0),
|
||||
metrics.get('spread_bps', 0.0),
|
||||
snapshot.spread,
|
||||
metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price,
|
||||
len(snapshot.bids),
|
||||
len(snapshot.asks),
|
||||
snapshot.mid_price,
|
||||
time.time() % 86400 # Time of day
|
||||
])
|
||||
|
||||
# Order book imbalance features (5 features)
|
||||
if self.order_book_imbalances[symbol]:
|
||||
latest_imbalance = self.order_book_imbalances[symbol][-1]
|
||||
features.extend([
|
||||
latest_imbalance['imbalance'],
|
||||
latest_imbalance['bid_size'],
|
||||
latest_imbalance['ask_size'],
|
||||
latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
|
||||
abs(latest_imbalance['imbalance'])
|
||||
])
|
||||
else:
|
||||
features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Flow signal features (5 features)
|
||||
recent_signals = [s for s in self.flow_signals[symbol]
|
||||
if (datetime.now() - s.timestamp).seconds < 60]
|
||||
|
||||
sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
|
||||
absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
|
||||
momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')
|
||||
|
||||
max_confidence = max([s.confidence for s in recent_signals], default=0.0)
|
||||
total_flow_volume = sum(s.volume for s in recent_signals)
|
||||
|
||||
features.extend([
|
||||
sweep_count,
|
||||
absorption_count,
|
||||
momentum_count,
|
||||
max_confidence,
|
||||
total_flow_volume
|
||||
])
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating CNN features for {symbol}: {e}")
|
||||
return None
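Because of the padding loops above, the returned vector has a fixed length of 100 values: 40 bid features, 40 ask features, 10 liquidity metrics, 5 imbalance features, and 5 flow-signal features. A hypothetical downstream shape check (the `provider` name is assumed):

features = provider.get_cnn_features('ETHUSDT')
if features is not None:
    assert features.shape == (100,)   # 40 bid + 40 ask + 10 liquidity + 5 imbalance + 5 flow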
|
||||
|
||||
def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Generate DQN state features from order book data"""
|
||||
try:
|
||||
if symbol not in self.order_books:
|
||||
return None
|
||||
|
||||
snapshot = self.order_books[symbol]
|
||||
state_features = []
|
||||
|
||||
# Normalized order book state (20 features)
|
||||
total_bid_size = sum(level.size for level in snapshot.bids[:10])
|
||||
total_ask_size = sum(level.size for level in snapshot.asks[:10])
|
||||
total_size = total_bid_size + total_ask_size
|
||||
|
||||
if total_size > 0:
|
||||
for i in range(min(10, len(snapshot.bids))):
|
||||
state_features.append(snapshot.bids[i].size / total_size)
|
||||
|
||||
# Pad bids
|
||||
while len(state_features) < 10:
|
||||
state_features.append(0.0)
|
||||
|
||||
for i in range(min(10, len(snapshot.asks))):
|
||||
state_features.append(snapshot.asks[i].size / total_size)
|
||||
|
||||
# Pad asks
|
||||
while len(state_features) < 20:
|
||||
state_features.append(0.0)
|
||||
else:
|
||||
state_features.extend([0.0] * 20)
|
||||
|
||||
# Market state indicators (10 features)
|
||||
metrics = self.liquidity_metrics.get(symbol, {})
|
||||
|
||||
# Normalize spread as percentage
|
||||
spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0
|
||||
|
||||
# Liquidity imbalance
|
||||
liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
|
||||
liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)
|
||||
|
||||
# Recent flow signals strength
|
||||
recent_signals = [s for s in self.flow_signals[symbol]
|
||||
if (datetime.now() - s.timestamp).seconds < 30]
|
||||
flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)
|
||||
|
||||
# Price volatility (from recent snapshots)
|
||||
if len(self.order_book_history[symbol]) >= 10:
|
||||
recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]]
|
||||
price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0
|
||||
else:
|
||||
price_volatility = 0
|
||||
|
||||
state_features.extend([
|
||||
spread_pct * 10000, # Spread in basis points
|
||||
liquidity_imbalance,
|
||||
flow_strength,
|
||||
price_volatility * 100, # Volatility as percentage
|
||||
min(len(snapshot.bids), 20) / 20, # Book depth ratio
|
||||
min(len(snapshot.asks), 20) / 20,
|
||||
sweep_count / 10 if 'sweep_count' in locals() else 0, # From CNN features
|
||||
absorption_count / 5 if 'absorption_count' in locals() else 0,
|
||||
momentum_count / 5 if 'momentum_count' in locals() else 0,
|
||||
(datetime.now().hour * 60 + datetime.now().minute) / 1440 # Time of day normalized
|
||||
])
|
||||
|
||||
return np.array(state_features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating DQN features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
|
||||
"""Generate order size heatmap matrix for dashboard visualization"""
|
||||
try:
|
||||
if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
|
||||
return None
|
||||
|
||||
# Create price levels around current mid price
|
||||
current_snapshot = self.order_books.get(symbol)
|
||||
if not current_snapshot:
|
||||
return None
|
||||
|
||||
mid_price = current_snapshot.mid_price
|
||||
price_step = mid_price * 0.0001 # 1 basis point steps
|
||||
|
||||
# Create matrix: time x price levels
|
||||
time_window = min(600, len(self.order_heatmaps[symbol])) # 10 minutes max
|
||||
heatmap_matrix = np.zeros((time_window, levels))
|
||||
|
||||
# Fill matrix with order sizes
|
||||
for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]):
|
||||
for price_offset, level_data in entry['levels'].items():
|
||||
# Convert price offset to matrix index
|
||||
level_idx = int((price_offset + (levels/2) * price_step) / price_step)
|
||||
|
||||
if 0 <= level_idx < levels:
|
||||
size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
|
||||
heatmap_matrix[t, level_idx] = level_data['size'] * size_weight
|
||||
|
||||
return heatmap_matrix
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
|
||||
"""Get session volume profile data"""
|
||||
try:
|
||||
if symbol not in self.volume_profiles:
|
||||
return None
|
||||
|
||||
profile_data = []
|
||||
for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price):
|
||||
profile_data.append({
|
||||
'price': level.price,
|
||||
'volume': level.volume,
|
||||
'buy_volume': level.buy_volume,
|
||||
'sell_volume': level.sell_volume,
|
||||
'trades_count': level.trades_count,
|
||||
'vwap': level.vwap,
|
||||
'net_volume': level.buy_volume - level.sell_volume
|
||||
})
|
||||
|
||||
return profile_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting volume profile for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_current_order_book(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get current order book snapshot"""
|
||||
try:
|
||||
if symbol not in self.order_books:
|
||||
return None
|
||||
|
||||
snapshot = self.order_books[symbol]
|
||||
|
||||
return {
|
||||
'timestamp': snapshot.timestamp.isoformat(),
|
||||
'symbol': symbol,
|
||||
'mid_price': snapshot.mid_price,
|
||||
'spread': snapshot.spread,
|
||||
'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]],
|
||||
'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]],
|
||||
'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
|
||||
'recent_signals': [
|
||||
{
|
||||
'type': s.signal_type,
|
||||
'price': s.price,
|
||||
'volume': s.volume,
|
||||
'confidence': s.confidence,
|
||||
'timestamp': s.timestamp.isoformat()
|
||||
}
|
||||
for s in list(self.flow_signals[symbol])[-5:] # Last 5 signals
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting order book for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get provider statistics"""
|
||||
return {
|
||||
'symbols': self.symbols,
|
||||
'is_streaming': self.is_streaming,
|
||||
'update_counts': dict(self.update_counts),
|
||||
'last_update_times': {k: v.isoformat() if isinstance(v, datetime) else v
|
||||
for k, v in self.last_update_times.items()},
|
||||
'order_books_active': len(self.order_books),
|
||||
'flow_signals_total': sum(len(signals) for signals in self.flow_signals.values()),
|
||||
'cnn_callbacks': len(self.cnn_callbacks),
|
||||
'dqn_callbacks': len(self.dqn_callbacks),
|
||||
'websocket_tasks': len(self.websocket_tasks)
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -1,785 +0,0 @@
"""
CNN Training Pipeline with Comprehensive Data Storage and Replay

This module implements a robust CNN training pipeline that:
1. Integrates with the comprehensive training data collection system
2. Stores all backpropagation data for gradient replay
3. Enables retraining on most profitable setups
4. Maintains training episode profitability tracking
5. Supports both real-time and batch training modes

Key Features:
- Integration with TrainingDataCollector for data validation
- Gradient and loss storage for each training step
- Profitable episode prioritization and replay
- Comprehensive training metrics and validation
- Real-time pivot point prediction with outcome tracking
"""

import asyncio
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
import json
import pickle
from collections import deque, defaultdict
import threading
from concurrent.futures import ThreadPoolExecutor

from .training_data_collector import (
    TrainingDataCollector,
    TrainingEpisode,
    ModelInputPackage,
    get_training_data_collector
)

logger = logging.getLogger(__name__)

@dataclass
class CNNTrainingStep:
    """Single CNN training step with complete backpropagation data"""
    step_id: str
    timestamp: datetime
    episode_id: str

    # Input data
    input_features: torch.Tensor
    target_labels: torch.Tensor

    # Forward pass results
    model_outputs: Dict[str, torch.Tensor]
    predictions: Dict[str, Any]
    confidence_scores: torch.Tensor

    # Loss components
    total_loss: float
    pivot_prediction_loss: float
    confidence_loss: float
    regularization_loss: float

    # Backpropagation data
    gradients: Dict[str, torch.Tensor]  # Gradients for each parameter
    gradient_norms: Dict[str, float]  # Gradient norms for monitoring

    # Model state
    model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    optimizer_state: Optional[Dict[str, Any]] = None

    # Training metadata
    learning_rate: float = 0.001
    batch_size: int = 32
    epoch: int = 0

    # Profitability tracking
    actual_profitability: Optional[float] = None
    prediction_accuracy: Optional[float] = None
    training_value: float = 0.0  # Value of this training step for replay

@dataclass
class CNNTrainingSession:
    """Complete CNN training session with multiple steps"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # Session configuration
    training_mode: str = 'real_time'  # 'real_time', 'batch', 'replay'
    symbol: str = ''

    # Training steps
    training_steps: List[CNNTrainingStep] = field(default_factory=list)

    # Session metrics
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')
    convergence_achieved: bool = False

    # Profitability metrics
    profitable_predictions: int = 0
    total_predictions: int = 0
    profitability_rate: float = 0.0

    # Session value for replay prioritization
    session_value: float = 0.0
||||
class CNNPivotPredictor(nn.Module):
|
||||
"""CNN model for pivot point prediction with comprehensive output"""
|
||||
|
||||
def __init__(self,
|
||||
input_channels: int = 10, # Multiple timeframes
|
||||
sequence_length: int = 300, # 300 bars
|
||||
hidden_dim: int = 256,
|
||||
num_pivot_classes: int = 3, # high, low, none
|
||||
dropout_rate: float = 0.2):
|
||||
|
||||
super(CNNPivotPredictor, self).__init__()
|
||||
|
||||
self.input_channels = input_channels
|
||||
self.sequence_length = sequence_length
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
# Convolutional layers for pattern extraction
|
||||
self.conv_layers = nn.Sequential(
|
||||
# First conv block
|
||||
nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
# Second conv block
|
||||
nn.Conv1d(64, 128, kernel_size=5, padding=2),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
# Third conv block
|
||||
nn.Conv1d(128, 256, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
)
|
||||
|
||||
# LSTM for temporal dependencies
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=256,
|
||||
hidden_size=hidden_dim,
|
||||
num_layers=2,
|
||||
batch_first=True,
|
||||
dropout=dropout_rate,
|
||||
bidirectional=True
|
||||
)
|
||||
|
||||
# Attention mechanism
|
||||
self.attention = nn.MultiheadAttention(
|
||||
embed_dim=hidden_dim * 2, # Bidirectional LSTM
|
||||
num_heads=8,
|
||||
dropout=dropout_rate,
|
||||
batch_first=True
|
||||
)
|
||||
|
||||
# Output heads
|
||||
self.pivot_classifier = nn.Sequential(
|
||||
nn.Linear(hidden_dim * 2, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(hidden_dim, num_pivot_classes)
|
||||
)
|
||||
|
||||
self.pivot_price_regressor = nn.Sequential(
|
||||
nn.Linear(hidden_dim * 2, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(hidden_dim, 1)
|
||||
)
|
||||
|
||||
self.confidence_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim * 2, hidden_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim // 2, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize weights with proper scaling"""
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through CNN pivot predictor
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, input_channels, sequence_length]
|
||||
|
||||
Returns:
|
||||
Dict containing predictions and hidden states
|
||||
"""
|
||||
batch_size = x.size(0)
|
||||
|
||||
# Convolutional feature extraction
|
||||
conv_features = self.conv_layers(x) # [batch, 256, sequence_length]
|
||||
|
||||
# Prepare for LSTM (transpose to [batch, sequence, features])
|
||||
lstm_input = conv_features.transpose(1, 2) # [batch, sequence_length, 256]
|
||||
|
||||
# LSTM processing
|
||||
lstm_output, (hidden, cell) = self.lstm(lstm_input) # [batch, sequence_length, hidden_dim*2]
|
||||
|
||||
# Attention mechanism
|
||||
attended_output, attention_weights = self.attention(
|
||||
lstm_output, lstm_output, lstm_output
|
||||
)
|
||||
|
||||
# Use the last timestep for predictions
|
||||
final_features = attended_output[:, -1, :] # [batch, hidden_dim*2]
|
||||
|
||||
# Generate predictions
|
||||
pivot_logits = self.pivot_classifier(final_features)
|
||||
pivot_price = self.pivot_price_regressor(final_features)
|
||||
confidence = self.confidence_head(final_features)
|
||||
|
||||
return {
|
||||
'pivot_logits': pivot_logits,
|
||||
'pivot_price': pivot_price,
|
||||
'confidence': confidence,
|
||||
'hidden_states': final_features,
|
||||
'attention_weights': attention_weights,
|
||||
'conv_features': conv_features,
|
||||
'lstm_output': lstm_output
|
||||
}
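As a quick sanity check of the shapes this forward pass produces, the following sketch (not part of the original file; it assumes CNNPivotPredictor is importable exactly as defined above) can be run on its own:

import torch

model = CNNPivotPredictor(input_channels=10, sequence_length=300)
x = torch.randn(4, 10, 300)                  # [batch, channels, bars]
out = model(x)
assert out['pivot_logits'].shape == (4, 3)   # high / low / none logits
assert out['pivot_price'].shape == (4, 1)    # regressed pivot price
assert out['confidence'].shape == (4, 1)     # sigmoid output in [0, 1]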
|
||||
|
||||
class CNNTrainingDataset(Dataset):
|
||||
"""Dataset for CNN training with training episodes"""
|
||||
|
||||
def __init__(self, training_episodes: List[TrainingEpisode]):
|
||||
self.episodes = training_episodes
|
||||
self.valid_episodes = self._validate_episodes()
|
||||
|
||||
def _validate_episodes(self) -> List[TrainingEpisode]:
|
||||
"""Validate and filter episodes for training"""
|
||||
valid = []
|
||||
for episode in self.episodes:
|
||||
try:
|
||||
# Check if episode has required data
|
||||
if (episode.input_package.cnn_features is not None and
|
||||
episode.actual_outcome.outcome_validated):
|
||||
valid.append(episode)
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid episode {episode.episode_id}: {e}")
|
||||
|
||||
logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
|
||||
return valid
|
||||
|
||||
def __len__(self):
|
||||
return len(self.valid_episodes)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
episode = self.valid_episodes[idx]
|
||||
|
||||
# Extract features
|
||||
features = torch.from_numpy(episode.input_package.cnn_features).float()
|
||||
|
||||
# Create labels from actual outcomes
|
||||
pivot_class = self._determine_pivot_class(episode.actual_outcome)
|
||||
pivot_price = episode.actual_outcome.optimal_exit_price
|
||||
confidence_target = episode.actual_outcome.profitability_score
|
||||
|
||||
return {
|
||||
'features': features,
|
||||
'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
|
||||
'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
|
||||
'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
|
||||
'episode_id': episode.episode_id,
|
||||
'profitability': episode.actual_outcome.profitability_score
|
||||
}
|
||||
|
||||
def _determine_pivot_class(self, outcome) -> int:
|
||||
"""Determine pivot class from outcome"""
|
||||
if outcome.price_change_15m > 0.5: # Significant upward movement
|
||||
return 0 # High pivot
|
||||
elif outcome.price_change_15m < -0.5: # Significant downward movement
|
||||
return 1 # Low pivot
|
||||
else:
|
||||
return 2 # No significant pivot
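Before the trainer below wires this up automatically, a short sketch of how the dataset is consumed (assuming `episodes` is a list of validated TrainingEpisode objects, e.g. taken from the data collector) may help:

from torch.utils.data import DataLoader

dataset = CNNTrainingDataset(episodes)        # `episodes` assumed to be available
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
# batch['features']          -> CNN input features (float tensor)
# batch['pivot_class']       -> 0 = high pivot, 1 = low pivot, 2 = no pivot
# batch['pivot_price']       -> optimal exit price from the validated outcome
# batch['confidence_target'] -> profitability score used as the confidence label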
|
||||
|
||||
class CNNTrainer:
|
||||
"""CNN trainer with comprehensive data storage and replay capabilities"""
|
||||
|
||||
def __init__(self,
|
||||
model: CNNPivotPredictor,
|
||||
device: str = 'cuda',
|
||||
learning_rate: float = 0.001,
|
||||
storage_dir: str = "cnn_training_storage"):
|
||||
|
||||
self.model = model.to(device)
|
||||
self.device = device
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
# Storage
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Optimizer
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=learning_rate,
|
||||
weight_decay=1e-5
|
||||
)
|
||||
|
||||
# Learning rate scheduler
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer, mode='min', patience=10, factor=0.5
|
||||
)
|
||||
|
||||
# Training data collector
|
||||
self.data_collector = get_training_data_collector()
|
||||
|
||||
# Training sessions storage
|
||||
self.training_sessions: List[CNNTrainingSession] = []
|
||||
self.current_session: Optional[CNNTrainingSession] = None
|
||||
|
||||
# Training statistics
|
||||
self.training_stats = {
|
||||
'total_sessions': 0,
|
||||
'total_steps': 0,
|
||||
'best_validation_loss': float('inf'),
|
||||
'profitable_predictions': 0,
|
||||
'total_predictions': 0,
|
||||
'replay_sessions': 0
|
||||
}
|
||||
|
||||
# Background training
|
||||
self.is_training = False
|
||||
self.training_thread = None
|
||||
|
||||
logger.info(f"CNN Trainer initialized")
|
||||
logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
||||
logger.info(f"Storage directory: {self.storage_dir}")
|
||||
|
||||
def start_real_time_training(self, symbol: str):
|
||||
"""Start real-time training for a symbol"""
|
||||
if self.is_training:
|
||||
logger.warning("CNN training already running")
|
||||
return
|
||||
|
||||
self.is_training = True
|
||||
self.training_thread = threading.Thread(
|
||||
target=self._real_time_training_worker,
|
||||
args=(symbol,),
|
||||
daemon=True
|
||||
)
|
||||
self.training_thread.start()
|
||||
|
||||
logger.info(f"Started real-time CNN training for {symbol}")
|
||||
|
||||
def stop_training(self):
|
||||
"""Stop training"""
|
||||
self.is_training = False
|
||||
if self.training_thread:
|
||||
self.training_thread.join(timeout=10)
|
||||
|
||||
if self.current_session:
|
||||
self._finalize_training_session()
|
||||
|
||||
logger.info("CNN training stopped")
|
||||
|
||||
def _real_time_training_worker(self, symbol: str):
|
||||
"""Real-time training worker"""
|
||||
logger.info(f"Real-time CNN training worker started for {symbol}")
|
||||
|
||||
while self.is_training:
|
||||
try:
|
||||
# Get high-priority episodes for training
|
||||
episodes = self.data_collector.get_high_priority_episodes(
|
||||
symbol=symbol,
|
||||
limit=100,
|
||||
min_priority=0.3
|
||||
)
|
||||
|
||||
if len(episodes) >= 32: # Minimum batch size
|
||||
self._train_on_episodes(episodes, training_mode='real_time')
|
||||
|
||||
# Wait before next training cycle
|
||||
threading.Event().wait(300) # Train every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in real-time training worker: {e}")
|
||||
threading.Event().wait(60) # Wait before retrying
|
||||
|
||||
logger.info(f"Real-time CNN training worker stopped for {symbol}")
|
||||
|
||||
def train_on_profitable_episodes(self,
|
||||
symbol: str,
|
||||
min_profitability: float = 0.7,
|
||||
max_episodes: int = 500) -> Dict[str, Any]:
|
||||
"""Train specifically on most profitable episodes"""
|
||||
try:
|
||||
# Get all episodes for symbol
|
||||
all_episodes = self.data_collector.training_episodes.get(symbol, [])
|
||||
|
||||
# Filter for profitable episodes
|
||||
profitable_episodes = [
|
||||
ep for ep in all_episodes
|
||||
if (ep.actual_outcome.is_profitable and
|
||||
ep.actual_outcome.profitability_score >= min_profitability)
|
||||
]
|
||||
|
||||
# Sort by profitability and limit
|
||||
profitable_episodes.sort(
|
||||
key=lambda x: x.actual_outcome.profitability_score,
|
||||
reverse=True
|
||||
)
|
||||
profitable_episodes = profitable_episodes[:max_episodes]
|
||||
|
||||
if len(profitable_episodes) < 10:
|
||||
logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
|
||||
return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}
|
||||
|
||||
# Train on profitable episodes
|
||||
results = self._train_on_episodes(
|
||||
profitable_episodes,
|
||||
training_mode='profitable_replay'
|
||||
)
|
||||
|
||||
logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on profitable episodes: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
def _train_on_episodes(self,
|
||||
episodes: List[TrainingEpisode],
|
||||
training_mode: str = 'batch') -> Dict[str, Any]:
|
||||
"""Train on a batch of episodes with comprehensive data storage"""
|
||||
try:
|
||||
# Start new training session
|
||||
session = CNNTrainingSession(
|
||||
session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
start_timestamp=datetime.now(),
|
||||
training_mode=training_mode,
|
||||
symbol=episodes[0].input_package.symbol if episodes else 'unknown'
|
||||
)
|
||||
self.current_session = session
|
||||
|
||||
# Create dataset and dataloader
|
||||
dataset = CNNTrainingDataset(episodes)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
num_workers=2
|
||||
)
|
||||
|
||||
# Training loop
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# Move to device
|
||||
features = batch['features'].to(self.device)
|
||||
pivot_class = batch['pivot_class'].to(self.device)
|
||||
pivot_price = batch['pivot_price'].to(self.device)
|
||||
confidence_target = batch['confidence_target'].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
outputs = self.model(features)
|
||||
|
||||
# Calculate losses
|
||||
classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
|
||||
regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price)
|
||||
confidence_loss = F.binary_cross_entropy(
|
||||
outputs['confidence'].squeeze(),
|
||||
confidence_target
|
||||
)
|
||||
|
||||
# Combined loss
|
||||
total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss
|
||||
|
||||
# Backward pass
|
||||
total_batch_loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||
|
||||
# Store gradients before optimizer step
|
||||
gradients = {}
|
||||
gradient_norms = {}
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.grad is not None:
|
||||
gradients[name] = param.grad.clone().detach()
|
||||
gradient_norms[name] = param.grad.norm().item()
|
||||
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
|
||||
# Create training step record
|
||||
step = CNNTrainingStep(
|
||||
step_id=f"{session.session_id}_step_{batch_idx}",
|
||||
timestamp=datetime.now(),
|
||||
episode_id=",".join(batch['episode_id']),  # real episode IDs in this batch, comma-separated, so replay can find them later
|
||||
input_features=features.detach().cpu(),
|
||||
target_labels=pivot_class.detach().cpu(),
|
||||
model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
|
||||
predictions=self._extract_predictions(outputs),
|
||||
confidence_scores=outputs['confidence'].detach().cpu(),
|
||||
total_loss=total_batch_loss.item(),
|
||||
pivot_prediction_loss=classification_loss.item(),
|
||||
confidence_loss=confidence_loss.item(),
|
||||
regularization_loss=0.0,
|
||||
gradients=gradients,
|
||||
gradient_norms=gradient_norms,
|
||||
learning_rate=self.optimizer.param_groups[0]['lr'],
|
||||
batch_size=features.size(0)
|
||||
)
|
||||
|
||||
# Calculate training value for this step
|
||||
step.training_value = self._calculate_step_training_value(step, batch)
|
||||
|
||||
# Add to session
|
||||
session.training_steps.append(step)
|
||||
|
||||
total_loss += total_batch_loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Log progress
|
||||
if batch_idx % 10 == 0:
|
||||
logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")
|
||||
|
||||
# Finalize session
|
||||
session.end_timestamp = datetime.now()
|
||||
session.total_steps = num_batches
|
||||
session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
session.best_loss = min((step.total_loss for step in session.training_steps), default=float('inf'))
|
||||
|
||||
# Calculate session value
|
||||
session.session_value = self._calculate_session_value(session)
|
||||
|
||||
# Update scheduler
|
||||
self.scheduler.step(session.average_loss)
|
||||
|
||||
# Save session
|
||||
self._save_training_session(session)
|
||||
|
||||
# Update statistics
|
||||
self.training_stats['total_sessions'] += 1
|
||||
self.training_stats['total_steps'] += session.total_steps
|
||||
if training_mode == 'profitable_replay':
|
||||
self.training_stats['replay_sessions'] += 1
|
||||
|
||||
logger.info(f"Training session completed: {session.session_id}")
|
||||
logger.info(f"Average loss: {session.average_loss:.4f}")
|
||||
logger.info(f"Session value: {session.session_value:.3f}")
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'session_id': session.session_id,
|
||||
'average_loss': session.average_loss,
|
||||
'total_steps': session.total_steps,
|
||||
'session_value': session.session_value
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training session: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
finally:
|
||||
self.current_session = None
|
||||
|
||||
def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
|
||||
"""Extract human-readable predictions from model outputs"""
|
||||
try:
|
||||
pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
|
||||
predicted_class = torch.argmax(pivot_probs, dim=1)
|
||||
|
||||
return {
|
||||
'pivot_class': predicted_class.cpu().numpy().tolist(),
|
||||
'pivot_probabilities': pivot_probs.cpu().numpy().tolist(),
|
||||
'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(),
|
||||
'confidence': outputs['confidence'].cpu().numpy().tolist()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting predictions: {e}")
|
||||
return {}
|
||||
|
||||
def _calculate_step_training_value(self,
|
||||
step: CNNTrainingStep,
|
||||
batch: Dict[str, Any]) -> float:
|
||||
"""Calculate the training value of a step for replay prioritization"""
|
||||
try:
|
||||
value = 0.0
|
||||
|
||||
# Base value from loss (lower loss = higher value)
|
||||
if step.total_loss > 0:
|
||||
value += 1.0 / (1.0 + step.total_loss)
|
||||
|
||||
# Bonus for high profitability episodes in batch
|
||||
avg_profitability = torch.mean(batch['profitability']).item()
|
||||
value += avg_profitability * 0.3
|
||||
|
||||
# Bonus for gradient magnitude (indicates learning)
|
||||
avg_grad_norm = np.mean(list(step.gradient_norms.values()))
|
||||
value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2
|
||||
|
||||
return min(value, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating step training value: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_session_value(self, session: CNNTrainingSession) -> float:
|
||||
"""Calculate overall session value for replay prioritization"""
|
||||
try:
|
||||
if not session.training_steps:
|
||||
return 0.0
|
||||
|
||||
# Average step values
|
||||
avg_step_value = np.mean([step.training_value for step in session.training_steps])
|
||||
|
||||
# Bonus for convergence
|
||||
convergence_bonus = 0.0
|
||||
if len(session.training_steps) > 10:
|
||||
early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
|
||||
late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
|
||||
if early_loss > late_loss:
|
||||
convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)
|
||||
|
||||
# Bonus for profitable replay sessions
|
||||
mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0
|
||||
|
||||
return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating session value: {e}")
|
||||
return 0.0
|
||||
|
||||
def _save_training_session(self, session: CNNTrainingSession):
|
||||
"""Save training session to disk"""
|
||||
try:
|
||||
session_dir = self.storage_dir / session.symbol / 'sessions'
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save full session data
|
||||
session_file = session_dir / f"{session.session_id}.pkl"
|
||||
with open(session_file, 'wb') as f:
|
||||
pickle.dump(session, f)
|
||||
|
||||
# Save session metadata
|
||||
metadata = {
|
||||
'session_id': session.session_id,
|
||||
'start_timestamp': session.start_timestamp.isoformat(),
|
||||
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
|
||||
'training_mode': session.training_mode,
|
||||
'symbol': session.symbol,
|
||||
'total_steps': session.total_steps,
|
||||
'average_loss': session.average_loss,
|
||||
'best_loss': session.best_loss,
|
||||
'session_value': session.session_value
|
||||
}
|
||||
|
||||
metadata_file = session_dir / f"{session.session_id}_metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
logger.debug(f"Saved training session: {session.session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving training session: {e}")
|
||||
|
||||
def _finalize_training_session(self):
|
||||
"""Finalize current training session"""
|
||||
if self.current_session:
|
||||
self.current_session.end_timestamp = datetime.now()
|
||||
self._save_training_session(self.current_session)
|
||||
self.training_sessions.append(self.current_session)
|
||||
self.current_session = None
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive training statistics"""
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
# Add recent session information
|
||||
if self.training_sessions:
|
||||
recent_sessions = sorted(
|
||||
self.training_sessions,
|
||||
key=lambda x: x.start_timestamp,
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
stats['recent_sessions'] = [
|
||||
{
|
||||
'session_id': s.session_id,
|
||||
'timestamp': s.start_timestamp.isoformat(),
|
||||
'mode': s.training_mode,
|
||||
'average_loss': s.average_loss,
|
||||
'session_value': s.session_value
|
||||
}
|
||||
for s in recent_sessions
|
||||
]
|
||||
|
||||
# Calculate profitability rate
|
||||
if stats['total_predictions'] > 0:
|
||||
stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
|
||||
else:
|
||||
stats['profitability_rate'] = 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def replay_high_value_sessions(self,
|
||||
symbol: str,
|
||||
min_session_value: float = 0.7,
|
||||
max_sessions: int = 10) -> Dict[str, Any]:
|
||||
"""Replay high-value training sessions"""
|
||||
try:
|
||||
# Find high-value sessions
|
||||
high_value_sessions = [
|
||||
s for s in self.training_sessions
|
||||
if (s.symbol == symbol and
|
||||
s.session_value >= min_session_value)
|
||||
]
|
||||
|
||||
# Sort by value and limit
|
||||
high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
|
||||
high_value_sessions = high_value_sessions[:max_sessions]
|
||||
|
||||
if not high_value_sessions:
|
||||
return {'status': 'no_high_value_sessions', 'sessions_found': 0}
|
||||
|
||||
# Replay sessions
|
||||
total_replayed = 0
|
||||
for session in high_value_sessions:
|
||||
# Extract episodes from session steps
|
||||
episode_ids = list({eid for step in session.training_steps for eid in step.episode_id.split(',')})
|
||||
|
||||
# Get corresponding episodes
|
||||
episodes = []
|
||||
for episode_id in episode_ids:
|
||||
# Find episode in data collector
|
||||
for ep in self.data_collector.training_episodes.get(symbol, []):
|
||||
if ep.episode_id == episode_id:
|
||||
episodes.append(ep)
|
||||
break
|
||||
|
||||
if episodes:
|
||||
self._train_on_episodes(episodes, training_mode='high_value_replay')
|
||||
total_replayed += 1
|
||||
|
||||
logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
|
||||
return {
|
||||
'status': 'success',
|
||||
'sessions_replayed': total_replayed,
|
||||
'sessions_found': len(high_value_sessions)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error replaying high-value sessions: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
# Global instance
|
||||
cnn_trainer = None
|
||||
|
||||
def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
|
||||
"""Get global CNN trainer instance"""
|
||||
global cnn_trainer
|
||||
if cnn_trainer is None:
|
||||
if model is None:
|
||||
model = CNNPivotPredictor()
|
||||
cnn_trainer = CNNTrainer(model)
|
||||
return cnn_trainer
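Putting the module's public entry points together, the intended usage (hypothetical symbol; the import path is assumed from this file's location under core/) looks roughly like:

from core.cnn_training_pipeline import get_cnn_trainer  # assumed module path

trainer = get_cnn_trainer()                     # builds a default CNNPivotPredictor
trainer.start_real_time_training('ETH/USDT')    # background thread, retrains every 5 minutes

# Later: replay the most profitable setups and inspect progress.
result = trainer.train_on_profitable_episodes('ETH/USDT', min_profitability=0.7)
print(result.get('status'), trainer.get_training_statistics()['total_sessions'])

trainer.stop_training()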
|
||||
@@ -1,864 +0,0 @@
|
||||
# """
|
||||
# Enhanced CNN Adapter for Standardized Input Format
|
||||
|
||||
# This module provides an adapter for the EnhancedCNN model to work with the standardized
|
||||
# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
|
||||
# """
|
||||
|
||||
# import torch
|
||||
# import numpy as np
|
||||
# import logging
|
||||
# import os
|
||||
# import random
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
# from threading import Lock
|
||||
|
||||
# from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
# from NN.models.enhanced_cnn import EnhancedCNN
|
||||
# from utils.inference_logger import log_model_inference
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# class EnhancedCNNAdapter:
|
||||
# """
|
||||
# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
|
||||
|
||||
# This adapter:
|
||||
# 1. Converts BaseDataInput to the format expected by EnhancedCNN
|
||||
# 2. Processes model outputs to create standardized ModelOutput
|
||||
# 3. Manages model training with collected data
|
||||
# 4. Handles checkpoint management
|
||||
# """
|
||||
|
||||
# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
|
||||
# """
|
||||
# Initialize the EnhancedCNN adapter
|
||||
|
||||
# Args:
|
||||
# model_path: Path to load model from, if None a new model is created
|
||||
# checkpoint_dir: Directory to save checkpoints to
|
||||
# """
|
||||
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# self.model = None
|
||||
# self.model_path = model_path
|
||||
# self.checkpoint_dir = checkpoint_dir
|
||||
# self.training_lock = Lock()
|
||||
# self.training_data = []
|
||||
# self.max_training_samples = 10000
|
||||
# self.batch_size = 32
|
||||
# self.learning_rate = 0.0001
|
||||
# self.model_name = "enhanced_cnn"
|
||||
|
||||
# # Enhanced metrics tracking
|
||||
# self.last_inference_time = None
|
||||
# self.last_inference_duration = 0.0
|
||||
# self.last_prediction_output = None
|
||||
# self.last_training_time = None
|
||||
# self.last_training_duration = 0.0
|
||||
# self.last_training_loss = 0.0
|
||||
# self.inference_count = 0
|
||||
# self.training_count = 0
|
||||
|
||||
# # Create checkpoint directory if it doesn't exist
|
||||
# os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# # Initialize the model
|
||||
# self._initialize_model()
|
||||
|
||||
# # Load checkpoint if available
|
||||
# if model_path and os.path.exists(model_path):
|
||||
# self._load_checkpoint(model_path)
|
||||
# else:
|
||||
# self._load_best_checkpoint()
|
||||
|
||||
# # Final device check and move
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
|
||||
# """Create realistic synthetic features instead of random data"""
|
||||
# try:
|
||||
# # Create realistic market-like features
|
||||
# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
|
||||
|
||||
# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
|
||||
# ohlcv_start = 0
|
||||
# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
|
||||
# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
|
||||
# for frame_idx in range(300):
|
||||
# # Create realistic price movement
|
||||
# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
|
||||
# current_price = base_price * (1 + price_change)
|
||||
|
||||
# # Realistic OHLCV values
|
||||
# open_price = current_price
|
||||
# high_price = current_price * random.uniform(1.0, 1.005)
# low_price = current_price * random.uniform(0.995, 1.0)
# close_price = current_price * random.uniform(0.998, 1.002)
# volume = random.uniform(500.0, 2000.0)
|
||||
|
||||
# # Set features
|
||||
# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
|
||||
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
||||
|
||||
# # BTC OHLCV features (1500 features: 300 frames x 5 features)
|
||||
# btc_start = 6000
|
||||
# btc_base_price = 50000.0
|
||||
# for frame_idx in range(300):
|
||||
# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
|
||||
# current_price = btc_base_price * (1 + price_change)
|
||||
|
||||
# open_price = current_price
|
||||
# high_price = current_price * random.uniform(1.0, 1.01)
# low_price = current_price * random.uniform(0.99, 1.0)
# close_price = current_price * random.uniform(0.995, 1.005)
# volume = random.uniform(100.0, 500.0)
|
||||
|
||||
# feature_idx = btc_start + frame_idx * 5
|
||||
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
||||
|
||||
# # COB features (200 features) - realistic order book data
|
||||
# cob_start = 7500
|
||||
# for i in range(200):
|
||||
# features[cob_start + i] = random.uniform(0.0, 1000.0) # Realistic COB values
|
||||
|
||||
# # Technical indicators (100 features)
|
||||
# indicator_start = 7700
|
||||
# for i in range(100):
|
||||
# features[indicator_start + i] = random.uniform(-1.0, 1.0) # Normalized indicators
|
||||
|
||||
# # Last predictions (50 features)
|
||||
# prediction_start = 7800
|
||||
# for i in range(50):
|
||||
# features[prediction_start + i] = random.uniform(0.0, 1.0) # Probability values
|
||||
|
||||
# return features
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error creating realistic synthetic features: {e}")
|
||||
# # Fallback to small random variation
|
||||
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
||||
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
||||
# return base_features + noise
|
||||
|
||||
# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
|
||||
# """Create features from real market data if available"""
|
||||
# try:
|
||||
# # This would need to be implemented to use actual market data
|
||||
# # For now, fall back to synthetic features
|
||||
# return self._create_realistic_synthetic_features(symbol)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error creating realistic features: {e}")
|
||||
# return self._create_realistic_synthetic_features(symbol)
|
||||
|
||||
# def _initialize_model(self):
|
||||
# """Initialize the EnhancedCNN model"""
|
||||
# try:
|
||||
# # Calculate input shape based on BaseDataInput structure
|
||||
# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
|
||||
# # BTC OHLCV: 300 frames x 5 features = 1500 features
|
||||
# # COB: ±20 buckets x 4 metrics = 160 features
|
||||
# # MA: 4 timeframes x 10 buckets = 40 features
|
||||
# # Technical indicators: 100 features
|
||||
# # Last predictions: 50 features
|
||||
# # Total: 7850 features
|
||||
# input_shape = 7850
|
||||
# n_actions = 3 # BUY, SELL, HOLD
|
||||
|
||||
# # Create model
|
||||
# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
# # Ensure model is moved to the correct device
|
||||
# self.model.to(self.device)
|
||||
|
||||
# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
# raise
|
||||
|
||||
# def _load_checkpoint(self, checkpoint_path: str) -> bool:
|
||||
# """Load model from checkpoint path"""
|
||||
# try:
|
||||
# if self.model and os.path.exists(checkpoint_path):
|
||||
# success = self.model.load(checkpoint_path)
|
||||
# if success:
|
||||
# # Ensure model is moved to the correct device after loading
|
||||
# self.model.to(self.device)
|
||||
# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
||||
# return True
|
||||
# else:
|
||||
# logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
# return False
|
||||
# else:
|
||||
# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
|
||||
# return False
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def _load_best_checkpoint(self) -> bool:
|
||||
# """Load the best available checkpoint"""
|
||||
# try:
|
||||
# return self.load_best_checkpoint()
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading best checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def load_best_checkpoint(self) -> bool:
|
||||
# """Load the best checkpoint based on accuracy"""
|
||||
# try:
|
||||
# # Import checkpoint manager
|
||||
# from utils.checkpoint_manager import CheckpointManager
|
||||
|
||||
# # Create checkpoint manager
|
||||
# checkpoint_manager = CheckpointManager(
|
||||
# checkpoint_dir=self.checkpoint_dir,
|
||||
# max_checkpoints=10,
|
||||
# metric_name="accuracy"
|
||||
# )
|
||||
|
||||
# # Load best checkpoint
|
||||
# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
|
||||
|
||||
# if not best_checkpoint_path:
|
||||
# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
|
||||
# return False
|
||||
|
||||
# # Load model
|
||||
# success = self.model.load(best_checkpoint_path)
|
||||
|
||||
# if success:
|
||||
# # Ensure model is moved to the correct device after loading
|
||||
# self.model.to(self.device)
|
||||
# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
||||
|
||||
# # Log metrics
|
||||
# metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
|
||||
|
||||
# return True
|
||||
# else:
|
||||
# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
|
||||
# return False
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading best checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def _ensure_model_on_device(self):
|
||||
# """Ensure model and all its components are on the correct device"""
|
||||
# try:
|
||||
# if self.model:
|
||||
# self.model.to(self.device)
|
||||
# # Also ensure the model's internal device is set correctly
|
||||
# if hasattr(self.model, 'device'):
|
||||
# self.model.device = self.device
|
||||
# logger.debug(f"Model ensured on device {self.device}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error ensuring model on device: {e}")
|
||||
|
||||
# def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
# """Create default output when prediction fails"""
|
||||
# return create_model_output(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=symbol,
|
||||
# action='HOLD',
|
||||
# confidence=0.0,
|
||||
# metadata={'error': 'Prediction failed, using default output'}
|
||||
# )
|
||||
|
||||
# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# """Process hidden states for cross-model feeding"""
|
||||
# processed_states = {}
|
||||
|
||||
# for key, value in hidden_states.items():
|
||||
# if isinstance(value, torch.Tensor):
|
||||
# # Convert tensor to numpy array
|
||||
# processed_states[key] = value.cpu().numpy().tolist()
|
||||
# else:
|
||||
# processed_states[key] = value
|
||||
|
||||
# return processed_states
|
||||
|
||||
|
||||
|
||||
# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
|
||||
# """
|
||||
# Convert BaseDataInput to feature vector for EnhancedCNN
|
||||
|
||||
# Args:
|
||||
# base_data: Standardized input data
|
||||
|
||||
# Returns:
|
||||
# torch.Tensor: Feature vector for EnhancedCNN
|
||||
# """
|
||||
# try:
|
||||
# # Use the get_feature_vector method from BaseDataInput
|
||||
# features = base_data.get_feature_vector()
|
||||
|
||||
# # Validate feature quality before using
|
||||
# self._validate_feature_quality(features)
|
||||
|
||||
# # Convert to torch tensor
|
||||
# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
|
||||
|
||||
# return features_tensor
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error converting BaseDataInput to features: {e}")
|
||||
# # Return empty tensor with correct shape
|
||||
# return torch.zeros(7850, dtype=torch.float32, device=self.device)
|
||||
|
||||
# def _validate_feature_quality(self, features: np.ndarray):
|
||||
# """Validate that features are realistic and not synthetic/placeholder data"""
|
||||
# try:
|
||||
# if len(features) != 7850:
|
||||
# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
|
||||
# return
|
||||
|
||||
# # Check for all-zero or all-identical features (indicates placeholder data)
|
||||
# if np.all(features == 0):
|
||||
# logger.warning("Feature vector contains all zeros - likely placeholder data")
|
||||
# return
|
||||
|
||||
# # Check for repetitive patterns in OHLCV data (first 6000 features)
|
||||
# ohlcv_features = features[:6000]
|
||||
# if len(ohlcv_features) >= 20:
|
||||
# # Check if first 20 values are identical (indicates padding with same bar)
|
||||
# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
|
||||
# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
|
||||
|
||||
# # Check for unrealistic values
|
||||
# if np.any(features > 1e6) or np.any(features < -1e6):
|
||||
# logger.warning("Feature vector contains unrealistic values")
|
||||
|
||||
# # Check for NaN or infinite values
|
||||
# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
|
||||
# logger.warning("Feature vector contains NaN or infinite values")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error validating feature quality: {e}")
|
||||
|
||||
# def predict(self, base_data: BaseDataInput) -> ModelOutput:
|
||||
# """
|
||||
# Make a prediction using the EnhancedCNN model
|
||||
|
||||
# Args:
|
||||
# base_data: Standardized input data
|
||||
|
||||
# Returns:
|
||||
# ModelOutput: Standardized model output
|
||||
# """
|
||||
# try:
|
||||
# # Track inference timing
|
||||
# start_time = datetime.now()
|
||||
# inference_start = start_time.timestamp()
|
||||
|
||||
# # Convert BaseDataInput to features
|
||||
# features = self._convert_base_data_to_features(base_data)
|
||||
|
||||
# # Ensure features has batch dimension
|
||||
# if features.dim() == 1:
|
||||
# features = features.unsqueeze(0)
|
||||
|
||||
# # Ensure model is on correct device before prediction
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# # Set model to evaluation mode
|
||||
# self.model.eval()
|
||||
|
||||
# # Make prediction
|
||||
# with torch.no_grad():
|
||||
# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
|
||||
|
||||
# # Get action and confidence
|
||||
# action_probs = torch.softmax(q_values, dim=1)
|
||||
# action_idx = torch.argmax(action_probs, dim=1).item()
|
||||
# raw_confidence = float(action_probs[0, action_idx].item())
|
||||
|
||||
# # Validate confidence - prevent 100% confidence which indicates overfitting
|
||||
# if raw_confidence >= 0.99:
|
||||
# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
|
||||
# # Cap confidence at 0.95 to prevent unrealistic predictions
|
||||
# confidence = min(raw_confidence, 0.95)
|
||||
# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
|
||||
# else:
|
||||
# confidence = raw_confidence
|
||||
|
||||
# # Map action index to action string
|
||||
# actions = ['BUY', 'SELL', 'HOLD']
|
||||
# action = actions[action_idx]
|
||||
|
||||
# # Extract pivot price prediction (simplified - take first value from price_pred)
|
||||
# pivot_price = None
|
||||
# if price_pred is not None and len(price_pred.squeeze()) > 0:
|
||||
# # Get current price from base_data for context
|
||||
# current_price = 0.0
|
||||
# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
|
||||
# current_price = base_data.ohlcv_1s[-1].close
|
||||
|
||||
# # Calculate pivot price as current price + predicted change
|
||||
# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
|
||||
# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
|
||||
|
||||
# # Create predictions dictionary
|
||||
# predictions = {
|
||||
# 'action': action,
|
||||
# 'buy_probability': float(action_probs[0, 0].item()),
|
||||
# 'sell_probability': float(action_probs[0, 1].item()),
|
||||
# 'hold_probability': float(action_probs[0, 2].item()),
|
||||
# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
# 'pivot_price': pivot_price
|
||||
# }
|
||||
|
||||
# # Create hidden states dictionary
|
||||
# hidden_states = {
|
||||
# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
|
||||
# }
|
||||
|
||||
# # Calculate inference duration
|
||||
# end_time = datetime.now()
|
||||
# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
|
||||
|
||||
# # Update metrics
|
||||
# self.last_inference_time = start_time
|
||||
# self.last_inference_duration = inference_duration
|
||||
# self.inference_count += 1
|
||||
|
||||
# # Store last prediction output for dashboard
|
||||
# self.last_prediction_output = {
|
||||
# 'action': action,
|
||||
# 'confidence': confidence,
|
||||
# 'pivot_price': pivot_price,
|
||||
# 'timestamp': start_time,
|
||||
# 'symbol': base_data.symbol
|
||||
# }
|
||||
|
||||
# # Create metadata dictionary
|
||||
# metadata = {
|
||||
# 'model_version': '1.0',
|
||||
# 'timestamp': start_time.isoformat(),
|
||||
# 'input_shape': features.shape,
|
||||
# 'inference_duration_ms': inference_duration,
|
||||
# 'inference_count': self.inference_count
|
||||
# }
|
||||
|
||||
# # Create ModelOutput
|
||||
# model_output = ModelOutput(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# timestamp=start_time,
|
||||
# confidence=confidence,
|
||||
# predictions=predictions,
|
||||
# hidden_states=hidden_states,
|
||||
# metadata=metadata
|
||||
# )
|
||||
|
||||
# # Log inference with full input data for training feedback
|
||||
# log_model_inference(
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# action=action,
|
||||
# confidence=confidence,
|
||||
# probabilities={
|
||||
# 'BUY': predictions['buy_probability'],
|
||||
# 'SELL': predictions['sell_probability'],
|
||||
# 'HOLD': predictions['hold_probability']
|
||||
# },
|
||||
# input_features=features.cpu().numpy(), # Store full feature vector
|
||||
# processing_time_ms=inference_duration,
|
||||
# checkpoint_id=None, # Could be enhanced to track checkpoint
|
||||
# metadata={
|
||||
# 'base_data_input': {
|
||||
# 'symbol': base_data.symbol,
|
||||
# 'timestamp': base_data.timestamp.isoformat(),
|
||||
# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
|
||||
# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
|
||||
# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
|
||||
# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
|
||||
# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
|
||||
# 'has_cob_data': base_data.cob_data is not None,
|
||||
# 'technical_indicators_count': len(base_data.technical_indicators),
|
||||
# 'pivot_points_count': len(base_data.pivot_points),
|
||||
# 'last_predictions_count': len(base_data.last_predictions)
|
||||
# },
|
||||
# 'model_predictions': {
|
||||
# 'pivot_price': pivot_price,
|
||||
# 'extrema_prediction': predictions['extrema'],
|
||||
# 'price_prediction': predictions['price_prediction']
|
||||
# }
|
||||
# }
|
||||
# )
|
||||
|
||||
# return model_output
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error making prediction with EnhancedCNN: {e}")
|
||||
# # Return default ModelOutput
|
||||
# return create_model_output(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# action='HOLD',
|
||||
# confidence=0.0
|
||||
# )
|
||||
|
||||
# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
|
||||
# """
|
||||
# Add a training sample to the training data
|
||||
|
||||
# Args:
|
||||
# symbol_or_base_data: Either a symbol string or BaseDataInput object
|
||||
# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
|
||||
# reward: Reward received for the action
|
||||
# """
|
||||
# try:
|
||||
# # Handle both symbol string and BaseDataInput object
|
||||
# if isinstance(symbol_or_base_data, str):
|
||||
# # For cold start mode - create a simple training sample with current features
|
||||
# # This is a simplified approach for rapid training
|
||||
# symbol = symbol_or_base_data
|
||||
|
||||
# # Create a realistic feature vector instead of random data
|
||||
# # Use actual market data if available, otherwise create realistic synthetic data
|
||||
# try:
|
||||
# # Try to get real market data first
|
||||
# if hasattr(self, 'data_provider') and self.data_provider:
|
||||
# # This would need to be implemented in the adapter
|
||||
# features = self._create_realistic_features(symbol)
|
||||
# else:
|
||||
# # Create realistic synthetic features (not random)
|
||||
# features = self._create_realistic_synthetic_features(symbol)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Could not create realistic features for {symbol}: {e}")
|
||||
# # Fallback to small random variation instead of pure random
|
||||
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
||||
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
||||
# features = base_features + noise
|
||||
|
||||
# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
# else:
|
||||
# # Full BaseDataInput object
|
||||
# base_data = symbol_or_base_data
|
||||
# features = self._convert_base_data_to_features(base_data)
|
||||
# symbol = base_data.symbol
|
||||
|
||||
# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
# # Convert action to index
|
||||
# actions = ['BUY', 'SELL', 'HOLD']
|
||||
# action_idx = actions.index(actual_action)
|
||||
|
||||
# # Add to training data
|
||||
# with self.training_lock:
|
||||
# self.training_data.append((features, action_idx, reward))
|
||||
|
||||
# # Limit training data size
|
||||
# if len(self.training_data) > self.max_training_samples:
|
||||
# # Sort by reward (highest first) and keep top samples
|
||||
# self.training_data.sort(key=lambda x: x[2], reverse=True)
|
||||
# self.training_data = self.training_data[:self.max_training_samples]
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error adding training sample: {e}")
|
||||
|
||||
# def train(self, epochs: int = 1) -> Dict[str, float]:
|
||||
# """
|
||||
# Train the model with collected data and inference history
|
||||
|
||||
# Args:
|
||||
# epochs: Number of epochs to train for
|
||||
|
||||
# Returns:
|
||||
# Dict[str, float]: Training metrics
|
||||
# """
|
||||
# try:
|
||||
# # Track training timing
|
||||
# training_start_time = datetime.now()
|
||||
# training_start = training_start_time.timestamp()
|
||||
|
||||
# with self.training_lock:
|
||||
# # Get additional training data from inference history
|
||||
# self._load_training_data_from_inference_history()
|
||||
|
||||
# # Check if we have enough data
|
||||
# if len(self.training_data) < self.batch_size:
|
||||
# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
||||
|
||||
# # Ensure model is on correct device before training
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# # Set model to training mode
|
||||
# self.model.train()
|
||||
|
||||
# # Create optimizer
|
||||
# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
||||
|
||||
# # Training metrics
|
||||
# total_loss = 0.0
|
||||
# correct_predictions = 0
|
||||
# total_predictions = 0
|
||||
|
||||
# # Train for specified number of epochs
|
||||
# for epoch in range(epochs):
|
||||
# # Shuffle training data
|
||||
# np.random.shuffle(self.training_data)
|
||||
|
||||
# # Process in batches
|
||||
# for i in range(0, len(self.training_data), self.batch_size):
|
||||
# batch = self.training_data[i:i+self.batch_size]
|
||||
|
||||
# # Skip if batch is too small
|
||||
# if len(batch) < 2:
|
||||
# continue
|
||||
|
||||
# # Prepare batch - ensure all tensors are on the correct device
|
||||
# features = torch.stack([sample[0].to(self.device) for sample in batch])
|
||||
# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
||||
# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
||||
|
||||
# # Zero gradients
|
||||
# optimizer.zero_grad()
|
||||
|
||||
# # Forward pass
|
||||
# q_values, _, _, _, _ = self.model(features)
|
||||
|
||||
# # Calculate loss (CrossEntropyLoss with reward weighting)
|
||||
# # First, apply softmax to get probabilities
|
||||
# probs = torch.softmax(q_values, dim=1)
|
||||
|
||||
# # Get probability of chosen action
|
||||
# chosen_probs = probs[torch.arange(len(actions)), actions]
|
||||
|
||||
# # Calculate negative log likelihood loss
|
||||
# nll_loss = -torch.log(chosen_probs + 1e-10)
|
||||
|
||||
# # Weight by reward (higher reward = higher weight)
|
||||
# # Normalize rewards to [0, 1] range
|
||||
# min_reward = rewards.min()
|
||||
# max_reward = rewards.max()
|
||||
# if max_reward > min_reward:
|
||||
# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
|
||||
# else:
|
||||
# normalized_rewards = torch.ones_like(rewards)
|
||||
|
||||
# # Apply reward weighting (higher reward = higher weight)
|
||||
# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
|
||||
|
||||
# # Mean loss
|
||||
# loss = weighted_loss.mean()
|
||||
|
||||
# # Backward pass
|
||||
# loss.backward()
|
||||
|
||||
# # Update weights
|
||||
# optimizer.step()
|
||||
|
||||
# # Update metrics
|
||||
# total_loss += loss.item()
|
||||
|
||||
# # Calculate accuracy
|
||||
# predicted_actions = torch.argmax(q_values, dim=1)
|
||||
# correct_predictions += (predicted_actions == actions).sum().item()
|
||||
# total_predictions += len(actions)
|
||||
|
||||
# # Validate training - detect overfitting
|
||||
# if total_predictions > 0:
|
||||
# current_accuracy = correct_predictions / total_predictions
|
||||
# if current_accuracy >= 0.99:
|
||||
# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
|
||||
# # Add regularization to prevent overfitting
|
||||
# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
|
||||
# loss = loss + l2_reg
|
||||
# logger.info("Added L2 regularization to prevent overfitting")
|
||||
|
||||
# # Calculate final metrics
# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

# # Calculate training duration
# training_end_time = datetime.now()
# training_duration = (training_end_time.timestamp() - training_start) * 1000  # Convert to milliseconds

# # Update training metrics
# self.last_training_time = training_start_time
# self.last_training_duration = training_duration
# self.last_training_loss = avg_loss
# self.training_count += 1

# # Save checkpoint
# self._save_checkpoint(avg_loss, accuracy)

# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")

# return {
#     'loss': avg_loss,
#     'accuracy': accuracy,
#     'samples': len(self.training_data),
#     'duration_ms': training_duration,
#     'training_count': self.training_count
# }

# except Exception as e:
#     logger.error(f"Error training model: {e}")
#     return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}

# def _save_checkpoint(self, loss: float, accuracy: float):
#     """
#     Save model checkpoint

#     Args:
#         loss: Training loss
#         accuracy: Training accuracy
#     """
#     try:
#         # Import checkpoint manager
#         from utils.checkpoint_manager import CheckpointManager

#         # Create checkpoint manager
#         checkpoint_manager = CheckpointManager(
#             checkpoint_dir=self.checkpoint_dir,
#             max_checkpoints=10,
#             metric_name="accuracy"
#         )

#         # Create temporary model file
#         temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
#         self.model.save(temp_path)

#         # Create metrics
#         metrics = {
#             'loss': loss,
#             'accuracy': accuracy,
#             'samples': len(self.training_data)
#         }

#         # Create metadata
#         metadata = {
#             'timestamp': datetime.now().isoformat(),
#             'model_name': self.model_name,
#             'input_shape': self.model.input_shape,
#             'n_actions': self.model.n_actions
#         }

#         # Save checkpoint
#         checkpoint_path = checkpoint_manager.save_checkpoint(
#             model_name=self.model_name,
#             model_path=f"{temp_path}.pt",
#             metrics=metrics,
#             metadata=metadata
#         )

#         # Delete temporary model file
#         if os.path.exists(f"{temp_path}.pt"):
#             os.remove(f"{temp_path}.pt")

#         logger.info(f"Model checkpoint saved to {checkpoint_path}")

#     except Exception as e:
#         logger.error(f"Error saving checkpoint: {e}")

# def _load_training_data_from_inference_history(self):
#     """Load training data from inference history for continuous learning"""
#     try:
#         from utils.database_manager import get_database_manager

#         db_manager = get_database_manager()

#         # Get recent inference records with input features
#         inference_records = db_manager.get_inference_records_for_training(
#             model_name=self.model_name,
#             hours_back=24,  # Last 24 hours
#             limit=1000
#         )

#         if not inference_records:
#             logger.debug("No inference records found for training")
#             return

#         # Convert inference records to training samples
#         # For now, use a simple approach: treat high-confidence predictions as ground truth
#         for record in inference_records:
#             if record.input_features is not None and record.confidence > 0.7:
#                 # Convert action to index
#                 actions = ['BUY', 'SELL', 'HOLD']
#                 if record.action in actions:
#                     action_idx = actions.index(record.action)

#                     # Use confidence as a proxy for reward (high confidence = good prediction)
#                     reward = record.confidence * 2 - 1  # Scale to [-1, 1]

#                     # Convert features to tensor
#                     features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)

#                     # Add to training data if not already present (avoid duplicates)
#                     sample_exists = any(
#                         torch.equal(features_tensor, existing[0])
#                         for existing in self.training_data
#                     )

#                     if not sample_exists:
#                         self.training_data.append((features_tensor, action_idx, reward))

#         logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")

#     except Exception as e:
#         logger.error(f"Error loading training data from inference history: {e}")

# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
#     """
#     Evaluate past predictions against actual market outcomes

#     Args:
#         hours_back: How many hours back to evaluate

#     Returns:
#         Dict with evaluation metrics
#     """
#     try:
#         from utils.database_manager import get_database_manager

#         db_manager = get_database_manager()

#         # Get inference records from the specified time period
#         inference_records = db_manager.get_inference_records_for_training(
#             model_name=self.model_name,
#             hours_back=hours_back,
#             limit=100
#         )

#         if not inference_records:
#             return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}

#         # For now, use a simple evaluation based on confidence
#         # In a real implementation, this would compare against actual price movements
#         correct_predictions = 0
#         total_predictions = len(inference_records)

#         # Simple heuristic: high confidence predictions are more likely to be correct
#         for record in inference_records:
#             if record.confidence > 0.8:  # High confidence threshold
#                 correct_predictions += 1
#             elif record.confidence > 0.6:  # Medium confidence
#                 correct_predictions += 0.5

#         accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

#         logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")

#         return {
#             'accuracy': accuracy,
#             'total_predictions': total_predictions,
#             'correct_predictions': correct_predictions
#         }

#     except Exception as e:
#         logger.error(f"Error evaluating predictions: {e}")
#         return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
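Note: the commented-out loader above turns prediction confidence into a pseudo-reward via `confidence * 2 - 1`. A minimal standalone sketch of that mapping (the helper name is illustrative, not from the codebase):

```python
def confidence_to_reward(confidence: float) -> float:
    """Map a confidence in [0, 1] to a reward in [-1, 1], as the loader above does."""
    return confidence * 2.0 - 1.0

print(confidence_to_reward(0.7))  # 0.4  -- the 0.7 inclusion threshold maps to a mildly positive reward
print(confidence_to_reward(0.5))  # 0.0  -- a coin-flip prediction earns nothing
```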
@@ -1,464 +0,0 @@
"""
Enhanced Trading Orchestrator

Central coordination hub for the multi-modal trading system that manages:
- Data subscription and management
- Model inference coordination
- Cross-model data feeding
- Training pipeline orchestration
- Decision making using Mixture of Experts
"""

import asyncio
import logging
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field

from core.data_provider import DataProvider
from core.trading_action import TradingAction
from utils.tensorboard_logger import TensorBoardLogger

logger = logging.getLogger(__name__)

@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types"""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info

@dataclass
class BaseDataInput:
    """Unified base data input for all models"""
    symbol: str
    timestamp: datetime
    ohlcv_data: Dict[str, Any] = field(default_factory=dict)  # Multi-timeframe OHLCV
    cob_data: Optional[Dict[str, Any]] = None  # COB buckets for 1s timeframe
    technical_indicators: Dict[str, float] = field(default_factory=dict)
    pivot_points: List[Any] = field(default_factory=list)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)  # From all models
    market_microstructure: Dict[str, Any] = field(default_factory=dict)  # Order flow, etc.

@dataclass
class COBData:
    """Cumulative Order Book data for price buckets"""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict)  # price -> {bid_volume, ask_volume, etc.}
    bid_ask_imbalance: Dict[float, float] = field(default_factory=dict)  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float] = field(default_factory=dict)  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float] = field(default_factory=dict)  # Various order flow indicators
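For orientation, a minimal sketch of how these dataclasses compose (field values are illustrative only):

```python
from datetime import datetime

cnn_output = ModelOutput(
    model_type='cnn',
    model_name='cnn_pivot_v1',  # illustrative identifier
    symbol='ETH/USDT',
    timestamp=datetime.now(),
    confidence=0.72,
    predictions={'pivot': 'LOW'},
)

base_input = BaseDataInput(symbol='ETH/USDT', timestamp=datetime.now())
base_input.last_predictions[cnn_output.model_name] = cnn_output  # cross-model feeding
```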
class EnhancedTradingOrchestrator:
    """
    Enhanced Trading Orchestrator implementing the design specification

    Coordinates data flow, model inference, and decision making for the multi-modal trading system.
    """

    def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None):
        """Initialize the enhanced orchestrator"""
        self.data_provider = data_provider
        self.symbols = symbols
        self.enhanced_rl_training = enhanced_rl_training
        self.model_registry = model_registry or {}

        # Data management
        self.data_buffers = {symbol: {} for symbol in symbols}
        self.last_update_times = {symbol: {} for symbol in symbols}

        # Model output storage
        self.model_outputs = {symbol: {} for symbol in symbols}
        self.model_output_history = {symbol: {} for symbol in symbols}

        # Training pipeline
        self.training_data = {symbol: [] for symbol in symbols}
        self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")

        # COB integration
        self.cob_data = {symbol: None for symbol in symbols}

        # Performance tracking
        self.performance_metrics = {
            'inference_count': 0,
            'successful_states': 0,
            'total_episodes': 0
        }

        logger.info("Enhanced Trading Orchestrator initialized")

    async def start_cob_integration(self):
        """Start COB data integration for real-time market microstructure"""
        try:
            # Subscribe to COB data updates
            self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
            logger.info("COB integration started")
        except Exception as e:
            logger.error(f"Error starting COB integration: {e}")

    async def start_realtime_processing(self):
        """Start real-time data processing"""
        try:
            # Subscribe to tick data for real-time processing
            for symbol in self.symbols:
                self.data_provider.subscribe_to_ticks(
                    callback=self._on_tick_data,
                    symbols=[symbol],
                    subscriber_name=f"orchestrator_{symbol}"
                )

            logger.info("Real-time processing started")
        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")

    def _on_cob_data_update(self, symbol: str, cob_data: dict):
        """Handle COB data updates"""
        try:
            # Process and store COB data
            self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
            logger.debug(f"COB data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")

    def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
        """Process raw COB data into structured format"""
        # Determine bucket size based on symbol (defined before the try block so the
        # error fallback below can use it)
        bucket_size = 1.0 if 'ETH' in symbol else 10.0
        try:
            # Extract current price
            stats = cob_data.get('stats', {})
            current_price = stats.get('mid_price', 0)

            # Create COB data structure
            cob = COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                current_price=current_price,
                bucket_size=bucket_size
            )

            # Process order book data into price buckets
            bids = cob_data.get('bids', [])
            asks = cob_data.get('asks', [])

            # Create price buckets around current price
            bucket_count = 20  # ±20 buckets
            for i in range(-bucket_count, bucket_count + 1):
                bucket_price = current_price + (i * bucket_size)
                cob.price_buckets[bucket_price] = {
                    'bid_volume': 0.0,
                    'ask_volume': 0.0
                }

            # Aggregate bid volumes into buckets
            for price, volume in bids:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['bid_volume'] += volume

            # Aggregate ask volumes into buckets
            for price, volume in asks:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['ask_volume'] += volume

            # Calculate bid/ask imbalances
            for price, volumes in cob.price_buckets.items():
                bid_vol = volumes['bid_volume']
                ask_vol = volumes['ask_volume']
                total_vol = bid_vol + ask_vol
                if total_vol > 0:
                    cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
                else:
                    cob.bid_ask_imbalance[price] = 0.0

            # Calculate volume-weighted prices
            # NOTE: with a single representative price per bucket this reduces to the
            # bucket price itself; kept for compatibility with the COBData schema.
            for price, volumes in cob.price_buckets.items():
                bid_vol = volumes['bid_volume']
                ask_vol = volumes['ask_volume']
                total_vol = bid_vol + ask_vol
                if total_vol > 0:
                    cob.volume_weighted_prices[price] = (
                        (price * bid_vol) + (price * ask_vol)
                    ) / total_vol
                else:
                    cob.volume_weighted_prices[price] = price

            # Calculate order flow metrics (totals are computed first so the ratio
            # does not read from the dict while it is still being built)
            total_bid_volume = sum(v['bid_volume'] for v in cob.price_buckets.values())
            total_ask_volume = sum(v['ask_volume'] for v in cob.price_buckets.values())
            cob.order_flow_metrics = {
                'total_bid_volume': total_bid_volume,
                'total_ask_volume': total_ask_volume,
                'bid_ask_ratio': 0.0 if total_ask_volume == 0 else total_bid_volume / total_ask_volume
            }

            return cob
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")
            return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size)
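As a worked example of the bucketing above, assuming ETH's $1 bucket size: a 5 ETH bid at $2,999.70 and a 3 ETH ask at $3,000.40 both round to the $3,000 bucket, giving an imbalance of (5 − 3) / (5 + 3) = 0.25. A self-contained sketch of the same aggregation:

```python
bucket_size = 1.0  # ETH bucket size
bids, asks = [(2999.70, 5.0)], [(3000.40, 3.0)]

buckets = {}
for side, book in (('bid_volume', bids), ('ask_volume', asks)):
    for price, volume in book:
        bucket_price = round(price / bucket_size) * bucket_size
        bucket = buckets.setdefault(bucket_price, {'bid_volume': 0.0, 'ask_volume': 0.0})
        bucket[side] += volume

for price, vols in buckets.items():
    total = vols['bid_volume'] + vols['ask_volume']
    imbalance = (vols['bid_volume'] - vols['ask_volume']) / total if total else 0.0
    print(price, imbalance)  # 3000.0 0.25
```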
    def _on_tick_data(self, tick):
        """Handle incoming tick data"""
        try:
            # Update data buffers
            symbol = tick.symbol
            if symbol not in self.data_buffers:
                self.data_buffers[symbol] = {}

            # Store tick data
            if 'ticks' not in self.data_buffers[symbol]:
                self.data_buffers[symbol]['ticks'] = []
            self.data_buffers[symbol]['ticks'].append(tick)

            # Keep only last 1000 ticks
            if len(self.data_buffers[symbol]['ticks']) > 1000:
                self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]

            # Update last update time
            self.last_update_times[symbol]['tick'] = datetime.now()

            logger.debug(f"Tick data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing tick data: {e}")

    def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """
        Build comprehensive RL state with 13,400 features as specified

        Returns:
            np.ndarray: State vector with 13,400 features
        """
        try:
            # Initialize state vector
            state_size = 13400
            state = np.zeros(state_size, dtype=np.float32)

            # Get latest data
            ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
            cob_data = self.cob_data.get(symbol)

            # Feature index tracking
            idx = 0

            # 1. OHLCV features (4000 features)
            if ohlcv_data is not None and not ohlcv_data.empty:
                # Use last 100 1s candles (40 features each: O,H,L,C,V + 35 indicator slots)
                for i in range(min(100, len(ohlcv_data))):
                    if idx + 40 <= state_size:
                        row = ohlcv_data.iloc[-(i+1)]
                        state[idx] = row.get('open', 0) / 100000  # Normalized
                        state[idx+1] = row.get('high', 0) / 100000
                        state[idx+2] = row.get('low', 0) / 100000
                        state[idx+3] = row.get('close', 0) / 100000
                        state[idx+4] = row.get('volume', 0) / 1000000

                        # Add technical indicators if available
                        indicator_idx = 5
                        for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
                                    'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
                            if col in row and idx + indicator_idx < state_size:
                                state[idx + indicator_idx] = row[col] / 100000
                                indicator_idx += 1

                        idx += 40

            # 2. COB features (8000 features)
            if cob_data and idx + 8000 <= state_size:
                # Use 200 price buckets (40 features each)
                bucket_prices = sorted(cob_data.price_buckets.keys())
                for i, price in enumerate(bucket_prices[:200]):
                    if idx + 40 <= state_size:
                        bucket = cob_data.price_buckets[price]
                        state[idx] = bucket.get('bid_volume', 0) / 1000000  # Normalized
                        state[idx+1] = bucket.get('ask_volume', 0) / 1000000
                        state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0)
                        state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000

                        # Additional COB metrics
                        state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
                        state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
                        state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)

                        idx += 40

            # 3. Technical indicator features (1000 features)
            # Already included in OHLCV section above

            # 4. Market microstructure features (400 features)
            if cob_data and idx + 400 <= state_size:
                # Add order flow metrics
                metrics = list(cob_data.order_flow_metrics.values())
                for i, metric in enumerate(metrics[:400]):
                    if idx + i < state_size:
                        state[idx + i] = metric

            # Log state building success
            self.performance_metrics['successful_states'] += 1
            logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")

            # Log to TensorBoard
            self.tensorboard_logger.log_state_metrics(
                symbol=symbol,
                state_info={
                    'size': len(state),
                    'quality': 1.0,
                    'feature_counts': {
                        'total': len(state),
                        'non_zero': np.count_nonzero(state)
                    }
                },
                step=self.performance_metrics['successful_states']
            )

            return state

        except Exception as e:
            logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
            return None
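The 13,400-feature budget above breaks down as 100 candles × 40 = 4,000 OHLCV slots, 200 COB buckets × 40 = 8,000 slots, 1,000 indicator slots (folded into the OHLCV section), and 400 microstructure slots: 4,000 + 8,000 + 1,000 + 400 = 13,400.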
    def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
        """
        Calculate enhanced pivot-based reward

        Args:
            trade_decision: Trading decision with action and confidence
            market_data: Market context data
            trade_outcome: Actual trade results

        Returns:
            float: Enhanced reward value
        """
        try:
            # Base reward from PnL
            pnl_reward = trade_outcome.get('net_pnl', 0) / 100  # Normalize

            # Confidence weighting
            confidence = trade_decision.get('confidence', 0.5)
            confidence_reward = confidence * 0.2

            # Volatility adjustment
            volatility = market_data.get('volatility', 0.01)
            volatility_reward = (1.0 - volatility * 10) * 0.1  # Prefer low volatility

            # Order flow alignment
            order_flow = market_data.get('order_flow_strength', 0)
            order_flow_reward = order_flow * 0.2

            # Pivot alignment bonus (if near pivot in favorable direction)
            pivot_bonus = 0.0
            if market_data.get('near_pivot', False):
                action = trade_decision.get('action', '').upper()
                pivot_type = market_data.get('pivot_type', '').upper()

                # Bonus for buying near support or selling near resistance
                if (action == 'BUY' and pivot_type == 'LOW') or \
                   (action == 'SELL' and pivot_type == 'HIGH'):
                    pivot_bonus = 0.5

            # Calculate final reward
            enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus

            # Log to TensorBoard
            self.tensorboard_logger.log_scalars('Rewards/Components', {
                'pnl_component': pnl_reward,
                'confidence': confidence_reward,
                'volatility': volatility_reward,
                'order_flow': order_flow_reward,
                'pivot_bonus': pivot_bonus
            }, self.performance_metrics['total_episodes'])

            self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes'])

            logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
            return enhanced_reward

        except Exception as e:
            logger.error(f"Error calculating enhanced pivot reward: {e}")
            return 0.0
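To make the composition concrete: a BUY near a LOW pivot with net_pnl=50, confidence=0.8, volatility=0.02, and order_flow_strength=0.5 scores 50/100 + 0.8×0.2 + (1 − 0.02×10)×0.1 + 0.5×0.2 + 0.5 = 0.5 + 0.16 + 0.08 + 0.1 + 0.5 = 1.34.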
    async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
        """
        Make coordinated trading decisions using all available models

        Returns:
            Dict[str, TradingAction]: Trading actions for each symbol
        """
        try:
            decisions = {}

            # For each symbol, coordinate model inference
            for symbol in self.symbols:
                # Build comprehensive state for RL model
                rl_state = self.build_comprehensive_rl_state(symbol)

                if rl_state is not None:
                    # Store state for training
                    self.performance_metrics['total_episodes'] += 1

                    # Create mock RL decision (in a real implementation, this would call the RL model)
                    action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
                    confidence = min(1.0, max(0.0, np.std(rl_state) * 10))

                    # Create trading action
                    decisions[symbol] = TradingAction(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        action=action,
                        confidence=confidence,
                        source='rl_orchestrator'
                    )

                    logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")
                else:
                    logger.warning(f"Failed to build state for {symbol}, skipping decision")

            self.performance_metrics['inference_count'] += 1
            return decisions

        except Exception as e:
            logger.error(f"Error making coordinated decisions: {e}")
            return {}

    def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
        """
        Calculate correlation between two symbols

        Args:
            symbol1: First symbol
            symbol2: Second symbol

        Returns:
            float: Correlation coefficient (-1 to 1)
        """
        try:
            # Get recent price data for both symbols
            data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
            data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)

            if data1 is None or data2 is None or data1.empty or data2.empty:
                return 0.0

            # Align data by timestamp
            merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')

            if len(merged) < 10:
                return 0.0

            # Calculate correlation
            correlation = merged['close_1'].corr(merged['close_2'])
            return correlation if not np.isnan(correlation) else 0.0

        except Exception as e:
            logger.error(f"Error calculating symbol correlation: {e}")
            return 0.0
@@ -1,775 +0,0 @@
"""
Enhanced Training Integration Module

This module provides comprehensive integration between the training data collection system,
CNN training pipeline, RL training pipeline, and your existing infrastructure.

Key Features:
- Real-time integration with existing DataProvider
- Coordinated training across CNN and RL models
- Automatic outcome validation and profitability tracking
- Integration with existing COB RL model
- Performance monitoring and optimization
- Seamless connection to existing orchestrator and trading executor
"""

import asyncio
import logging
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass
import threading
import time
from pathlib import Path

# Import existing components
from .data_provider import DataProvider
from .orchestrator import Orchestrator
from .trading_executor import TradingExecutor

# Import our training system components
from .training_data_collector import (
    TrainingDataCollector,
    get_training_data_collector
)
from .cnn_training_pipeline import (
    CNNPivotPredictor,
    CNNTrainer,
    get_cnn_trainer
)
from .rl_training_pipeline import (
    RLTradingAgent,
    RLTrainer,
    get_rl_trainer
)
from .training_integration import TrainingIntegration

# The logger must exist before the import fallback below can log a warning
logger = logging.getLogger(__name__)

# Import existing RL model
try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    logger.warning("Could not import COBRLModelInterface - using fallback")
    COBRLModelInterface = None

@dataclass
class EnhancedTrainingConfig:
    """Enhanced configuration for comprehensive training integration"""
    # Data collection
    collection_interval: float = 1.0
    min_data_completeness: float = 0.8

    # Training triggers
    min_episodes_for_cnn_training: int = 100
    min_experiences_for_rl_training: int = 200
    training_frequency_minutes: int = 30

    # Profitability thresholds
    min_profitability_for_replay: float = 0.1
    high_profitability_threshold: float = 0.5

    # Model integration
    use_existing_cob_rl_model: bool = True
    enable_cross_model_learning: bool = True

    # Performance optimization
    max_concurrent_training_sessions: int = 2
    enable_background_validation: bool = True
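Since this is a dataclass with defaults, a caller only overrides what it needs. A minimal sketch (values illustrative):

```python
config = EnhancedTrainingConfig(
    collection_interval=0.5,             # collect twice per second
    training_frequency_minutes=15,       # check training triggers more often
    enable_background_validation=False,  # skip the validation thread in a backtest
)
```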
class EnhancedTrainingIntegration:
    """Enhanced training integration with existing infrastructure"""

    def __init__(self,
                 data_provider: DataProvider,
                 orchestrator: Orchestrator = None,
                 trading_executor: TradingExecutor = None,
                 config: EnhancedTrainingConfig = None):

        self.data_provider = data_provider
        self.orchestrator = orchestrator
        self.trading_executor = trading_executor
        self.config = config or EnhancedTrainingConfig()

        # Initialize training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Initialize RL components
        if self.config.use_existing_cob_rl_model and COBRLModelInterface:
            self.existing_rl_model = COBRLModelInterface()
            logger.info("Using existing COB RL model")
        else:
            self.existing_rl_model = None

        self.rl_agent = RLTradingAgent()
        self.rl_trainer = get_rl_trainer(self.rl_agent)

        # Integration state
        self.is_running = False
        self.training_threads = {}
        self.validation_thread = None

        # Performance tracking
        self.integration_stats = {
            'total_data_packages': 0,
            'cnn_training_sessions': 0,
            'rl_training_sessions': 0,
            'profitable_predictions': 0,
            'total_predictions': 0,
            'cross_model_improvements': 0,
            'last_update': datetime.now()
        }

        # Model prediction tracking
        self.recent_predictions = {}
        self.prediction_outcomes = {}

        # Cross-model learning
        self.model_performance_history = {
            'cnn': [],
            'rl': [],
            'orchestrator': []
        }

        logger.info("Enhanced Training Integration initialized")
        logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
        logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
        logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")

    def start_enhanced_integration(self):
        """Start the enhanced training integration system"""
        if self.is_running:
            logger.warning("Enhanced training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start CNN training
        if self.config.min_episodes_for_cnn_training > 0:
            for symbol in self.data_provider.symbols:
                self.cnn_trainer.start_real_time_training(symbol)

        # Start coordinated training thread
        self.training_threads['coordinator'] = threading.Thread(
            target=self._training_coordinator_worker,
            daemon=True
        )
        self.training_threads['coordinator'].start()

        # Start data collection and validation
        self.training_threads['data_collector'] = threading.Thread(
            target=self._enhanced_data_collection_worker,
            daemon=True
        )
        self.training_threads['data_collector'].start()

        # Start outcome validation if enabled
        if self.config.enable_background_validation:
            self.validation_thread = threading.Thread(
                target=self._outcome_validation_worker,
                daemon=True
            )
            self.validation_thread.start()

        logger.info("Enhanced training integration started")

    def stop_enhanced_integration(self):
        """Stop the enhanced training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish
        for thread_name, thread in self.training_threads.items():
            thread.join(timeout=10)
            logger.info(f"Stopped {thread_name} thread")

        if self.validation_thread:
            self.validation_thread.join(timeout=5)

        logger.info("Enhanced training integration stopped")
    def _enhanced_data_collection_worker(self):
        """Enhanced data collection with real-time model integration"""
        logger.info("Enhanced data collection worker started")

        while self.is_running:
            try:
                for symbol in self.data_provider.symbols:
                    self._collect_enhanced_training_data(symbol)

                time.sleep(self.config.collection_interval)

            except Exception as e:
                logger.error(f"Error in enhanced data collection: {e}")
                time.sleep(5)

        logger.info("Enhanced data collection worker stopped")

    def _collect_enhanced_training_data(self, symbol: str):
        """Collect enhanced training data with model predictions"""
        try:
            # Get comprehensive market data
            market_data = self._get_comprehensive_market_data(symbol)

            if not market_data or not self._validate_market_data(market_data):
                return

            # Get current model predictions
            model_predictions = self._get_all_model_predictions(symbol, market_data)

            # Create enhanced features
            cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
            rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)

            # Collect training data with predictions
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=market_data['ohlcv'],
                tick_data=market_data['ticks'],
                cob_data=market_data['cob'],
                technical_indicators=market_data['indicators'],
                pivot_points=market_data['pivots'],
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=market_data['context'],
                model_predictions=model_predictions
            )

            if episode_id:
                # Store predictions for outcome validation
                self.recent_predictions[episode_id] = {
                    'timestamp': datetime.now(),
                    'symbol': symbol,
                    'predictions': model_predictions,
                    'market_data': market_data
                }

                # Add RL experience if we have an action
                if 'rl_action' in model_predictions:
                    self._add_rl_experience(symbol, market_data, model_predictions, episode_id)

                self.integration_stats['total_data_packages'] += 1

        except Exception as e:
            logger.error(f"Error collecting enhanced training data for {symbol}: {e}")

    def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
        """Get comprehensive market data from all sources"""
        try:
            market_data = {}

            # OHLCV data
            ohlcv_data = {}
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            market_data['ohlcv'] = ohlcv_data

            # Tick data
            market_data['ticks'] = self._get_recent_tick_data(symbol)

            # COB data
            market_data['cob'] = self._get_cob_data(symbol)

            # Technical indicators
            market_data['indicators'] = self._get_technical_indicators(symbol)

            # Pivot points
            market_data['pivots'] = self._get_pivot_points(symbol)

            # Market context
            market_data['context'] = self._get_market_context(symbol)

            return market_data

        except Exception as e:
            logger.error(f"Error getting comprehensive market data: {e}")
            return {}

    def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get predictions from all available models"""
        predictions = {}

        try:
            # CNN predictions
            if self.cnn_model and market_data.get('ohlcv'):
                cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
                if cnn_features is not None:
                    cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)

                    # Reshape for CNN (add channel dimension)
                    cnn_input = cnn_input.view(1, 10, -1)  # Assuming 10 channels

                    with torch.no_grad():
                        cnn_outputs = self.cnn_model(cnn_input)
                        predictions['cnn'] = {
                            'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
                            'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
                            'confidence': cnn_outputs['confidence'].cpu().numpy(),
                            'timestamp': datetime.now()
                        }

            # RL predictions
            if self.rl_agent and market_data.get('cob'):
                rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
                if rl_state is not None:
                    action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
                    predictions['rl'] = {
                        'action': action,
                        'confidence': confidence,
                        'timestamp': datetime.now()
                    }
                    predictions['rl_action'] = action

            # Existing COB RL model predictions
            if self.existing_rl_model and market_data.get('cob'):
                cob_features = market_data['cob'].get('cob_features', [])
                if cob_features and len(cob_features) >= 2000:
                    cob_array = np.array(cob_features[:2000], dtype=np.float32)
                    cob_prediction = self.existing_rl_model.predict(cob_array)
                    predictions['cob_rl'] = {
                        'predicted_direction': cob_prediction.get('predicted_direction', 1),
                        'confidence': cob_prediction.get('confidence', 0.5),
                        'value': cob_prediction.get('value', 0.0),
                        'timestamp': datetime.now()
                    }

            # Orchestrator predictions (if available)
            if self.orchestrator:
                try:
                    # This would integrate with your orchestrator's prediction method
                    orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
                    if orchestrator_prediction:
                        predictions['orchestrator'] = orchestrator_prediction
                except Exception as e:
                    logger.debug(f"Could not get orchestrator prediction: {e}")

            return predictions

        except Exception as e:
            logger.error(f"Error getting model predictions: {e}")
            return {}
def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
|
||||
predictions: Dict[str, Any], episode_id: str):
|
||||
"""Add RL experience to the training buffer"""
|
||||
try:
|
||||
# Create RL state
|
||||
state = self._create_enhanced_rl_state(symbol, market_data, predictions)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
# Get action from predictions
|
||||
action = predictions.get('rl_action', 1) # Default to HOLD
|
||||
|
||||
# Calculate immediate reward (placeholder - would be updated with actual outcome)
|
||||
reward = 0.0
|
||||
|
||||
# Create next state (same as current for now - would be updated)
|
||||
next_state = state.copy()
|
||||
|
||||
# Market context
|
||||
market_context = {
|
||||
'symbol': symbol,
|
||||
'episode_id': episode_id,
|
||||
'timestamp': datetime.now(),
|
||||
'market_session': market_data['context'].get('market_session', 'unknown'),
|
||||
'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
|
||||
}
|
||||
|
||||
# Add experience
|
||||
experience_id = self.rl_trainer.add_experience(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=False,
|
||||
market_context=market_context,
|
||||
cnn_predictions=predictions.get('cnn'),
|
||||
confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
|
||||
)
|
||||
|
||||
if experience_id:
|
||||
logger.debug(f"Added RL experience: {experience_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding RL experience: {e}")
|
||||
|
||||
def _training_coordinator_worker(self):
|
||||
"""Coordinate training across all models"""
|
||||
logger.info("Training coordinator worker started")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
# Check if we should trigger training
|
||||
for symbol in self.data_provider.symbols:
|
||||
self._check_and_trigger_training(symbol)
|
||||
|
||||
# Wait before next check
|
||||
time.sleep(self.config.training_frequency_minutes * 60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training coordinator: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
logger.info("Training coordinator worker stopped")
|
||||
|
||||
def _check_and_trigger_training(self, symbol: str):
|
||||
"""Check conditions and trigger training if needed"""
|
||||
try:
|
||||
# Get training episodes and experiences
|
||||
episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)
|
||||
|
||||
# Check CNN training conditions
|
||||
if len(episodes) >= self.config.min_episodes_for_cnn_training:
|
||||
profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
|
||||
|
||||
if len(profitable_episodes) >= 20: # Minimum profitable episodes
|
||||
logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
|
||||
|
||||
results = self.cnn_trainer.train_on_profitable_episodes(
|
||||
symbol=symbol,
|
||||
min_profitability=self.config.min_profitability_for_replay,
|
||||
max_episodes=len(profitable_episodes)
|
||||
)
|
||||
|
||||
if results.get('status') == 'success':
|
||||
self.integration_stats['cnn_training_sessions'] += 1
|
||||
logger.info(f"CNN training completed for {symbol}")
|
||||
|
||||
# Check RL training conditions
|
||||
buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
|
||||
total_experiences = buffer_stats.get('total_experiences', 0)
|
||||
|
||||
if total_experiences >= self.config.min_experiences_for_rl_training:
|
||||
profitable_experiences = buffer_stats.get('profitable_experiences', 0)
|
||||
|
||||
if profitable_experiences >= 50: # Minimum profitable experiences
|
||||
logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
|
||||
|
||||
results = self.rl_trainer.train_on_profitable_experiences(
|
||||
min_profitability=self.config.min_profitability_for_replay,
|
||||
max_experiences=min(profitable_experiences, 500),
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
if results.get('status') == 'success':
|
||||
self.integration_stats['rl_training_sessions'] += 1
|
||||
logger.info("RL training completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking training conditions for {symbol}: {e}")
|
||||
|
||||
def _outcome_validation_worker(self):
|
||||
"""Background worker for validating prediction outcomes"""
|
||||
logger.info("Outcome validation worker started")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
self._validate_recent_predictions()
|
||||
time.sleep(300) # Check every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in outcome validation: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
logger.info("Outcome validation worker stopped")
|
||||
|
||||
    def _validate_recent_predictions(self):
        """Validate recent predictions against actual outcomes"""
        try:
            current_time = datetime.now()
            validation_delay = timedelta(hours=1)  # Wait 1 hour to validate

            validated_predictions = []

            for episode_id, prediction_data in self.recent_predictions.items():
                prediction_time = prediction_data['timestamp']

                if current_time - prediction_time >= validation_delay:
                    # Validate this prediction
                    outcome = self._calculate_prediction_outcome(prediction_data)

                    if outcome:
                        self.prediction_outcomes[episode_id] = outcome

                        # Update RL experience if one exists
                        if 'rl_action' in prediction_data['predictions']:
                            self._update_rl_experience_outcome(episode_id, outcome)

                        # Update statistics
                        if outcome['is_profitable']:
                            self.integration_stats['profitable_predictions'] += 1
                        self.integration_stats['total_predictions'] += 1

                    validated_predictions.append(episode_id)

            # Remove validated predictions
            for episode_id in validated_predictions:
                del self.recent_predictions[episode_id]

            if validated_predictions:
                logger.info(f"Validated {len(validated_predictions)} predictions")

        except Exception as e:
            logger.error(f"Error validating predictions: {e}")

    def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Calculate actual outcome for a prediction"""
        try:
            symbol = prediction_data['symbol']
            prediction_time = prediction_data['timestamp']

            # Get price data after the prediction
            current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)

            if current_df is None or current_df.empty:
                return None

            # Find price at prediction time and current price
            prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
            if prediction_price.empty:
                return None

            base_price = float(prediction_price['close'].iloc[-1])
            current_price = float(current_df['close'].iloc[-1])

            # Calculate outcome
            price_change = (current_price - base_price) / base_price
            is_profitable = abs(price_change) > 0.005  # 0.5% threshold

            return {
                # Not set in recent_predictions entries; kept for schema compatibility
                'episode_id': prediction_data.get('episode_id'),
                'base_price': base_price,
                'current_price': current_price,
                'price_change': price_change,
                'is_profitable': is_profitable,
                'profitability_score': abs(price_change) * 10,  # Rough scaling toward a 0-1 range
                'validation_time': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error calculating prediction outcome: {e}")
            return None
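For example, with base_price 3000.0 and current_price 3021.0, price_change = (3021 − 3000) / 3000 = 0.007, which clears the 0.5% threshold, so the prediction is marked profitable with a profitability_score of 0.07.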
    def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
        """Update RL experience with actual outcome"""
        try:
            # Find the experience ID associated with this episode
            # This is a simplified approach - in practice you'd maintain better mapping
            actual_profit = outcome['price_change']

            # Determine optimal action based on outcome
            if outcome['price_change'] > 0.01:
                optimal_action = 2  # BUY
            elif outcome['price_change'] < -0.01:
                optimal_action = 0  # SELL
            else:
                optimal_action = 1  # HOLD

            # Update experience (this would need proper experience ID mapping)
            # For now, we'll update the most recent experience
            # In practice, you'd maintain a mapping between episodes and experiences

        except Exception as e:
            logger.error(f"Error updating RL experience outcome: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add component statistics
        stats['data_collector'] = self.data_collector.get_collection_statistics()
        stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
        stats['rl_trainer'] = self.rl_trainer.get_training_statistics()

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['recent_predictions_count'] = len(self.recent_predictions)
        stats['validated_outcomes_count'] = len(self.prediction_outcomes)

        # Calculate profitability rate
        if stats['total_predictions'] > 0:
            stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
        else:
            stats['overall_profitability_rate'] = 0.0

        return stats

    def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
        """Manually trigger training"""
        results = {}

        try:
            if training_type in ['all', 'cnn']:
                symbols = [symbol] if symbol else self.data_provider.symbols
                for sym in symbols:
                    cnn_results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=sym,
                        min_profitability=0.1,
                        max_episodes=200
                    )
                    results[f'cnn_{sym}'] = cnn_results

            if training_type in ['all', 'rl']:
                rl_results = self.rl_trainer.train_on_profitable_experiences(
                    min_profitability=0.1,
                    max_experiences=500,
                    batch_size=32
                )
                results['rl'] = rl_results

            return {'status': 'success', 'results': results}

        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}
    # Helper methods (simplified implementations)
    def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data"""
        # Implementation would get tick data from data provider
        return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get COB data"""
        # Implementation would get COB data from data provider
        return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators"""
        # Implementation would get indicators from data provider
        return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get pivot points"""
        # Implementation would get pivot points from data provider
        return []

    def _get_market_context(self, symbol: str) -> Dict[str, Any]:
        """Get market context"""
        return {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'market_session': 'unknown',
            'volatility_regime': 'unknown'
        }

    def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
        """Validate market data completeness"""
        required_fields = ['ohlcv', 'indicators']
        return all(field in market_data for field in required_fields)

    def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
        """Create enhanced CNN features"""
        try:
            # Simplified feature creation
            features = []

            # Add OHLCV features
            for timeframe in ['1m', '5m', '15m', '1h']:
                if timeframe in market_data.get('ohlcv', {}):
                    df = market_data['ohlcv'][timeframe]
                    if not df.empty:
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Pad to target size
            target_size = 3000  # 10 channels * 300 sequence length
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating CNN features: {e}")
            return None

    def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
                                  predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
        """Create enhanced RL state"""
        try:
            state_features = []

            # Add market features
            if '1m' in market_data.get('ohlcv', {}):
                df = market_data['ohlcv']['1m']
                if not df.empty:
                    latest = df.iloc[-1]
                    state_features.extend([
                        latest['open'], latest['high'],
                        latest['low'], latest['close'], latest['volume']
                    ])

            # Add technical indicators
            indicators = market_data.get('indicators', {})
            for value in indicators.values():
                state_features.append(value)

            # Add model predictions as features
            if predictions:
                if 'cnn' in predictions:
                    cnn_pred = predictions['cnn']
                    state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0]))
                    state_features.append(cnn_pred.get('confidence', [0.0])[0])

                if 'cob_rl' in predictions:
                    cob_pred = predictions['cob_rl']
                    state_features.append(cob_pred.get('predicted_direction', 1))
                    state_features.append(cob_pred.get('confidence', 0.5))

            # Pad to target size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating RL state: {e}")
            return None

    def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
                                     predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Get orchestrator prediction"""
        # This would integrate with your orchestrator
        return None

# Global instance
enhanced_training_integration = None

def get_enhanced_training_integration(data_provider: DataProvider = None,
                                      orchestrator: Orchestrator = None,
                                      trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
    """Get global enhanced training integration instance"""
    global enhanced_training_integration
    if enhanced_training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        enhanced_training_integration = EnhancedTrainingIntegration(
            data_provider, orchestrator, trading_executor
        )
    return enhanced_training_integration
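A minimal usage sketch of the module-level singleton above (the DataProvider construction is illustrative; its real constructor arguments may differ):

```python
from core.data_provider import DataProvider

provider = DataProvider()  # illustrative construction
integration = get_enhanced_training_integration(data_provider=provider)
integration.start_enhanced_integration()
# ... run for a while, then inspect and shut down ...
print(integration.get_integration_statistics()['overall_profitability_rate'])
integration.stop_enhanced_integration()
```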
@@ -1,8 +0,0 @@
# MEXC Web Client Module
#
# This module provides web-based trading capabilities for MEXC futures trading
# which is not supported by their official API.

from .mexc_futures_client import MEXCFuturesWebClient

__all__ = ['MEXCFuturesWebClient']
@@ -1,555 +0,0 @@
#!/usr/bin/env python3
"""
MEXC Auto Browser with Request Interception

This script automatically spawns a ChromeDriver instance and captures
all MEXC futures trading requests in real-time, including full request
and response data needed for reverse engineering.
"""

import logging
import time
import json
import sys
import os
from typing import Dict, List, Optional, Any
from datetime import datetime
import threading
import queue

# Selenium imports
try:
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.webdriver.chrome.service import Service
    from selenium.webdriver.common.by import By
    from selenium.webdriver.support.ui import WebDriverWait
    from selenium.webdriver.support import expected_conditions as EC
    from selenium.common.exceptions import TimeoutException, WebDriverException
    from webdriver_manager.chrome import ChromeDriverManager
except ImportError:
    print("Please install selenium and webdriver-manager:")
    print("pip install selenium webdriver-manager")
    sys.exit(1)

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class MEXCRequestInterceptor:
    """
    Automatically spawns ChromeDriver and intercepts all MEXC API requests
    """

    def __init__(self, headless: bool = False, save_to_file: bool = True):
        """
        Initialize the request interceptor

        Args:
            headless: Run browser in headless mode
            save_to_file: Save captured requests to JSON file
        """
        self.driver = None
        self.headless = headless
        self.save_to_file = save_to_file
        self.captured_requests = []
        self.captured_responses = []
        self.session_cookies = {}
        self.monitoring = False
        self.request_queue = queue.Queue()

        # File paths for saving data
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.requests_file = f"mexc_requests_{self.timestamp}.json"
        self.cookies_file = f"mexc_cookies_{self.timestamp}.json"

    def setup_browser(self):
        """Setup Chrome browser with necessary options"""
        chrome_options = webdriver.ChromeOptions()
        # Enable headless mode if needed
        if self.headless:
            chrome_options.add_argument('--headless')
        chrome_options.add_argument('--disable-gpu')
        chrome_options.add_argument('--window-size=1920,1080')
        chrome_options.add_argument('--disable-extensions')

        # Set up Chrome options with a user data directory to persist session
        user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data')
        os.makedirs(user_data_base_dir, exist_ok=True)

        # Check for existing session directories
        session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')]
        session_dirs.sort(reverse=True)  # Sort descending to get the most recent first

        user_data_dir = None
        if session_dirs:
            use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y'
            if use_existing:
                print("Available sessions:")
                for i, session in enumerate(session_dirs[:5], 1):  # Show up to 5 most recent
                    print(f"{i}. {session}")
                choice = input("Enter session number (default 1) or any other key for most recent: ")
                if choice.isdigit() and 1 <= int(choice) <= len(session_dirs):
                    selected_session = session_dirs[int(choice) - 1]
                else:
                    selected_session = session_dirs[0]
                user_data_dir = os.path.join(user_data_base_dir, selected_session)
                print(f"Using session: {selected_session}")

        if user_data_dir is None:
            user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}')
            os.makedirs(user_data_dir, exist_ok=True)
            print(f"Creating new session: session_{self.timestamp}")

        chrome_options.add_argument(f'--user-data-dir={user_data_dir}')

        # Enable logging to capture JS console output and network activity
        chrome_options.set_capability('goog:loggingPrefs', {
            'browser': 'ALL',
            'performance': 'ALL'
        })

        try:
            self.driver = webdriver.Chrome(options=chrome_options)
        except Exception as e:
            print(f"Failed to start browser with session: {e}")
            print("Falling back to a new session...")
            user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback')
            os.makedirs(user_data_dir, exist_ok=True)
            print(f"Creating fallback session: session_{self.timestamp}_fallback")
            chrome_options = webdriver.ChromeOptions()
            if self.headless:
                chrome_options.add_argument('--headless')
            chrome_options.add_argument('--disable-gpu')
            chrome_options.add_argument('--window-size=1920,1080')
            chrome_options.add_argument('--disable-extensions')
            chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
            chrome_options.set_capability('goog:loggingPrefs', {
                'browser': 'ALL',
                'performance': 'ALL'
            })
            self.driver = webdriver.Chrome(options=chrome_options)

        return self.driver
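The `goog:loggingPrefs` capability set in `setup_browser` exposes Chrome's DevTools network events through Selenium's `driver.get_log('performance')`. A hedged sketch of how a monitor loop can read them (this mirrors the standard pattern; the helper name is illustrative and exact event payloads vary by Chrome version):

```python
import json

def read_network_events(driver):
    """Yield CDP Network.* events from Selenium's performance log."""
    for entry in driver.get_log('performance'):
        message = json.loads(entry['message'])['message']
        if message.get('method', '').startswith('Network.'):
            yield message
```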
    def start_monitoring(self):
        """Start the browser and begin monitoring"""
        logger.info("Starting MEXC Request Interceptor...")

        try:
            # Setup ChromeDriver
            self.driver = self.setup_browser()

            # Navigate to MEXC futures
            mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap"
            logger.info(f"Navigating to: {mexc_url}")
            self.driver.get(mexc_url)

            # Wait for page load
            WebDriverWait(self.driver, 10).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )

            logger.info("✅ MEXC page loaded successfully!")
            logger.info("📝 Please log in manually in the browser window")
            logger.info("🔍 Request monitoring is now active...")

            # Start monitoring in background thread
            self.monitoring = True
            monitor_thread = threading.Thread(target=self._monitor_requests, daemon=True)
            monitor_thread.start()

            # Wait for manual login
            self._wait_for_login()

            return True

        except Exception as e:
            logger.error(f"Failed to start monitoring: {e}")
            return False

    def _wait_for_login(self):
        """Wait for user to log in and show interactive menu"""
        logger.info("\n" + "="*60)
        logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE")
        logger.info("="*60)

        while True:
            print("\nOptions:")
            print("1. Check login status")
            print("2. Extract current cookies")
            print("3. Show captured requests summary")
            print("4. Save captured data to files")
            print("5. Perform test trade (manual)")
            print("6. Monitor for 60 seconds")
            print("0. Stop and exit")

            choice = input("\nEnter choice (0-6): ").strip()

            if choice == "1":
                self._check_login_status()
            elif choice == "2":
                self._extract_cookies()
            elif choice == "3":
                self._show_requests_summary()
            elif choice == "4":
                self._save_all_data()
            elif choice == "5":
                self._guide_test_trade()
            elif choice == "6":
                self._monitor_for_duration(60)
            elif choice == "0":
                break
            else:
                print("Invalid choice. Please try again.")

        self.stop_monitoring()

    def _check_login_status(self):
        """Check if user is logged into MEXC"""
        try:
            cookies = self.driver.get_cookies()
            auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint']
            found_auth = []

            for cookie in cookies:
                if cookie['name'] in auth_cookies and cookie['value']:
                    found_auth.append(cookie['name'])

            if len(found_auth) >= 2:
                print("✅ LOGIN DETECTED - You appear to be logged in!")
                print(f"   Found auth cookies: {', '.join(found_auth)}")
                return True
            else:
                print("❌ NOT LOGGED IN - Please log in to MEXC in the browser")
                print("   Missing required authentication cookies")
                return False

        except Exception as e:
            print(f"❌ Error checking login: {e}")
            return False

    def _extract_cookies(self):
        """Extract and display current session cookies"""
        try:
            cookies = self.driver.get_cookies()
            cookie_dict = {}

            for cookie in cookies:
                cookie_dict[cookie['name']] = cookie['value']

            self.session_cookies = cookie_dict

            print(f"\n📊 Extracted {len(cookie_dict)} cookies:")

            # Show important cookies
            important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
            for name in important:
                if name in cookie_dict:
                    value = cookie_dict[name]
                    display_value = value[:20] + "..." if len(value) > 20 else value
                    print(f"   ✅ {name}: {display_value}")
                else:
                    print(f"   ❌ {name}: Missing")

            # Save cookies to file
            if self.save_to_file:
                with open(self.cookies_file, 'w') as f:
                    json.dump(cookie_dict, f, indent=2)
                print(f"\n💾 Cookies saved to: {self.cookies_file}")

        except Exception as e:
            print(f"❌ Error extracting cookies: {e}")

    def _monitor_requests(self):
        """Background thread to monitor network requests"""
        last_log_count = 0

        while self.monitoring:
            try:
                # Get performance logs
|
||||
logs = self.driver.get_log('performance')
|
||||
|
||||
for log in logs:
|
||||
try:
|
||||
message = json.loads(log['message'])
|
||||
method = message.get('message', {}).get('method', '')
|
||||
|
||||
# Capture network requests
|
||||
if method == 'Network.requestWillBeSent':
|
||||
self._process_request(message['message']['params'])
|
||||
elif method == 'Network.responseReceived':
|
||||
self._process_response(message['message']['params'])
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
continue
|
||||
|
||||
# Show progress every 10 new requests
|
||||
if len(self.captured_requests) >= last_log_count + 10:
|
||||
last_log_count = len(self.captured_requests)
|
||||
logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses")
|
||||
|
||||
except Exception as e:
|
||||
if self.monitoring: # Only log if we're still supposed to be monitoring
|
||||
logger.debug(f"Monitor error: {e}")
|
||||
|
||||
time.sleep(0.5) # Check every 500ms
|
||||
|
||||
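
    # Editor's note (illustrative, not part of the original file): each entry
    # returned by driver.get_log('performance') wraps a Chrome DevTools
    # Protocol event as a JSON string. A minimal sketch of the shape that
    # _monitor_requests() relies on, with hypothetical values:
    #
    #   {
    #       "message": {
    #           "method": "Network.requestWillBeSent",
    #           "params": {
    #               "requestId": "1234.5",
    #               "request": {
    #                   "url": "https://futures.mexc.com/api/v1/private/order/create",
    #                   "method": "POST",
    #                   "headers": {"...": "..."},
    #                   "postData": "{...}"
    #               }
    #           }
    #       }
    #   }
    #
    # This nesting is why the code indexes message['message']['params'] before
    # handing the payload to _process_request() / _process_response().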
    def _process_request(self, request_data):
        """Process a captured network request"""
        try:
            url = request_data.get('request', {}).get('url', '')

            # Filter for MEXC API requests
            if self._is_mexc_request(url):
                request_info = {
                    'type': 'request',
                    'timestamp': datetime.now().isoformat(),
                    'url': url,
                    'method': request_data.get('request', {}).get('method', ''),
                    'headers': request_data.get('request', {}).get('headers', {}),
                    'postData': request_data.get('request', {}).get('postData', ''),
                    'requestId': request_data.get('requestId', '')
                }

                self.captured_requests.append(request_info)

                # Show important requests immediately
                if 'futures.mexc.com' in url or 'captcha' in url:
                    print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}")
                    if request_info['postData']:
                        print(f"   📄 POST Data: {request_info['postData'][:100]}...")

                # Enhanced captcha detection and detailed logging
                if 'captcha' in url.lower() or 'robot' in url.lower():
                    logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}")
                    logger.info(f"  Headers: {request_data.get('request', {}).get('headers', {})}")
                    if request_data.get('request', {}).get('postData', ''):
                        logger.info(f"  Data: {request_data.get('request', {}).get('postData', '')}")
                    # Attempt to capture related JavaScript or DOM elements (if possible)
                    if self.driver is not None:
                        try:
                            js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';")
                            logger.info(f"  Related JS Snippet: {js_snippet}")
                        except Exception as e:
                            logger.warning(f"  Could not capture JS snippet: {e}")
                        try:
                            dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';")
                            logger.info(f"  Related DOM Element: {dom_element}")
                        except Exception as e:
                            logger.warning(f"  Could not capture DOM element: {e}")
                    else:
                        logger.warning("  Driver not initialized, cannot capture JS or DOM elements")

        except Exception as e:
            logger.debug(f"Error processing request: {e}")

    def _process_response(self, response_data):
        """Process a captured network response"""
        try:
            url = response_data.get('response', {}).get('url', '')

            # Filter for MEXC API responses
            if self._is_mexc_request(url):
                response_info = {
                    'type': 'response',
                    'timestamp': datetime.now().isoformat(),
                    'url': url,
                    'status': response_data.get('response', {}).get('status', 0),
                    'headers': response_data.get('response', {}).get('headers', {}),
                    'requestId': response_data.get('requestId', '')
                }

                self.captured_responses.append(response_info)

                # Show important responses immediately
                if 'futures.mexc.com' in url or 'captcha' in url:
                    status = response_info['status']
                    status_emoji = "✅" if status == 200 else "❌"
                    print(f"   {status_emoji} RESPONSE: {status} for {url}")

        except Exception as e:
            logger.debug(f"Error processing response: {e}")

    def _is_mexc_request(self, url: str) -> bool:
        """Check if URL is a relevant MEXC API request"""
        mexc_indicators = [
            'futures.mexc.com',
            'ucgateway/captcha_api',
            'api/v1/private',
            'api/v3/order',
            'mexc.com/api'
        ]

        return any(indicator in url for indicator in mexc_indicators)

    def _show_requests_summary(self):
        """Show a summary of captured requests"""
        print("\n📊 CAPTURE SUMMARY:")
        print(f"   Total Requests: {len(self.captured_requests)}")
        print(f"   Total Responses: {len(self.captured_responses)}")

        # Group by URL pattern
        url_counts = {}
        for req in self.captured_requests:
            base_url = req['url'].split('?')[0]  # Remove query params
            url_counts[base_url] = url_counts.get(base_url, 0) + 1

        print("\n🔗 Top URLs:")
        for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
            print(f"   {count}x {url}")

        # Show recent futures API calls
        futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']]
        if futures_requests:
            print(f"\n🚀 Futures API Calls: {len(futures_requests)}")
            for req in futures_requests[-3:]:  # Show last 3
                print(f"   {req['method']} {req['url']}")

    def _save_all_data(self):
        """Save all captured data to files"""
        if not self.save_to_file:
            print("File saving is disabled")
            return

        try:
            # Save requests
            with open(self.requests_file, 'w') as f:
                json.dump({
                    'requests': self.captured_requests,
                    'responses': self.captured_responses,
                    'summary': {
                        'total_requests': len(self.captured_requests),
                        'total_responses': len(self.captured_responses),
                        'capture_session': self.timestamp
                    }
                }, f, indent=2)

            # Save cookies if we have them
            if self.session_cookies:
                with open(self.cookies_file, 'w') as f:
                    json.dump(self.session_cookies, f, indent=2)

            print("\n💾 Data saved to:")
            print(f"   📋 Requests: {self.requests_file}")
            if self.session_cookies:
                print(f"   🍪 Cookies: {self.cookies_file}")

            # Extract and save CAPTCHA tokens from captured requests
            captcha_tokens = self.extract_captcha_tokens()
            if captcha_tokens:
                captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json"
                with open(captcha_file, 'w') as f:
                    json.dump(captcha_tokens, f, indent=2)
                logger.info(f"Saved CAPTCHA tokens to {captcha_file}")
            else:
                logger.warning("No CAPTCHA tokens found in captured requests")

        except Exception as e:
            print(f"❌ Error saving data: {e}")

    def _guide_test_trade(self):
        """Guide the user through performing a test trade"""
        print("\n🧪 TEST TRADE GUIDE:")
        print("1. Make sure you're logged into MEXC")
        print("2. Go to the trading interface")
        print("3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)")
        print("4. Watch the console for captured API calls")
        print("\n⚠️ IMPORTANT: Use very small amounts for testing!")
        input("\nPress Enter when you're ready to start monitoring...")

        self._monitor_for_duration(120)  # Monitor for 2 minutes

    def _monitor_for_duration(self, seconds: int):
        """Monitor requests for a specific duration"""
        print(f"\n🔍 Monitoring requests for {seconds} seconds...")
        print("Perform your trading actions now!")

        start_time = time.time()
        initial_count = len(self.captured_requests)

        while time.time() - start_time < seconds:
            current_count = len(self.captured_requests)
            new_requests = current_count - initial_count

            remaining = seconds - int(time.time() - start_time)
            print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True)

            time.sleep(1)

        final_count = len(self.captured_requests)
        new_total = final_count - initial_count
        print(f"\n✅ Monitoring complete! Captured {new_total} new requests")

    def stop_monitoring(self):
        """Stop monitoring and close the browser"""
        logger.info("Stopping request monitoring...")
        self.monitoring = False

        if self.driver:
            self.driver.quit()
            logger.info("Browser closed")

        # Final save
        if self.save_to_file and (self.captured_requests or self.captured_responses):
            self._save_all_data()
            logger.info("Final data save complete")

    def extract_captcha_tokens(self):
        """Extract CAPTCHA tokens from captured requests and responses"""
        captcha_tokens = []
        for request in self.captured_requests:
            if 'captcha-token' in request.get('headers', {}):
                captcha_tokens.append({
                    'token': request['headers']['captcha-token'],
                    'url': request.get('url', ''),
                    'timestamp': request.get('timestamp', '')
                })
        # Responses are stored in a separate list (captured requests never carry
        # a 'response' key), so scan captured_responses for tokens returned by
        # the captcha endpoints as well
        for response in self.captured_responses:
            if 'captcha' in response.get('url', '').lower() and 'captcha-token' in response.get('headers', {}):
                captcha_tokens.append({
                    'token': response['headers']['captcha-token'],
                    'url': response.get('url', ''),
                    'timestamp': response.get('timestamp', '')
                })
        return captcha_tokens
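
# Editor's note (illustrative, hypothetical values): extract_captcha_tokens()
# returns a list shaped like the following, which _save_all_data() writes to
# mexc_captcha_tokens_<timestamp>.json and which the futures web client later
# re-reads in _extract_captcha_token_from_browser():
#
#   [
#       {
#           "token": "robot-token-value",
#           "url": "https://www.mexc.com/ucgateway/captcha_api/robot.future.openlong.ETH_USDT.200X",
#           "timestamp": "2025-07-03T00:36:25.123456"
#       }
#   ]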
def main():
    """Main function to run the interceptor"""
    print("🚀 MEXC Request Interceptor with ChromeDriver")
    print("=" * 50)
    print("This will automatically:")
    print("✅ Download/setup ChromeDriver")
    print("✅ Open MEXC futures page")
    print("✅ Capture all API requests/responses")
    print("✅ Extract session cookies")
    print("✅ Save data to JSON files")
    print("\nPress Ctrl+C to stop at any time")

    # Ask for preferences
    headless = input("\nRun in headless mode? (y/n): ").lower().strip() == 'y'

    interceptor = MEXCRequestInterceptor(headless=headless, save_to_file=True)

    try:
        success = interceptor.start_monitoring()
        if not success:
            print("❌ Failed to start monitoring")
            return

    except KeyboardInterrupt:
        print("\n\n⏹️ Stopping interceptor...")
    except Exception as e:
        print(f"\n❌ Error: {e}")
    finally:
        interceptor.stop_monitoring()
        print("\n👋 Goodbye!")

if __name__ == "__main__":
    main()
@@ -1,358 +0,0 @@
"""
MEXC Browser Automation for Cookie Extraction and Request Monitoring

This module uses Selenium to automate browser interactions and extract
session cookies and request data for MEXC futures trading.
"""

import logging
import time
import json
from typing import Dict, List, Optional, Any
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager

logger = logging.getLogger(__name__)

class MEXCBrowserAutomation:
    """
    Browser automation for MEXC futures trading session management
    """

    def __init__(self, headless: bool = False, proxy: Optional[str] = None):
        """
        Initialize browser automation

        Args:
            headless: Run browser in headless mode
            proxy: HTTP proxy to use (format: host:port)
        """
        self.driver = None
        self.headless = headless
        self.proxy = proxy
        self.logged_in = False

    def setup_chrome_driver(self) -> webdriver.Chrome:
        """Setup Chrome driver with appropriate options"""
        chrome_options = Options()

        if self.headless:
            chrome_options.add_argument("--headless")

        # Basic Chrome options for automation
        chrome_options.add_argument("--no-sandbox")
        chrome_options.add_argument("--disable-dev-shm-usage")
        chrome_options.add_argument("--disable-blink-features=AutomationControlled")
        chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
        chrome_options.add_experimental_option('useAutomationExtension', False)

        # Set user agent to avoid detection
        chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")

        # Proxy setup if provided
        if self.proxy:
            chrome_options.add_argument(f"--proxy-server=http://{self.proxy}")

        # Enable network logging
        chrome_options.add_argument("--enable-logging")
        chrome_options.add_argument("--log-level=0")
        chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})

        # Automatically download and setup ChromeDriver
        service = Service(ChromeDriverManager().install())

        try:
            driver = webdriver.Chrome(service=service, options=chrome_options)

            # Execute script to avoid detection
            driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")

            return driver
        except WebDriverException as e:
            logger.error(f"Failed to setup Chrome driver: {e}")
            raise

    def start_browser(self):
        """Start the browser session"""
        if self.driver is None:
            logger.info("Starting Chrome browser for MEXC automation")
            self.driver = self.setup_chrome_driver()
            logger.info("Browser started successfully")

    def stop_browser(self):
        """Stop the browser session"""
        if self.driver:
            logger.info("Stopping browser")
            self.driver.quit()
            self.driver = None

    def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"):
        """
        Navigate to MEXC futures trading page

        Args:
            symbol: Trading symbol to navigate to
        """
        if not self.driver:
            self.start_browser()

        url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap"
        logger.info(f"Navigating to MEXC futures: {url}")

        self.driver.get(url)

        # Wait for page to load
        try:
            WebDriverWait(self.driver, 10).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )
            logger.info("MEXC futures page loaded")
        except TimeoutException:
            logger.error("Timeout waiting for MEXC page to load")

    def wait_for_login(self, timeout: int = 300) -> bool:
        """
        Wait for user to manually log in to MEXC

        Args:
            timeout: Maximum time to wait for login (seconds)

        Returns:
            bool: True if login detected, False if timeout
        """
        logger.info("Please log in to MEXC manually in the browser window")
        logger.info("Waiting for login completion...")

        start_time = time.time()

        while time.time() - start_time < timeout:
            # Check if we can find elements that indicate a logged-in state
            try:
                # Look for user-specific cookies that appear after login
                cookies = self.driver.get_cookies()

                # Check for authentication cookies
                auth_cookies = ['uc_token', 'u_id']
                logged_in_indicators = 0

                for cookie in cookies:
                    if cookie['name'] in auth_cookies and cookie['value']:
                        logged_in_indicators += 1

                if logged_in_indicators >= 2:
                    logger.info("Login detected!")
                    self.logged_in = True
                    return True

            except Exception as e:
                logger.debug(f"Error checking login status: {e}")

            time.sleep(2)  # Check every 2 seconds

        logger.error(f"Login timeout after {timeout} seconds")
        return False

    def extract_session_cookies(self) -> Dict[str, str]:
        """
        Extract all cookies from current browser session

        Returns:
            Dictionary of cookie name-value pairs
        """
        if not self.driver:
            logger.error("Browser not started")
            return {}

        cookies = {}

        try:
            browser_cookies = self.driver.get_cookies()

            for cookie in browser_cookies:
                cookies[cookie['name']] = cookie['value']

            logger.info(f"Extracted {len(cookies)} cookies from browser session")

            # Log important cookies (without values, for security)
            important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
            for cookie_name in important_cookies:
                if cookie_name in cookies:
                    logger.info(f"Found important cookie: {cookie_name}")
                else:
                    logger.warning(f"Missing important cookie: {cookie_name}")

            return cookies

        except Exception as e:
            logger.error(f"Failed to extract cookies: {e}")
            return {}

    def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]:
        """
        Monitor network requests for the specified duration

        Args:
            duration: How long to monitor requests (seconds)

        Returns:
            List of captured network requests
        """
        if not self.driver:
            logger.error("Browser not started")
            return []

        logger.info(f"Starting network monitoring for {duration} seconds")
        logger.info("Please perform trading actions in the browser (open/close positions)")

        start_time = time.time()
        captured_requests = []

        while time.time() - start_time < duration:
            try:
                # Get performance logs (network requests)
                logs = self.driver.get_log('performance')

                for log in logs:
                    message = json.loads(log['message'])

                    # Filter for relevant MEXC API requests
                    if message.get('message', {}).get('method') == 'Network.responseReceived':
                        response = message['message']['params']['response']
                        url = response.get('url', '')

                        # Look for futures API calls
                        if ('futures.mexc.com' in url or
                                'ucgateway/captcha_api' in url or
                                'api/v1/private' in url):

                            request_data = {
                                'url': url,
                                'mimeType': response.get('mimeType', ''),  # was mislabelled 'method' in the original
                                'status': response.get('status'),
                                'headers': response.get('headers', {}),
                                'timestamp': log['timestamp']
                            }

                            captured_requests.append(request_data)
                            logger.info(f"Captured request: {url}")

            except Exception as e:
                logger.debug(f"Error in network monitoring: {e}")

            time.sleep(1)

        logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests")
        return captured_requests

    def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200):
        """
        Attempt to perform a test trade to capture the complete request flow

        Args:
            symbol: Trading symbol
            volume: Position size
            leverage: Leverage multiplier
        """
        if not self.logged_in:
            logger.error("Not logged in - cannot perform test trade")
            return

        logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x")
        logger.info("This will attempt to click trading interface elements")

        try:
            # This would need to be implemented based on MEXC's specific UI elements
            # For now, just wait and let the user perform manual actions
            logger.info("Please manually place a small test trade while monitoring is active")
            time.sleep(30)

        except Exception as e:
            logger.error(f"Error during test trade: {e}")

    def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]:
        """
        Complete session capture workflow

        Args:
            symbol: Trading symbol to use

        Returns:
            Dictionary containing cookies and captured requests
        """
        logger.info("Starting full MEXC session capture")

        try:
            # Start browser and navigate to MEXC
            self.navigate_to_mexc_futures(symbol)

            # Wait for manual login
            if not self.wait_for_login():
                return {'success': False, 'error': 'Login timeout'}

            # Extract session cookies
            cookies = self.extract_session_cookies()

            if not cookies:
                return {'success': False, 'error': 'Failed to extract cookies'}

            # Monitor network requests while the user performs actions
            logger.info("Starting network monitoring - please perform trading actions now")
            requests = self.monitor_network_requests(duration=120)  # 2 minutes

            return {
                'success': True,
                'cookies': cookies,
                'network_requests': requests,
                'timestamp': int(time.time())
            }

        except Exception as e:
            logger.error(f"Error in session capture: {e}")
            return {'success': False, 'error': str(e)}

        finally:
            self.stop_browser()

def main():
    """Main function for standalone execution"""
    logging.basicConfig(level=logging.INFO)

    print("MEXC Browser Automation - Session Capture")
    print("This will open a browser window for you to log into MEXC")
    print("Make sure you have Chrome browser installed")

    automation = MEXCBrowserAutomation(headless=False)

    try:
        result = automation.full_session_capture()

        if result['success']:
            print("\nSession capture successful!")
            print(f"Extracted {len(result['cookies'])} cookies")
            print(f"Captured {len(result['network_requests'])} network requests")

            # Save results to file
            output_file = f"mexc_session_capture_{int(time.time())}.json"
            with open(output_file, 'w') as f:
                json.dump(result, f, indent=2)

            print(f"Results saved to: {output_file}")

        else:
            print(f"Session capture failed: {result['error']}")

    except KeyboardInterrupt:
        print("\nSession capture interrupted by user")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        automation.stop_browser()

if __name__ == "__main__":
    main()
@@ -1,525 +0,0 @@
"""
MEXC Futures Web Client

This module implements a web-based client for MEXC futures trading,
since their official API doesn't support futures (leverage) trading.

It mimics browser behavior by replicating the exact HTTP requests
that the web interface makes.
"""

import logging
import requests
import time
import json
import hmac
import hashlib
import base64
from typing import Dict, List, Optional, Any
from datetime import datetime
import uuid
from urllib.parse import urlencode
import glob
import os

logger = logging.getLogger(__name__)

class MEXCSessionManager:
    def __init__(self):
        self.captcha_token = None

    def get_captcha_token(self) -> str:
        return self.captcha_token if self.captcha_token else ""

    def save_captcha_token(self, token: str):
        self.captcha_token = token
        logger.info("MEXC: Captcha token saved in session manager")

class MEXCFuturesWebClient:
    """
    MEXC Futures Web Client that mimics browser behavior for futures trading.

    Since MEXC's official API doesn't support futures, this client replicates
    the exact HTTP requests made by their web interface.
    """

    def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True):
        """
        Initialize the MEXC Futures Web Client

        Args:
            api_key: API key for authentication
            api_secret: API secret for authentication
            user_id: User ID for authentication
            base_url: Base URL for the MEXC website
            headless: Whether to run the browser in headless mode
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.user_id = user_id
        self.base_url = base_url
        self.is_authenticated = False
        self.headless = headless
        self.session = requests.Session()
        self.session_manager = MEXCSessionManager()  # Holds the current captcha token
        self.captcha_url = f'{base_url}/ucgateway/captcha_api'
        self.futures_api_url = "https://futures.mexc.com/api/v1"

        # Setup default headers that mimic a real browser
        self.setup_browser_headers()

    def setup_browser_headers(self):
        """Setup default headers that mimic the Chrome browser"""
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
            'Accept': '*/*',
            'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8',
            'Accept-Encoding': 'gzip, deflate, br',
            'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache',
            'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap',
            'Language': 'English',
            'X-Language': 'en-GB',
            'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
            'trochilus-uid': str(self.user_id) if self.user_id is not None else ''
        })

    def load_session_cookies(self, cookies: Dict[str, str]):
        """
        Load session cookies from browser

        Args:
            cookies: Dictionary of cookie name-value pairs
        """
        for name, value in cookies.items():
            self.session.cookies.set(name, value)

        # Extract important session info from cookies
        self.auth_token = cookies.get('uc_token')
        self.user_id = cookies.get('u_id')
        self.fingerprint = cookies.get('x-mxc-fingerprint')
        self.visitor_id = cookies.get('mexc_fingerprint_visitorId')

        if self.auth_token and self.user_id:
            self.is_authenticated = True
            logger.info("MEXC: Loaded authenticated session")
        else:
            logger.warning("MEXC: Session cookies incomplete - authentication may fail")

    def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]:
        """
        Extract cookies from a browser cookie string

        Args:
            cookie_string: Raw cookie string from browser (copy from Network tab)

        Returns:
            Dictionary of parsed cookies
        """
        cookies = {}
        cookie_pairs = cookie_string.split(';')

        for pair in cookie_pairs:
            if '=' in pair:
                name, value = pair.strip().split('=', 1)
                cookies[name] = value

        return cookies

    def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool:
        """
        Verify captcha for robot trading protection

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            side: 'openlong', 'closelong', 'openshort', 'closeshort'
            leverage: Leverage string (e.g., '200X')

        Returns:
            bool: True if captcha verification successful
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot verify captcha - not authenticated")
            return False

        # Build captcha endpoint URL
        endpoint = f"robot.future.{side}.{symbol}.{leverage}"
        url = f"{self.captcha_url}/{endpoint}"

        # Attempt to get captcha token from session manager
        captcha_token = self.session_manager.get_captcha_token()
        if not captcha_token:
            logger.warning("MEXC: No captcha token available, attempting to fetch from browser")
            captcha_token = self._extract_captcha_token_from_browser()
            if captcha_token:
                self.session_manager.save_captcha_token(captcha_token)
            else:
                logger.error("MEXC: Failed to extract captcha token from browser")
                return False

        headers = {
            'Content-Type': 'application/json',
            'Language': 'en-GB',
            'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap',
            'trochilus-uid': self.user_id if self.user_id else '',
            'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
            'captcha-token': captcha_token
        }

        logger.info(f"MEXC: Verifying captcha for {endpoint}")
        try:
            response = self.session.get(url, headers=headers, timeout=10)
            if response.status_code == 200:
                data = response.json()
                if data.get('success'):
                    logger.info(f"MEXC: Captcha verified successfully for {endpoint}")
                    return True
                else:
                    logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}")
                    return False
            else:
                logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}")
                return False
        except Exception as e:
            logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}")
            return False

    def _extract_captcha_token_from_browser(self) -> str:
        """
        Extract a captcha token from saved browser session data.

        This method looks for the most recent mexc_captcha_tokens JSON file
        (written by the request interceptor) and retrieves a token from it.
        """
        try:
            # Look for the most recent mexc_captcha_tokens file
            captcha_files = glob.glob("mexc_captcha_tokens_*.json")
            if not captcha_files:
                logger.error("MEXC: No CAPTCHA token files found")
                return ""

            # Pick the most recently created token file
            latest_file = max(captcha_files, key=os.path.getctime)
            logger.info(f"MEXC: Using CAPTCHA token file {latest_file}")

            with open(latest_file, 'r') as f:
                captcha_data = json.load(f)

            if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0:
                # Return the most recent token
                return captcha_data[0].get('token', '')
            else:
                logger.error("MEXC: No valid CAPTCHA tokens found in file")
                return ""
        except Exception as e:
            logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}")
            return ""

    def generate_signature(self, method: str, path: str, params: Dict[str, Any],
                           timestamp: int, nonce: int) -> str:
        """
        Generate signature for MEXC futures API requests

        This is reverse-engineered from the browser requests
        """
        # This is a placeholder - the actual signature generation would need
        # to be reverse-engineered from the browser's JavaScript
        # For now, return an empty string and rely on cookie authentication
        return ""
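
    # Editor's sketch (assumption, not the reverse-engineered scheme): exchange
    # web APIs commonly sign requests with HMAC-SHA256 over a canonical string
    # of sorted parameters plus a timestamp and nonce. If MEXC's x-mxc-sign
    # turns out to follow that pattern, a signer could look roughly like this
    # (hmac, hashlib, and urlencode are already imported above):
    #
    #   def _hmac_sign(self, method, path, params, timestamp, nonce):
    #       query = urlencode(sorted(params.items()))
    #       payload = f"{method}\n{path}\n{query}\n{timestamp}\n{nonce}"
    #       return hmac.new(self.api_secret.encode(), payload.encode(),
    #                       hashlib.sha256).hexdigest()
    #
    # The real canonical string and key material must be confirmed from the
    # site's JavaScript before this can replace the placeholder above.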
    def open_long_position(self, symbol: str, volume: float, leverage: int = 200,
                           price: Optional[float] = None) -> Dict[str, Any]:
        """
        Open a long futures position

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            volume: Position size (contracts)
            leverage: Leverage multiplier (default 200)
            price: Limit price (None for market order)

        Returns:
            dict: Order response with order ID
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot open position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        # First verify captcha
        if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'):
            logger.error("MEXC: Captcha verification failed for opening long position")
            return {'success': False, 'error': 'Captcha verification failed'}

        # Prepare order parameters based on the request dump
        timestamp = int(time.time() * 1000)
        nonce = timestamp

        order_data = {
            'symbol': symbol,
            'side': 1,  # 1 = long, 2 = short
            'openType': 2,  # Open position
            'type': '5',  # Market order (might be '1' for limit)
            'vol': volume,
            'leverage': leverage,
            'marketCeiling': False,
            'priceProtect': '0',
            'ts': timestamp,
            'mhash': self._generate_mhash(),  # This needs to be implemented
            'mtoken': self.visitor_id
        }

        # Add price for limit orders
        if price is not None:
            order_data['price'] = price
            order_data['type'] = '1'  # Limit order

        # Add encrypted parameters (these would need proper implementation)
        order_data['p0'] = self._encrypt_p0(order_data)  # Placeholder
        order_data['k0'] = self._encrypt_k0(order_data)  # Placeholder
        order_data['chash'] = self._generate_chash(order_data)  # Placeholder

        # Setup headers for the order request
        headers = {
            'Authorization': self.auth_token,
            'Content-Type': 'application/json',
            'Language': 'English',
            'x-language': 'en-GB',
            'x-mxc-nonce': str(nonce),
            'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
            'trochilus-uid': self.user_id,
            'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
            'Referer': 'https://www.mexc.com/'
        }

        # Make the order request
        url = f"{self.futures_api_url}/private/order/create"

        try:
            # First make OPTIONS request (preflight)
            options_response = self.session.options(url, headers=headers, timeout=10)

            if options_response.status_code == 200:
                # Now make the actual POST request
                response = self.session.post(url, json=order_data, headers=headers, timeout=15)

                if response.status_code == 200:
                    data = response.json()
                    if data.get('success') and data.get('code') == 0:
                        order_id = data.get('data', {}).get('orderId')
                        logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}")
                        return {
                            'success': True,
                            'order_id': order_id,
                            'timestamp': data.get('data', {}).get('ts'),
                            'symbol': symbol,
                            'side': 'long',
                            'volume': volume,
                            'leverage': leverage
                        }
                    else:
                        logger.error(f"MEXC: Order failed: {data}")
                        return {'success': False, 'error': data.get('msg', 'Unknown error')}
                else:
                    logger.error(f"MEXC: Order request failed with status {response.status_code}")
                    return {'success': False, 'error': f'HTTP {response.status_code}'}
            else:
                logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}")
                return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'}

        except Exception as e:
            logger.error(f"MEXC: Order execution error: {e}")
            return {'success': False, 'error': str(e)}

    def close_long_position(self, symbol: str, volume: float, leverage: int = 200,
                            price: Optional[float] = None) -> Dict[str, Any]:
        """
        Close a long futures position

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            volume: Position size to close (contracts)
            leverage: Leverage multiplier
            price: Limit price (None for market order)

        Returns:
            dict: Order response
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot close position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        # First verify captcha
        if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'):
            logger.error("MEXC: Captcha verification failed for closing long position")
            return {'success': False, 'error': 'Captcha verification failed'}

        # Similar to open_long_position but with closeType instead of openType
        timestamp = int(time.time() * 1000)
        nonce = timestamp

        order_data = {
            'symbol': symbol,
            'side': 2,  # Close side is opposite
            'closeType': 1,  # Close position
            'type': '5',  # Market order
            'vol': volume,
            'leverage': leverage,
            'marketCeiling': False,
            'priceProtect': '0',
            'ts': timestamp,
            'mhash': self._generate_mhash(),
            'mtoken': self.visitor_id
        }

        if price is not None:
            order_data['price'] = price
            order_data['type'] = '1'

        order_data['p0'] = self._encrypt_p0(order_data)
        order_data['k0'] = self._encrypt_k0(order_data)
        order_data['chash'] = self._generate_chash(order_data)

        return self._execute_order(order_data, 'close_long')

    def open_short_position(self, symbol: str, volume: float, leverage: int = 200,
                            price: Optional[float] = None) -> Dict[str, Any]:
        """Open a short futures position"""
        if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'):
            return {'success': False, 'error': 'Captcha verification failed'}

        order_data = {
            'symbol': symbol,
            'side': 2,  # 2 = short
            'openType': 2,
            'type': '5',
            'vol': volume,
            'leverage': leverage,
            'marketCeiling': False,
            'priceProtect': '0',
            'ts': int(time.time() * 1000),
            'mhash': self._generate_mhash(),
            'mtoken': self.visitor_id
        }

        if price is not None:
            order_data['price'] = price
            order_data['type'] = '1'

        order_data['p0'] = self._encrypt_p0(order_data)
        order_data['k0'] = self._encrypt_k0(order_data)
        order_data['chash'] = self._generate_chash(order_data)

        return self._execute_order(order_data, 'open_short')

    def close_short_position(self, symbol: str, volume: float, leverage: int = 200,
                             price: Optional[float] = None) -> Dict[str, Any]:
        """Close a short futures position"""
        if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'):
            return {'success': False, 'error': 'Captcha verification failed'}

        order_data = {
            'symbol': symbol,
            'side': 1,  # Close side is opposite
            'closeType': 1,
            'type': '5',
            'vol': volume,
            'leverage': leverage,
            'marketCeiling': False,
            'priceProtect': '0',
            'ts': int(time.time() * 1000),
            'mhash': self._generate_mhash(),
            'mtoken': self.visitor_id
        }

        if price is not None:
            order_data['price'] = price
            order_data['type'] = '1'

        order_data['p0'] = self._encrypt_p0(order_data)
        order_data['k0'] = self._encrypt_k0(order_data)
        order_data['chash'] = self._generate_chash(order_data)

        return self._execute_order(order_data, 'close_short')

    def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]:
        """Common order execution logic"""
        timestamp = order_data['ts']
        nonce = timestamp

        headers = {
            'Authorization': self.auth_token,
            'Content-Type': 'application/json',
            'Language': 'English',
            'x-language': 'en-GB',
            'x-mxc-nonce': str(nonce),
            'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
            'trochilus-uid': self.user_id,
            'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
            'Referer': 'https://www.mexc.com/'
        }

        url = f"{self.futures_api_url}/private/order/create"

        try:
            response = self.session.post(url, json=order_data, headers=headers, timeout=15)

            if response.status_code == 200:
                data = response.json()
                if data.get('success') and data.get('code') == 0:
                    order_id = data.get('data', {}).get('orderId')
                    logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}")
                    return {
                        'success': True,
                        'order_id': order_id,
                        'timestamp': data.get('data', {}).get('ts'),
                        'action': action
                    }
                else:
                    logger.error(f"MEXC: {action} failed: {data}")
                    return {'success': False, 'error': data.get('msg', 'Unknown error')}
            else:
                logger.error(f"MEXC: {action} request failed with status {response.status_code}")
                return {'success': False, 'error': f'HTTP {response.status_code}'}

        except Exception as e:
            logger.error(f"MEXC: {action} execution error: {e}")
            return {'success': False, 'error': str(e)}

    # Placeholder methods for encryption/hashing - these need proper implementation
    def _generate_mhash(self) -> str:
        """Generate mhash parameter (needs reverse engineering)"""
        return "a0015441fd4c3b6ba427b894b76cb7dd"  # Placeholder from request dump

    def _encrypt_p0(self, order_data: Dict[str, Any]) -> str:
        """Encrypt p0 parameter (needs reverse engineering)"""
        return "placeholder_p0_encryption"  # This needs proper implementation

    def _encrypt_k0(self, order_data: Dict[str, Any]) -> str:
        """Encrypt k0 parameter (needs reverse engineering)"""
        return "placeholder_k0_encryption"  # This needs proper implementation

    def _generate_chash(self, order_data: Dict[str, Any]) -> str:
        """Generate chash parameter (needs reverse engineering)"""
        return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8"  # Placeholder
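    # Editor's sketch (assumption): the hard-coded chash above is 64 hex
    # characters, the width of a SHA-256 hex digest. If chash is indeed
    # SHA-256 over some canonical form of the order payload, the computation
    # would look roughly like:
    #
    #   canonical = json.dumps(order_data, sort_keys=True, separators=(',', ':'))
    #   chash = hashlib.sha256(canonical.encode()).hexdigest()
    #
    # The actual field ordering and any salt must be confirmed from the site's
    # JavaScript before relying on this.
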
    def get_account_info(self) -> Dict[str, Any]:
        """Get account information including positions and balances"""
        if not self.is_authenticated:
            return {'success': False, 'error': 'Not authenticated'}

        # This would need to be implemented by reverse engineering the account info endpoints
        logger.info("MEXC: Account info endpoint not yet implemented")
        return {'success': False, 'error': 'Not implemented'}

    def get_open_positions(self) -> List[Dict[str, Any]]:
        """Get list of open futures positions"""
        if not self.is_authenticated:
            return []

        # This would need to be implemented by reverse engineering the positions endpoint
        logger.info("MEXC: Open positions endpoint not yet implemented")
        return []
@@ -1,259 +0,0 @@
"""
MEXC Session Manager

Helper utilities for managing MEXC web sessions and extracting cookies from the browser.
"""

import logging
import json
import re
import time
from typing import Dict, Optional, Any
from pathlib import Path

logger = logging.getLogger(__name__)

class MEXCSessionManager:
    """
    Helper class for managing MEXC web sessions and extracting browser cookies
    """

    def __init__(self):
        self.session_file = Path("mexc_session.json")

    def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]:
        """
        Extract cookies from a browser Network tab cookie header

        Args:
            cookie_header: Raw cookie string from browser (copy from Request Headers)

        Returns:
            Dictionary of parsed cookies
        """
        cookies = {}

        # Remove 'Cookie: ' prefix if present
        if cookie_header.startswith('Cookie: '):
            cookie_header = cookie_header[8:]
        elif cookie_header.startswith('cookie: '):
            cookie_header = cookie_header[8:]

        # Split by semicolon and parse each cookie
        cookie_pairs = cookie_header.split(';')

        for pair in cookie_pairs:
            pair = pair.strip()
            if '=' in pair:
                name, value = pair.split('=', 1)
                cookies[name.strip()] = value.strip()

        logger.info(f"Extracted {len(cookies)} cookies from browser")
        return cookies

    def validate_session_cookies(self, cookies: Dict[str, str]) -> bool:
        """
        Validate that essential cookies are present for authentication

        Args:
            cookies: Dictionary of cookie name-value pairs

        Returns:
            bool: True if cookies appear valid for authentication
        """
        required_cookies = [
            'uc_token',  # User authentication token
            'u_id',  # User ID
            'x-mxc-fingerprint',  # Browser fingerprint
            'mexc_fingerprint_visitorId'  # Visitor ID
        ]

        missing_cookies = []
        for cookie_name in required_cookies:
            if cookie_name not in cookies or not cookies[cookie_name]:
                missing_cookies.append(cookie_name)

        if missing_cookies:
            logger.warning(f"Missing required cookies: {missing_cookies}")
            return False

        logger.info("All required cookies are present")
        return True

    def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None):
        """
        Save session cookies to file for reuse

        Args:
            cookies: Dictionary of cookies to save
            metadata: Optional metadata about the session
        """
        session_data = {
            'cookies': cookies,
            'metadata': metadata or {},
            'timestamp': int(time.time())  # requires the module-level time import added above
        }

        try:
            with open(self.session_file, 'w') as f:
                json.dump(session_data, f, indent=2)
            logger.info(f"Session saved to {self.session_file}")
        except Exception as e:
            logger.error(f"Failed to save session: {e}")

    def load_session(self) -> Optional[Dict[str, str]]:
        """
        Load session cookies from file

        Returns:
            Dictionary of cookies if successful, None otherwise
        """
        if not self.session_file.exists():
            logger.info("No saved session found")
            return None

        try:
            with open(self.session_file, 'r') as f:
                session_data = json.load(f)

            cookies = session_data.get('cookies', {})
            timestamp = session_data.get('timestamp', 0)

            # Check if session is too old (24 hours)
            if time.time() - timestamp > 24 * 3600:
                logger.warning("Saved session is too old (>24h), may be expired")

            if self.validate_session_cookies(cookies):
                logger.info("Loaded valid session from file")
                return cookies
            else:
                logger.warning("Loaded session has invalid cookies")
                return None

        except Exception as e:
            logger.error(f"Failed to load session: {e}")
            return None

    def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]:
        """
        Extract cookies from a curl command copied from the browser

        Args:
            curl_command: Complete curl command from browser "Copy as cURL"

        Returns:
            Dictionary of extracted cookies
        """
        cookies = {}

        # Find cookie header in curl command
        cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
        if not cookie_match:
            cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)

        if cookie_match:
            cookie_header = cookie_match.group(1)
            cookies = self.extract_cookies_from_network_tab(cookie_header)
            logger.info(f"Extracted {len(cookies)} cookies from curl command")
        else:
            logger.warning("No cookie header found in curl command")

        return cookies

    def print_cookie_extraction_guide(self):
        """Print instructions for extracting cookies from the browser"""
        print("\n" + "="*80)
        print("MEXC COOKIE EXTRACTION GUIDE")
        print("="*80)
        print("""
To extract cookies from your browser for MEXC futures trading:

METHOD 1: Browser Network Tab
1. Open the MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT
2. Open browser Developer Tools (F12)
3. Go to the Network tab
4. Try to place a small futures trade (it will fail, but we need the request)
5. Find the request to 'futures.mexc.com' in the Network tab
6. Right-click on the request -> Copy -> Copy request headers
7. Find the 'Cookie:' line and copy everything after 'Cookie: '

METHOD 2: Copy as cURL
1. Follow steps 1-5 above
2. Right-click on the futures API request -> Copy -> Copy as cURL
3. Paste the entire cURL command

METHOD 3: Manual Cookie Extraction
1. While logged into MEXC, press F12 -> Application/Storage tab
2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com'
3. Copy the values for these important cookies:
   - uc_token
   - u_id
   - x-mxc-fingerprint
   - mexc_fingerprint_visitorId

IMPORTANT NOTES:
- Cookies expire after some time (usually 24 hours)
- You must be logged into MEXC futures (not just spot trading)
- Keep your cookies secure - they provide access to your account
- Test with small amounts first

Example usage:
    session_manager = MEXCSessionManager()

    # Method 1: From cookie header
    cookie_header = "uc_token=ABC123; u_id=DEF456; ..."
    cookies = session_manager.extract_cookies_from_network_tab(cookie_header)

    # Method 2: From cURL command
    curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'"
    cookies = session_manager.extract_from_curl_command(curl_cmd)

    # Save session for reuse
    session_manager.save_session(cookies)
""")
        print("="*80)

if __name__ == "__main__":
    # When run directly, show the extraction guide
    manager = MEXCSessionManager()
    manager.print_cookie_extraction_guide()

    print("\nWould you like to:")
    print("1. Load saved session")
    print("2. Extract cookies from clipboard")
    print("3. Exit")

    choice = input("\nEnter choice (1-3): ").strip()

    if choice == "1":
        cookies = manager.load_session()
        if cookies:
            print(f"\nLoaded {len(cookies)} cookies from saved session")
            if manager.validate_session_cookies(cookies):
                print("Session appears valid for trading")
            else:
                print("Warning: Session may be incomplete or expired")
        else:
            print("No valid saved session found")

    elif choice == "2":
        print("\nPaste your cookie header or cURL command:")
        user_input = input().strip()

        if user_input.startswith('curl'):
            cookies = manager.extract_from_curl_command(user_input)
        else:
            cookies = manager.extract_cookies_from_network_tab(user_input)

        if cookies and manager.validate_session_cookies(cookies):
            print(f"\nSuccessfully extracted {len(cookies)} valid cookies")
            save = input("Save session for reuse? (y/n): ").strip().lower()
            if save == 'y':
                manager.save_session(cookies)
        else:
            print("Failed to extract valid cookies")

    else:
        print("Goodbye!")
@@ -1,346 +0,0 @@
#!/usr/bin/env python3
"""
Test MEXC Futures Web Client

This script demonstrates how to use the MEXC Futures Web Client
for futures trading that isn't supported by their official API.

IMPORTANT: This requires extracting cookies from your browser session.
"""

import logging
import sys
import os
import time
import json
import uuid

# Add the project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from mexc_futures_client import MEXCFuturesWebClient
from session_manager import MEXCSessionManager

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
SYMBOL = "ETH_USDT"
LEVERAGE = 300
CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')

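# NOTE (editor's sketch, not part of the original file): named constants for
# the integer codes passed to create_order() in main() below; the mapping is
# read off the inline comments there (side: 1=BUY, 2=SELL; position_side:
# 1=LONG; order_type: 1=LIMIT) and is otherwise an assumption.
SIDE_BUY, SIDE_SELL = 1, 2
POSITION_SIDE_LONG = 1
ORDER_TYPE_LIMIT = 1
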
# Read credentials from mexc_credentials.json in JSON format
def load_credentials():
    credentials_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
    cookies = {}
    captcha_token_open = ''
    captcha_token_close = ''
    try:
        with open(credentials_file, 'r') as f:
            data = json.load(f)
        cookies = data.get('credentials', {}).get('cookies', {})
        captcha_token_open = data.get('credentials', {}).get('captcha_token_open', '')
        captcha_token_close = data.get('credentials', {}).get('captcha_token_close', '')
        logger.info(f"Loaded credentials from {credentials_file}")
    except Exception as e:
        logger.error(f"Error loading credentials: {e}")
    return cookies, captcha_token_open, captcha_token_close

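# NOTE (editor's sketch): load_credentials() above assumes a JSON layout like
# the following; the keys are inferred from the accesses in this file and the
# values are illustrative.
#
#   {
#     "credentials": {
#       "cookies": {"uc_token": "...", "u_id": "...", "x-mxc-fingerprint": "..."},
#       "captcha_token_open": "...",
#       "captcha_token_close": "..."
#     }
#   }
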
def test_basic_connection():
    """Test basic connection and authentication"""
    logger.info("Testing MEXC Futures Web Client")

    # Initialize session manager
    session_manager = MEXCSessionManager()

    # Try to load saved session first
    cookies = session_manager.load_session()

    if not cookies:
        # Explicitly load the cookies from the file we have
        cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json')
        if os.path.exists(cookies_file):
            try:
                with open(cookies_file, 'r') as f:
                    cookies = json.load(f)
                logger.info(f"Loaded cookies from {cookies_file}")
            except Exception as e:
                logger.error(f"Failed to load cookies from {cookies_file}: {e}")
                cookies = None
        else:
            logger.error(f"Cookies file not found at {cookies_file}")
            cookies = None

    if not cookies:
        print("\nNo saved session found. You need to extract cookies from your browser.")
        session_manager.print_cookie_extraction_guide()

        print("\nPaste your cookie header or cURL command (or press Enter to exit):")
        user_input = input().strip()

        if not user_input:
            print("No input provided. Exiting.")
            return False

        # Extract cookies from user input
        if user_input.startswith('curl'):
            cookies = session_manager.extract_from_curl_command(user_input)
        else:
            cookies = session_manager.extract_cookies_from_network_tab(user_input)

        if not cookies:
            logger.error("Failed to extract cookies from input")
            return False

        # Validate and save session
        if session_manager.validate_session_cookies(cookies):
            session_manager.save_session(cookies)
            logger.info("Session saved for future use")
        else:
            logger.warning("Extracted cookies may be incomplete")

    # Initialize the web client
    client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True)
    # Load cookies into the client's session
    for name, value in cookies.items():
        client.session.cookies.set(name, value)

    # Update headers to include additional parameters from captured requests
    client.session.headers.update({
        'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
        'trochilus-uid': cookies.get('u_id', ''),
        'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap',
        'Language': 'English',
        'X-Language': 'en-GB'
    })

    if not client.is_authenticated:
        logger.error("Failed to authenticate with extracted cookies")
        return False

    logger.info("Successfully authenticated with MEXC")
    logger.info(f"User ID: {client.user_id}")
    logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token")

    return True


def test_captcha_verification(client: MEXCFuturesWebClient):
    """Test captcha verification system"""
    logger.info("Testing captcha verification...")

    # Test captcha for an ETH_USDT long position with 200x leverage
    success = client.verify_captcha('ETH_USDT', 'openlong', '200X')

    if success:
        logger.info("Captcha verification successful")
    else:
        logger.warning("Captcha verification failed - this may be normal if no position is being opened")

    return success


def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True):
    """Test opening a position (dry run by default)"""
    if dry_run:
        logger.info("DRY RUN: Testing position opening (no actual trade)")
    else:
        logger.warning("LIVE TRADING: Opening actual position!")

    symbol = 'ETH_USDT'
    volume = 1  # Small test position
    leverage = 200

    logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")

    if not dry_run:
        result = client.open_long_position(symbol, volume, leverage)

        if result['success']:
            logger.info("Position opened successfully!")
            logger.info(f"Order ID: {result['order_id']}")
            logger.info(f"Timestamp: {result['timestamp']}")
            return True
        else:
            logger.error(f"Failed to open position: {result['error']}")
            return False
    else:
        logger.info("DRY RUN: Would attempt to open position here")
        # Test just the captcha verification part
        return client.verify_captcha(symbol, 'openlong', f'{leverage}X')


def test_position_opening_live(client):
    symbol = "ETH_USDT"
    volume = 1  # Small volume for testing
    leverage = 200

    logger.info("LIVE TRADING: Opening actual position!")
    logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")

    result = client.open_long_position(symbol, volume, leverage)
    if result.get('success'):
        logger.info(f"Successfully opened position: {result}")
    else:
        logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}")


def interactive_menu(client: MEXCFuturesWebClient):
    """Interactive menu for testing different functions"""
    while True:
        print("\n" + "="*50)
        print("MEXC Futures Web Client Test Menu")
        print("="*50)
        print("1. Test captcha verification")
        print("2. Test position opening (DRY RUN)")
        print("3. Test position opening (LIVE - BE CAREFUL!)")
        print("4. Test position closing (DRY RUN)")
        print("5. Show session info")
        print("6. Refresh session")
        print("0. Exit")

        choice = input("\nEnter choice (0-6): ").strip()

        if choice == "1":
            test_captcha_verification(client)

        elif choice == "2":
            test_position_opening(client, dry_run=True)

        elif choice == "3":
            test_position_opening_live(client)

        elif choice == "4":
            logger.info("DRY RUN: Position closing test")
            success = client.verify_captcha('ETH_USDT', 'closelong', '200X')
            if success:
                logger.info("DRY RUN: Would close position here")
            else:
                logger.warning("Captcha verification failed for position closing")

        elif choice == "5":
            print("\nSession Information:")
            print(f"Authenticated: {client.is_authenticated}")
            print(f"User ID: {client.user_id}")
            print(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None")
            print(f"Fingerprint: {client.fingerprint}")
            print(f"Visitor ID: {client.visitor_id}")

        elif choice == "6":
            session_manager = MEXCSessionManager()
            session_manager.print_cookie_extraction_guide()

        elif choice == "0":
            print("Goodbye!")
            break

        else:
            print("Invalid choice. Please try again.")


def main():
    """Main test function"""
    print("MEXC Futures Web Client Test")
    print("WARNING: This is experimental software for futures trading")
    print("Use at your own risk and test with small amounts first!")

    # Load cookies and tokens
    cookies, captcha_token_open, captcha_token_close = load_credentials()
    if not cookies:
        logger.error("Failed to load cookies from credentials file")
        sys.exit(1)

    # Initialize client with loaded cookies and tokens
    client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='')
    # Load cookies into the client's session
    for name, value in cookies.items():
        client.session.cookies.set(name, value)
    # Set captcha tokens
    client.captcha_token_open = captcha_token_open
    client.captcha_token_close = captcha_token_close

    # Try to load credentials from the new JSON file
    try:
        with open(CREDENTIALS_FILE, 'r') as f:
            credentials_data = json.load(f)
        cookies = credentials_data['credentials']['cookies']
        captcha_token_open = credentials_data['credentials']['captcha_token_open']
        captcha_token_close = credentials_data['credentials']['captcha_token_close']
        client.load_session_cookies(cookies)
        client.session_manager.save_captcha_token(captcha_token_open)  # Assuming this is for opening

    except FileNotFoundError:
        logger.error(f"Credentials file not found at {CREDENTIALS_FILE}")
        return False
    except json.JSONDecodeError as e:
        logger.error(f"Error loading credentials: {e}")
        return False
    except KeyError as e:
        logger.error(f"Missing key in credentials file: {e}")
        return False

    if not client.is_authenticated:
        logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json")
        return False

    # Test connection and authentication
    logger.info("Successfully authenticated with MEXC")

    # Set leverage
    leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE)
    if leverage_response and leverage_response.get('code') == 200:
        logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}")
    else:
        logger.error(f"Failed to set leverage: {leverage_response}")
        sys.exit(1)

    # Get current price
    ticker = client.get_ticker_data(symbol=SYMBOL)
    if ticker and ticker.get('code') == 200:
        current_price = float(ticker['data']['last'])
        logger.info(f"Current {SYMBOL} price: {current_price}")
    else:
        logger.error(f"Failed to get ticker data: {ticker}")
        sys.exit(1)

    # Calculate order size for a small test trade (e.g., $10 worth)
    trade_usdt = 10.0
    order_qty = round((trade_usdt / current_price) * LEVERAGE, 3)
    logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x")

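    # NOTE (editor's worked example, illustrative price only): with ETH at
    # 2,500 USDT, order_qty = round((10.0 / 2500) * 300, 3) = 1.2, i.e. the
    # $10 notional is scaled by the 300x leverage before rounding.
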
    # Test 1: Open LONG position
    logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}")
    open_long_order = client.create_order(
        symbol=SYMBOL,
        side=1,           # 1 for BUY
        position_side=1,  # 1 for LONG
        order_type=1,     # 1 for LIMIT
        price=current_price,
        vol=order_qty
    )
    if open_long_order and open_long_order.get('code') == 200:
        logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}")
    else:
        logger.error(f"❌ Failed to open LONG position: {open_long_order}")
        sys.exit(1)

    # Test 2: Close LONG position
    logger.info(f"Closing LONG position for {SYMBOL}")
    close_long_order = client.create_order(
        symbol=SYMBOL,
        side=2,           # 2 for SELL
        position_side=1,  # 1 for LONG
        order_type=1,     # 1 for LIMIT
        price=current_price,
        vol=order_qty,
        reduce_only=True
    )
    if close_long_order and close_long_order.get('code') == 200:
        logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}")
    else:
        logger.error(f"❌ Failed to close LONG position: {close_long_order}")
        sys.exit(1)

    logger.info("All tests completed successfully!")


if __name__ == "__main__":
    main()
@@ -1,595 +0,0 @@
"""
Negative Case Trainer - Intensive Training on Losing Trades

This module focuses on learning from losses to prevent future mistakes.
Stores negative cases in the testcases/negative folder for reuse and retraining.
Supports simultaneous inference and training.
"""

import os
import json
import logging
import pickle
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import deque
import numpy as np
import pandas as pd

# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)


@dataclass
class NegativeCase:
    """Represents a losing trade case for intensive training"""
    case_id: str
    timestamp: datetime
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    loss_amount: float
    loss_percentage: float
    confidence_used: float
    market_state_before: Dict[str, Any]
    market_state_after: Dict[str, Any]
    tick_data: List[Dict[str, Any]]  # 15 minutes of tick data around the trade
    technical_indicators: Dict[str, float]
    what_should_have_been_done: str  # 'HOLD', 'OPPOSITE', 'WAIT'
    lesson_learned: str
    training_priority: int  # 1-5, 5 being highest priority
    retraining_count: int = 0
    last_retrained: Optional[datetime] = None


@dataclass
class TrainingSession:
    """Represents an intensive training session on negative cases"""
    session_id: str
    start_time: datetime
    cases_trained: List[str]  # case_ids
    epochs_completed: int
    loss_improvement: float
    accuracy_improvement: float
    inference_paused: bool = False
    training_active: bool = True


class NegativeCaseTrainer:
    """
    Intensive trainer focused on learning from losing trades, with checkpoint management

    Features:
    - Stores all losing trades as negative cases
    - Intensive retraining on losses
    - Simultaneous inference and training
    - Persistent storage in testcases/negative
    - Priority-based training (bigger losses = higher priority)
    - Checkpoint management for training progress
    """

    def __init__(self, storage_dir: str = "testcases/negative",
                 model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
        self.storage_dir = storage_dir
        self.stored_cases: List[NegativeCase] = []
        self.training_queue = deque(maxlen=1000)
        self.training_lock = threading.Lock()
        self.inference_lock = threading.Lock()

        # Checkpoint management
        self.model_name = model_name
        self.enable_checkpoints = enable_checkpoints
        self.training_integration = get_training_integration() if enable_checkpoints else None
        self.training_session_count = 0
        self.best_loss_reduction = 0.0
        self.checkpoint_frequency = 25  # Save checkpoint every 25 training sessions

        # Training configuration
        self.max_concurrent_training = 3  # Max parallel training sessions
        self.intensive_training_epochs = 50  # Epochs per negative case
        self.priority_multiplier = 2.0  # Training time multiplier for high-priority cases

        # Simultaneous inference/training control
        self.inference_active = True
        self.training_active = False
        self.current_training_sessions: List[TrainingSession] = []

        # Performance tracking
        self.total_cases_processed = 0
        self.total_training_time = 0.0
        self.accuracy_improvements = []

        # Initialize storage
        self._initialize_storage()
        self._load_existing_cases()

        # Load best checkpoint if available
        if self.enable_checkpoints:
            self.load_best_checkpoint()

        # Start background training thread
        self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
        self.training_thread.start()

        logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
        logger.info(f"Storage directory: {self.storage_dir}")
        logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
        logger.info("Background training thread started")

    def _initialize_storage(self):
        """Initialize storage directories"""
        try:
            os.makedirs(self.storage_dir, exist_ok=True)
            os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
            os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
            os.makedirs(f"{self.storage_dir}/models", exist_ok=True)

            # Create index file if it doesn't exist
            index_file = f"{self.storage_dir}/case_index.json"
            if not os.path.exists(index_file):
                with open(index_file, 'w') as f:
                    json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)

            logger.info(f"Storage initialized at {self.storage_dir}")

        except Exception as e:
            logger.error(f"Error initializing storage: {e}")

    def _load_existing_cases(self):
        """Load existing negative cases from storage"""
        try:
            index_file = f"{self.storage_dir}/case_index.json"
            if os.path.exists(index_file):
                with open(index_file, 'r') as f:
                    index_data = json.load(f)

                for case_info in index_data.get("cases", []):
                    case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
                    if os.path.exists(case_file):
                        try:
                            with open(case_file, 'rb') as f:
                                case = pickle.load(f)
                            self.stored_cases.append(case)
                        except Exception as e:
                            logger.warning(f"Error loading case {case_info['case_id']}: {e}")

                logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")

        except Exception as e:
            logger.error(f"Error loading existing cases: {e}")

    def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
        """
        Add a losing trade as a negative case for intensive training

        Args:
            trade_info: Trade information including P&L
            market_data: Market state and tick data around the trade

        Returns:
            case_id: Unique identifier for the negative case
        """
        try:
            # Generate unique case ID
            case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"

            # Calculate loss metrics
            loss_amount = abs(trade_info.get('pnl', 0))
            loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100

            # Determine training priority based on loss size
            if loss_percentage > 10:
                priority = 5  # Critical loss
            elif loss_percentage > 5:
                priority = 4  # High loss
            elif loss_percentage > 2:
                priority = 3  # Medium loss
            elif loss_percentage > 1:
                priority = 2  # Small loss
            else:
                priority = 1  # Minimal loss

            # Analyze what should have been done
            what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
            lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)

            # Create negative case
            negative_case = NegativeCase(
                case_id=case_id,
                timestamp=trade_info['timestamp'],
                symbol=trade_info['symbol'],
                action=trade_info['action'],
                entry_price=trade_info['price'],
                exit_price=market_data.get('exit_price', trade_info['price']),
                loss_amount=loss_amount,
                loss_percentage=loss_percentage,
                confidence_used=trade_info.get('confidence', 0.5),
                market_state_before=market_data.get('state_before', {}),
                market_state_after=market_data.get('state_after', {}),
                tick_data=market_data.get('tick_data', []),
                technical_indicators=market_data.get('technical_indicators', {}),
                what_should_have_been_done=what_should_have_been_done,
                lesson_learned=lesson_learned,
                training_priority=priority
            )

            # Store the case
            self._store_case(negative_case)

            # Add to training queue with priority
            with self.training_lock:
                self.training_queue.append(negative_case)
                self.stored_cases.append(negative_case)

            logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
            logger.error(f"Lesson: {lesson_learned}")

            return case_id

        except Exception as e:
            logger.error(f"Error adding losing trade: {e}")
            return ""

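    # NOTE (editor's usage sketch, not in the original file): the dict keys
    # below mirror the accesses in add_losing_trade(); all values are
    # illustrative.
    #
    #   trainer = NegativeCaseTrainer()
    #   case_id = trainer.add_losing_trade(
    #       trade_info={'symbol': 'ETH/USDT', 'action': 'BUY', 'price': 2500.0,
    #                   'pnl': -12.5, 'value': 250.0, 'confidence': 0.8,
    #                   'timestamp': datetime.now()},
    #       market_data={'exit_price': 2487.5, 'tick_data': [],
    #                    'technical_indicators': {}},
    #   )
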
    def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
        """Analyze what the optimal action should have been"""
        try:
            # Simple analysis based on price movement
            entry_price = trade_info['price']
            exit_price = market_data.get('exit_price', entry_price)
            action = trade_info['action']

            price_change = (exit_price - entry_price) / entry_price

            if action == 'BUY' and price_change < 0:
                # Bought but price went down
                if abs(price_change) > 0.005:  # >0.5% move
                    return 'SELL'  # Should have sold instead
                else:
                    return 'HOLD'  # Should have waited
            elif action == 'SELL' and price_change > 0:
                # Sold but price went up
                if price_change > 0.005:  # >0.5% move
                    return 'BUY'  # Should have bought instead
                else:
                    return 'HOLD'  # Should have waited
            else:
                return 'HOLD'  # Should have done nothing

        except Exception as e:
            logger.error(f"Error analyzing optimal action: {e}")
            return 'HOLD'

    def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
        """Generate a lesson learned from the losing trade"""
        try:
            action = trade_info['action']
            symbol = trade_info['symbol']
            loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
            confidence = trade_info.get('confidence', 0.5)

            if optimal_action == 'HOLD':
                return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
            elif optimal_action == 'BUY' and action == 'SELL':
                return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
            elif optimal_action == 'SELL' and action == 'BUY':
                return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
            else:
                return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."

        except Exception as e:
            logger.error(f"Error generating lesson: {e}")
            return "Learn from this loss to improve future decisions."

    def _store_case(self, case: NegativeCase):
        """Store negative case to persistent storage"""
        try:
            # Store case file
            case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
            with open(case_file, 'wb') as f:
                pickle.dump(case, f)

            # Update index
            index_file = f"{self.storage_dir}/case_index.json"
            with open(index_file, 'r') as f:
                index_data = json.load(f)

            # Add case to index
            case_info = {
                'case_id': case.case_id,
                'timestamp': case.timestamp.isoformat(),
                'symbol': case.symbol,
                'loss_amount': case.loss_amount,
                'loss_percentage': case.loss_percentage,
                'training_priority': case.training_priority,
                'retraining_count': case.retraining_count
            }

            index_data['cases'].append(case_info)
            index_data['last_updated'] = datetime.now().isoformat()

            with open(index_file, 'w') as f:
                json.dump(index_data, f, indent=2)

            logger.info(f"Stored negative case: {case.case_id}")

        except Exception as e:
            logger.error(f"Error storing case: {e}")

    def _background_training_loop(self):
        """Background loop for intensive training on negative cases"""
        logger.info("Background training loop started")

        while True:
            try:
                # Check if we have cases to train on
                with self.training_lock:
                    if not self.training_queue:
                        time.sleep(5)  # Wait for new cases
                        continue

                    # Get highest priority case
                    cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
                    case_to_train = cases_by_priority[0]
                    self.training_queue.remove(case_to_train)

                # Start intensive training session
                self._start_intensive_training_session(case_to_train)

                # Brief pause between training sessions
                time.sleep(2)

            except Exception as e:
                logger.error(f"Error in background training loop: {e}")
                time.sleep(10)  # Wait longer on error

    def _start_intensive_training_session(self, case: NegativeCase):
        """Start an intensive training session for a negative case"""
        try:
            session_id = f"session_{case.case_id}_{int(time.time())}"

            # Create training session
            session = TrainingSession(
                session_id=session_id,
                start_time=datetime.now(),
                cases_trained=[case.case_id],
                epochs_completed=0,
                loss_improvement=0.0,
                accuracy_improvement=0.0
            )

            self.current_training_sessions.append(session)
            self.training_active = True

            logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
            logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")

            # Calculate training epochs based on priority
            epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)

            # Simulate intensive training (replace with actual model training)
            for epoch in range(epochs):
                # Pause inference during critical training phases
                if case.training_priority >= 4 and epoch % 10 == 0:
                    with self.inference_lock:
                        session.inference_paused = True
                        time.sleep(0.1)  # Brief pause for critical training
                        session.inference_paused = False

                # Simulate training step
                session.epochs_completed = epoch + 1

                # Log progress for high-priority cases
                if case.training_priority >= 4 and epoch % 10 == 0:
                    logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")

                time.sleep(0.05)  # Simulate training time

            # Update case retraining info
            case.retraining_count += 1
            case.last_retrained = datetime.now()

            # Calculate improvements (simulated)
            session.loss_improvement = np.random.uniform(0.1, 0.5)  # 10-50% improvement
            session.accuracy_improvement = np.random.uniform(0.05, 0.2)  # 5-20% improvement

            # Store training session results
            self._store_training_session(session)

            # Update statistics
            self.total_cases_processed += 1
            self.total_training_time += (datetime.now() - session.start_time).total_seconds()
            self.accuracy_improvements.append(session.accuracy_improvement)

            # Remove from active sessions
            self.current_training_sessions.remove(session)
            if not self.current_training_sessions:
                self.training_active = False

            logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
            logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")

        except Exception as e:
            logger.error(f"Error in intensive training session: {e}")

    def _store_training_session(self, session: TrainingSession):
        """Store training session results"""
        try:
            session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
            session_data = {
                'session_id': session.session_id,
                'start_time': session.start_time.isoformat(),
                'end_time': datetime.now().isoformat(),
                'cases_trained': session.cases_trained,
                'epochs_completed': session.epochs_completed,
                'loss_improvement': session.loss_improvement,
                'accuracy_improvement': session.accuracy_improvement
            }

            with open(session_file, 'w') as f:
                json.dump(session_data, f, indent=2)

        except Exception as e:
            logger.error(f"Error storing training session: {e}")

    def can_inference_proceed(self) -> bool:
        """Check if inference can proceed (not blocked by critical training)"""
        with self.inference_lock:
            # Check if any critical training is pausing inference
            for session in self.current_training_sessions:
                if session.inference_paused:
                    return False
            return True

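    # NOTE (editor's usage sketch): how a caller is assumed to gate inference
    # on the flag above; the surrounding loop and names are hypothetical.
    #
    #   if trainer.can_inference_proceed():
    #       decision = model.predict(features)
    #   else:
    #       time.sleep(0.05)  # critical retraining in progress, retry shortly
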
    def get_training_stats(self) -> Dict[str, Any]:
        """Get training statistics"""
        try:
            avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0

            return {
                'total_negative_cases': len(self.stored_cases),
                'cases_in_queue': len(self.training_queue),
                'total_cases_processed': self.total_cases_processed,
                'total_training_time': self.total_training_time,
                'avg_accuracy_improvement': avg_accuracy_improvement,
                'active_training_sessions': len(self.current_training_sessions),
                'training_active': self.training_active,
                'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
                'storage_directory': self.storage_dir
            }

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {}

    def get_recent_lessons(self, count: int = 5) -> List[str]:
        """Get recent lessons learned from negative cases"""
        try:
            recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
            return [case.lesson_learned for case in recent_cases]
        except Exception as e:
            logger.error(f"Error getting recent lessons: {e}")
            return []

    def retrain_all_cases(self):
        """Retrain all stored negative cases (for periodic retraining)"""
        try:
            logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")

            with self.training_lock:
                # Add all stored cases back to the training queue
                for case in self.stored_cases:
                    if case not in self.training_queue:
                        self.training_queue.append(case)

            logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")

        except Exception as e:
            logger.error(f"Error retraining all cases: {e}")

    def load_best_checkpoint(self):
        """Load the best checkpoint for this negative case trainer"""
        try:
            if not self.enable_checkpoints:
                return

            result = load_best_checkpoint(self.model_name)
            if result:
                file_path, metadata = result
                checkpoint = torch.load(file_path, map_location='cpu')

                # Load training state
                if 'training_session_count' in checkpoint:
                    self.training_session_count = checkpoint['training_session_count']
                if 'best_loss_reduction' in checkpoint:
                    self.best_loss_reduction = checkpoint['best_loss_reduction']
                if 'total_cases_processed' in checkpoint:
                    self.total_cases_processed = checkpoint['total_cases_processed']
                if 'total_training_time' in checkpoint:
                    self.total_training_time = checkpoint['total_training_time']
                if 'accuracy_improvements' in checkpoint:
                    self.accuracy_improvements = checkpoint['accuracy_improvements']

                logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
                logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")

        except Exception as e:
            logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")

    def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
        """Save checkpoint if performance improved or forced"""
        try:
            if not self.enable_checkpoints:
                return False

            self.training_session_count += 1

            # Update best loss reduction
            improved = False
            if loss_improvement > self.best_loss_reduction:
                self.best_loss_reduction = loss_improvement
                improved = True

            # Save checkpoint if improved, forced, or at regular intervals
            should_save = (
                force_save or
                improved or
                self.training_session_count % self.checkpoint_frequency == 0
            )

            if should_save:
                # Prepare checkpoint data
                checkpoint_data = {
                    'training_session_count': self.training_session_count,
                    'best_loss_reduction': self.best_loss_reduction,
                    'total_cases_processed': self.total_cases_processed,
                    'total_training_time': self.total_training_time,
                    'accuracy_improvements': self.accuracy_improvements,
                    'storage_dir': self.storage_dir,
                    'max_concurrent_training': self.max_concurrent_training,
                    'intensive_training_epochs': self.intensive_training_epochs
                }

                # Create performance metrics for the checkpoint manager
                avg_accuracy_improvement = (
                    sum(self.accuracy_improvements) / len(self.accuracy_improvements)
                    if self.accuracy_improvements else 0.0
                )

                performance_metrics = {
                    'loss_reduction': self.best_loss_reduction,
                    'avg_accuracy_improvement': avg_accuracy_improvement,
                    'total_cases_processed': self.total_cases_processed,
                    'training_efficiency': (
                        self.total_cases_processed / self.total_training_time
                        if self.total_training_time > 0 else 0.0
                    )
                }

                # Save using the checkpoint manager
                metadata = save_checkpoint(
                    model=checkpoint_data,  # We're saving a data dict instead of a model
                    model_name=self.model_name,
                    model_type="negative_case_trainer",
                    performance_metrics=performance_metrics,
                    training_metadata={
                        'session': self.training_session_count,
                        'cases_processed': self.total_cases_processed,
                        'training_time_hours': self.total_training_time / 3600
                    },
                    force_save=force_save
                )

                if metadata:
                    logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
                    return True

            return False

        except Exception as e:
            logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
            return False
@@ -1,277 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Decision Fusion System
Central NN that merges all model outputs + market data for final trading decisions
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import logging

logger = logging.getLogger(__name__)


@dataclass
class ModelPrediction:
    """Standardized prediction from any model"""
    model_name: str
    prediction_type: str  # 'price', 'direction', 'action'
    value: float  # -1 to 1 for direction, actual price for price predictions
    confidence: float  # 0 to 1
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class MarketContext:
    """Current market context for decision fusion"""
    symbol: str
    current_price: float
    price_change_1m: float
    price_change_5m: float
    volume_ratio: float
    volatility: float
    timestamp: datetime


@dataclass
class FusionDecision:
    """Final trading decision from the fusion NN"""
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # 0 to 1
    expected_return: float  # Expected return percentage
    risk_score: float  # 0 to 1, higher = riskier
    position_size: float  # Recommended position size
    reasoning: str  # Human-readable explanation
    model_contributions: Dict[str, float]  # How much each model contributed
    timestamp: datetime


class DecisionFusionNetwork(nn.Module):
    """Small NN that fuses model predictions with market context"""

    def __init__(self, input_dim: int = 32, hidden_dim: int = 64):
        super().__init__()

        self.fusion_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 16)
        )

        # Output heads
        self.action_head = nn.Linear(16, 3)  # BUY, SELL, HOLD
        self.confidence_head = nn.Linear(16, 1)
        self.return_head = nn.Linear(16, 1)

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through the fusion network"""
        fusion_output = self.fusion_layers(features)

        action_logits = self.action_head(fusion_output)
        action_probs = F.softmax(action_logits, dim=1)

        confidence = torch.sigmoid(self.confidence_head(fusion_output))
        expected_return = torch.tanh(self.return_head(fusion_output))

        return {
            'action_probs': action_probs,
            'confidence': confidence.squeeze(),
            'expected_return': expected_return.squeeze()
        }


class NeuralDecisionFusion:
    """Main NN-based decision fusion system"""

    def __init__(self, training_mode: bool = True):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network = DecisionFusionNetwork().to(self.device)
        self.training_mode = training_mode
        self.registered_models = {}
        self.last_predictions = {}

        logger.info(f"Neural Decision Fusion initialized on {self.device}")

    def register_model(self, model_name: str, model_type: str, prediction_format: str):
        """Register a model that will provide predictions"""
        self.registered_models[model_name] = {
            'type': model_type,
            'format': prediction_format,
            'prediction_count': 0
        }
        logger.info(f"Registered NN model: {model_name} ({model_type})")

    def add_prediction(self, prediction: ModelPrediction):
        """Add a prediction from a registered model"""
        self.last_predictions[prediction.model_name] = prediction
        if prediction.model_name in self.registered_models:
            self.registered_models[prediction.model_name]['prediction_count'] += 1

        logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} "
                     f"(confidence: {prediction.confidence:.3f})")

    def make_decision(self, symbol: str, market_context: MarketContext,
                      min_confidence: float = 0.25) -> Optional[FusionDecision]:
        """Make an NN-driven trading decision"""
        try:
            if len(self.last_predictions) < 1:
                logger.debug("No NN predictions available")
                return None

            # Prepare features
            features = self._prepare_features(market_context)
            if features is None:
                return None

            # Run NN inference
            with torch.no_grad():
                self.network.eval()
                features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
                outputs = self.network(features_tensor)

                action_probs = outputs['action_probs'][0].cpu().numpy()
                confidence = outputs['confidence'].cpu().item()
                expected_return = outputs['expected_return'].cpu().item()

            # Determine action
            action_idx = np.argmax(action_probs)
            actions = ['BUY', 'SELL', 'HOLD']
            action = actions[action_idx]

            # Check confidence threshold
            if confidence < min_confidence:
                action = 'HOLD'
                logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD")

            # Calculate position size
            position_size = self._calculate_position_size(confidence, expected_return)

            # Generate reasoning
            reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs)

            # Calculate risk score and model contributions
            risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5)
            model_contributions = self._calculate_model_contributions()

            decision = FusionDecision(
                action=action,
                confidence=confidence,
                expected_return=expected_return,
                risk_score=risk_score,
                position_size=position_size,
                reasoning=reasoning,
                model_contributions=model_contributions,
                timestamp=datetime.now()
            )

            logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, "
                        f"return: {expected_return:.3f}, size: {position_size:.4f})")

            return decision

        except Exception as e:
            logger.error(f"Error in NN decision making: {e}")
            return None

    def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]:
        """Prepare the feature vector for the NN"""
        try:
            features = np.zeros(32)

            # Model predictions (slots 0-15)
            idx = 0
            for model_name, prediction in self.last_predictions.items():
                if idx < 14:  # Leave room for other features
                    features[idx] = prediction.value
                    features[idx + 1] = prediction.confidence
                    idx += 2

            # Market context (slots 16-31)
            features[16] = np.tanh(context.price_change_1m * 100)  # 1m change
            features[17] = np.tanh(context.price_change_5m * 100)  # 5m change
            features[18] = np.tanh(context.volume_ratio - 1)  # Volume ratio
            features[19] = np.tanh(context.volatility * 100)  # Volatility
            features[20] = context.current_price / 10000.0  # Normalized price

            # Time features
            now = context.timestamp
            features[21] = now.hour / 24.0
            features[22] = now.weekday() / 7.0

            # Model agreement features
            if len(self.last_predictions) >= 2:
                values = [p.value for p in self.last_predictions.values()]
                features[23] = np.mean(values)  # Average prediction
                features[24] = np.std(values)   # Prediction variance
                features[25] = len(self.last_predictions)  # Model count

            return features

        except Exception as e:
            logger.error(f"Error preparing NN features: {e}")
            return None

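    # NOTE (editor's usage sketch, not part of the original file): a minimal
    # end-to-end call with illustrative numbers; the model name and values
    # are hypothetical.
    #
    #   fusion = NeuralDecisionFusion(training_mode=False)
    #   fusion.register_model('cnn_v1', 'cnn', 'direction')
    #   fusion.add_prediction(ModelPrediction(
    #       model_name='cnn_v1', prediction_type='direction',
    #       value=0.4, confidence=0.7, timestamp=datetime.now()))
    #   ctx = MarketContext('ETH/USDT', current_price=2500.0,
    #                       price_change_1m=0.001, price_change_5m=-0.002,
    #                       volume_ratio=1.2, volatility=0.01,
    #                       timestamp=datetime.now())
    #   decision = fusion.make_decision('ETH/USDT', ctx)
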
    def _calculate_position_size(self, confidence: float, expected_return: float) -> float:
        """Calculate position size based on NN outputs"""
        base_size = 0.01  # 0.01 ETH base

        # Scale by confidence
        confidence_multiplier = max(0.1, min(2.0, confidence * 1.5))

        # Scale by expected return
        return_multiplier = 1.0 + abs(expected_return) * 0.5

        final_size = base_size * confidence_multiplier * return_multiplier
        return max(0.001, min(0.05, final_size))

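    # NOTE (editor's worked example, illustrative numbers): with
    # confidence=0.8 and expected_return=0.02, the sizing above gives
    # 0.01 * min(2.0, 0.8 * 1.5) * (1.0 + 0.02 * 0.5) = 0.01 * 1.2 * 1.01
    # ≈ 0.0121 ETH, which falls inside the [0.001, 0.05] clamp.
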
    def _generate_reasoning(self, action: str, confidence: float,
                            expected_return: float, action_probs: np.ndarray) -> str:
        """Generate human-readable reasoning"""
        reasons = []

        if action == 'BUY':
            reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})")
        elif action == 'SELL':
            reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})")
        else:
            reasons.append("NN suggests HOLD")

        if confidence > 0.7:
            reasons.append("High confidence")
        elif confidence > 0.5:
            reasons.append("Moderate confidence")
        else:
            reasons.append("Low confidence")

        if abs(expected_return) > 0.01:
            direction = "positive" if expected_return > 0 else "negative"
            reasons.append(f"Expected {direction} return: {expected_return:.2%}")

        reasons.append(f"Based on {len(self.last_predictions)} NN models")

        return " | ".join(reasons)

    def _calculate_model_contributions(self) -> Dict[str, float]:
        """Calculate how much each model contributed to the decision"""
        contributions = {}
        total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0

        if total_confidence > 0:
            for model_name, prediction in self.last_predictions.items():
                contributions[model_name] = prediction.confidence / total_confidence

        return contributions

    def get_status(self) -> Dict[str, Any]:
        """Get NN fusion system status"""
        return {
            'device': str(self.device),
            'training_mode': self.training_mode,
            'registered_models': len(self.registered_models),
            'recent_predictions': len(self.last_predictions),
            'model_parameters': sum(p.numel() for p in self.network.parameters())
        }
Binary file not shown.
@@ -1,649 +0,0 @@
"""
Real-Time Tick Processing Neural Network Module

This module acts as a Neural Network DPS (Data Processing System) alternative,
processing raw tick data with ultra-low latency and feeding processed features
to trading models in real time.

Features:
- Real-time tick ingestion with volume processing
- Neural network feature extraction from tick streams
- Ultra-low latency processing (sub-millisecond)
- Volume-weighted price analysis
- Microstructure pattern detection
- Real-time feature streaming to models
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Deque
from collections import deque
from threading import Thread, Lock
import websockets
import json
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class TickData:
    """Raw tick data structure"""
    timestamp: datetime
    price: float
    volume: float
    side: str  # 'buy' or 'sell'
    trade_id: Optional[str] = None


@dataclass
class ProcessedTickFeatures:
    """Processed tick features for model consumption"""
    timestamp: datetime
    price_features: np.ndarray  # Price-based features
    volume_features: np.ndarray  # Volume-based features
    microstructure_features: np.ndarray  # Market microstructure features
    neural_features: np.ndarray  # Neural network extracted features
    confidence: float  # Feature quality confidence


class TickProcessingNN(nn.Module):
    """
    Neural network for real-time tick processing.
    Extracts high-level features from raw tick data.
    """

    def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64):
        super(TickProcessingNN, self).__init__()

        # Tick sequence processing layers
        self.tick_encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # LSTM for temporal patterns
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2)

        # Attention mechanism for important tick selection
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)

        # Feature extraction heads
        self.price_head = nn.Linear(hidden_size, 16)  # Price pattern features
        self.volume_head = nn.Linear(hidden_size, 16)  # Volume pattern features
        self.microstructure_head = nn.Linear(hidden_size, 16)  # Microstructure features

        # Final feature fusion
        self.feature_fusion = nn.Sequential(
            nn.Linear(48, output_size),  # 16 + 16 + 16 = 48
            nn.ReLU(),
            nn.Linear(output_size, output_size)
        )

        # Confidence estimation
        self.confidence_head = nn.Sequential(
            nn.Linear(output_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process a tick sequence and extract features

        Args:
            tick_sequence: [batch, sequence_length, features]

        Returns:
            features: [batch, output_size] - extracted features
            confidence: [batch, 1] - feature confidence
        """
        batch_size, seq_len, _ = tick_sequence.shape

        # Encode each tick
        encoded = self.tick_encoder(tick_sequence)  # [batch, seq_len, hidden_size]

        # LSTM processing for temporal patterns
        lstm_out, _ = self.lstm(encoded)  # [batch, seq_len, hidden_size]

        # Attention to focus on important ticks
        attended, _ = self.attention(lstm_out, lstm_out, lstm_out)  # [batch, seq_len, hidden_size]

        # Use the last attended output
        final_features = attended[:, -1, :]  # [batch, hidden_size]

        # Extract specialized features
        price_features = self.price_head(final_features)
        volume_features = self.volume_head(final_features)
        microstructure_features = self.microstructure_head(final_features)

        # Fuse all features
        combined_features = torch.cat([price_features, volume_features, microstructure_features], dim=1)
        final_features = self.feature_fusion(combined_features)

        # Estimate confidence
        confidence = self.confidence_head(final_features)

        return final_features, confidence

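# NOTE (editor's sketch, not in the original file): a quick shape check for
# TickProcessingNN; batch and sequence sizes are arbitrary.
#
#   net = TickProcessingNN(input_size=9)
#   x = torch.randn(2, 50, 9)   # [batch, seq_len, features]
#   feats, conf = net(x)
#   assert feats.shape == (2, 64) and conf.shape == (2, 1)
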
class RealTimeTickProcessor:
    """
    Real-time tick processing system with neural network feature extraction.
    Acts as a DPS alternative for ultra-low latency tick processing.
    """

    def __init__(self, symbols: List[str] = None, tick_buffer_size: int = 1000):
        """Initialize the real-time tick processor"""
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
        self.tick_buffer_size = tick_buffer_size

        # Tick storage buffers
        self.tick_buffers: Dict[str, Deque[TickData]] = {}
        self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {}

        # Initialize buffers for each symbol
        for symbol in self.symbols:
            self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size)
            self.processed_features[symbol] = deque(maxlen=100)  # Keep last 100 processed features

        # Neural network for feature extraction
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tick_nn = TickProcessingNN(input_size=9).to(self.device)
        self.tick_nn.eval()  # Start in evaluation mode

        # Processing parameters
        self.processing_window = 50  # Number of ticks to process at once
        self.min_ticks_for_processing = 10  # Minimum ticks before processing

        # Real-time streaming
        self.streaming = False
        self.websocket_tasks = {}
        self.processing_threads = {}

        # Performance tracking
        self.processing_times = deque(maxlen=1000)
        self.tick_counts = {symbol: 0 for symbol in self.symbols}

        # Thread safety
        self.data_lock = Lock()

        # Feature subscribers (models that want real-time features)
        self.feature_subscribers = []

        logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}")
        logger.info(f"Neural network device: {self.device}")
        logger.info(f"Tick buffer size: {tick_buffer_size}")

    def add_feature_subscriber(self, callback):
        """Add a callback function to receive processed features"""
        self.feature_subscribers.append(callback)
        logger.info(f"Added feature subscriber: {callback.__name__}")

    def remove_feature_subscriber(self, callback):
        """Remove a feature subscriber"""
        if callback in self.feature_subscribers:
            self.feature_subscribers.remove(callback)
            logger.info(f"Removed feature subscriber: {callback.__name__}")

    async def start_processing(self):
        """Start real-time tick processing"""
        logger.info("Starting real-time tick processing...")
        self.streaming = True

        # Start a WebSocket stream for each symbol
        for symbol in self.symbols:
            task = asyncio.create_task(self._websocket_stream(symbol))
            self.websocket_tasks[symbol] = task

            # Start a processing thread for each symbol
            thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True)
            thread.start()
            self.processing_threads[symbol] = thread

        logger.info("Real-time tick processing started")

    async def stop_processing(self):
        """Stop real-time tick processing"""
        logger.info("Stopping real-time tick processing...")
        self.streaming = False

        # Cancel WebSocket tasks
        for symbol, task in self.websocket_tasks.items():
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        self.websocket_tasks.clear()
        logger.info("Real-time tick processing stopped")

    async def _websocket_stream(self, symbol: str):
        """WebSocket stream for real-time tick data"""
        binance_symbol = symbol.replace('/', '').lower()
        url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"

        while self.streaming:
            try:
                async with websockets.connect(url) as websocket:
                    logger.info(f"Tick WebSocket connected for {symbol}")

                    async for message in websocket:
                        if not self.streaming:
                            break

                        try:
                            data = json.loads(message)
                            await self._process_raw_tick(symbol, data)
                        except Exception as e:
                            logger.warning(f"Error processing tick for {symbol}: {e}")

            except Exception as e:
                logger.error(f"WebSocket error for {symbol}: {e}")
                if self.streaming:
                    logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...")
                    await asyncio.sleep(2)

    async def _process_raw_tick(self, symbol: str, raw_data: Dict):
        """Process raw tick data from the WebSocket"""
        try:
            # Extract tick information
            tick = TickData(
                timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000),
                price=float(raw_data['p']),
                volume=float(raw_data['q']),
                side='sell' if raw_data['m'] else 'buy',  # m=true means the buyer is the market maker (a sell)
                trade_id=raw_data.get('t')
            )

            # Add to buffer
            with self.data_lock:
                self.tick_buffers[symbol].append(tick)
                self.tick_counts[symbol] += 1

        except Exception as e:
            logger.error(f"Error processing raw tick for {symbol}: {e}")

def _processing_loop(self, symbol: str):
|
||||
"""Main processing loop for a symbol"""
|
||||
logger.info(f"Starting processing loop for {symbol}")
|
||||
|
||||
while self.streaming:
|
||||
try:
|
||||
# Check if we have enough ticks to process
|
||||
with self.data_lock:
|
||||
tick_count = len(self.tick_buffers[symbol])
|
||||
|
||||
if tick_count >= self.min_ticks_for_processing:
|
||||
start_time = time.time()
|
||||
|
||||
# Process ticks
|
||||
features = self._extract_neural_features(symbol)
|
||||
|
||||
if features is not None:
|
||||
# Store processed features
|
||||
with self.data_lock:
|
||||
self.processed_features[symbol].append(features)
|
||||
|
||||
# Notify subscribers
|
||||
self._notify_feature_subscribers(symbol, features)
|
||||
|
||||
# Track processing time
|
||||
processing_time = (time.time() - start_time) * 1000 # Convert to ms
|
||||
self.processing_times.append(processing_time)
|
||||
|
||||
if len(self.processing_times) % 100 == 0:
|
||||
avg_time = np.mean(list(self.processing_times))
|
||||
logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms")
|
||||
|
||||
# Small sleep to prevent CPU overload
|
||||
time.sleep(0.001) # 1ms sleep for ultra-low latency
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in processing loop for {symbol}: {e}")
|
||||
time.sleep(0.01) # Longer sleep on error
|
||||
|
||||
    def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
        """Extract neural network features from recent ticks"""
        try:
            with self.data_lock:
                # Get recent ticks
                recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:]

            if len(recent_ticks) < self.min_ticks_for_processing:
                return None

            # Convert ticks to neural network input
            tick_features = self._ticks_to_features(recent_ticks)

            # Process with neural network
            with torch.no_grad():
                tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device)
                neural_features, confidence = self.tick_nn(tick_tensor)

                neural_features = neural_features.cpu().numpy().flatten()
                confidence = confidence.cpu().numpy().item()

            # Extract traditional features
            price_features = self._extract_price_features(recent_ticks)
            volume_features = self._extract_volume_features(recent_ticks)
            microstructure_features = self._extract_microstructure_features(recent_ticks)

            # Create processed features object
            processed = ProcessedTickFeatures(
                timestamp=recent_ticks[-1].timestamp,
                price_features=price_features,
                volume_features=volume_features,
                microstructure_features=microstructure_features,
                neural_features=neural_features,
                confidence=confidence
            )

            return processed

        except Exception as e:
            logger.error(f"Error extracting neural features for {symbol}: {e}")
            return None
    def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray:
        """Convert tick data to neural network input features"""
        features = []

        for i, tick in enumerate(ticks):
            tick_features = [
                tick.price,
                tick.volume,
                1.0 if tick.side == 'buy' else 0.0,  # Buy/sell indicator
                tick.timestamp.timestamp(),  # Timestamp
            ]

            # Add relative features if we have previous ticks
            if i > 0:
                prev_tick = ticks[i-1]
                price_change = (tick.price - prev_tick.price) / prev_tick.price
                volume_ratio = tick.volume / (prev_tick.volume + 1e-8)
                time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds()

                tick_features.extend([
                    price_change,
                    volume_ratio,
                    time_delta
                ])
            else:
                tick_features.extend([0.0, 1.0, 0.0])  # Default values for first tick

            # Add moving averages if we have enough data
            if i >= 5:
                recent_prices = [t.price for t in ticks[max(0, i-4):i+1]]
                recent_volumes = [t.volume for t in ticks[max(0, i-4):i+1]]

                price_ma = np.mean(recent_prices)
                volume_ma = np.mean(recent_volumes)

                tick_features.extend([
                    (tick.price - price_ma) / price_ma,  # Price deviation from MA
                    (tick.volume - volume_ma) / (volume_ma + 1e-8)  # Volume deviation from MA
                ])
            else:
                tick_features.extend([0.0, 0.0])

            features.append(tick_features)

        # Pad or truncate to fixed size
        target_length = self.processing_window
        if len(features) < target_length:
            # Pad with zeros
            padding = [[0.0] * len(features[0])] * (target_length - len(features))
            features = padding + features
        elif len(features) > target_length:
            # Take the most recent ticks
            features = features[-target_length:]

        return np.array(features, dtype=np.float32)
    def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract price-based features"""
        prices = np.array([tick.price for tick in ticks])

        features = [
            prices[-1],  # Current price
            np.mean(prices),  # Average price
            np.std(prices),  # Price volatility
            np.max(prices),  # High
            np.min(prices),  # Low
            (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0,  # Total return
        ]

        # Price momentum features
        if len(prices) >= 10:
            short_ma = np.mean(prices[-5:])
            long_ma = np.mean(prices[-10:])
            momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
            features.append(momentum)
        else:
            features.append(0.0)

        return np.array(features, dtype=np.float32)

    def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract volume-based features"""
        volumes = np.array([tick.volume for tick in ticks])
        buy_volumes = np.array([tick.volume for tick in ticks if tick.side == 'buy'])
        sell_volumes = np.array([tick.volume for tick in ticks if tick.side == 'sell'])

        features = [
            np.sum(volumes),  # Total volume
            np.mean(volumes),  # Average volume
            np.std(volumes),  # Volume volatility
            np.sum(buy_volumes) if len(buy_volumes) > 0 else 0,  # Buy volume
            np.sum(sell_volumes) if len(sell_volumes) > 0 else 0,  # Sell volume
        ]

        # Volume imbalance
        total_buy = np.sum(buy_volumes) if len(buy_volumes) > 0 else 0
        total_sell = np.sum(sell_volumes) if len(sell_volumes) > 0 else 0
        total_volume = total_buy + total_sell

        if total_volume > 0:
            buy_ratio = total_buy / total_volume
            volume_imbalance = buy_ratio - 0.5  # -0.5 to 0.5 range
        else:
            volume_imbalance = 0.0

        features.append(volume_imbalance)

        # VWAP (Volume Weighted Average Price)
        if np.sum(volumes) > 0:
            prices = np.array([tick.price for tick in ticks])
            vwap = np.sum(prices * volumes) / np.sum(volumes)
            current_price = ticks[-1].price
            vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0
        else:
            vwap_deviation = 0.0

        features.append(vwap_deviation)

        return np.array(features, dtype=np.float32)
    def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract market microstructure features"""
        features = []

        # Trade frequency
        if len(ticks) >= 2:
            time_deltas = [(ticks[i].timestamp - ticks[i-1].timestamp).total_seconds()
                           for i in range(1, len(ticks))]
            avg_time_delta = np.mean(time_deltas)
            trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0
        else:
            trade_frequency = 0.0

        features.append(trade_frequency)

        # Price impact features
        prices = [tick.price for tick in ticks]
        volumes = [tick.volume for tick in ticks]

        if len(prices) >= 3:
            # Calculate price changes and corresponding volumes
            price_changes = [(prices[i] - prices[i-1]) / prices[i-1]
                             for i in range(1, len(prices)) if prices[i-1] != 0]
            corresponding_volumes = volumes[1:len(price_changes)+1]

            if len(price_changes) > 0 and len(corresponding_volumes) > 0:
                # Simple price impact measure
                price_impact = np.corrcoef(np.abs(price_changes), corresponding_volumes)[0, 1]
                if np.isnan(price_impact):
                    price_impact = 0.0
            else:
                price_impact = 0.0
        else:
            price_impact = 0.0

        features.append(price_impact)

        # Bid-ask spread proxy (using price volatility)
        if len(prices) >= 5:
            recent_prices = prices[-5:]
            spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / np.mean(recent_prices)
        else:
            spread_proxy = 0.0

        features.append(spread_proxy)

        # Order flow imbalance (already calculated in volume features, but different perspective)
        buy_count = sum(1 for tick in ticks if tick.side == 'buy')
        sell_count = len(ticks) - buy_count
        total_trades = len(ticks)

        if total_trades > 0:
            order_flow_imbalance = (buy_count - sell_count) / total_trades
        else:
            order_flow_imbalance = 0.0

        features.append(order_flow_imbalance)

        return np.array(features, dtype=np.float32)

    def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures):
        """Notify all feature subscribers of new processed features"""
        for callback in self.feature_subscribers:
            try:
                callback(symbol, features)
            except Exception as e:
                logger.error(f"Error notifying feature subscriber {callback.__name__}: {e}")

    def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
        """Get the latest processed features for a symbol"""
        with self.data_lock:
            if symbol in self.processed_features and self.processed_features[symbol]:
                return self.processed_features[symbol][-1]
        return None

    def get_processing_stats(self) -> Dict[str, Any]:
        """Get processing performance statistics"""
        stats = {
            'symbols': self.symbols,
            'streaming': self.streaming,
            'tick_counts': dict(self.tick_counts),
            'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols},
            'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols},
            'subscribers': len(self.feature_subscribers)
        }

        if self.processing_times:
            stats['processing_performance'] = {
                'avg_time_ms': np.mean(list(self.processing_times)),
                'min_time_ms': np.min(list(self.processing_times)),
                'max_time_ms': np.max(list(self.processing_times)),
                'std_time_ms': np.std(list(self.processing_times))
            }

        return stats
    def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100):
        """Train the tick processing neural network"""
        logger.info("Training tick processing neural network...")

        self.tick_nn.train()
        optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001)
        criterion = nn.MSELoss()

        for epoch in range(epochs):
            total_loss = 0.0

            for batch_features, batch_targets in training_data:
                optimizer.zero_grad()

                # Convert to tensors
                features_tensor = torch.FloatTensor(batch_features).to(self.device)
                targets_tensor = torch.FloatTensor(batch_targets).to(self.device)

                # Forward pass
                outputs, confidence = self.tick_nn(features_tensor)

                # Calculate loss
                loss = criterion(outputs, targets_tensor)

                # Backward pass
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if epoch % 10 == 0:
                avg_loss = total_loss / len(training_data)
                logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}")

        self.tick_nn.eval()
        logger.info("Neural network training completed")

# Integration with existing orchestrator
def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor):
    """Integrate tick processor with enhanced orchestrator"""

    def feature_callback(symbol: str, features: ProcessedTickFeatures):
        """Callback to feed processed features to orchestrator"""
        try:
            # Convert processed features to format expected by orchestrator
            feature_dict = {
                'symbol': symbol,
                'timestamp': features.timestamp,
                'neural_features': features.neural_features,
                'price_features': features.price_features,
                'volume_features': features.volume_features,
                'microstructure_features': features.microstructure_features,
                'confidence': features.confidence
            }

            # Feed to orchestrator's real-time feature processing
            if hasattr(orchestrator, 'process_realtime_features'):
                orchestrator.process_realtime_features(feature_dict)

        except Exception as e:
            logger.error(f"Error integrating features with orchestrator: {e}")

    # Add the callback to tick processor
    tick_processor.add_feature_subscriber(feature_callback)
    logger.info("Tick processor integrated with orchestrator")

# Factory function for easy creation
def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor:
    """Create and configure a real-time tick processor"""
    if symbols is None:
        symbols = ['ETH/USDT', 'BTC/USDT']

    processor = RealTimeTickProcessor(symbols=symbols)
    logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}")

    return processor
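# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hypothetical wiring of the tick processor from an asyncio entry
# point; the print-style subscriber and sleep duration are made-up examples.
#
#   import asyncio
#
#   async def main():
#       processor = create_realtime_tick_processor(['ETH/USDT'])
#       processor.add_feature_subscriber(
#           lambda symbol, feats: print(symbol, feats.confidence))
#       await processor.start_processing()
#       await asyncio.sleep(60)            # stream for a minute
#       await processor.stop_processing()
#
#   asyncio.run(main())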
@@ -1,453 +0,0 @@
"""
Retrospective Training System

This module implements a retrospective training system that:
1. Triggers training when trades close with known P&L outcomes
2. Uses captured model inputs from trade entry to train models
3. Optimizes for profit by learning from profitable vs unprofitable patterns
4. Supports simultaneous inference and training without weight reloading
5. Implements reinforcement learning with immediate reward feedback
"""

import logging
import threading
import time
import queue
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
import numpy as np
from collections import deque

logger = logging.getLogger(__name__)

@dataclass
class TrainingCase:
    """Represents a completed trade case for retrospective training"""
    case_id: str
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    entry_time: datetime
    exit_time: datetime
    pnl: float
    fees: float
    confidence: float
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    outcome_label: int  # 1 for profit, 0 for loss, 2 for breakeven
    reward_signal: float  # Scaled reward for RL training
    leverage: float = 1.0

class RetrospectiveTrainer:
    """Retrospective training system for real-time model optimization"""

    def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the retrospective trainer"""
        self.orchestrator = orchestrator
        self.config = config or {}

        # Training configuration
        self.batch_size = self.config.get('batch_size', 32)
        self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
        self.profit_threshold = self.config.get('profit_threshold', 0.0)
        self.training_frequency = self.config.get('training_frequency_seconds', 120)  # 2 minutes
        self.max_training_cases = self.config.get('max_training_cases', 1000)

        # Training state
        self.training_queue = queue.Queue()
        self.completed_cases = deque(maxlen=self.max_training_cases)
        self.training_stats = {
            'total_cases': 0,
            'profitable_cases': 0,
            'loss_cases': 0,
            'breakeven_cases': 0,
            'avg_profit': 0.0,
            'last_training_time': datetime.now(),
            'training_sessions': 0,
            'model_updates': 0
        }

        # Threading
        self.training_thread = None
        self.is_training_active = False
        self.training_lock = threading.Lock()

        logger.info("RetrospectiveTrainer initialized")
        logger.info(f"Configuration: batch_size={self.batch_size}, "
                    f"min_cases={self.min_cases_for_training}, "
                    f"training_freq={self.training_frequency}s")
    def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
        """Add a completed trade for retrospective training"""
        try:
            # Create training case from trade record
            case = self._create_training_case(trade_record, model_inputs)
            if case is None:
                return False

            # Add to completed cases
            self.completed_cases.append(case)
            self.training_queue.put(case)

            # Update statistics
            self.training_stats['total_cases'] += 1
            if case.outcome_label == 1:  # Profit
                self.training_stats['profitable_cases'] += 1
            elif case.outcome_label == 0:  # Loss
                self.training_stats['loss_cases'] += 1
            else:  # Breakeven
                self.training_stats['breakeven_cases'] += 1

            # Calculate running average profit
            total_pnl = sum(c.pnl for c in self.completed_cases)
            self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)

            logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
                        f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")

            # Trigger training if we have enough cases
            self._maybe_trigger_training()

            return True

        except Exception as e:
            logger.error(f"Error adding completed trade for retrospective training: {e}")
            return False

    def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
        """Create a training case from trade record and model inputs"""
        try:
            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Calculate net P&L after fees
            net_pnl = pnl - fees

            # Determine outcome label and reward signal
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                # Scale reward by profit magnitude and confidence
                reward_signal = min(10.0, net_pnl * confidence * 10)  # Amplify for training
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                # Negative reward scaled by loss magnitude
                reward_signal = max(-10.0, net_pnl * confidence * 10)  # Negative reward
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0
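            # Worked example of the scaling above (illustrative numbers only, not from the original code):
            #   net_pnl = +0.50, confidence = 0.6
            #     -> outcome_label = 1, reward_signal = min(10.0, 0.50 * 0.6 * 10) = 3.0
            #   net_pnl = -1.20, confidence = 0.8
            #     -> outcome_label = 0, reward_signal = max(-10.0, -1.20 * 0.8 * 10) = -9.6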
            # Create case ID
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            # Create training case
            case = TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )

            return case

        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None

    def _maybe_trigger_training(self):
        """Check if we should trigger a training session"""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()

        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")
    def _start_training_session(self):
        """Start a training session in background thread"""
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")

        except Exception as e:
            logger.error(f"Error starting training session: {e}")

    def _run_training_session(self):
        """Run a complete training session"""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()

                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['last_training_time'] = datetime.now()
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += len(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")

        except Exception as e:
            logger.error(f"Error in retrospective training session: {e}")
            import traceback
            logger.error(traceback.format_exc())
        finally:
            self.is_training_active = False

    def _train_models(self) -> Dict[str, Any]:
        """Train available models using retrospective data"""
        results = {}

        try:
            # Prepare training data
            profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
            loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]

            if len(profitable_cases) == 0 and len(loss_cases) == 0:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            # Train DQN agent if available
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                try:
                    # Update extrema trainer with retrospective feedback
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")

            return results

        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}
    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Train DQN agent using retrospective experience replay"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return {'error': 'DQN agent not available'}

            dqn_agent = self.orchestrator.rl_agent
            experiences_added = 0

            # Add retrospective experiences to DQN replay buffer
            for case in self.completed_cases:
                try:
                    # Extract state from model inputs
                    state = self._extract_state_vector(case.model_inputs)
                    if state is None:
                        continue

                    # Action mapping: BUY=0, SELL=1
                    action = 0 if case.action == 'BUY' else 1

                    # Use reward signal as immediate reward
                    reward = case.reward_signal

                    # For retrospective training, next_state is terminal
                    next_state = np.zeros_like(state)  # Terminal state
                    done = True

                    # Add experience to DQN replay buffer
                    if hasattr(dqn_agent, 'add_experience'):
                        dqn_agent.add_experience(state, action, reward, next_state, done)
                        experiences_added += 1

                except Exception as e:
                    logger.debug(f"Error adding DQN experience: {e}")
                    continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                try:
                    # Perform multiple training steps on retrospective data
                    training_steps = min(10, experiences_added // 4)  # Conservative training
                    for _ in range(training_steps):
                        loss = dqn_agent.train()
                        if loss is None:
                            break

                    return {
                        'experiences_added': experiences_added,
                        'training_steps': training_steps,
                        'method': 'retrospective_experience_replay'
                    }
                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}

        except Exception as e:
            logger.error(f"Error in DQN retrospective training: {e}")
            return {'error': str(e)}

    def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
        """Extract state vector for DQN training from model inputs"""
        try:
            # Try to get pre-built RL state
            if 'dqn_state' in model_inputs:
                state = model_inputs['dqn_state']
                if isinstance(state, dict) and 'state_vector' in state:
                    return np.array(state['state_vector'])

            # Build state from market features
            market_state = model_inputs.get('market_state', {})
            features = []

            # Price features
            for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
                features.append(market_state.get(key, 0.0))

            # Volume features
            for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
                features.append(market_state.get(key, 0.0))

            # Technical indicators
            indicators = model_inputs.get('technical_indicators', {})
            for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
                features.append(indicators.get(key, 0.0))

            if len(features) < 5:  # Minimum required features
                return None

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Error extracting state vector: {e}")
            return None
    def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
        """Create feedback data for extrema trainer"""
        feedback = []

        try:
            for case in self.completed_cases:
                if case.outcome_label in [0, 1]:  # Only profit/loss cases
                    feedback_item = {
                        'symbol': case.symbol,
                        'action': case.action,
                        'entry_price': case.entry_price,
                        'exit_price': case.exit_price,
                        'was_profitable': case.outcome_label == 1,
                        'reward_signal': case.reward_signal,
                        'market_state': case.market_state
                    }
                    feedback.append(feedback_item)

            return feedback

        except Exception as e:
            logger.error(f"Error creating extrema feedback: {e}")
            return []

    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics"""
        stats = self.training_stats.copy()
        stats['total_cases_in_memory'] = len(self.completed_cases)
        stats['training_queue_size'] = self.training_queue.qsize()
        stats['is_training_active'] = self.is_training_active

        # Calculate profit metrics
        if len(self.completed_cases) > 0:
            profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
            stats['profit_rate'] = profitable_count / len(self.completed_cases)
            stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
            stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)

        return stats

    def force_training_session(self) -> bool:
        """Force a training session regardless of timing constraints"""
        try:
            if self.is_training_active:
                logger.warning("Training already in progress")
                return False

            if len(self.completed_cases) < 1:
                logger.warning("No completed cases available for training")
                return False

            logger.info("RETROSPECTIVE: Forcing training session")
            self._start_training_session()
            return True

        except Exception as e:
            logger.error(f"Error forcing training session: {e}")
            return False

    def stop(self):
        """Stop the retrospective trainer"""
        try:
            self.is_training_active = False
            if self.training_thread and self.training_thread.is_alive():
                self.training_thread.join(timeout=10)
            logger.info("RetrospectiveTrainer stopped")
        except Exception as e:
            logger.error(f"Error stopping RetrospectiveTrainer: {e}")


def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
    """Factory function to create a RetrospectiveTrainer instance"""
    return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
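# --- Illustrative usage sketch (not part of the original file) ---
# A hedged example of feeding one closed trade into the trainer; the
# trade_record and model_inputs payloads below are hypothetical minimal values.
#
#   trainer = create_retrospective_trainer()
#   trade_record = {
#       'symbol': 'ETH/USDT', 'side': 'BUY',
#       'entry_price': 3000.0, 'exit_price': 3012.0,
#       'pnl': 1.2, 'fees': 0.1, 'confidence': 0.7,
#   }
#   model_inputs = {'market_state': {'current_price': 3012.0}}
#   trainer.add_completed_trade(trade_record, model_inputs)
#   print(trainer.get_training_stats()['profit_rate'])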
@@ -1,529 +0,0 @@
"""
RL Training Pipeline with Comprehensive Experience Storage and Replay

This module implements a robust RL training pipeline that:
1. Stores all training experiences with profitability metrics
2. Implements profit-weighted experience replay
3. Tracks gradient information for each training step
4. Enables retraining on most profitable trading sequences
5. Maintains comprehensive trading episode analysis
"""

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import pickle
from collections import deque
import threading
import random

from .training_data_collector import get_training_data_collector

logger = logging.getLogger(__name__)

@dataclass
class RLExperience:
    """Single RL experience with complete state-action-reward information"""
    experience_id: str
    timestamp: datetime
    episode_id: str

    # Core RL components
    state: np.ndarray
    action: int  # 0=SELL, 1=HOLD, 2=BUY
    reward: float
    next_state: np.ndarray
    done: bool

    # Extended state information
    market_context: Dict[str, Any]
    cnn_predictions: Optional[Dict[str, Any]] = None
    confidence_score: float = 0.0

    # Actual trading outcome
    actual_profit: Optional[float] = None
    actual_holding_time: Optional[timedelta] = None
    optimal_action: Optional[int] = None

    # Experience value for replay
    experience_value: float = 0.0
    profitability_score: float = 0.0
    learning_priority: float = 0.0

    # Training metadata
    times_trained: int = 0
    last_trained: Optional[datetime] = None
class ProfitWeightedExperienceBuffer:
    """Experience buffer with profit-weighted sampling for replay"""

    def __init__(self, max_size: int = 100000):
        self.max_size = max_size
        self.experiences: Dict[str, RLExperience] = {}
        self.experience_order: deque = deque(maxlen=max_size)
        self.profitable_experiences: List[str] = []
        self.total_experiences = 0
        self.total_profitable = 0

    def add_experience(self, experience: RLExperience):
        """Add experience to buffer"""
        try:
            self.experiences[experience.experience_id] = experience
            self.experience_order.append(experience.experience_id)

            if experience.actual_profit is not None and experience.actual_profit > 0:
                self.profitable_experiences.append(experience.experience_id)
                self.total_profitable += 1

            # Remove oldest if buffer is full
            if len(self.experiences) > self.max_size:
                oldest_id = self.experience_order[0]
                if oldest_id in self.experiences:
                    del self.experiences[oldest_id]
                    if oldest_id in self.profitable_experiences:
                        self.profitable_experiences.remove(oldest_id)

            self.total_experiences += 1

        except Exception as e:
            logger.error(f"Error adding experience to buffer: {e}")

    def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
        """Sample batch with profit-weighted prioritization"""
        try:
            if len(self.experiences) < batch_size:
                return list(self.experiences.values())

            if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
                # Sample mix of profitable and all experiences
                profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
                remaining_sample_size = batch_size - profitable_sample_size

                profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
                all_ids = list(self.experiences.keys())
                remaining_ids = random.sample(all_ids, remaining_sample_size)

                sampled_ids = profitable_ids + remaining_ids
            else:
                # Random sampling from all experiences
                all_ids = list(self.experiences.keys())
                sampled_ids = random.sample(all_ids, batch_size)

            sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]

            # Update training counts
            for experience in sampled_experiences:
                experience.times_trained += 1
                experience.last_trained = datetime.now()

            return sampled_experiences

        except Exception as e:
            logger.error(f"Error sampling batch: {e}")
            return list(self.experiences.values())[:batch_size]

    def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
        """Get most profitable experiences for targeted training"""
        try:
            profitable_experiences = [
                self.experiences[exp_id] for exp_id in self.profitable_experiences
                if exp_id in self.experiences
            ]

            profitable_experiences.sort(
                key=lambda x: x.actual_profit if x.actual_profit else 0,
                reverse=True
            )

            return profitable_experiences[:limit]

        except Exception as e:
            logger.error(f"Error getting profitable experiences: {e}")
            return []
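# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of the profit-weighted buffer; field names come from the
# RLExperience dataclass above, the values are made up.
#
#   buf = ProfitWeightedExperienceBuffer(max_size=1000)
#   exp = RLExperience(
#       experience_id='exp_1', timestamp=datetime.now(), episode_id='ep_1',
#       state=np.zeros(2000, dtype=np.float32), action=2, reward=1.0,
#       next_state=np.zeros(2000, dtype=np.float32), done=True,
#       market_context={}, actual_profit=0.8)
#   buf.add_experience(exp)
#   batch = buf.sample_batch(batch_size=1, prioritize_profitable=True)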
class RLTradingAgent(nn.Module):
    """RL Trading Agent with comprehensive state processing"""

    def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
        super(RLTradingAgent, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # State processing network
        self.state_processor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU()
        )

        # Q-value network
        self.q_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim)
        )

        # Policy network
        self.policy_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim),
            nn.Softmax(dim=-1)
        )

        # Value network
        self.value_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, state):
        """Forward pass through the agent"""
        processed_state = self.state_processor(state)

        q_values = self.q_network(processed_state)
        policy_probs = self.policy_network(processed_state)
        state_value = self.value_network(processed_state)

        return {
            'q_values': q_values,
            'policy_probs': policy_probs,
            'state_value': state_value,
            'processed_state': processed_state
        }

    def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
        """Select action using epsilon-greedy policy"""
        self.eval()
        with torch.no_grad():
            if isinstance(state, np.ndarray):
                state = torch.from_numpy(state).float().unsqueeze(0)

            outputs = self.forward(state)

            if random.random() < epsilon:
                action = random.randint(0, self.action_dim - 1)
                confidence = 0.33
            else:
                q_values = outputs['q_values']
                action = torch.argmax(q_values, dim=1).item()
                q_softmax = F.softmax(q_values, dim=1)
                confidence = torch.max(q_softmax).item()

        return action, confidence
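# --- Illustrative usage sketch (not part of the original file) ---
# Exercises the agent on a random state vector; shapes follow the defaults
# above (state_dim=2000, action_dim=3).
#
#   agent = RLTradingAgent()
#   dummy_state = np.random.randn(2000).astype(np.float32)
#   action, confidence = agent.select_action(dummy_state, epsilon=0.05)
#   outputs = agent(torch.from_numpy(dummy_state).unsqueeze(0))
#   print(action, confidence, outputs['q_values'].shape)   # -> (1, 3)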
@dataclass
class RLTrainingStep:
    """Single RL training step with backpropagation data"""
    step_id: str
    timestamp: datetime
    batch_experiences: List[str]

    # Training data
    total_loss: float
    q_loss: float
    policy_loss: float

    # Gradients
    gradients: Dict[str, torch.Tensor]
    gradient_norms: Dict[str, float]

    # Metadata
    learning_rate: float = 0.001
    batch_size: int = 32

    # Performance
    batch_profitability: float = 0.0
    correct_actions: int = 0
    total_actions: int = 0
    step_value: float = 0.0

@dataclass
class RLTrainingSession:
    """Complete RL training session"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    training_mode: str = 'experience_replay'
    symbol: str = ''

    training_steps: List[RLTrainingStep] = field(default_factory=list)

    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')

    profitable_actions: int = 0
    total_actions: int = 0
    profitability_rate: float = 0.0
    session_value: float = 0.0
class RLTrainer:
    """RL trainer with comprehensive experience storage and replay"""

    def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
        self.agent = agent.to(device)
        self.device = device
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
        self.experience_buffer = ProfitWeightedExperienceBuffer()
        self.data_collector = get_training_data_collector()

        self.training_sessions: List[RLTrainingSession] = []
        self.current_session: Optional[RLTrainingSession] = None

        self.gamma = 0.99

        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'total_experiences': 0,
            'profitable_actions': 0,
            'total_actions': 0,
            'average_reward': 0.0
        }

        logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")

    def add_experience(self, state: np.ndarray, action: int, reward: float,
                       next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
                       cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str:
        """Add experience to the buffer"""
        try:
            experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"

            experience = RLExperience(
                experience_id=experience_id,
                timestamp=datetime.now(),
                episode_id=market_context.get('episode_id', 'unknown'),
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                market_context=market_context,
                cnn_predictions=cnn_predictions,
                confidence_score=confidence_score
            )

            self.experience_buffer.add_experience(experience)
            self.training_stats['total_experiences'] += 1

            return experience_id

        except Exception as e:
            logger.error(f"Error adding experience: {e}")
            return None
    def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
        """Train on experiences with comprehensive data storage"""
        try:
            session = RLTrainingSession(
                session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode='experience_replay'
            )
            self.current_session = session

            self.agent.train()
            total_loss = 0.0

            for batch_idx in range(num_batches):
                experiences = self.experience_buffer.sample_batch(batch_size, True)

                if len(experiences) < batch_size:
                    continue

                # Prepare batch tensors
                states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
                actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
                rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
                next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
                dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)

                # Forward pass
                self.optimizer.zero_grad()

                current_outputs = self.agent(states)
                current_q_values = current_outputs['q_values']

                # Calculate target Q-values
                with torch.no_grad():
                    next_outputs = self.agent(next_states)
                    next_q_values = next_outputs['q_values']
                    max_next_q_values = torch.max(next_q_values, dim=1)[0]
                    target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)

                # Calculate loss
                current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)

                # Backward pass
                q_loss.backward()

                # Store gradients
                gradients = {}
                gradient_norms = {}
                for name, param in self.agent.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.clone().detach()
                        gradient_norms[name] = param.grad.norm().item()

                torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
                self.optimizer.step()

                # Create training step record
                step = RLTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    batch_experiences=[exp.experience_id for exp in experiences],
                    total_loss=q_loss.item(),
                    q_loss=q_loss.item(),
                    policy_loss=0.0,
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    batch_size=len(experiences)
                )

                session.training_steps.append(step)
                total_loss += q_loss.item()

            # Finalize session
            session.end_timestamp = datetime.now()
            session.total_steps = num_batches
            session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0

            self._save_training_session(session)

            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps

            logger.info(f"RL training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")

            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps
            }

        except Exception as e:
            logger.error(f"Error in RL training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None
    def train_on_profitable_experiences(self, min_profitability: float = 0.1,
                                        max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
        """Train specifically on most profitable experiences"""
        try:
            profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)

            filtered_experiences = [
                exp for exp in profitable_experiences
                if exp.actual_profit is not None and exp.actual_profit >= min_profitability
            ]

            if len(filtered_experiences) < batch_size:
                return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}

            logger.info(f"Training on {len(filtered_experiences)} profitable experiences")

            num_batches = len(filtered_experiences) // batch_size

            # Temporarily replace buffer sampling
            original_sample_method = self.experience_buffer.sample_batch

            def profitable_sample_batch(batch_size, prioritize_profitable=True):
                return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))

            self.experience_buffer.sample_batch = profitable_sample_batch

            try:
                results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
                results['training_mode'] = 'profitable_replay'
                results['experiences_used'] = len(filtered_experiences)
                return results
            finally:
                self.experience_buffer.sample_batch = original_sample_method

        except Exception as e:
            logger.error(f"Error training on profitable experiences: {e}")
            return {'status': 'error', 'error': str(e)}
    def _save_training_session(self, session: RLTrainingSession):
        """Save training session to disk"""
        try:
            session_dir = self.storage_dir / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss
            }

            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics"""
        stats = self.training_stats.copy()

        if self.training_sessions:
            recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss
                }
                for s in recent_sessions
            ]

        return stats

# Global instance
rl_trainer = None

def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
    """Get global RL trainer instance"""
    global rl_trainer
    if rl_trainer is None:
        if agent is None:
            agent = RLTradingAgent()
        rl_trainer = RLTrainer(agent)
    return rl_trainer
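# --- Illustrative usage sketch (not part of the original file) ---
# Hedged end-to-end example: add a handful of synthetic experiences, then run
# a short experience-replay session. Assumes a GPU-free run via device='cpu';
# all values are made up.
#
#   trainer = RLTrainer(RLTradingAgent(), device='cpu')
#   for i in range(64):
#       s = np.random.randn(2000).astype(np.float32)
#       trainer.add_experience(state=s, action=i % 3, reward=float(i % 5 - 2),
#                              next_state=s, done=True,
#                              market_context={'episode_id': 'demo'})
#   print(trainer.train_on_experiences(batch_size=32, num_batches=2))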
@@ -1,460 +0,0 @@
"""
Robust COB (Consolidated Order Book) Provider

This module provides a robust COB data provider that handles:
- HTTP 418 errors from Binance (rate limiting)
- Thread safety issues
- API rate limiting and backoff
- Fallback data sources
- Error recovery strategies

Features:
- Automatic rate limiting and backoff
- Multiple exchange support with fallbacks
- Thread-safe operations
- Comprehensive error handling
- Data validation and integrity checking
"""

import asyncio
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from collections import deque
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests

from .api_rate_limiter import get_rate_limiter, RateLimitConfig

logger = logging.getLogger(__name__)

@dataclass
class COBData:
    """Consolidated Order Book data structure"""
    symbol: str
    timestamp: datetime
    bids: List[Tuple[float, float]]  # [(price, quantity), ...]
    asks: List[Tuple[float, float]]  # [(price, quantity), ...]

    # Derived metrics
    spread: float = 0.0
    mid_price: float = 0.0
    total_bid_volume: float = 0.0
    total_ask_volume: float = 0.0

    # Data quality
    data_source: str = 'unknown'
    quality_score: float = 1.0

    def __post_init__(self):
        """Calculate derived metrics"""
        if self.bids and self.asks:
            self.spread = self.asks[0][0] - self.bids[0][0]
            self.mid_price = (self.asks[0][0] + self.bids[0][0]) / 2
            self.total_bid_volume = sum(qty for _, qty in self.bids)
            self.total_ask_volume = sum(qty for _, qty in self.asks)

            # Calculate quality score based on data completeness
            self.quality_score = min(
                len(self.bids) / 20,  # Expect at least 20 bid levels
                len(self.asks) / 20,  # Expect at least 20 ask levels
                1.0
            )
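# --- Illustrative usage sketch (not part of the original file) ---
# Shows the derived metrics computed in __post_init__; the prices and
# quantities are made-up values.
#
#   cob = COBData(
#       symbol='ETHUSDT', timestamp=datetime.now(),
#       bids=[(3000.0, 2.0), (2999.5, 1.0)],
#       asks=[(3000.5, 1.5), (3001.0, 3.0)])
#   print(cob.spread)         # 0.5
#   print(cob.mid_price)      # 3000.25
#   print(cob.quality_score)  # 2 levels / 20 expected -> 0.1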
class RobustCOBProvider:
    """Robust COB provider with error handling and rate limiting"""

    def __init__(self, symbols: List[str] = None):
        self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']

        # Rate limiter
        self.rate_limiter = get_rate_limiter()

        # Thread safety
        self.lock = threading.RLock()

        # Data cache
        self.cob_cache: Dict[str, COBData] = {}
        self.cache_timestamps: Dict[str, datetime] = {}
        self.cache_ttl = timedelta(seconds=5)  # 5 second cache TTL

        # Error tracking
        self.error_counts: Dict[str, int] = {}
        self.last_successful_fetch: Dict[str, datetime] = {}

        # Background fetching
        self.is_running = False
        self.fetch_threads: Dict[str, threading.Thread] = {}
        self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher")

        # Fallback data
        self.fallback_data: Dict[str, COBData] = {}

        # Performance tracking
        self.fetch_stats = {
            'total_requests': 0,
            'successful_requests': 0,
            'failed_requests': 0,
            'rate_limited_requests': 0,
            'cache_hits': 0,
            'fallback_uses': 0
        }

        logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}")

    def start_background_fetching(self):
        """Start background COB data fetching"""
        if self.is_running:
            logger.warning("Background fetching already running")
            return

        self.is_running = True

        # Start fetching thread for each symbol
        for symbol in self.symbols:
            thread = threading.Thread(
                target=self._background_fetch_worker,
                args=(symbol,),
                name=f"COB-{symbol}",
                daemon=True
            )
            self.fetch_threads[symbol] = thread
            thread.start()

        logger.info(f"Started background COB fetching for {len(self.symbols)} symbols")

    def stop_background_fetching(self):
        """Stop background COB data fetching"""
        self.is_running = False

        # Wait for threads to finish
        for symbol, thread in self.fetch_threads.items():
            thread.join(timeout=5)
            logger.debug(f"Stopped COB fetching for {symbol}")

        # Shutdown executor (ThreadPoolExecutor.shutdown takes no timeout argument)
        self.executor.shutdown(wait=True)

        logger.info("Stopped background COB fetching")
    def _background_fetch_worker(self, symbol: str):
        """Background worker for fetching COB data"""
        logger.info(f"Started COB fetching worker for {symbol}")

        while self.is_running:
            try:
                # Fetch COB data
                cob_data = self._fetch_cob_data_safe(symbol)

                if cob_data:
                    with self.lock:
                        self.cob_cache[symbol] = cob_data
                        self.cache_timestamps[symbol] = datetime.now()
                        self.last_successful_fetch[symbol] = datetime.now()
                        self.error_counts[symbol] = 0  # Reset error count on success

                    logger.debug(f"Updated COB cache for {symbol}")
                else:
                    with self.lock:
                        self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1

                    logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}")

                # Wait before next fetch (adaptive based on errors)
                error_count = self.error_counts.get(symbol, 0)
                base_interval = 2.0  # Base 2 second interval
                backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0)  # Max 60s

                time.sleep(backoff_interval)

            except Exception as e:
                logger.error(f"Error in COB fetching worker for {symbol}: {e}")
                time.sleep(10)  # Wait 10s on unexpected errors

        logger.info(f"Stopped COB fetching worker for {symbol}")

    def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]:
        """Safely fetch COB data with error handling"""
        try:
            self.fetch_stats['total_requests'] += 1

            # Try Binance first
            cob_data = self._fetch_binance_cob(symbol)
            if cob_data:
                self.fetch_stats['successful_requests'] += 1
                return cob_data

            # Try MEXC as fallback
            cob_data = self._fetch_mexc_cob(symbol)
            if cob_data:
                self.fetch_stats['successful_requests'] += 1
                cob_data.data_source = 'mexc_fallback'
                return cob_data

            # Use cached fallback data if available
            if symbol in self.fallback_data:
                self.fetch_stats['fallback_uses'] += 1
                fallback = self.fallback_data[symbol]
                fallback.timestamp = datetime.now()
                fallback.data_source = 'fallback_cache'
                fallback.quality_score *= 0.5  # Reduce quality score for old data
                return fallback

            self.fetch_stats['failed_requests'] += 1
            return None

        except Exception as e:
            logger.error(f"Error fetching COB data for {symbol}: {e}")
            self.fetch_stats['failed_requests'] += 1
            return None
def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]:
|
||||
"""Fetch COB data from Binance with rate limiting"""
|
||||
try:
|
||||
url = f"https://api.binance.com/api/v3/depth"
|
||||
params = {
|
||||
'symbol': symbol,
|
||||
'limit': 100 # Get 100 levels
|
||||
}
|
||||
|
||||
# Use rate limiter
|
||||
response = self.rate_limiter.make_request(
|
||||
'binance_api',
|
||||
url,
|
||||
method='GET',
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response:
|
||||
self.fetch_stats['rate_limited_requests'] += 1
|
||||
return None
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Binance COB API returned {response.status_code} for {symbol}")
|
||||
return None
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Parse order book data
|
||||
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
|
||||
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
|
||||
|
||||
if not bids or not asks:
|
||||
logger.warning(f"Empty order book data from Binance for {symbol}")
|
||||
return None
|
||||
|
||||
cob_data = COBData(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
data_source='binance'
|
||||
)
|
||||
|
||||
# Store as fallback for future use
|
||||
self.fallback_data[symbol] = cob_data
|
||||
|
||||
return cob_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Binance COB for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]:
|
||||
"""Fetch COB data from MEXC as fallback"""
|
||||
try:
|
||||
url = f"https://api.mexc.com/api/v3/depth"
|
||||
params = {
|
||||
'symbol': symbol,
|
||||
'limit': 100
|
||||
}
|
||||
|
||||
response = self.rate_limiter.make_request(
|
||||
'mexc_api',
|
||||
url,
|
||||
method='GET',
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or response.status_code != 200:
|
||||
return None
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Parse order book data
|
||||
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
|
||||
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
|
||||
|
||||
if not bids or not asks:
|
||||
return None
|
||||
|
||||
return COBData(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
data_source='mexc'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching MEXC COB for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_cob_data(self, symbol: str) -> Optional[COBData]:
|
||||
"""Get COB data for a symbol (from cache or fresh fetch)"""
|
||||
with self.lock:
|
||||
# Check cache first
|
||||
if symbol in self.cob_cache:
|
||||
cached_data = self.cob_cache[symbol]
|
||||
cache_time = self.cache_timestamps.get(symbol, datetime.min)
|
||||
|
||||
# Return cached data if still fresh
|
||||
if datetime.now() - cache_time < self.cache_ttl:
|
||||
self.fetch_stats['cache_hits'] += 1
|
||||
return cached_data
|
||||
|
||||
# If background fetching is running, return cached data even if stale
|
||||
if self.is_running and symbol in self.cob_cache:
|
||||
return self.cob_cache[symbol]
|
||||
|
||||
# Fetch fresh data if not running background fetching
|
||||
if not self.is_running:
|
||||
return self._fetch_cob_data_safe(symbol)
|
||||
|
||||
return None
|
||||
|
||||
def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get COB features for ML models
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
feature_count: Number of features to return
|
||||
|
||||
Returns:
|
||||
Numpy array of COB features or None if no data
|
||||
"""
|
||||
cob_data = self.get_cob_data(symbol)
|
||||
if not cob_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
features = []
|
||||
|
||||
# Basic market metrics
|
||||
features.extend([
|
||||
cob_data.mid_price,
|
||||
cob_data.spread,
|
||||
cob_data.total_bid_volume,
|
||||
cob_data.total_ask_volume,
|
||||
cob_data.quality_score
|
||||
])
|
||||
|
||||
# Bid levels (price and volume)
|
||||
max_levels = min(len(cob_data.bids), 20)
|
||||
for i in range(max_levels):
|
||||
price, volume = cob_data.bids[i]
|
||||
features.extend([price, volume])
|
||||
|
||||
# Pad bid levels if needed
|
||||
for i in range(max_levels, 20):
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Ask levels (price and volume)
|
||||
max_levels = min(len(cob_data.asks), 20)
|
||||
for i in range(max_levels):
|
||||
price, volume = cob_data.asks[i]
|
||||
features.extend([price, volume])
|
||||
|
||||
# Pad ask levels if needed
|
||||
for i in range(max_levels, 20):
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Calculate additional features
|
||||
if len(cob_data.bids) > 0 and len(cob_data.asks) > 0:
|
||||
# Volume imbalance
|
||||
bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5])
|
||||
ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5])
|
||||
volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0
|
||||
features.append(volume_imbalance)
|
||||
|
||||
# Price levels
|
||||
bid_price_levels = [price for price, _ in cob_data.bids[:10]]
|
||||
ask_price_levels = [price for price, _ in cob_data.asks[:10]]
|
||||
features.extend(bid_price_levels + ask_price_levels)
|
||||
|
||||
# Pad or truncate to desired feature count
|
||||
if len(features) < feature_count:
|
||||
features.extend([0.0] * (feature_count - len(features)))
|
||||
else:
|
||||
features = features[:feature_count]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating COB features for {symbol}: {e}")
|
||||
return None
|
||||
|
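# Under the layout built above, slots 0-4 hold the summary metrics, 5-44 the 20
# padded (price, volume) bid pairs, 45-84 the 20 ask pairs, and the remainder
# the volume imbalance plus top-10 bid/ask prices before padding to
# feature_count. A sketch of how a consumer might slice that vector (the
# function name and dict keys are assumptions, not part of this module):

def _split_cob_features_sketch(features: np.ndarray) -> Dict[str, np.ndarray]:
    return {
        'summary': features[0:5],                      # mid price, spread, volumes, quality score
        'bid_levels': features[5:45].reshape(20, 2),   # (price, volume) pairs
        'ask_levels': features[45:85].reshape(20, 2),
        'tail': features[85:],                         # imbalance, top-10 prices, padding
    }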
||||
def get_provider_status(self) -> Dict[str, Any]:
|
||||
"""Get provider status and statistics"""
|
||||
with self.lock:
|
||||
status = {
|
||||
'is_running': self.is_running,
|
||||
'symbols': self.symbols,
|
||||
'cache_status': {},
|
||||
'error_counts': self.error_counts.copy(),
|
||||
'last_successful_fetch': {
|
||||
symbol: timestamp.isoformat()
|
||||
for symbol, timestamp in self.last_successful_fetch.items()
|
||||
},
|
||||
'fetch_stats': self.fetch_stats.copy(),
|
||||
'rate_limiter_status': self.rate_limiter.get_all_endpoint_status()
|
||||
}
|
||||
|
||||
# Cache status for each symbol
|
||||
for symbol in self.symbols:
|
||||
cache_time = self.cache_timestamps.get(symbol)
|
||||
status['cache_status'][symbol] = {
|
||||
'has_data': symbol in self.cob_cache,
|
||||
'cache_time': cache_time.isoformat() if cache_time else None,
|
||||
'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None,
|
||||
'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
def reset_errors(self):
|
||||
"""Reset error counts and rate limiter"""
|
||||
with self.lock:
|
||||
self.error_counts.clear()
|
||||
self.rate_limiter.reset_all_endpoints()
|
||||
logger.info("Reset all error counts and rate limiter")
|
||||
|
||||
def force_refresh(self, symbol: str = None):
|
||||
"""Force refresh COB data for symbol(s)"""
|
||||
symbols_to_refresh = [symbol] if symbol else self.symbols
|
||||
|
||||
for sym in symbols_to_refresh:
|
||||
# Clear cache to force refresh
|
||||
with self.lock:
|
||||
if sym in self.cob_cache:
|
||||
del self.cob_cache[sym]
|
||||
if sym in self.cache_timestamps:
|
||||
del self.cache_timestamps[sym]
|
||||
|
||||
logger.info(f"Forced refresh for {sym}")
|
||||

# Global COB provider instance
_global_cob_provider = None

def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider:
    """Get global COB provider instance"""
    global _global_cob_provider
    if _global_cob_provider is None:
        _global_cob_provider = RobustCOBProvider(symbols)
    return _global_cob_provider
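# Minimal usage sketch of the accessor above (symbol names follow Binance spot
# conventions; the call sequence is an assumed typical flow, not taken from the
# original module):

provider = get_cob_provider(symbols=['ETHUSDT', 'BTCUSDT'])
provider.start_background_fetching()

snapshot = provider.get_cob_data('ETHUSDT')      # served from the 5s cache when fresh
features = provider.get_cob_features('ETHUSDT')  # 120-element float32 vector, or None
print(provider.get_provider_status()['fetch_stats'])

provider.stop_background_fetching()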
@@ -1,350 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Shared COB Service - Eliminates Redundant COB Implementations
|
||||
|
||||
This service provides a singleton COB integration that can be shared across:
|
||||
- Dashboard components
|
||||
- RL trading systems
|
||||
- Enhanced orchestrators
|
||||
- Training pipelines
|
||||
|
||||
Instead of each component creating its own COBIntegration instance,
|
||||
they all share this single service, eliminating redundant connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import weakref
|
||||
from typing import Dict, List, Optional, Any, Callable, Set
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .cob_integration import COBIntegration
|
||||
from .multi_exchange_cob_provider import COBSnapshot
|
||||
from .data_provider import DataProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class COBSubscription:
|
||||
"""Represents a subscription to COB updates"""
|
||||
subscriber_id: str
|
||||
callback: Callable
|
||||
symbol_filter: Optional[List[str]] = None
|
||||
callback_type: str = "general" # general, cnn, dqn, dashboard
|
||||
|
||||
class SharedCOBService:
|
||||
"""
|
||||
Shared COB Service - Singleton pattern for unified COB data access
|
||||
|
||||
This service eliminates redundant COB integrations by providing a single
|
||||
shared instance that all components can subscribe to.
|
||||
"""
|
||||
|
||||
_instance: Optional['SharedCOBService'] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Singleton pattern implementation"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SharedCOBService, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None):
|
||||
"""Initialize shared COB service (only called once due to singleton)"""
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
|
||||
self.data_provider = data_provider
|
||||
|
||||
# Single COB integration instance
|
||||
self.cob_integration: Optional[COBIntegration] = None
|
||||
self.is_running = False
|
||||
|
||||
# Subscriber management
|
||||
self.subscribers: Dict[str, COBSubscription] = {}
|
||||
self.subscriber_counter = 0
|
||||
self.subscription_lock = Lock()
|
||||
|
||||
# Cached data for immediate access
|
||||
self.latest_snapshots: Dict[str, COBSnapshot] = {}
|
||||
self.latest_cnn_features: Dict[str, Any] = {}
|
||||
self.latest_dqn_states: Dict[str, Any] = {}
|
||||
|
||||
# Performance tracking
|
||||
self.total_subscribers = 0
|
||||
self.update_count = 0
|
||||
self.start_time = None
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"SharedCOBService initialized for symbols: {self.symbols}")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the shared COB service"""
|
||||
if self.is_running:
|
||||
logger.warning("SharedCOBService already running")
|
||||
return
|
||||
|
||||
logger.info("Starting SharedCOBService...")
|
||||
|
||||
try:
|
||||
# Initialize COB integration if not already done
|
||||
if self.cob_integration is None:
|
||||
self.cob_integration = COBIntegration(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols
|
||||
)
|
||||
|
||||
# Register internal callbacks
|
||||
self.cob_integration.add_cnn_callback(self._on_cob_cnn_update)
|
||||
self.cob_integration.add_dqn_callback(self._on_cob_dqn_update)
|
||||
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update)
|
||||
|
||||
# Start COB integration
|
||||
await self.cob_integration.start()
|
||||
|
||||
self.is_running = True
|
||||
self.start_time = datetime.now()
|
||||
|
||||
logger.info("SharedCOBService started successfully")
|
||||
logger.info(f"Active subscribers: {len(self.subscribers)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting SharedCOBService: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the shared COB service"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping SharedCOBService...")
|
||||
|
||||
try:
|
||||
if self.cob_integration:
|
||||
await self.cob_integration.stop()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
# Notify all subscribers of shutdown
|
||||
for subscription in self.subscribers.values():
|
||||
try:
|
||||
if hasattr(subscription.callback, '__call__'):
|
||||
subscription.callback("SHUTDOWN", None)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
|
||||
|
||||
logger.info("SharedCOBService stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping SharedCOBService: {e}")
|
||||
|
||||
def subscribe(self,
|
||||
callback: Callable,
|
||||
callback_type: str = "general",
|
||||
symbol_filter: Optional[List[str]] = None,
|
||||
subscriber_name: str = None) -> str:
|
||||
"""
|
||||
Subscribe to COB updates
|
||||
|
||||
Args:
|
||||
callback: Function to call on updates
|
||||
callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard')
|
||||
symbol_filter: Only receive updates for these symbols (None = all)
|
||||
subscriber_name: Optional name for the subscriber
|
||||
|
||||
Returns:
|
||||
Subscription ID for unsubscribing
|
||||
"""
|
||||
with self.subscription_lock:
|
||||
self.subscriber_counter += 1
|
||||
subscriber_id = f"{callback_type}_{self.subscriber_counter}"
|
||||
if subscriber_name:
|
||||
subscriber_id = f"{subscriber_name}_{subscriber_id}"
|
||||
|
||||
subscription = COBSubscription(
|
||||
subscriber_id=subscriber_id,
|
||||
callback=callback,
|
||||
symbol_filter=symbol_filter,
|
||||
callback_type=callback_type
|
||||
)
|
||||
|
||||
self.subscribers[subscriber_id] = subscription
|
||||
self.total_subscribers += 1
|
||||
|
||||
logger.info(f"New subscriber: {subscriber_id} ({callback_type})")
|
||||
logger.info(f"Total active subscribers: {len(self.subscribers)}")
|
||||
|
||||
return subscriber_id
|
||||
|
||||
def unsubscribe(self, subscriber_id: str) -> bool:
|
||||
"""
|
||||
Unsubscribe from COB updates
|
||||
|
||||
Args:
|
||||
subscriber_id: ID returned from subscribe()
|
||||
|
||||
Returns:
|
||||
True if successfully unsubscribed
|
||||
"""
|
||||
with self.subscription_lock:
|
||||
if subscriber_id in self.subscribers:
|
||||
del self.subscribers[subscriber_id]
|
||||
logger.info(f"Unsubscribed: {subscriber_id}")
|
||||
logger.info(f"Remaining subscribers: {len(self.subscribers)}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Subscriber not found: {subscriber_id}")
|
||||
return False
|
||||
|
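# Sketch of how a dashboard component might register through the API above,
# using the module-level get_shared_cob_service() accessor defined at the end
# of this file (the callback body is illustrative):

def _on_cob_update_sketch(symbol: str, snapshot) -> None:
    print(f"{symbol}: COB update received at {datetime.now()}")

_service = get_shared_cob_service(symbols=['ETH/USDT'])
_sub_id = _service.subscribe(
    callback=_on_cob_update_sketch,
    callback_type='dashboard',
    symbol_filter=['ETH/USDT'],
    subscriber_name='main_dashboard',
)
# ... later, on teardown:
# _service.unsubscribe(_sub_id)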
||||
# Internal callback handlers
|
||||
|
||||
async def _on_cob_cnn_update(self, symbol: str, data: Dict):
|
||||
"""Handle CNN feature updates from COB integration"""
|
||||
try:
|
||||
self.latest_cnn_features[symbol] = data
|
||||
await self._notify_subscribers("cnn", symbol, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN update handler: {e}")
|
||||
|
||||
async def _on_cob_dqn_update(self, symbol: str, data: Dict):
|
||||
"""Handle DQN state updates from COB integration"""
|
||||
try:
|
||||
self.latest_dqn_states[symbol] = data
|
||||
await self._notify_subscribers("dqn", symbol, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in DQN update handler: {e}")
|
||||
|
||||
async def _on_cob_dashboard_update(self, symbol: str, data: Dict):
|
||||
"""Handle dashboard updates from COB integration"""
|
||||
try:
|
||||
# Store snapshot if it's a COBSnapshot
|
||||
if hasattr(data, 'volume_weighted_mid'): # Duck typing for COBSnapshot
|
||||
self.latest_snapshots[symbol] = data
|
||||
|
||||
await self._notify_subscribers("dashboard", symbol, data)
|
||||
await self._notify_subscribers("general", symbol, data)
|
||||
|
||||
self.update_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in dashboard update handler: {e}")
|
||||
|
||||
async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any):
|
||||
"""Notify all relevant subscribers of an update"""
|
||||
try:
|
||||
relevant_subscribers = [
|
||||
sub for sub in self.subscribers.values()
|
||||
if (sub.callback_type == callback_type or sub.callback_type == "general") and
|
||||
(sub.symbol_filter is None or symbol in sub.symbol_filter)
|
||||
]
|
||||
|
||||
for subscription in relevant_subscribers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(subscription.callback):
|
||||
asyncio.create_task(subscription.callback(symbol, data))
|
||||
else:
|
||||
subscription.callback(symbol, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying subscribers: {e}")
|
||||
|
||||
# Public data access methods
|
||||
|
||||
def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
|
||||
"""Get latest COB snapshot for a symbol"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_cob_snapshot(symbol)
|
||||
return self.latest_snapshots.get(symbol)
|
||||
|
||||
def get_cnn_features(self, symbol: str) -> Optional[Any]:
|
||||
"""Get latest CNN features for a symbol"""
|
||||
return self.latest_cnn_features.get(symbol)
|
||||
|
||||
def get_dqn_state(self, symbol: str) -> Optional[Any]:
|
||||
"""Get latest DQN state for a symbol"""
|
||||
return self.latest_dqn_states.get(symbol)
|
||||
|
||||
def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get detailed market depth analysis"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_market_depth_analysis(symbol)
|
||||
return None
|
||||
|
||||
def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get liquidity breakdown by exchange"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_exchange_breakdown(symbol)
|
||||
return None
|
||||
|
||||
def get_price_buckets(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get fine-grain price buckets"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_price_buckets(symbol)
|
||||
return None
|
||||
|
||||
def get_session_volume_profile(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get session volume profile"""
|
||||
if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'):
|
||||
return self.cob_integration.cob_provider.get_session_volume_profile(symbol)
|
||||
return None
|
||||
|
||||
def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
|
||||
"""Get real-time statistics formatted for NN models"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_realtime_stats_for_nn(symbol)
|
||||
return {}
|
||||
|
||||
def get_service_statistics(self) -> Dict[str, Any]:
|
||||
"""Get service statistics"""
|
||||
uptime = None
|
||||
if self.start_time:
|
||||
uptime = (datetime.now() - self.start_time).total_seconds()
|
||||
|
||||
base_stats = {
|
||||
'service_name': 'SharedCOBService',
|
||||
'is_running': self.is_running,
|
||||
'symbols': self.symbols,
|
||||
'total_subscribers': len(self.subscribers),
|
||||
'lifetime_subscribers': self.total_subscribers,
|
||||
'update_count': self.update_count,
|
||||
'uptime_seconds': uptime,
|
||||
'subscribers_by_type': {}
|
||||
}
|
||||
|
||||
# Count subscribers by type
|
||||
for subscription in self.subscribers.values():
|
||||
callback_type = subscription.callback_type
|
||||
if callback_type not in base_stats['subscribers_by_type']:
|
||||
base_stats['subscribers_by_type'][callback_type] = 0
|
||||
base_stats['subscribers_by_type'][callback_type] += 1
|
||||
|
||||
# Get COB integration stats if available
|
||||
if self.cob_integration:
|
||||
cob_stats = self.cob_integration.get_statistics()
|
||||
base_stats.update(cob_stats)
|
||||
|
||||
return base_stats
|
||||
|
||||
# Global service instance access functions

def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Get the shared COB service instance"""
    return SharedCOBService(symbols=symbols, data_provider=data_provider)

async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Start the shared COB service"""
    service = get_shared_cob_service(symbols=symbols, data_provider=data_provider)
    await service.start()
    return service

async def stop_shared_cob_service():
    """Stop the shared COB service"""
    service = get_shared_cob_service()
    await service.stop()
@@ -1,425 +0,0 @@
|
||||
"""
|
||||
Shared Data Manager for UI Stability Fix
|
||||
|
||||
Manages data sharing between processes through files with proper locking
|
||||
and atomic operations to prevent corruption and conflicts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import platform
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Windows-compatible file locking
|
||||
if platform.system() == "Windows":
|
||||
import msvcrt
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ProcessStatus:
|
||||
"""Model for process status information"""
|
||||
name: str
|
||||
pid: int
|
||||
status: str # 'running', 'stopped', 'error'
|
||||
start_time: datetime
|
||||
last_heartbeat: datetime
|
||||
memory_usage: float
|
||||
cpu_usage: float
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary with datetime serialization"""
|
||||
data = asdict(self)
|
||||
data['start_time'] = self.start_time.isoformat()
|
||||
data['last_heartbeat'] = self.last_heartbeat.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus':
|
||||
"""Create from dictionary with datetime deserialization"""
|
||||
data['start_time'] = datetime.fromisoformat(data['start_time'])
|
||||
data['last_heartbeat'] = datetime.fromisoformat(data['last_heartbeat'])
|
||||
return cls(**data)
|
||||
|
||||
@dataclass
|
||||
class TrainingStatus:
|
||||
"""Model for training status information"""
|
||||
is_running: bool
|
||||
current_epoch: int
|
||||
total_epochs: int
|
||||
loss: float
|
||||
accuracy: float
|
||||
last_update: datetime
|
||||
model_path: str
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary with datetime serialization"""
|
||||
data = asdict(self)
|
||||
data['last_update'] = self.last_update.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus':
|
||||
"""Create from dictionary with datetime deserialization"""
|
||||
data['last_update'] = datetime.fromisoformat(data['last_update'])
|
||||
return cls(**data)
|
||||
|
||||
@dataclass
|
||||
class DashboardState:
|
||||
"""Model for dashboard state information"""
|
||||
is_connected: bool
|
||||
last_data_update: datetime
|
||||
active_connections: int
|
||||
error_count: int
|
||||
performance_metrics: Dict[str, float]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary with datetime serialization"""
|
||||
data = asdict(self)
|
||||
data['last_data_update'] = self.last_data_update.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState':
|
||||
"""Create from dictionary with datetime deserialization"""
|
||||
data['last_data_update'] = datetime.fromisoformat(data['last_data_update'])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class SharedDataManager:
|
||||
"""
|
||||
Manages data sharing between processes through files with proper locking
|
||||
and atomic operations to prevent corruption and conflicts.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str = "shared_data"):
|
||||
"""
|
||||
Initialize the shared data manager
|
||||
|
||||
Args:
|
||||
data_dir: Directory to store shared data files
|
||||
"""
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Define file paths for different data types
|
||||
self.training_status_file = self.data_dir / "training_status.json"
|
||||
self.dashboard_state_file = self.data_dir / "dashboard_state.json"
|
||||
self.process_status_file = self.data_dir / "process_status.json"
|
||||
self.market_data_file = self.data_dir / "market_data.json"
|
||||
self.model_metrics_file = self.data_dir / "model_metrics.json"
|
||||
|
||||
logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}")
|
||||
|
||||
    def _lock_file(self, file_handle, exclusive=True):
        """Cross-platform file locking"""
        if platform.system() == "Windows":
            # Windows file locking: msvcrt has no shared-lock mode, so shared
            # and exclusive requests both use the same lock call
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
            except IOError:
                pass  # File locking may not be available in all scenarios
        else:
            # Unix file locking
            lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
            fcntl.flock(file_handle.fileno(), lock_type)

    def _unlock_file(self, file_handle):
        """Cross-platform file unlocking"""
        if platform.system() == "Windows":
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
            except IOError:
                pass
        else:
            fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Write JSON data atomically with file locking
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to write
|
||||
data: Data to write as JSON
|
||||
"""
|
||||
temp_path = None
|
||||
try:
|
||||
# Create temporary file in the same directory
|
||||
temp_fd, temp_path = tempfile.mkstemp(
|
||||
dir=file_path.parent,
|
||||
prefix=f".{file_path.name}.",
|
||||
suffix=".tmp"
|
||||
)
|
||||
|
||||
with os.fdopen(temp_fd, 'w') as temp_file:
|
||||
# Lock the temporary file
|
||||
self._lock_file(temp_file, exclusive=True)
|
||||
|
||||
# Write data with proper formatting
|
||||
json.dump(data, temp_file, indent=2, default=str)
|
||||
temp_file.flush()
|
||||
os.fsync(temp_file.fileno())
|
||||
|
||||
# Unlock before closing
|
||||
self._unlock_file(temp_file)
|
||||
|
||||
# Atomically replace the original file
|
||||
os.replace(temp_path, file_path)
|
||||
logger.debug(f"Successfully wrote data to {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temporary file if it exists
|
||||
if temp_path:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"Failed to write data to {file_path}: {e}")
|
||||
raise
|
||||
|
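# The method above relies on the write-to-temp-then-os.replace pattern so that
# readers never observe a half-written file. The same idea as a self-contained
# sketch, without the locking layer (the function name is illustrative; it uses
# only the module-level imports above):

def _write_json_atomic_sketch(path: Path, payload: Dict[str, Any]) -> None:
    # The temp file must live in the target directory so os.replace stays atomic
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(payload, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)  # atomic on both POSIX and Windows
    except Exception:
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise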
||||
def _read_json_safe(self, file_path: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Read JSON data safely with file locking
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to read
|
||||
|
||||
Returns:
|
||||
Dictionary containing the JSON data
|
||||
"""
|
||||
if not file_path.exists():
|
||||
logger.debug(f"File {file_path} does not exist, returning empty dict")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(file_path, 'r') as file:
|
||||
# Lock the file for reading
|
||||
self._lock_file(file, exclusive=False)
|
||||
data = json.load(file)
|
||||
self._unlock_file(file)
|
||||
logger.debug(f"Successfully read data from {file_path}")
|
||||
return data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in {file_path}: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read data from {file_path}: {e}")
|
||||
return {}
|
||||
|
||||
def write_training_status(self, status: TrainingStatus) -> None:
|
||||
"""
|
||||
Write training status to shared file
|
||||
|
||||
Args:
|
||||
status: TrainingStatus object to write
|
||||
"""
|
||||
try:
|
||||
data = status.to_dict()
|
||||
self._write_json_atomic(self.training_status_file, data)
|
||||
logger.debug("Training status written successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write training status: {e}")
|
||||
raise
|
||||
|
||||
def read_training_status(self) -> Optional[TrainingStatus]:
|
||||
"""
|
||||
Read training status from shared file
|
||||
|
||||
Returns:
|
||||
TrainingStatus object or None if not available
|
||||
"""
|
||||
try:
|
||||
data = self._read_json_safe(self.training_status_file)
|
||||
if not data:
|
||||
return None
|
||||
return TrainingStatus.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read training status: {e}")
|
||||
return None
|
||||
|
||||
def write_dashboard_state(self, state: DashboardState) -> None:
|
||||
"""
|
||||
Write dashboard state to shared file
|
||||
|
||||
Args:
|
||||
state: DashboardState object to write
|
||||
"""
|
||||
try:
|
||||
data = state.to_dict()
|
||||
self._write_json_atomic(self.dashboard_state_file, data)
|
||||
logger.debug("Dashboard state written successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write dashboard state: {e}")
|
||||
raise
|
||||
|
||||
def read_dashboard_state(self) -> Optional[DashboardState]:
|
||||
"""
|
||||
Read dashboard state from shared file
|
||||
|
||||
Returns:
|
||||
DashboardState object or None if not available
|
||||
"""
|
||||
try:
|
||||
data = self._read_json_safe(self.dashboard_state_file)
|
||||
if not data:
|
||||
return None
|
||||
return DashboardState.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read dashboard state: {e}")
|
||||
return None
|
||||
|
||||
def write_process_status(self, status: ProcessStatus) -> None:
|
||||
"""
|
||||
Write process status to shared file
|
||||
|
||||
Args:
|
||||
status: ProcessStatus object to write
|
||||
"""
|
||||
try:
|
||||
data = status.to_dict()
|
||||
self._write_json_atomic(self.process_status_file, data)
|
||||
logger.debug("Process status written successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write process status: {e}")
|
||||
raise
|
||||
|
||||
def read_process_status(self) -> Optional[ProcessStatus]:
|
||||
"""
|
||||
Read process status from shared file
|
||||
|
||||
Returns:
|
||||
ProcessStatus object or None if not available
|
||||
"""
|
||||
try:
|
||||
data = self._read_json_safe(self.process_status_file)
|
||||
if not data:
|
||||
return None
|
||||
return ProcessStatus.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read process status: {e}")
|
||||
return None
|
||||
|
||||
def write_market_data(self, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Write market data to shared file
|
||||
|
||||
Args:
|
||||
data: Market data dictionary to write
|
||||
"""
|
||||
try:
|
||||
# Add timestamp to market data
|
||||
data['timestamp'] = datetime.now().isoformat()
|
||||
self._write_json_atomic(self.market_data_file, data)
|
||||
logger.debug("Market data written successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write market data: {e}")
|
||||
raise
|
||||
|
||||
def read_market_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Read market data from shared file
|
||||
|
||||
Returns:
|
||||
Dictionary containing market data
|
||||
"""
|
||||
try:
|
||||
return self._read_json_safe(self.market_data_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read market data: {e}")
|
||||
return {}
|
||||
|
||||
def write_model_metrics(self, metrics: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Write model metrics to shared file
|
||||
|
||||
Args:
|
||||
metrics: Model metrics dictionary to write
|
||||
"""
|
||||
try:
|
||||
# Add timestamp to metrics
|
||||
metrics['timestamp'] = datetime.now().isoformat()
|
||||
self._write_json_atomic(self.model_metrics_file, metrics)
|
||||
logger.debug("Model metrics written successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write model metrics: {e}")
|
||||
raise
|
||||
|
||||
def read_model_metrics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Read model metrics from shared file
|
||||
|
||||
Returns:
|
||||
Dictionary containing model metrics
|
||||
"""
|
||||
try:
|
||||
return self._read_json_safe(self.model_metrics_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read model metrics: {e}")
|
||||
return {}
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""
|
||||
Clean up shared data files
|
||||
"""
|
||||
try:
|
||||
for file_path in [
|
||||
self.training_status_file,
|
||||
self.dashboard_state_file,
|
||||
self.process_status_file,
|
||||
self.market_data_file,
|
||||
self.model_metrics_file
|
||||
]:
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
logger.debug(f"Removed {file_path}")
|
||||
|
||||
# Remove directory if empty
|
||||
if self.data_dir.exists() and not any(self.data_dir.iterdir()):
|
||||
self.data_dir.rmdir()
|
||||
logger.debug(f"Removed empty directory {self.data_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup shared data: {e}")
|
||||
|
||||
def get_data_age(self, data_type: str) -> Optional[float]:
|
||||
"""
|
||||
Get the age of data in seconds
|
||||
|
||||
Args:
|
||||
data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics')
|
||||
|
||||
Returns:
|
||||
Age in seconds or None if file doesn't exist
|
||||
"""
|
||||
file_map = {
|
||||
'training': self.training_status_file,
|
||||
'dashboard': self.dashboard_state_file,
|
||||
'process': self.process_status_file,
|
||||
'market': self.market_data_file,
|
||||
'metrics': self.model_metrics_file
|
||||
}
|
||||
|
||||
file_path = file_map.get(data_type)
|
||||
if not file_path or not file_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime
|
||||
return time.time() - mtime
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get data age for {data_type}: {e}")
|
||||
return None
|
||||
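# Putting the pieces together: a writer process and a reader process would use
# the manager roughly as follows (a sketch; the field values are made up):

manager = SharedDataManager(data_dir="shared_data")

# Writer side (e.g. the training process)
manager.write_training_status(TrainingStatus(
    is_running=True,
    current_epoch=12,
    total_epochs=100,
    loss=0.0432,
    accuracy=0.71,
    last_update=datetime.now(),
    model_path="models/checkpoint_12.pt",
))

# Reader side (e.g. the dashboard process)
status = manager.read_training_status()
age = manager.get_data_age('training')
if status and age is not None and age < 30:
    print(f"epoch {status.current_epoch}/{status.total_epochs}, loss {status.loss:.4f}")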
@@ -1,59 +0,0 @@
"""
Trading Action Module

Defines the TradingAction class used throughout the trading system.
"""

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List

@dataclass
class TradingAction:
    """Represents a trading action with full context"""
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    quantity: float
    confidence: float
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]

    def __post_init__(self):
        """Validate the trading action after initialization"""
        if self.action not in ['BUY', 'SELL', 'HOLD']:
            raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")

        if self.confidence < 0.0 or self.confidence > 1.0:
            raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")

        if self.quantity < 0:
            raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")

        if self.price <= 0:
            raise ValueError(f"Invalid price: {self.price}. Must be positive")

    def to_dict(self) -> Dict[str, Any]:
        """Convert trading action to dictionary"""
        return {
            'symbol': self.symbol,
            'action': self.action,
            'quantity': self.quantity,
            'confidence': self.confidence,
            'price': self.price,
            'timestamp': self.timestamp.isoformat(),
            'reasoning': self.reasoning
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
        """Create trading action from dictionary"""
        return cls(
            symbol=data['symbol'],
            action=data['action'],
            quantity=data['quantity'],
            confidence=data['confidence'],
            price=data['price'],
            timestamp=datetime.fromisoformat(data['timestamp']),
            reasoning=data['reasoning']
        )
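# Round-trip example of the dataclass above, exercising the __post_init__
# validation and the dict serialization (values are illustrative):

action = TradingAction(
    symbol='ETH/USDT',
    action='BUY',
    quantity=0.25,
    confidence=0.82,
    price=3150.5,
    timestamp=datetime.now(),
    reasoning={'signal': 'cob_imbalance', 'score': 0.82},
)

payload = action.to_dict()                   # JSON-serializable dict
restored = TradingAction.from_dict(payload)  # validation runs again here
assert restored.action == 'BUY'

# Out-of-range values raise immediately, e.g. confidence=1.5 -> ValueError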
@@ -1,401 +0,0 @@
|
||||
"""
|
||||
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
|
||||
|
||||
This module provides fixes for:
|
||||
1. Identical entry prices issue
|
||||
2. Price caching problems
|
||||
3. Position tracking reset logic
|
||||
4. Trade cooldown implementation
|
||||
5. P&L calculation verification
|
||||
|
||||
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingExecutorFix:
|
||||
"""
|
||||
Fixes for the TradingExecutor class to address entry/exit price issues
|
||||
and improve P&L calculation accuracy.
|
||||
"""
|
||||
|
||||
def __init__(self, trading_executor):
|
||||
"""
|
||||
Initialize the fix with a reference to the trading executor
|
||||
|
||||
Args:
|
||||
trading_executor: The TradingExecutor instance to fix
|
||||
"""
|
||||
self.trading_executor = trading_executor
|
||||
|
||||
# Add cooldown tracking
|
||||
self.last_trade_time = {} # {symbol: timestamp}
|
||||
self.min_trade_cooldown = 30 # 30 seconds minimum between trades
|
||||
|
||||
# Add price history for validation
|
||||
self.recent_entry_prices = {} # {symbol: [recent_prices]}
|
||||
self.max_price_history = 10 # Keep last 10 entry prices
|
||||
|
||||
# Add position reset tracking
|
||||
self.position_reset_flags = {} # {symbol: bool}
|
||||
|
||||
# Add price update tracking
|
||||
self.last_price_update = {} # {symbol: timestamp}
|
||||
self.price_update_threshold = 5 # 5 seconds max since last price update
|
||||
|
||||
# Add P&L verification
|
||||
self.trade_history = {} # {symbol: [trade_records]}
|
||||
|
||||
logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")
|
||||
|
||||
def apply_fixes(self):
|
||||
"""Apply all fixes to the trading executor"""
|
||||
self._patch_execute_action()
|
||||
self._patch_close_position()
|
||||
self._patch_calculate_pnl()
|
||||
self._patch_update_prices()
|
||||
|
||||
logger.info("All trading executor fixes applied successfully")
|
||||
|
||||
def _patch_execute_action(self):
|
||||
"""Patch the execute_action method to add price validation and cooldown"""
|
||||
original_execute_action = self.trading_executor.execute_action
|
||||
|
||||
def execute_action_with_fixes(decision):
|
||||
"""Enhanced execute_action with price validation and cooldown"""
|
||||
try:
|
||||
symbol = decision.symbol
|
||||
action = decision.action
|
||||
current_time = datetime.now()
|
||||
|
||||
# 1. Check cooldown period
|
||||
if symbol in self.last_trade_time:
|
||||
time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
|
||||
if time_since_last_trade < self.min_trade_cooldown:
|
||||
logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
|
||||
return False
|
||||
|
||||
# 2. Validate price freshness
|
||||
if symbol in self.last_price_update:
|
||||
time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
|
||||
if time_since_update > self.price_update_threshold:
|
||||
logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
|
||||
# Force price refresh
|
||||
self._refresh_price(symbol)
|
||||
return False
|
||||
|
||||
# 3. Validate entry price against recent history
|
||||
current_price = self._get_current_price(symbol)
|
||||
if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
|
||||
# Check if price is identical to any recent entry
|
||||
if current_price in self.recent_entry_prices[symbol]:
|
||||
logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
|
||||
return False
|
||||
|
||||
# 4. Ensure position is properly reset before new entry
|
||||
if not self._ensure_position_reset(symbol):
|
||||
logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
|
||||
return False
|
||||
|
||||
# Execute the original action
|
||||
result = original_execute_action(decision)
|
||||
|
||||
# If successful, update tracking
|
||||
if result:
|
||||
# Update cooldown timestamp
|
||||
self.last_trade_time[symbol] = current_time
|
||||
|
||||
# Update price history
|
||||
if symbol not in self.recent_entry_prices:
|
||||
self.recent_entry_prices[symbol] = []
|
||||
|
||||
self.recent_entry_prices[symbol].append(current_price)
|
||||
# Keep only the most recent prices
|
||||
if len(self.recent_entry_prices[symbol]) > self.max_price_history:
|
||||
self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]
|
||||
|
||||
# Mark position as active
|
||||
self.position_reset_flags[symbol] = False
|
||||
|
||||
logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execute_action_with_fixes: {e}")
|
||||
return original_execute_action(decision)
|
||||
|
||||
# Replace the original method
|
||||
self.trading_executor.execute_action = execute_action_with_fixes
|
||||
logger.info("Patched execute_action with price validation and cooldown")
|
||||
|
||||
def _patch_close_position(self):
|
||||
"""Patch the close_position method to ensure proper position reset"""
|
||||
original_close_position = self.trading_executor.close_position
|
||||
|
||||
def close_position_with_fixes(symbol, **kwargs):
|
||||
"""Enhanced close_position with proper reset logic"""
|
||||
try:
|
||||
# Get current price for P&L verification
|
||||
exit_price = self._get_current_price(symbol)
|
||||
|
||||
# Call original close position
|
||||
result = original_close_position(symbol, **kwargs)
|
||||
|
||||
if result:
|
||||
# Mark position as reset
|
||||
self.position_reset_flags[symbol] = True
|
||||
|
||||
# Record trade for verification
|
||||
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
|
||||
position = self.trading_executor.positions[symbol]
|
||||
|
||||
# Create trade record
|
||||
trade_record = {
|
||||
'symbol': symbol,
|
||||
'entry_time': getattr(position, 'entry_time', datetime.now()),
|
||||
'exit_time': datetime.now(),
|
||||
'entry_price': getattr(position, 'entry_price', 0),
|
||||
'exit_price': exit_price,
|
||||
'size': getattr(position, 'size', 0),
|
||||
'side': getattr(position, 'side', 'UNKNOWN'),
|
||||
'pnl': self._calculate_verified_pnl(position, exit_price),
|
||||
'fees': getattr(position, 'fees', 0),
|
||||
'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
|
||||
}
|
||||
|
||||
# Store trade record
|
||||
if symbol not in self.trade_history:
|
||||
self.trade_history[symbol] = []
|
||||
self.trade_history[symbol].append(trade_record)
|
||||
|
||||
logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_with_fixes: {e}")
|
||||
return original_close_position(symbol, **kwargs)
|
||||
|
||||
# Replace the original method
|
||||
self.trading_executor.close_position = close_position_with_fixes
|
||||
logger.info("Patched close_position with proper reset logic")
|
||||
|
||||
def _patch_calculate_pnl(self):
|
||||
"""Patch the calculate_pnl method to ensure accurate P&L calculation"""
|
||||
original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
|
||||
|
||||
def calculate_pnl_with_fixes(position, current_price=None):
|
||||
"""Enhanced calculate_pnl with verification"""
|
||||
try:
|
||||
# If no original method, implement our own
|
||||
if original_calculate_pnl is None:
|
||||
return self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# Call original method
|
||||
original_pnl = original_calculate_pnl(position, current_price)
|
||||
|
||||
# Calculate our verified P&L
|
||||
verified_pnl = self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# If there's a significant difference, log it
|
||||
if abs(original_pnl - verified_pnl) > 0.01:
|
||||
logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
|
||||
# Use the verified P&L
|
||||
return verified_pnl
|
||||
|
||||
return original_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in calculate_pnl_with_fixes: {e}")
|
||||
if original_calculate_pnl:
|
||||
return original_calculate_pnl(position, current_price)
|
||||
return 0.0
|
||||
|
||||
# Replace the original method if it exists
|
||||
if original_calculate_pnl:
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Patched calculate_pnl with verification")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Added calculate_pnl method with verification")
|
||||
|
||||
def _patch_update_prices(self):
|
||||
"""Patch the update_prices method to track price updates"""
|
||||
original_update_prices = getattr(self.trading_executor, 'update_prices', None)
|
||||
|
||||
def update_prices_with_tracking(prices):
|
||||
"""Enhanced update_prices with timestamp tracking"""
|
||||
try:
|
||||
# Call original method if it exists
|
||||
if original_update_prices:
|
||||
result = original_update_prices(prices)
|
||||
else:
|
||||
# If no original method, update prices directly
|
||||
if hasattr(self.trading_executor, 'current_prices'):
|
||||
self.trading_executor.current_prices.update(prices)
|
||||
result = True
|
||||
|
||||
# Track update timestamps
|
||||
current_time = datetime.now()
|
||||
for symbol in prices:
|
||||
self.last_price_update[symbol] = current_time
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in update_prices_with_tracking: {e}")
|
||||
if original_update_prices:
|
||||
return original_update_prices(prices)
|
||||
return False
|
||||
|
||||
# Replace the original method if it exists
|
||||
if original_update_prices:
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Patched update_prices with timestamp tracking")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Added update_prices method with timestamp tracking")
|
||||
|
||||
def _calculate_verified_pnl(self, position, current_price=None):
|
||||
"""Calculate verified P&L for a position"""
|
||||
try:
|
||||
# Get position details
|
||||
entry_price = getattr(position, 'entry_price', 0)
|
||||
size = getattr(position, 'size', 0)
|
||||
side = getattr(position, 'side', 'UNKNOWN')
|
||||
leverage = getattr(position, 'leverage', 1.0)
|
||||
fees = getattr(position, 'fees', 0.0)
|
||||
|
||||
# If current_price is not provided, try to get it
|
||||
if current_price is None:
|
||||
symbol = getattr(position, 'symbol', None)
|
||||
if symbol:
|
||||
current_price = self._get_current_price(symbol)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
# Calculate P&L based on position side
|
||||
if side == 'LONG':
|
||||
pnl = (current_price - entry_price) * size * leverage
|
||||
elif side == 'SHORT':
|
||||
pnl = (entry_price - current_price) * size * leverage
|
||||
else:
|
||||
pnl = 0.0
|
||||
|
||||
# Subtract fees for net P&L
|
||||
net_pnl = pnl - fees
|
||||
|
||||
return net_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating verified P&L: {e}")
|
||||
return 0.0
|
||||
|
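# Worked instance of the formula above: a LONG position of size 0.5 opened at
# 3000 with 10x leverage and 1.20 in fees, marked at 3012, nets
#     (3012 - 3000) * 0.5 * 10 - 1.20 = 58.80
# while a SHORT of the same size over the same move nets
#     (3000 - 3012) * 0.5 * 10 - 1.20 = -61.20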
||||
def _get_current_price(self, symbol):
|
||||
"""Get current price for a symbol with fallbacks"""
|
||||
try:
|
||||
# Try to get from trading executor
|
||||
if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
|
||||
return self.trading_executor.current_prices[symbol]
|
||||
|
||||
# Try to get from data provider
|
||||
if hasattr(self.trading_executor, 'data_provider'):
|
||||
data_provider = self.trading_executor.data_provider
|
||||
if hasattr(data_provider, 'get_current_price'):
|
||||
price = data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Try to get from COB data
|
||||
if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
|
||||
cob_data = self.trading_executor.latest_cob_data[symbol]
|
||||
if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
|
||||
return cob_data.stats['mid_price']
|
||||
|
||||
# Default fallback
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
def _refresh_price(self, symbol):
|
||||
"""Force a price refresh for a symbol"""
|
||||
try:
|
||||
# Try to refresh from data provider
|
||||
if hasattr(self.trading_executor, 'data_provider'):
|
||||
data_provider = self.trading_executor.data_provider
|
||||
if hasattr(data_provider, 'fetch_current_price'):
|
||||
price = data_provider.fetch_current_price(symbol)
|
||||
if price and price > 0:
|
||||
# Update trading executor price
|
||||
if hasattr(self.trading_executor, 'current_prices'):
|
||||
self.trading_executor.current_prices[symbol] = price
|
||||
|
||||
# Update timestamp
|
||||
self.last_price_update[symbol] = datetime.now()
|
||||
|
||||
logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
|
||||
return True
|
||||
|
||||
logger.warning(f"Failed to refresh price for {symbol}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing price for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
def _ensure_position_reset(self, symbol):
|
||||
"""Ensure position is properly reset before new entry"""
|
||||
try:
|
||||
# Check if we have an active position
|
||||
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
|
||||
# Position exists, check if it's valid
|
||||
position = self.trading_executor.positions[symbol]
|
||||
if position and getattr(position, 'active', False):
|
||||
logger.warning(f"Position already active for {symbol}, cannot enter new position")
|
||||
return False
|
||||
|
||||
# Check reset flag
|
||||
if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
|
||||
# Force position cleanup
|
||||
if hasattr(self.trading_executor, 'positions'):
|
||||
self.trading_executor.positions.pop(symbol, None)
|
||||
|
||||
logger.info(f"Forced position reset for {symbol}")
|
||||
self.position_reset_flags[symbol] = True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring position reset for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
def get_trade_history(self, symbol=None):
|
||||
"""Get verified trade history"""
|
||||
if symbol:
|
||||
return self.trade_history.get(symbol, [])
|
||||
return self.trade_history
|
||||
|
||||
def get_price_update_status(self):
|
||||
"""Get price update status for all symbols"""
|
||||
status = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for symbol, timestamp in self.last_price_update.items():
|
||||
time_since_update = (current_time - timestamp).total_seconds()
|
||||
status[symbol] = {
|
||||
'last_update': timestamp,
|
||||
'seconds_ago': time_since_update,
|
||||
'is_fresh': time_since_update <= self.price_update_threshold
|
||||
}
|
||||
|
||||
return status
|
||||
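# How the fix object is meant to be wired into an existing executor, as a
# sketch (the TradingExecutor construction is hypothetical; that class is not
# defined or imported in this module):

executor = TradingExecutor(config_path='config.yaml')  # hypothetical constructor

fixes = TradingExecutorFix(executor)
fixes.apply_fixes()  # patches execute_action, close_position, calculate_pnl, update_prices

# The patched executor is then used exactly as before, e.g.:
# executor.execute_action(decision)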
@@ -1,795 +0,0 @@
|
||||
"""
|
||||
Comprehensive Training Data Collection System
|
||||
|
||||
This module implements a robust training data collection system that:
|
||||
1. Captures all model inputs with validation and completeness checks
|
||||
2. Stores training data packages with future outcome validation
|
||||
3. Detects rapid price changes for high-value training examples
|
||||
4. Enables replay and retraining on most profitable setups
|
||||
5. Maintains data integrity and traceability
|
||||
|
||||
Key Features:
|
||||
- Real-time data package creation with all model inputs
|
||||
- Future outcome validation (profitable vs unprofitable predictions)
|
||||
- Rapid price change detection for premium training examples
|
||||
- Comprehensive data validation and completeness verification
|
||||
- Backpropagation data storage for gradient replay
|
||||
- Training episode profitability tracking and ranking
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import torch
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from collections import deque
|
||||
import hashlib
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ModelInputPackage:
|
||||
"""Complete package of all model inputs at a specific timestamp"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
|
||||
# Market data inputs
|
||||
ohlcv_data: Dict[str, pd.DataFrame] # {timeframe: DataFrame}
|
||||
tick_data: List[Dict[str, Any]] # Raw tick data
|
||||
cob_data: Dict[str, Any] # Consolidated Order Book data
|
||||
technical_indicators: Dict[str, float] # All technical indicators
|
||||
pivot_points: List[Dict[str, Any]] # Detected pivot points
|
||||
|
||||
# Model-specific inputs
|
||||
cnn_features: np.ndarray # CNN input features
|
||||
rl_state: np.ndarray # RL state representation
|
||||
orchestrator_context: Dict[str, Any] # Orchestrator context
|
||||
|
||||
# Cross-model inputs (outputs from other models)
|
||||
cnn_predictions: Optional[Dict[str, Any]] = None
|
||||
rl_predictions: Optional[Dict[str, Any]] = None
|
||||
orchestrator_decision: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Data validation
|
||||
data_hash: str = ""
|
||||
completeness_score: float = 0.0
|
||||
validation_flags: Dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Calculate data hash and completeness after initialization"""
|
||||
self.data_hash = self._calculate_hash()
|
||||
self.completeness_score = self._calculate_completeness()
|
||||
self.validation_flags = self._validate_data()
|
||||
|
||||
def _calculate_hash(self) -> str:
|
||||
"""Calculate hash for data integrity verification"""
|
||||
try:
|
||||
# Create a string representation of all data
|
||||
data_str = f"{self.timestamp}_{self.symbol}"
|
||||
data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}"
|
||||
data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}"
|
||||
data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}"
|
||||
|
||||
return hashlib.md5(data_str.encode()).hexdigest()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating data hash: {e}")
|
||||
return "invalid_hash"
|
||||
|
||||
def _calculate_completeness(self) -> float:
|
||||
"""Calculate completeness score (0.0 to 1.0)"""
|
||||
try:
|
||||
total_fields = 10 # Total expected data fields
|
||||
complete_fields = 0
|
||||
|
||||
# Check each required field
|
||||
if self.ohlcv_data and len(self.ohlcv_data) > 0:
|
||||
complete_fields += 1
|
||||
if self.tick_data and len(self.tick_data) > 0:
|
||||
complete_fields += 1
|
||||
if self.cob_data and len(self.cob_data) > 0:
|
||||
complete_fields += 1
|
||||
if self.technical_indicators and len(self.technical_indicators) > 0:
|
||||
complete_fields += 1
|
||||
if self.pivot_points and len(self.pivot_points) > 0:
|
||||
complete_fields += 1
|
||||
if self.cnn_features is not None and self.cnn_features.size > 0:
|
||||
complete_fields += 1
|
||||
if self.rl_state is not None and self.rl_state.size > 0:
|
||||
complete_fields += 1
|
||||
if self.orchestrator_context and len(self.orchestrator_context) > 0:
|
||||
complete_fields += 1
|
||||
if self.cnn_predictions is not None:
|
||||
complete_fields += 1
|
||||
if self.rl_predictions is not None:
|
||||
complete_fields += 1
|
||||
|
||||
return complete_fields / total_fields
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating completeness: {e}")
|
||||
return 0.0
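# Worked example (illustrative, not from the original code): a package with OHLCV,
# ticks, COB, indicators, pivots, CNN features, RL state and orchestrator context
# populated, but no cross-model predictions yet, scores 8 / 10 = 0.8. That clears both
# thresholds used later: >= 0.5 to be collected at all and > 0.7 for the
# 'data_consistent' validation flag.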
|
||||
|
||||
def _validate_data(self) -> Dict[str, bool]:
|
||||
"""Validate data integrity and consistency"""
|
||||
flags = {}
|
||||
|
||||
try:
|
||||
# Validate timestamp
|
||||
flags['valid_timestamp'] = isinstance(self.timestamp, datetime)
|
||||
|
||||
# Validate OHLCV data
|
||||
flags['valid_ohlcv'] = (
|
||||
self.ohlcv_data is not None and
|
||||
len(self.ohlcv_data) > 0 and
|
||||
all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values())
|
||||
)
|
||||
|
||||
# Validate feature arrays
|
||||
flags['valid_cnn_features'] = (
|
||||
self.cnn_features is not None and
|
||||
isinstance(self.cnn_features, np.ndarray) and
|
||||
self.cnn_features.size > 0
|
||||
)
|
||||
|
||||
flags['valid_rl_state'] = (
|
||||
self.rl_state is not None and
|
||||
isinstance(self.rl_state, np.ndarray) and
|
||||
self.rl_state.size > 0
|
||||
)
|
||||
|
||||
# Validate data consistency
|
||||
flags['data_consistent'] = self.completeness_score > 0.7
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error validating data: {e}")
|
||||
flags['validation_error'] = True
|
||||
|
||||
return flags
|
||||
|
||||
@dataclass
|
||||
class TrainingOutcome:
|
||||
"""Future outcome validation for training data"""
|
||||
input_package_hash: str
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
|
||||
# Price movement outcomes
|
||||
price_change_1m: float
|
||||
price_change_5m: float
|
||||
price_change_15m: float
|
||||
price_change_1h: float
|
||||
|
||||
# Profitability metrics
|
||||
max_profit_potential: float
|
||||
max_loss_potential: float
|
||||
optimal_entry_price: float
|
||||
optimal_exit_price: float
|
||||
optimal_holding_time: timedelta
|
||||
|
||||
# Classification labels
|
||||
is_profitable: bool
|
||||
profitability_score: float # 0.0 to 1.0
|
||||
risk_reward_ratio: float
|
||||
|
||||
# Rapid price change detection
|
||||
is_rapid_change: bool
|
||||
change_velocity: float # Price change per minute
|
||||
volatility_spike: bool
|
||||
|
||||
# Validation
|
||||
outcome_validated: bool = False
|
||||
validation_timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
@dataclass
|
||||
class TrainingEpisode:
|
||||
"""Complete training episode with inputs, predictions, and outcomes"""
|
||||
episode_id: str
|
||||
input_package: ModelInputPackage
|
||||
model_predictions: Dict[str, Any] # Predictions from all models
|
||||
actual_outcome: TrainingOutcome
|
||||
|
||||
# Training metadata
|
||||
episode_type: str # 'normal', 'rapid_change', 'high_profit'
|
||||
profitability_rank: float # Ranking among all episodes
|
||||
training_priority: float # Priority for replay training
|
||||
|
||||
# Backpropagation data storage
|
||||
gradient_data: Optional[Dict[str, torch.Tensor]] = None
|
||||
loss_components: Optional[Dict[str, float]] = None
|
||||
model_states: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Episode statistics
|
||||
created_timestamp: datetime = field(default_factory=datetime.now)
|
||||
last_trained_timestamp: Optional[datetime] = None
|
||||
training_count: int = 0
|
||||
|
||||
def calculate_training_priority(self) -> float:
|
||||
"""Calculate training priority based on profitability and characteristics"""
|
||||
try:
|
||||
priority = 0.0
|
||||
|
||||
# Base priority from profitability
|
||||
if self.actual_outcome.is_profitable:
|
||||
priority += self.actual_outcome.profitability_score * 0.4
|
||||
|
||||
# Bonus for rapid changes (high learning value)
|
||||
if self.actual_outcome.is_rapid_change:
|
||||
priority += 0.3
|
||||
|
||||
# Bonus for high risk-reward ratio
|
||||
if self.actual_outcome.risk_reward_ratio > 2.0:
|
||||
priority += 0.2
|
||||
|
||||
# Bonus for data completeness
|
||||
priority += self.input_package.completeness_score * 0.1
|
||||
|
||||
# Penalty for frequent training (avoid overfitting)
|
||||
if self.training_count > 5:
|
||||
priority *= 0.8
|
||||
|
||||
return min(priority, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating training priority: {e}")
|
||||
return 0.0
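# Worked example (illustrative, not from the original code): a profitable episode with
# profitability_score=0.9, a rapid price change, risk_reward_ratio=2.5 and
# completeness_score=0.8 scores 0.9*0.4 + 0.3 + 0.2 + 0.8*0.1 = 0.94.
# If that same episode had already been replayed more than five times, the
# overfitting penalty would reduce it to 0.94 * 0.8 = 0.752 (still capped at 1.0).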
|
||||
|
||||
class RapidChangeDetector:
|
||||
"""Detects rapid price changes for high-value training examples"""
|
||||
|
||||
def __init__(self,
|
||||
velocity_threshold: float = 0.5, # % per minute
|
||||
volatility_multiplier: float = 3.0,
|
||||
lookback_minutes: int = 5):
|
||||
self.velocity_threshold = velocity_threshold
|
||||
self.volatility_multiplier = volatility_multiplier
|
||||
self.lookback_minutes = lookback_minutes
|
||||
|
||||
# Price history for change detection
|
||||
self.price_history: Dict[str, deque] = {}
|
||||
self.volatility_baseline: Dict[str, float] = {}
|
||||
|
||||
def add_price_point(self, symbol: str, timestamp: datetime, price: float):
|
||||
"""Add new price point for change detection"""
|
||||
if symbol not in self.price_history:
|
||||
self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60) # 1 second resolution
|
||||
self.volatility_baseline[symbol] = 0.0
|
||||
|
||||
self.price_history[symbol].append((timestamp, price))
|
||||
self._update_volatility_baseline(symbol)
|
||||
|
||||
def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
|
||||
"""
|
||||
Detect rapid price changes
|
||||
|
||||
Returns:
|
||||
(is_rapid_change, change_velocity, volatility_spike)
|
||||
"""
|
||||
if symbol not in self.price_history or len(self.price_history[symbol]) < 60:
|
||||
return False, 0.0, False
|
||||
|
||||
try:
|
||||
prices = list(self.price_history[symbol])
|
||||
|
||||
# Calculate recent velocity (last minute)
|
||||
recent_prices = prices[-60:] # Last 60 seconds
|
||||
if len(recent_prices) < 2:
|
||||
return False, 0.0, False
|
||||
|
||||
start_price = recent_prices[0][1]
|
||||
end_price = recent_prices[-1][1]
|
||||
time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0 # minutes
|
||||
|
||||
if time_diff <= 0:
|
||||
return False, 0.0, False
|
||||
|
||||
# Calculate velocity (% change per minute)
|
||||
velocity = abs((end_price - start_price) / start_price * 100) / time_diff
|
||||
|
||||
# Check for rapid change
|
||||
is_rapid = velocity > self.velocity_threshold
|
||||
|
||||
# Check for volatility spike
|
||||
current_volatility = self._calculate_current_volatility(symbol)
|
||||
baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
|
||||
volatility_spike = (
|
||||
baseline_volatility > 0 and
|
||||
current_volatility > baseline_volatility * self.volatility_multiplier
|
||||
)
|
||||
|
||||
return is_rapid, velocity, volatility_spike
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error detecting rapid change for {symbol}: {e}")
|
||||
return False, 0.0, False
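# Worked example (illustrative, assumes the default velocity_threshold of 0.5 %/min):
# with one price point per second, a move from 2500.0 to 2520.0 over the last minute
# gives velocity = |(2520 - 2500) / 2500 * 100| / 1.0 ≈ 0.8 %/min, which exceeds the
# threshold, so detect_rapid_change() returns (True, ~0.8, volatility_spike), where the
# spike flag depends on the EMA baseline maintained by the method below.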
|
||||
|
||||
def _update_volatility_baseline(self, symbol: str):
|
||||
"""Update volatility baseline for the symbol"""
|
||||
try:
|
||||
if len(self.price_history[symbol]) < 120: # Need at least 2 minutes of data
|
||||
return
|
||||
|
||||
# Calculate rolling volatility over longer period
|
||||
prices = [p[1] for p in list(self.price_history[symbol])[-300:]] # Last 5 minutes
|
||||
if len(prices) < 2:
|
||||
return
|
||||
|
||||
# Calculate standard deviation of price changes
|
||||
price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
|
||||
volatility = np.std(price_changes) * 100 # Convert to percentage
|
||||
|
||||
# Update baseline with exponential moving average
|
||||
alpha = 0.1
|
||||
if self.volatility_baseline[symbol] == 0:
|
||||
self.volatility_baseline[symbol] = volatility
|
||||
else:
|
||||
self.volatility_baseline[symbol] = (
|
||||
alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error updating volatility baseline for {symbol}: {e}")
|
||||
|
||||
def _calculate_current_volatility(self, symbol: str) -> float:
|
||||
"""Calculate current volatility for the symbol"""
|
||||
try:
|
||||
if len(self.price_history[symbol]) < 60:
|
||||
return 0.0
|
||||
|
||||
# Use last minute of data
|
||||
recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]]
|
||||
if len(recent_prices) < 2:
|
||||
return 0.0
|
||||
|
||||
price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
|
||||
for i in range(1, len(recent_prices))]
|
||||
return np.std(price_changes) * 100
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating current volatility for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
class TrainingDataCollector:
|
||||
"""Main training data collection system"""
|
||||
|
||||
def __init__(self,
|
||||
storage_dir: str = "training_data",
|
||||
max_episodes_per_symbol: int = 10000,
|
||||
outcome_validation_delay: timedelta = timedelta(hours=1)):
|
||||
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_episodes_per_symbol = max_episodes_per_symbol
|
||||
self.outcome_validation_delay = outcome_validation_delay
|
||||
|
||||
# Data storage
|
||||
self.training_episodes: Dict[str, List[TrainingEpisode]] = {} # {symbol: episodes}
|
||||
self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {} # Awaiting outcome validation
|
||||
|
||||
# Rapid change detection
|
||||
self.rapid_change_detector = RapidChangeDetector()
|
||||
|
||||
# Data validation and statistics
|
||||
self.collection_stats = {
|
||||
'total_episodes': 0,
|
||||
'profitable_episodes': 0,
|
||||
'rapid_change_episodes': 0,
|
||||
'validation_errors': 0,
|
||||
'data_completeness_avg': 0.0
|
||||
}
|
||||
|
||||
# Background processing
|
||||
self.is_collecting = False
|
||||
self.collection_thread = None
|
||||
self.outcome_validation_thread = None
|
||||
|
||||
# Thread safety
|
||||
self.data_lock = threading.Lock()
|
||||
|
||||
logger.info(f"Training Data Collector initialized")
|
||||
logger.info(f"Storage directory: {self.storage_dir}")
|
||||
logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")
|
||||
|
||||
def start_collection(self):
|
||||
"""Start the training data collection system"""
|
||||
if self.is_collecting:
|
||||
logger.warning("Training data collection already running")
|
||||
return
|
||||
|
||||
self.is_collecting = True
|
||||
|
||||
# Start outcome validation thread
|
||||
self.outcome_validation_thread = threading.Thread(
|
||||
target=self._outcome_validation_worker,
|
||||
daemon=True
|
||||
)
|
||||
self.outcome_validation_thread.start()
|
||||
|
||||
logger.info("Training data collection started")
|
||||
|
||||
def stop_collection(self):
|
||||
"""Stop the training data collection system"""
|
||||
self.is_collecting = False
|
||||
|
||||
if self.outcome_validation_thread:
|
||||
self.outcome_validation_thread.join(timeout=5)
|
||||
|
||||
logger.info("Training data collection stopped")
|
||||
|
||||
def collect_training_data(self,
|
||||
symbol: str,
|
||||
ohlcv_data: Dict[str, pd.DataFrame],
|
||||
tick_data: List[Dict[str, Any]],
|
||||
cob_data: Dict[str, Any],
|
||||
technical_indicators: Dict[str, float],
|
||||
pivot_points: List[Dict[str, Any]],
|
||||
cnn_features: np.ndarray,
|
||||
rl_state: np.ndarray,
|
||||
orchestrator_context: Dict[str, Any],
|
||||
model_predictions: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Collect comprehensive training data package
|
||||
|
||||
Returns:
|
||||
episode_id for tracking
|
||||
"""
|
||||
try:
|
||||
# Create input package
|
||||
input_package = ModelInputPackage(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
ohlcv_data=ohlcv_data,
|
||||
tick_data=tick_data,
|
||||
cob_data=cob_data,
|
||||
technical_indicators=technical_indicators,
|
||||
pivot_points=pivot_points,
|
||||
cnn_features=cnn_features,
|
||||
rl_state=rl_state,
|
||||
orchestrator_context=orchestrator_context
|
||||
)
|
||||
|
||||
# Validate data completeness
|
||||
if input_package.completeness_score < 0.5:
|
||||
logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
|
||||
self.collection_stats['validation_errors'] += 1
|
||||
return None
|
||||
|
||||
# Check for rapid price changes
|
||||
current_price = self._extract_current_price(ohlcv_data)
|
||||
if current_price:
|
||||
self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)
|
||||
|
||||
# Add to pending outcomes for future validation
|
||||
with self.data_lock:
|
||||
if symbol not in self.pending_outcomes:
|
||||
self.pending_outcomes[symbol] = []
|
||||
|
||||
self.pending_outcomes[symbol].append(input_package)
|
||||
|
||||
# Limit pending outcomes to prevent memory issues
|
||||
if len(self.pending_outcomes[symbol]) > 1000:
|
||||
self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:]
|
||||
|
||||
# Generate episode ID
|
||||
episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
|
||||
|
||||
# Update statistics
|
||||
self.collection_stats['total_episodes'] += 1
|
||||
self.collection_stats['data_completeness_avg'] = (
|
||||
(self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) +
|
||||
input_package.completeness_score) / self.collection_stats['total_episodes']
|
||||
)
|
||||
|
||||
logger.debug(f"Collected training data for {symbol}: {episode_id}")
|
||||
logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")
|
||||
|
||||
return episode_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting training data for {symbol}: {e}")
|
||||
self.collection_stats['validation_errors'] += 1
|
||||
return None
|
||||
|
||||
def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
|
||||
"""Extract current price from OHLCV data"""
|
||||
try:
|
||||
# Try to get price from shortest timeframe first
|
||||
for timeframe in ['1s', '1m', '5m', '15m', '1h']:
|
||||
if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
|
||||
return float(ohlcv_data[timeframe]['close'].iloc[-1])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting current price: {e}")
|
||||
return None
|
||||
|
||||
def _outcome_validation_worker(self):
|
||||
"""Background worker for validating training outcomes"""
|
||||
logger.info("Outcome validation worker started")
|
||||
|
||||
while self.is_collecting:
|
||||
try:
|
||||
self._validate_pending_outcomes()
|
||||
threading.Event().wait(60) # Check every minute
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in outcome validation worker: {e}")
|
||||
threading.Event().wait(30) # Wait before retrying
|
||||
|
||||
logger.info("Outcome validation worker stopped")
|
||||
|
||||
def _validate_pending_outcomes(self):
|
||||
"""Validate outcomes for pending training data"""
|
||||
current_time = datetime.now()
|
||||
|
||||
with self.data_lock:
|
||||
for symbol in list(self.pending_outcomes.keys()):
|
||||
if symbol not in self.pending_outcomes:
|
||||
continue
|
||||
|
||||
validated_packages = []
|
||||
remaining_packages = []
|
||||
|
||||
for package in self.pending_outcomes[symbol]:
|
||||
# Check if enough time has passed for outcome validation
|
||||
if current_time - package.timestamp >= self.outcome_validation_delay:
|
||||
outcome = self._calculate_training_outcome(package)
|
||||
if outcome:
|
||||
self._create_training_episode(package, outcome)
|
||||
validated_packages.append(package)
|
||||
else:
|
||||
remaining_packages.append(package)
|
||||
else:
|
||||
remaining_packages.append(package)
|
||||
|
||||
# Update pending outcomes
|
||||
self.pending_outcomes[symbol] = remaining_packages
|
||||
|
||||
if validated_packages:
|
||||
logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}")
|
||||
|
||||
def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
|
||||
"""Calculate training outcome based on future price movements"""
|
||||
try:
|
||||
# This would typically fetch recent price data to calculate outcomes
|
||||
# For now, we'll create a placeholder implementation
|
||||
|
||||
# Extract base price from input package
|
||||
base_price = self._extract_current_price(input_package.ohlcv_data)
|
||||
if not base_price:
|
||||
return None
|
||||
|
||||
# Simulate outcome calculation (in real implementation, fetch actual future prices)
|
||||
# This is where you would integrate with your data provider to get actual outcomes
|
||||
|
||||
# Check for rapid change
|
||||
is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
|
||||
input_package.symbol
|
||||
)
|
||||
|
||||
# Create outcome (placeholder values - replace with actual calculation)
|
||||
outcome = TrainingOutcome(
|
||||
input_package_hash=input_package.data_hash,
|
||||
timestamp=input_package.timestamp,
|
||||
symbol=input_package.symbol,
|
||||
price_change_1m=0.0, # Calculate from actual future data
|
||||
price_change_5m=0.0,
|
||||
price_change_15m=0.0,
|
||||
price_change_1h=0.0,
|
||||
max_profit_potential=0.0,
|
||||
max_loss_potential=0.0,
|
||||
optimal_entry_price=base_price,
|
||||
optimal_exit_price=base_price,
|
||||
optimal_holding_time=timedelta(minutes=5),
|
||||
is_profitable=False, # Determine from actual outcomes
|
||||
profitability_score=0.0,
|
||||
risk_reward_ratio=1.0,
|
||||
is_rapid_change=is_rapid,
|
||||
change_velocity=velocity,
|
||||
volatility_spike=volatility_spike,
|
||||
outcome_validated=True
|
||||
)
|
||||
|
||||
return outcome
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating training outcome: {e}")
|
||||
return None
|
||||
|
||||
def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome):
|
||||
"""Create complete training episode"""
|
||||
try:
|
||||
episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
|
||||
|
||||
# Determine episode type
|
||||
episode_type = 'normal'
|
||||
if outcome.is_rapid_change:
|
||||
episode_type = 'rapid_change'
|
||||
self.collection_stats['rapid_change_episodes'] += 1
|
||||
elif outcome.profitability_score > 0.8:
|
||||
episode_type = 'high_profit'
|
||||
|
||||
if outcome.is_profitable:
|
||||
self.collection_stats['profitable_episodes'] += 1
|
||||
|
||||
# Create training episode
|
||||
episode = TrainingEpisode(
|
||||
episode_id=episode_id,
|
||||
input_package=input_package,
|
||||
model_predictions={}, # Will be filled when models make predictions
|
||||
actual_outcome=outcome,
|
||||
episode_type=episode_type,
|
||||
profitability_rank=0.0, # Will be calculated later
|
||||
training_priority=0.0
|
||||
)
|
||||
|
||||
# Calculate training priority
|
||||
episode.training_priority = episode.calculate_training_priority()
|
||||
|
||||
# Store episode
|
||||
symbol = input_package.symbol
|
||||
if symbol not in self.training_episodes:
|
||||
self.training_episodes[symbol] = []
|
||||
|
||||
self.training_episodes[symbol].append(episode)
|
||||
|
||||
# Limit episodes per symbol
|
||||
if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol:
|
||||
# Keep highest priority episodes
|
||||
self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True)
|
||||
self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol]
|
||||
|
||||
# Save episode to disk
|
||||
self._save_episode_to_disk(episode)
|
||||
|
||||
logger.debug(f"Created training episode: {episode_id}")
|
||||
logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating training episode: {e}")
|
||||
|
||||
def _save_episode_to_disk(self, episode: TrainingEpisode):
|
||||
"""Save training episode to disk for persistence"""
|
||||
try:
|
||||
symbol_dir = self.storage_dir / episode.input_package.symbol
|
||||
symbol_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save episode data
|
||||
episode_file = symbol_dir / f"{episode.episode_id}.pkl"
|
||||
with open(episode_file, 'wb') as f:
|
||||
pickle.dump(episode, f)
|
||||
|
||||
# Save episode metadata for quick access
|
||||
metadata = {
|
||||
'episode_id': episode.episode_id,
|
||||
'timestamp': episode.input_package.timestamp.isoformat(),
|
||||
'episode_type': episode.episode_type,
|
||||
'training_priority': episode.training_priority,
|
||||
'profitability_score': episode.actual_outcome.profitability_score,
|
||||
'is_profitable': episode.actual_outcome.is_profitable,
|
||||
'is_rapid_change': episode.actual_outcome.is_rapid_change,
|
||||
'data_completeness': episode.input_package.completeness_score
|
||||
}
|
||||
|
||||
metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving episode to disk: {e}")
|
||||
|
||||
def get_high_priority_episodes(self,
|
||||
symbol: str,
|
||||
limit: int = 100,
|
||||
min_priority: float = 0.5) -> List[TrainingEpisode]:
|
||||
"""Get high-priority training episodes for replay training"""
|
||||
try:
|
||||
if symbol not in self.training_episodes:
|
||||
return []
|
||||
|
||||
# Filter and sort by priority
|
||||
high_priority = [
|
||||
ep for ep in self.training_episodes[symbol]
|
||||
if ep.training_priority >= min_priority
|
||||
]
|
||||
|
||||
high_priority.sort(key=lambda x: x.training_priority, reverse=True)
|
||||
|
||||
return high_priority[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting high priority episodes for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def get_collection_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive collection statistics"""
|
||||
stats = self.collection_stats.copy()
|
||||
|
||||
# Add per-symbol statistics
|
||||
stats['episodes_per_symbol'] = {
|
||||
symbol: len(episodes)
|
||||
for symbol, episodes in self.training_episodes.items()
|
||||
}
|
||||
|
||||
# Add pending outcomes count
|
||||
stats['pending_outcomes'] = {
|
||||
symbol: len(packages)
|
||||
for symbol, packages in self.pending_outcomes.items()
|
||||
}
|
||||
|
||||
# Calculate profitability rate
|
||||
if stats['total_episodes'] > 0:
|
||||
stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes']
|
||||
stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes']
|
||||
else:
|
||||
stats['profitability_rate'] = 0.0
|
||||
stats['rapid_change_rate'] = 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def validate_data_integrity(self) -> Dict[str, Any]:
|
||||
"""Comprehensive data integrity validation"""
|
||||
validation_results = {
|
||||
'total_episodes_checked': 0,
|
||||
'hash_mismatches': 0,
|
||||
'completeness_issues': 0,
|
||||
'validation_flag_failures': 0,
|
||||
'corrupted_episodes': [],
|
||||
'integrity_score': 1.0
|
||||
}
|
||||
|
||||
try:
|
||||
for symbol, episodes in self.training_episodes.items():
|
||||
for episode in episodes:
|
||||
validation_results['total_episodes_checked'] += 1
|
||||
|
||||
# Check data hash
|
||||
expected_hash = episode.input_package._calculate_hash()
|
||||
if expected_hash != episode.input_package.data_hash:
|
||||
validation_results['hash_mismatches'] += 1
|
||||
validation_results['corrupted_episodes'].append(episode.episode_id)
|
||||
|
||||
# Check completeness
|
||||
if episode.input_package.completeness_score < 0.7:
|
||||
validation_results['completeness_issues'] += 1
|
||||
|
||||
# Check validation flags
|
||||
if not episode.input_package.validation_flags.get('data_consistent', False):
|
||||
validation_results['validation_flag_failures'] += 1
|
||||
|
||||
# Calculate integrity score
|
||||
total_issues = (
|
||||
validation_results['hash_mismatches'] +
|
||||
validation_results['completeness_issues'] +
|
||||
validation_results['validation_flag_failures']
|
||||
)
|
||||
|
||||
if validation_results['total_episodes_checked'] > 0:
|
||||
validation_results['integrity_score'] = 1.0 - (
|
||||
total_issues / validation_results['total_episodes_checked']
|
||||
)
|
||||
|
||||
logger.info(f"Data integrity validation completed")
|
||||
logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during data integrity validation: {e}")
|
||||
validation_results['validation_error'] = str(e)
|
||||
|
||||
return validation_results
|
||||
|
||||
# Global instance for easy access
|
||||
training_data_collector: Optional[TrainingDataCollector] = None
|
||||
|
||||
def get_training_data_collector() -> TrainingDataCollector:
|
||||
"""Get global training data collector instance"""
|
||||
global training_data_collector
|
||||
if training_data_collector is None:
|
||||
training_data_collector = TrainingDataCollector()
|
||||
return training_data_collector
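# Minimal usage sketch added for illustration only; the values below are hypothetical
# placeholders (not data produced by the original system) and the block is not part of
# the original module.
if __name__ == "__main__":
    demo_candles = pd.DataFrame({
        'open': [2500.0], 'high': [2510.0], 'low': [2495.0],
        'close': [2505.0], 'volume': [12.3]
    })
    collector = get_training_data_collector()
    collector.start_collection()
    episode_id = collector.collect_training_data(
        symbol='ETH/USDT',
        ohlcv_data={'1m': demo_candles},
        tick_data=[{'price': 2505.0, 'size': 0.5}],
        cob_data={'bids': [], 'asks': []},
        technical_indicators={'rsi_14': 48.2},
        pivot_points=[{'type': 'low', 'price': 2490.0}],
        cnn_features=np.zeros(512, dtype=np.float32),
        rl_state=np.zeros(64, dtype=np.float32),
        orchestrator_context={'regime': 'ranging'},
    )
    # Example ID shape (hypothetical): ETH/USDT_20240115_103000_a1b2c3d4
    print(f"Collected episode: {episode_id}")
    # After the 1-hour outcome_validation_delay the background worker fills in outcomes,
    # and the most valuable setups can be pulled for replay training:
    replay_batch = collector.get_high_priority_episodes('ETH/USDT', limit=32)
    collector.stop_collection()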
|
||||
File diff suppressed because it is too large
@@ -1,164 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify that both model prediction and trading statistics issues are fixed
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_executor import TradingExecutor
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_model_predictions():
|
||||
"""Test that model predictions are working correctly"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING MODEL PREDICTIONS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider)
|
||||
|
||||
# Check model registration
|
||||
logger.info("1. Checking model registration...")
|
||||
models = orchestrator.model_registry.get_all_models()
|
||||
logger.info(f" Registered models: {list(models.keys()) if models else 'None'}")
|
||||
|
||||
# Test making a decision
|
||||
logger.info("2. Testing trading decision generation...")
|
||||
decision = await orchestrator.make_trading_decision('ETH/USDT')
|
||||
|
||||
if decision:
|
||||
logger.info(f" ✅ Decision generated: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||
logger.info(f" ✅ Reasoning: {decision.reasoning}")
|
||||
return True
|
||||
else:
|
||||
logger.error(" ❌ No decision generated")
|
||||
return False
|
||||
|
||||
def test_trading_statistics():
|
||||
"""Test that trading statistics calculations are working correctly"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING TRADING STATISTICS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Check if we have any trades
|
||||
trade_history = trading_executor.get_trade_history()
|
||||
logger.info(f"1. Current trade history: {len(trade_history)} trades")
|
||||
|
||||
# Get daily stats
|
||||
daily_stats = trading_executor.get_daily_stats()
|
||||
logger.info("2. Daily statistics from trading executor:")
|
||||
logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
|
||||
logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
|
||||
logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
|
||||
logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
|
||||
logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
|
||||
logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
|
||||
|
||||
# Simulate some trades if we don't have any
|
||||
if daily_stats.get('total_trades', 0) == 0:
|
||||
logger.info("3. No trades found - simulating some test trades...")
|
||||
|
||||
# Add some mock trades to the trade history
|
||||
from core.trading_executor import TradeRecord
|
||||
from datetime import datetime
|
||||
|
||||
# Add a winning trade
|
||||
winning_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=0.01,
|
||||
entry_price=2500.0,
|
||||
exit_price=2550.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=0.50, # $0.50 profit
|
||||
fees=0.01,
|
||||
confidence=0.8
|
||||
)
|
||||
trading_executor.trade_history.append(winning_trade)
|
||||
|
||||
# Add a losing trade
|
||||
losing_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=0.01,
|
||||
entry_price=2500.0,
|
||||
exit_price=2480.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=-0.20, # $0.20 loss
|
||||
fees=0.01,
|
||||
confidence=0.7
|
||||
)
|
||||
trading_executor.trade_history.append(losing_trade)
|
||||
|
||||
# Get updated stats
|
||||
daily_stats = trading_executor.get_daily_stats()
|
||||
logger.info(" Updated statistics after adding test trades:")
|
||||
logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
|
||||
logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
|
||||
logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
|
||||
logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
|
||||
logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
|
||||
logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
|
||||
|
||||
# Verify calculations
|
||||
expected_win_rate = 1/2 # 1 win out of 2 trades = 50%
|
||||
expected_avg_win = 0.50
|
||||
expected_avg_loss = -0.20
|
||||
|
||||
actual_win_rate = daily_stats.get('win_rate', 0.0)
|
||||
actual_avg_win = daily_stats.get('avg_winning_trade', 0.0)
|
||||
actual_avg_loss = daily_stats.get('avg_losing_trade', 0.0)
|
||||
|
||||
logger.info("4. Verifying calculations:")
|
||||
logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ✅" if abs(actual_win_rate - expected_win_rate) < 0.01 else f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ❌")
|
||||
logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ✅" if abs(actual_avg_win - expected_avg_win) < 0.01 else f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ❌")
|
||||
logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ✅" if abs(actual_avg_loss - expected_avg_loss) < 0.01 else f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ❌")
|
||||
|
||||
return True
|
||||
|
||||
return True
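# Quick sanity check of the expected numbers above (illustrative, not part of the test):
# with one +$0.50 winner and one -$0.20 loser, win_rate = 1/2 = 50%,
# avg_winning_trade = 0.50, avg_losing_trade = -0.20 and total_pnl = 0.50 - 0.20 = 0.30
# (before the $0.02 of combined fees, assuming get_daily_stats() reports gross P&L).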
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
|
||||
logger.info("🚀 STARTING COMPREHENSIVE FIXES TEST")
|
||||
logger.info("Testing both model prediction fixes and trading statistics fixes")
|
||||
|
||||
# Test model predictions
|
||||
prediction_success = await test_model_predictions()
|
||||
|
||||
# Test trading statistics
|
||||
stats_success = test_trading_statistics()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Model Predictions: {'✅ FIXED' if prediction_success else '❌ STILL BROKEN'}")
|
||||
logger.info(f"Trading Statistics: {'✅ FIXED' if stats_success else '❌ STILL BROKEN'}")
|
||||
|
||||
if prediction_success and stats_success:
|
||||
logger.info("🎉 ALL ISSUES FIXED! The system should now work correctly.")
|
||||
else:
|
||||
logger.error("❌ Some issues remain. Check the logs above for details.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,250 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify trading fixes:
|
||||
1. Position sizes with leverage
|
||||
2. ETH-only trading
|
||||
3. Correct win rate calculations
|
||||
4. Meaningful P&L values
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.trading_executor import TradeRecord
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_position_sizing():
|
||||
"""Test that position sizing now includes leverage and meaningful amounts"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING POSITION SIZING WITH LEVERAGE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test position calculation
|
||||
confidence = 0.8
|
||||
current_price = 2500.0 # ETH price
|
||||
|
||||
position_value = trading_executor._calculate_position_size(confidence, current_price)
|
||||
quantity = position_value / current_price
|
||||
|
||||
logger.info(f"1. Position calculation test:")
|
||||
logger.info(f" Confidence: {confidence}")
|
||||
logger.info(f" ETH Price: ${current_price}")
|
||||
logger.info(f" Position Value: ${position_value:.2f}")
|
||||
logger.info(f" Quantity: {quantity:.6f} ETH")
|
||||
|
||||
# Check if position is meaningful
|
||||
if position_value > 1000: # Should be >$1000 with 10x leverage
|
||||
logger.info(" ✅ Position size is meaningful (>$1000)")
|
||||
else:
|
||||
logger.error(f" ❌ Position size too small: ${position_value:.2f}")
|
||||
|
||||
# Test different confidence levels
|
||||
logger.info("2. Testing different confidence levels:")
|
||||
for conf in [0.2, 0.5, 0.8, 1.0]:
|
||||
pos_val = trading_executor._calculate_position_size(conf, current_price)
|
||||
qty = pos_val / current_price
|
||||
logger.info(f" Confidence {conf}: ${pos_val:.2f} ({qty:.6f} ETH)")
|
||||
|
||||
def test_eth_only_restriction():
|
||||
"""Test that only ETH trades are allowed"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING ETH-ONLY TRADING RESTRICTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test ETH trade (should be allowed)
|
||||
logger.info("1. Testing ETH/USDT trade (should be allowed):")
|
||||
eth_allowed = trading_executor._check_safety_conditions('ETH/USDT', 'BUY')
|
||||
logger.info(f" ETH/USDT allowed: {'✅ YES' if eth_allowed else '❌ NO'}")
|
||||
|
||||
# Test BTC trade (should be blocked)
|
||||
logger.info("2. Testing BTC/USDT trade (should be blocked):")
|
||||
btc_allowed = trading_executor._check_safety_conditions('BTC/USDT', 'BUY')
|
||||
logger.info(f" BTC/USDT allowed: {'❌ YES (ERROR!)' if btc_allowed else '✅ NO (CORRECT)'}")
|
||||
|
||||
def test_win_rate_calculation():
|
||||
"""Test that win rate calculations are correct"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING WIN RATE CALCULATIONS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Clear existing trades
|
||||
trading_executor.trade_history = []
|
||||
|
||||
# Add test trades with meaningful P&L
|
||||
logger.info("1. Adding test trades with meaningful P&L:")
|
||||
|
||||
# Add 3 winning trades
|
||||
for i in range(3):
|
||||
winning_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=1.0,
|
||||
entry_price=2500.0,
|
||||
exit_price=2550.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=50.0, # $50 profit with leverage
|
||||
fees=1.0,
|
||||
confidence=0.8,
|
||||
hold_time_seconds=30.0 # 30 second hold
|
||||
)
|
||||
trading_executor.trade_history.append(winning_trade)
|
||||
logger.info(f" Added winning trade #{i+1}: +$50.00 (30s hold)")
|
||||
|
||||
# Add 2 losing trades
|
||||
for i in range(2):
|
||||
losing_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=1.0,
|
||||
entry_price=2500.0,
|
||||
exit_price=2475.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=-25.0, # $25 loss with leverage
|
||||
fees=1.0,
|
||||
confidence=0.7,
|
||||
hold_time_seconds=15.0 # 15 second hold
|
||||
)
|
||||
trading_executor.trade_history.append(losing_trade)
|
||||
logger.info(f" Added losing trade #{i+1}: -$25.00 (15s hold)")
|
||||
|
||||
# Get statistics
|
||||
stats = trading_executor.get_daily_stats()
|
||||
|
||||
logger.info("2. Calculated statistics:")
|
||||
logger.info(f" Total trades: {stats['total_trades']}")
|
||||
logger.info(f" Winning trades: {stats['winning_trades']}")
|
||||
logger.info(f" Losing trades: {stats['losing_trades']}")
|
||||
logger.info(f" Win rate: {stats['win_rate']*100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${stats['avg_winning_trade']:.2f}")
|
||||
logger.info(f" Avg losing trade: ${stats['avg_losing_trade']:.2f}")
|
||||
logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")
|
||||
|
||||
# Verify calculations
|
||||
expected_win_rate = 3/5 # 3 wins out of 5 trades = 60%
|
||||
expected_avg_win = 50.0
|
||||
expected_avg_loss = -25.0
|
||||
|
||||
logger.info("3. Verification:")
|
||||
win_rate_ok = abs(stats['win_rate'] - expected_win_rate) < 0.01
|
||||
avg_win_ok = abs(stats['avg_winning_trade'] - expected_avg_win) < 0.01
|
||||
avg_loss_ok = abs(stats['avg_losing_trade'] - expected_avg_loss) < 0.01
|
||||
|
||||
logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {stats['win_rate']*100:.1f}% {'✅' if win_rate_ok else '❌'}")
|
||||
logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${stats['avg_winning_trade']:.2f} {'✅' if avg_win_ok else '❌'}")
|
||||
logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${stats['avg_losing_trade']:.2f} {'✅' if avg_loss_ok else '❌'}")
|
||||
|
||||
return win_rate_ok and avg_win_ok and avg_loss_ok
|
||||
|
||||
def test_new_features():
|
||||
"""Test new features: hold time, leverage, percentage-based sizing"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING NEW FEATURES")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test account info
|
||||
account_info = trading_executor.get_account_info()
|
||||
logger.info(f"1. Account Information:")
|
||||
logger.info(f" Account Balance: ${account_info['account_balance']:.2f}")
|
||||
logger.info(f" Leverage: {account_info['leverage']:.0f}x")
|
||||
logger.info(f" Trading Mode: {account_info['trading_mode']}")
|
||||
logger.info(f" Position Sizing: {account_info['position_sizing']['base_percent']:.1f}% base")
|
||||
|
||||
# Test leverage setting
|
||||
logger.info("2. Testing leverage control:")
|
||||
old_leverage = trading_executor.get_leverage()
|
||||
logger.info(f" Current leverage: {old_leverage:.0f}x")
|
||||
|
||||
success = trading_executor.set_leverage(100.0)
|
||||
new_leverage = trading_executor.get_leverage()
|
||||
logger.info(f" Set to 100x: {'✅ SUCCESS' if success and new_leverage == 100.0 else '❌ FAILED'}")
|
||||
|
||||
# Reset leverage
|
||||
trading_executor.set_leverage(old_leverage)
|
||||
|
||||
# Test percentage-based position sizing
|
||||
logger.info("3. Testing percentage-based position sizing:")
|
||||
confidence = 0.8
|
||||
eth_price = 2500.0
|
||||
|
||||
position_value = trading_executor._calculate_position_size(confidence, eth_price)
|
||||
account_balance = trading_executor._get_account_balance_for_sizing()
|
||||
base_percent = trading_executor.mexc_config.get('base_position_percent', 5.0)
|
||||
leverage = trading_executor.get_leverage()
|
||||
|
||||
expected_base = account_balance * (base_percent / 100.0) * confidence
|
||||
expected_leveraged = expected_base * leverage
|
||||
|
||||
logger.info(f" Account: ${account_balance:.2f}")
|
||||
logger.info(f" Base %: {base_percent:.1f}%")
|
||||
logger.info(f" Confidence: {confidence:.1f}")
|
||||
logger.info(f" Leverage: {leverage:.0f}x")
|
||||
logger.info(f" Expected base: ${expected_base:.2f}")
|
||||
logger.info(f" Expected leveraged: ${expected_leveraged:.2f}")
|
||||
logger.info(f" Actual: ${position_value:.2f}")
|
||||
|
||||
sizing_ok = abs(position_value - expected_leveraged) < 0.01
|
||||
logger.info(f" Percentage sizing: {'✅ CORRECT' if sizing_ok else '❌ INCORRECT'}")
|
||||
|
||||
return sizing_ok
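# Numeric illustration of the sizing formula checked above (hypothetical account values):
# with a $100 balance, base_position_percent = 5.0, confidence = 0.8 and 50x leverage,
# base = 100 * 0.05 * 0.8 = $4.00 and the leveraged position value = 4.00 * 50 = $200.00,
# i.e. 0.08 ETH at a $2,500 ETH price.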
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
|
||||
logger.info("🚀 TESTING TRADING FIXES AND NEW FEATURES")
|
||||
logger.info("Testing position sizing, ETH-only trading, win rate calculations, and new features")
|
||||
|
||||
# Test position sizing
|
||||
test_position_sizing()
|
||||
|
||||
# Test ETH-only restriction
|
||||
test_eth_only_restriction()
|
||||
|
||||
# Test win rate calculation
|
||||
calculation_success = test_win_rate_calculation()
|
||||
|
||||
# Test new features
|
||||
features_success = test_new_features()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Position Sizing: ✅ Updated with percentage-based leverage")
|
||||
logger.info(f"ETH-Only Trading: ✅ Configured in config")
|
||||
logger.info(f"Win Rate Calculation: {'✅ FIXED' if calculation_success else '❌ STILL BROKEN'}")
|
||||
logger.info(f"New Features: {'✅ WORKING' if features_success else '❌ ISSUES FOUND'}")
|
||||
|
||||
if calculation_success and features_success:
|
||||
logger.info("🎉 ALL FEATURES WORKING! Now you should see:")
|
||||
logger.info(" - Percentage-based position sizing (2-20% of account)")
|
||||
logger.info(" - 50x leverage (adjustable in UI)")
|
||||
logger.info(" - Hold time in seconds for each trade")
|
||||
logger.info(" - Total fees in trading statistics")
|
||||
logger.info(" - Only ETH/USDT trades")
|
||||
logger.info(" - Correct win rate calculations")
|
||||
else:
|
||||
logger.error("❌ Some issues remain. Check the logs above for details.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,344 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Trade Audit Tool
|
||||
|
||||
This tool analyzes trade data to identify potential issues with:
|
||||
- Duplicate entry prices
|
||||
- Rapid consecutive trades
|
||||
- P&L calculation accuracy
|
||||
- Position tracking problems
|
||||
|
||||
Usage:
|
||||
python debug/trade_audit.py [--trades-file path/to/trades.json]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def parse_trade_time(time_str):
|
||||
"""Parse trade time string to datetime object"""
|
||||
try:
|
||||
# Try HH:MM:SS format
|
||||
return datetime.strptime(time_str, "%H:%M:%S")
|
||||
except ValueError:
|
||||
try:
|
||||
# Try full datetime format
|
||||
return datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
# Return as is if parsing fails
|
||||
return time_str
|
||||
|
||||
def load_trades_from_file(file_path):
|
||||
"""Load trades from JSON file"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {file_path} not found")
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: File {file_path} is not valid JSON")
|
||||
return []
|
||||
|
||||
def load_trades_from_dashboard_cache():
|
||||
"""Load trades from dashboard cache file if available"""
|
||||
cache_paths = [
|
||||
"cache/dashboard_trades.json",
|
||||
"cache/closed_trades.json",
|
||||
"data/trades_history.json"
|
||||
]
|
||||
|
||||
for path in cache_paths:
|
||||
if os.path.exists(path):
|
||||
print(f"Loading trades from cache: {path}")
|
||||
return load_trades_from_file(path)
|
||||
|
||||
print("No trade cache files found")
|
||||
return []
|
||||
|
||||
def parse_trade_data(trades_data):
|
||||
"""Parse trade data into a pandas DataFrame for analysis"""
|
||||
parsed_trades = []
|
||||
|
||||
for trade in trades_data:
|
||||
# Handle different trade data formats
|
||||
parsed_trade = {}
|
||||
|
||||
# Time field might be named entry_time or time
|
||||
if 'entry_time' in trade:
|
||||
parsed_trade['time'] = parse_trade_time(trade['entry_time'])
|
||||
elif 'time' in trade:
|
||||
parsed_trade['time'] = parse_trade_time(trade['time'])
|
||||
else:
|
||||
parsed_trade['time'] = None
|
||||
|
||||
# Side might be named side or action
|
||||
parsed_trade['side'] = trade.get('side', trade.get('action', 'UNKNOWN'))
|
||||
|
||||
# Size might be named size or quantity
|
||||
parsed_trade['size'] = float(trade.get('size', trade.get('quantity', 0)))
|
||||
|
||||
# Entry and exit prices
|
||||
parsed_trade['entry_price'] = float(trade.get('entry_price', trade.get('entry', 0)))
|
||||
parsed_trade['exit_price'] = float(trade.get('exit_price', trade.get('exit', 0)))
|
||||
|
||||
# Hold time in seconds
|
||||
parsed_trade['hold_time'] = float(trade.get('hold_time_seconds', trade.get('hold', 0)))
|
||||
|
||||
# P&L and fees
|
||||
parsed_trade['pnl'] = float(trade.get('pnl', 0))
|
||||
parsed_trade['fees'] = float(trade.get('fees', 0))
|
||||
|
||||
# Calculate expected P&L for verification
|
||||
if parsed_trade['side'] == 'LONG' or parsed_trade['side'] == 'BUY':
|
||||
expected_pnl = (parsed_trade['exit_price'] - parsed_trade['entry_price']) * parsed_trade['size']
|
||||
else: # SHORT or SELL
|
||||
expected_pnl = (parsed_trade['entry_price'] - parsed_trade['exit_price']) * parsed_trade['size']
|
||||
|
||||
parsed_trade['expected_pnl'] = expected_pnl
|
||||
parsed_trade['pnl_difference'] = parsed_trade['pnl'] - expected_pnl
|
||||
|
||||
parsed_trades.append(parsed_trade)
|
||||
|
||||
# Convert to DataFrame
|
||||
if parsed_trades:
|
||||
df = pd.DataFrame(parsed_trades)
|
||||
return df
|
||||
else:
|
||||
return pd.DataFrame()
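# Worked example for the expected-P&L check above (illustrative): a LONG of 0.01 ETH
# entered at $2,500 and exited at $2,550 has expected_pnl = (2550 - 2500) * 0.01 = $0.50;
# the same prices on a SHORT give (2500 - 2550) * 0.01 = -$0.50. Any reported pnl that
# differs from this by more than $0.01 is flagged in the audit below.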
|
||||
|
||||
def analyze_trades(df):
|
||||
"""Analyze trades for potential issues"""
|
||||
if df.empty:
|
||||
print("No trades to analyze")
|
||||
return
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print("TRADE AUDIT RESULTS")
|
||||
print(f"{'='*50}")
|
||||
print(f"Total trades analyzed: {len(df)}")
|
||||
|
||||
# Check for duplicate entry prices
|
||||
entry_price_counts = df['entry_price'].value_counts()
|
||||
duplicate_entries = entry_price_counts[entry_price_counts > 1]
|
||||
|
||||
print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}")
|
||||
if not duplicate_entries.empty:
|
||||
print(f"Found {len(duplicate_entries)} prices with multiple entries:")
|
||||
for price, count in duplicate_entries.items():
|
||||
print(f" ${price:.2f}: {count} trades")
|
||||
|
||||
# Analyze the duplicate entry trades in more detail
|
||||
for price in duplicate_entries.index:
|
||||
duplicate_df = df[df['entry_price'] == price].copy()
|
||||
duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds()
|
||||
|
||||
print(f"\nDetailed analysis for entry price ${price:.2f}:")
|
||||
print(f" Time gaps between consecutive trades:")
|
||||
for i, (_, row) in enumerate(duplicate_df.iterrows()):
|
||||
if i > 0: # Skip first row as it has no previous trade
|
||||
time_diff = row['time_diff']
|
||||
if pd.notna(time_diff):
|
||||
print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade")
|
||||
else:
|
||||
print("No duplicate entry prices found")
|
||||
|
||||
# Check for rapid consecutive trades
|
||||
df = df.sort_values('time')
|
||||
df['time_since_last'] = df['time'].diff().dt.total_seconds()
|
||||
|
||||
rapid_trades = df[df['time_since_last'] < 30].copy()
|
||||
|
||||
print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}")
|
||||
if not rapid_trades.empty:
|
||||
print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:")
|
||||
for _, row in rapid_trades.iterrows():
|
||||
if pd.notna(row['time_since_last']):
|
||||
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous")
|
||||
else:
|
||||
print("No rapid consecutive trades found")
|
||||
|
||||
# Check for P&L calculation accuracy
|
||||
pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy()
|
||||
|
||||
print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}")
|
||||
if not pnl_diff.empty:
|
||||
print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:")
|
||||
for _, row in pnl_diff.iterrows():
|
||||
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}")
|
||||
else:
|
||||
print("No P&L calculation issues found")
|
||||
|
||||
# Check for side distribution
|
||||
side_counts = df['side'].value_counts()
|
||||
|
||||
print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}")
|
||||
for side, count in side_counts.items():
|
||||
print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||
|
||||
# Check for hold time distribution
|
||||
print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}")
|
||||
print(f" Min hold time: {df['hold_time'].min():.0f} seconds")
|
||||
print(f" Max hold time: {df['hold_time'].max():.0f} seconds")
|
||||
print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds")
|
||||
print(f" Median hold time: {df['hold_time'].median():.0f} seconds")
|
||||
|
||||
# Hold time buckets
|
||||
hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')]
|
||||
hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+']
|
||||
|
||||
df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels)
|
||||
hold_dist = df['hold_bucket'].value_counts().sort_index()
|
||||
|
||||
for bucket, count in hold_dist.items():
|
||||
print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||
|
||||
# Generate summary statistics
|
||||
print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}")
|
||||
winning_trades = df[df['pnl'] > 0]
|
||||
losing_trades = df[df['pnl'] < 0]
|
||||
|
||||
print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)")
|
||||
print(f" Avg win: ${winning_trades['pnl'].mean():.2f}")
|
||||
print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}")
|
||||
print(f" Total P&L: ${df['pnl'].sum():.2f}")
|
||||
print(f" Total fees: ${df['fees'].sum():.2f}")
|
||||
print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}")
|
||||
|
||||
# Plot entry price distribution
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(df['entry_price'], bins=20, alpha=0.7)
|
||||
plt.title('Entry Price Distribution')
|
||||
plt.xlabel('Entry Price ($)')
|
||||
plt.ylabel('Number of Trades')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.savefig('debug/entry_price_distribution.png')
|
||||
|
||||
# Plot P&L distribution
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(df['pnl'], bins=20, alpha=0.7)
|
||||
plt.title('P&L Distribution')
|
||||
plt.xlabel('P&L ($)')
|
||||
plt.ylabel('Number of Trades')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.savefig('debug/pnl_distribution.png')
|
||||
|
||||
print(f"\n{'='*20} AUDIT COMPLETE {'='*20}")
|
||||
print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png")
|
||||
|
||||
def analyze_manual_trades(trades_data):
|
||||
"""Analyze manually provided trade data"""
|
||||
# Parse the trade data into a structured format
|
||||
parsed_trades = []
|
||||
|
||||
for line in trades_data.strip().split('\n'):
|
||||
if not line or line.startswith('from last session') or line.startswith('Recent Closed Trades') or line.startswith('Trading Performance'):
|
||||
continue
|
||||
|
||||
if line.startswith('Win Rate:'):
|
||||
# This is the summary line, skip it
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse trade line format: Time Side Size Entry Exit Hold P&L Fees
|
||||
parts = line.split('$')
|
||||
|
||||
time_side = parts[0].strip().split()
|
||||
time = time_side[0]
|
||||
side = time_side[1]
|
||||
|
||||
size = float(parts[1].split()[0])
|
||||
entry = float(parts[2].split()[0])
|
||||
exit = float(parts[3].split()[0])
|
||||
|
||||
# The hold time and P&L are in the last parts
|
||||
remaining = parts[3].split()
|
||||
hold = int(remaining[1])
|
||||
pnl = float(parts[4].split()[0])
|
||||
|
||||
# Fees might be in a different format
|
||||
if len(parts) > 5:
|
||||
fees = float(parts[5].strip())
|
||||
else:
|
||||
fees = 0.0
|
||||
|
||||
parsed_trade = {
|
||||
'time': parse_trade_time(time),
|
||||
'side': side,
|
||||
'size': size,
|
||||
'entry_price': entry,
|
||||
'exit_price': exit,
|
||||
'hold_time': hold,
|
||||
'pnl': pnl,
|
||||
'fees': fees
|
||||
}
|
||||
|
||||
# Calculate expected P&L
|
||||
if side == 'LONG' or side == 'BUY':
|
||||
expected_pnl = (exit - entry) * size
|
||||
else: # SHORT or SELL
|
||||
expected_pnl = (entry - exit) * size
|
||||
|
||||
parsed_trade['expected_pnl'] = expected_pnl
|
||||
parsed_trade['pnl_difference'] = pnl - expected_pnl
|
||||
|
||||
parsed_trades.append(parsed_trade)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing trade line: {line}")
|
||||
print(f"Error details: {e}")
|
||||
|
||||
# Convert to DataFrame
|
||||
if parsed_trades:
|
||||
df = pd.DataFrame(parsed_trades)
|
||||
return df
|
||||
else:
|
||||
return pd.DataFrame()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Trade Audit Tool')
|
||||
parser.add_argument('--trades-file', type=str, help='Path to trades JSON file')
|
||||
parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create debug directory if it doesn't exist
|
||||
os.makedirs('debug', exist_ok=True)
|
||||
|
||||
if args.trades_file:
|
||||
trades_data = load_trades_from_file(args.trades_file)
|
||||
df = parse_trade_data(trades_data)
|
||||
elif args.manual_trades:
|
||||
try:
|
||||
with open(args.manual_trades, 'r') as f:
|
||||
manual_trades = f.read()
|
||||
df = analyze_manual_trades(manual_trades)
|
||||
except Exception as e:
|
||||
print(f"Error reading manual trades file: {e}")
|
||||
df = pd.DataFrame()
|
||||
else:
|
||||
# Try to load from dashboard cache
|
||||
trades_data = load_trades_from_dashboard_cache()
|
||||
if trades_data:
|
||||
df = parse_trade_data(trades_data)
|
||||
else:
|
||||
print("No trade data provided. Use --trades-file or --manual-trades")
|
||||
return
|
||||
|
||||
if not df.empty:
|
||||
analyze_trades(df)
|
||||
else:
|
||||
print("No valid trade data to analyze")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,84 +0,0 @@
#!/usr/bin/env python3
"""
Debug Training Methods

This script checks what training methods are available on each model.
"""

import asyncio
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider

async def debug_training_methods():
    """Debug the available training methods on each model"""
    print("=== Debugging Training Methods ===")

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Wait for initialization
    await asyncio.sleep(2)

    print("\n2. Checking available training methods on each model:")

    for model_name, model_interface in orchestrator.model_registry.models.items():
        print(f"\n--- {model_name} ---")
        print(f"Interface type: {type(model_interface).__name__}")

        # Get underlying model
        underlying_model = getattr(model_interface, 'model', None)
        if underlying_model:
            print(f"Underlying model type: {type(underlying_model).__name__}")
        else:
            print("No underlying model found")
            continue

        # Check for training methods
        training_methods = []
        for method in ['train_on_outcome', 'add_experience', 'remember', 'replay', 'add_training_sample', 'train', 'train_with_reward', 'update_loss']:
            if hasattr(underlying_model, method):
                training_methods.append(method)

        print(f"Available training methods: {training_methods}")

        # Check for specific attributes
        attributes = []
        for attr in ['memory', 'batch_size', 'training_data']:
            if hasattr(underlying_model, attr):
                attr_value = getattr(underlying_model, attr)
                if attr == 'memory' and hasattr(attr_value, '__len__'):
                    attributes.append(f"{attr}(len={len(attr_value)})")
                elif attr == 'training_data' and hasattr(attr_value, '__len__'):
                    attributes.append(f"{attr}(len={len(attr_value)})")
                else:
                    attributes.append(f"{attr}={attr_value}")

        print(f"Relevant attributes: {attributes}")

        # Check if it's an RL agent
        if hasattr(underlying_model, 'act') and hasattr(underlying_model, 'remember'):
            print("✅ Detected as RL Agent")
        elif hasattr(underlying_model, 'predict') and hasattr(underlying_model, 'add_training_sample'):
            print("✅ Detected as CNN Model")
        else:
            print("❓ Unknown model type")

    print("\n3. Testing a simple training attempt:")

    # Get a prediction first
    predictions = await orchestrator._get_all_predictions('ETH/USDT')
    print(f"Got {len(predictions)} predictions")

    # Try to trigger training for each model
    for model_name in orchestrator.model_registry.models.keys():
        print(f"\nTesting training for {model_name}...")
        try:
            await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT')
            print(f"✅ Training attempt completed for {model_name}")
        except Exception as e:
            print(f"❌ Training failed for {model_name}: {e}")

if __name__ == "__main__":
    asyncio.run(debug_training_methods())
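The hasattr probes above imply a dispatch convention; a minimal sketch of routing a sample to whichever interface a model exposes (the method names come from the probe list above, the sample shape is an assumption):

def train_model(underlying_model, sample):
    """Route a training sample to whichever training interface the model exposes."""
    if hasattr(underlying_model, 'remember') and hasattr(underlying_model, 'replay'):
        # RL-agent style: store the experience, then learn from the replay buffer
        underlying_model.remember(*sample)
        return underlying_model.replay()
    if hasattr(underlying_model, 'add_training_sample'):
        # CNN style: queue the sample for the next training batch
        return underlying_model.add_training_sample(sample)
    raise TypeError(f"No known training interface on {type(underlying_model).__name__}")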
@@ -1,233 +0,0 @@
"""
Bybit Integration Examples
Based on official pybit library documentation and examples
"""

import os
from pybit.unified_trading import HTTP
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_bybit_session(testnet=True):
    """Create a Bybit HTTP session.

    Args:
        testnet (bool): Use testnet if True, live if False

    Returns:
        HTTP: Bybit session object
    """
    api_key = os.getenv('BYBIT_API_KEY')
    api_secret = os.getenv('BYBIT_API_SECRET')

    if not api_key or not api_secret:
        raise ValueError("BYBIT_API_KEY and BYBIT_API_SECRET must be set in environment")

    session = HTTP(
        testnet=testnet,
        api_key=api_key,
        api_secret=api_secret,
    )

    logger.info(f"Created Bybit session (testnet: {testnet})")
    return session

def get_account_info(session):
    """Get account information and balances."""
    try:
        # Get account info
        account_info = session.get_wallet_balance(accountType="UNIFIED")
        logger.info(f"Account info: {account_info}")

        return account_info
    except Exception as e:
        logger.error(f"Error getting account info: {e}")
        return None

def get_ticker_info(session, symbol="BTCUSDT"):
    """Get ticker information for a symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol (default: BTCUSDT)
    """
    try:
        ticker = session.get_tickers(category="linear", symbol=symbol)
        logger.info(f"Ticker for {symbol}: {ticker}")
        return ticker
    except Exception as e:
        logger.error(f"Error getting ticker for {symbol}: {e}")
        return None

def get_orderbook(session, symbol="BTCUSDT", limit=25):
    """Get orderbook for a symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        limit: Number of price levels to return
    """
    try:
        orderbook = session.get_orderbook(
            category="linear",
            symbol=symbol,
            limit=limit
        )
        logger.info(f"Orderbook for {symbol}: {orderbook}")
        return orderbook
    except Exception as e:
        logger.error(f"Error getting orderbook for {symbol}: {e}")
        return None

def place_limit_order(session, symbol="BTCUSDT", side="Buy", qty="0.001", price="50000"):
    """Place a limit order.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        side: "Buy" or "Sell"
        qty: Order quantity as string
        price: Order price as string
    """
    try:
        order = session.place_order(
            category="linear",
            symbol=symbol,
            side=side,
            orderType="Limit",
            qty=qty,
            price=price,
            timeInForce="GTC"  # Good Till Cancelled
        )
        logger.info(f"Placed order: {order}")
        return order
    except Exception as e:
        logger.error(f"Error placing order: {e}")
        return None

def place_market_order(session, symbol="BTCUSDT", side="Buy", qty="0.001"):
    """Place a market order.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        side: "Buy" or "Sell"
        qty: Order quantity as string
    """
    try:
        order = session.place_order(
            category="linear",
            symbol=symbol,
            side=side,
            orderType="Market",
            qty=qty
        )
        logger.info(f"Placed market order: {order}")
        return order
    except Exception as e:
        logger.error(f"Error placing market order: {e}")
        return None

def get_open_orders(session, symbol=None):
    """Get open orders.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol (optional, gets all if None)
    """
    try:
        params = {"category": "linear", "openOnly": True}
        if symbol:
            params["symbol"] = symbol

        orders = session.get_open_orders(**params)
        logger.info(f"Open orders: {orders}")
        return orders
    except Exception as e:
        logger.error(f"Error getting open orders: {e}")
        return None

def cancel_order(session, symbol, order_id):
    """Cancel an order.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        order_id: Order ID to cancel
    """
    try:
        result = session.cancel_order(
            category="linear",
            symbol=symbol,
            orderId=order_id
        )
        logger.info(f"Cancelled order {order_id}: {result}")
        return result
    except Exception as e:
        logger.error(f"Error cancelling order {order_id}: {e}")
        return None

def get_position(session, symbol="BTCUSDT"):
    """Get position information.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
    """
    try:
        positions = session.get_positions(
            category="linear",
            symbol=symbol
        )
        logger.info(f"Position for {symbol}: {positions}")
        return positions
    except Exception as e:
        logger.error(f"Error getting position for {symbol}: {e}")
        return None

def get_trade_history(session, symbol="BTCUSDT", limit=50):
    """Get trade history.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        limit: Number of trades to return
    """
    try:
        trades = session.get_executions(
            category="linear",
            symbol=symbol,
            limit=limit
        )
        logger.info(f"Trade history for {symbol}: {trades}")
        return trades
    except Exception as e:
        logger.error(f"Error getting trade history for {symbol}: {e}")
        return None

# Example usage
if __name__ == "__main__":
    # Create session (testnet by default)
    session = create_bybit_session(testnet=True)

    # Get account info
    account_info = get_account_info(session)

    # Get ticker
    ticker = get_ticker_info(session, "BTCUSDT")

    # Get orderbook
    orderbook = get_orderbook(session, "BTCUSDT")

    # Get open orders
    open_orders = get_open_orders(session)

    # Get position
    position = get_position(session, "BTCUSDT")

    # Note: Uncomment below to actually place orders (use with caution)
    # order = place_limit_order(session, "BTCUSDT", "Buy", "0.001", "30000")
    # market_order = place_market_order(session, "BTCUSDT", "Buy", "0.001")
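The helpers above log and return pybit's raw responses. Bybit's v5 API wraps every payload in a retCode/retMsg/result envelope, so a small unwrap step catches API-level errors that don't raise exceptions; a sketch:

def unwrap(response):
    """Return the result payload, raising if Bybit reports a non-zero return code."""
    if response is None:
        raise RuntimeError("No response from Bybit")
    if response.get("retCode", -1) != 0:
        raise RuntimeError(f"Bybit error {response.get('retCode')}: {response.get('retMsg')}")
    return response.get("result", {})

# Usage: ticker = unwrap(get_ticker_info(session, "BTCUSDT"))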
@@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""
Example usage of the simplified data provider
"""

import time
import logging
from core.data_provider import DataProvider

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    """Demonstrate the simplified data provider usage"""

    # Initialize data provider (starts automatic maintenance)
    logger.info("Initializing DataProvider...")
    dp = DataProvider()

    # Wait for initial data load (happens automatically in background)
    logger.info("Waiting for initial data load...")
    time.sleep(15)  # Give it time to load data

    # Example 1: Get cached historical data (no API calls)
    logger.info("\n=== Example 1: Getting Historical Data ===")
    eth_1m_data = dp.get_historical_data('ETH/USDT', '1m', limit=50)
    if eth_1m_data is not None:
        logger.info(f"ETH/USDT 1m data: {len(eth_1m_data)} candles")
        logger.info(f"Latest candle: {eth_1m_data.iloc[-1]['close']}")

    # Example 2: Get current prices
    logger.info("\n=== Example 2: Current Prices ===")
    eth_price = dp.get_current_price('ETH/USDT')
    btc_price = dp.get_current_price('BTC/USDT')
    logger.info(f"ETH current price: ${eth_price}")
    logger.info(f"BTC current price: ${btc_price}")

    # Example 3: Check cache status
    logger.info("\n=== Example 3: Cache Status ===")
    cache_summary = dp.get_cached_data_summary()
    for symbol in cache_summary['cached_data']:
        logger.info(f"\n{symbol}:")
        for timeframe, info in cache_summary['cached_data'][symbol].items():
            if 'candle_count' in info and info['candle_count'] > 0:
                logger.info(f"  {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
            else:
                logger.info(f"  {timeframe}: {info.get('status', 'no data')}")

    # Example 4: Multiple timeframe data
    logger.info("\n=== Example 4: Multiple Timeframes ===")
    for tf in ['1s', '1m', '1h', '1d']:
        data = dp.get_historical_data('ETH/USDT', tf, limit=5)
        if data is not None and not data.empty:
            logger.info(f"ETH {tf}: {len(data)} candles, range: ${data['close'].min():.2f} - ${data['close'].max():.2f}")

    # Example 5: Health check
    logger.info("\n=== Example 5: Health Check ===")
    health = dp.health_check()
    logger.info(f"Data maintenance active: {health['data_maintenance_active']}")
    logger.info(f"Symbols: {health['symbols']}")
    logger.info(f"Timeframes: {health['timeframes']}")

    # Example 6: Wait and show automatic updates
    logger.info("\n=== Example 6: Automatic Updates ===")
    logger.info("Waiting 30 seconds to show automatic data updates...")

    # Get initial timestamp
    initial_data = dp.get_historical_data('ETH/USDT', '1s', limit=1)
    initial_time = initial_data.index[-1] if initial_data is not None else None

    time.sleep(30)

    # Check if data was updated
    updated_data = dp.get_historical_data('ETH/USDT', '1s', limit=1)
    updated_time = updated_data.index[-1] if updated_data is not None else None

    if initial_time and updated_time and updated_time > initial_time:
        logger.info(f"✅ Data automatically updated! New timestamp: {updated_time}")
    else:
        logger.info("⏳ Data update in progress...")

    # Clean shutdown
    logger.info("\n=== Shutting Down ===")
    dp.stop_automatic_data_maintenance()
    logger.info("DataProvider stopped successfully")

if __name__ == "__main__":
    main()
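The fixed time.sleep(15) above is a guess at warm-up time. A polling sketch that waits only as long as needed, assuming (as the example implies) that get_historical_data returns None or an empty frame until the cache is warm:

def wait_for_data(dp, symbol='ETH/USDT', timeframe='1m', timeout=60, poll=1.0):
    """Poll the provider until cached candles appear or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        df = dp.get_historical_data(symbol, timeframe, limit=1)
        if df is not None and not df.empty:
            return True
        time.sleep(poll)
    return False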
@@ -1,331 +0,0 @@
"""
Kill Stale Processes

This script identifies and kills stale Python processes that might be causing
the dashboard startup freeze. It looks for:
1. Hanging dashboard processes
2. Stale COB data collection threads
3. Matplotlib GUI processes
4. Blocked network connections

Usage:
    python kill_stale_processes.py
"""

import os
import sys
import psutil
import signal
import time
from datetime import datetime

def find_python_processes():
    """Find all Python processes"""
    python_processes = []

    try:
        for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'create_time', 'status']):
            try:
                if proc.info['name'] and 'python' in proc.info['name'].lower():
                    # Get command line to identify dashboard processes
                    cmdline = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else ''

                    python_processes.append({
                        'pid': proc.info['pid'],
                        'name': proc.info['name'],
                        'cmdline': cmdline,
                        'create_time': proc.info['create_time'],
                        'status': proc.info['status'],
                        'process': proc
                    })
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue

    except Exception as e:
        print(f"Error finding Python processes: {e}")

    return python_processes

def identify_dashboard_processes(python_processes):
    """Identify processes related to the dashboard"""
    dashboard_processes = []

    dashboard_keywords = [
        'clean_dashboard',
        'run_clean_dashboard',
        'dashboard',
        'trading',
        'cob_data',
        'orchestrator',
        'data_provider'
    ]

    for proc_info in python_processes:
        cmdline = proc_info['cmdline'].lower()

        # Check if this is a dashboard-related process
        is_dashboard = any(keyword in cmdline for keyword in dashboard_keywords)

        if is_dashboard:
            dashboard_processes.append(proc_info)

    return dashboard_processes

def identify_stale_processes(python_processes):
    """Identify potentially stale processes"""
    stale_processes = []
    current_time = time.time()

    for proc_info in python_processes:
        try:
            proc = proc_info['process']

            # Check if process is in a problematic state
            if proc_info['status'] in ['zombie', 'stopped']:
                stale_processes.append({
                    **proc_info,
                    'reason': f"Process status: {proc_info['status']}"
                })
                continue

            # Check if process has been running for a very long time without activity
            age_hours = (current_time - proc_info['create_time']) / 3600
            if age_hours > 24:  # Running for more than 24 hours
                try:
                    # Check CPU usage
                    cpu_percent = proc.cpu_percent(interval=1)
                    if cpu_percent < 0.1:  # Very low CPU usage
                        stale_processes.append({
                            **proc_info,
                            'reason': f"Old process ({age_hours:.1f}h) with low CPU usage ({cpu_percent:.1f}%)"
                        })
                except Exception:
                    pass

            # Check for processes with high memory usage but no activity
            try:
                memory_info = proc.memory_info()
                memory_mb = memory_info.rss / 1024 / 1024

                if memory_mb > 500:  # More than 500MB
                    cpu_percent = proc.cpu_percent(interval=1)
                    if cpu_percent < 0.1:
                        stale_processes.append({
                            **proc_info,
                            'reason': f"High memory usage ({memory_mb:.1f}MB) with low CPU usage ({cpu_percent:.1f}%)"
                        })
            except Exception:
                pass

        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    return stale_processes

def kill_process_safely(proc_info, force=False):
    """Kill a process safely"""
    try:
        proc = proc_info['process']
        pid = proc_info['pid']

        print(f"Attempting to {'force kill' if force else 'terminate'} PID {pid}: {proc_info['name']}")

        if force:
            # Force kill
            if os.name == 'nt':  # Windows
                os.system(f"taskkill /F /PID {pid}")
            else:  # Unix/Linux
                os.kill(pid, signal.SIGKILL)
        else:
            # Graceful termination
            proc.terminate()

            # Wait for termination
            try:
                proc.wait(timeout=5)
                print(f"✅ Process {pid} terminated gracefully")
                return True
            except psutil.TimeoutExpired:
                print(f"⚠️ Process {pid} didn't terminate gracefully, will force kill")
                return False

        print(f"✅ Process {pid} killed")
        return True

    except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
        print(f"⚠️ Could not kill process {proc_info['pid']}: {e}")
        return False
    except Exception as e:
        print(f"❌ Error killing process {proc_info['pid']}: {e}")
        return False

def check_port_usage():
    """Check if dashboard port is in use"""
    try:
        import socket

        # Check if port 8050 is in use
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        result = sock.connect_ex(('localhost', 8050))
        sock.close()

        if result == 0:
            print("⚠️ Port 8050 is in use")

            # Find process using the port
            for conn in psutil.net_connections():
                if conn.laddr.port == 8050:
                    try:
                        proc = psutil.Process(conn.pid)
                        print(f"  Port 8050 used by PID {conn.pid}: {proc.name()}")
                        return conn.pid
                    except Exception:
                        pass
        else:
            print("✅ Port 8050 is available")

        return None

    except Exception as e:
        print(f"Error checking port usage: {e}")
        return None

def main():
    """Main function"""
    print("🔍 Stale Process Killer")
    print("=" * 50)

    try:
        # Step 1: Find all Python processes
        print("🔍 Finding Python processes...")
        python_processes = find_python_processes()
        print(f"Found {len(python_processes)} Python processes")

        # Step 2: Identify dashboard processes
        print("\n🎯 Identifying dashboard processes...")
        dashboard_processes = identify_dashboard_processes(python_processes)

        if dashboard_processes:
            print(f"Found {len(dashboard_processes)} dashboard-related processes:")
            for proc in dashboard_processes:
                age_hours = (time.time() - proc['create_time']) / 3600
                print(f"  PID {proc['pid']}: {proc['name']} (age: {age_hours:.1f}h, status: {proc['status']})")
                print(f"    Command: {proc['cmdline'][:100]}...")
        else:
            print("No dashboard processes found")

        # Step 3: Check port usage
        print("\n🌐 Checking port usage...")
        port_pid = check_port_usage()

        # Step 4: Identify stale processes
        print("\n🕵️ Identifying stale processes...")
        stale_processes = identify_stale_processes(python_processes)

        if stale_processes:
            print(f"Found {len(stale_processes)} potentially stale processes:")
            for proc in stale_processes:
                print(f"  PID {proc['pid']}: {proc['name']} - {proc['reason']}")
        else:
            print("No stale processes identified")

        # Step 5: Ask user what to do
        if dashboard_processes or stale_processes or port_pid:
            print("\n🤔 What would you like to do?")
            print("1. Kill all dashboard processes")
            print("2. Kill only stale processes")
            print("3. Kill process using port 8050")
            print("4. Kill all identified processes")
            print("5. Show process details and exit")
            print("6. Exit without killing anything")

            try:
                choice = input("\nEnter your choice (1-6): ").strip()

                if choice == '1':
                    # Kill dashboard processes
                    print("\n🔫 Killing dashboard processes...")
                    for proc in dashboard_processes:
                        if not kill_process_safely(proc):
                            kill_process_safely(proc, force=True)

                elif choice == '2':
                    # Kill stale processes
                    print("\n🔫 Killing stale processes...")
                    for proc in stale_processes:
                        if not kill_process_safely(proc):
                            kill_process_safely(proc, force=True)

                elif choice == '3':
                    # Kill process using port 8050
                    if port_pid:
                        print(f"\n🔫 Killing process using port 8050 (PID {port_pid})...")
                        try:
                            proc = psutil.Process(port_pid)
                            proc_info = {
                                'pid': port_pid,
                                'name': proc.name(),
                                'process': proc
                            }
                            if not kill_process_safely(proc_info):
                                kill_process_safely(proc_info, force=True)
                        except Exception:
                            print(f"❌ Could not kill process {port_pid}")
                    else:
                        print("No process found using port 8050")

                elif choice == '4':
                    # Kill all identified processes
                    print("\n🔫 Killing all identified processes...")
                    all_processes = dashboard_processes + stale_processes
                    if port_pid:
                        try:
                            proc = psutil.Process(port_pid)
                            all_processes.append({
                                'pid': port_pid,
                                'name': proc.name(),
                                'process': proc
                            })
                        except Exception:
                            pass

                    for proc in all_processes:
                        if not kill_process_safely(proc):
                            kill_process_safely(proc, force=True)

                elif choice == '5':
                    # Show details
                    print("\n📋 Process Details:")
                    all_processes = dashboard_processes + stale_processes
                    for proc in all_processes:
                        print(f"\nPID {proc['pid']}: {proc['name']}")
                        print(f"  Status: {proc['status']}")
                        print(f"  Command: {proc['cmdline']}")
                        print(f"  Created: {datetime.fromtimestamp(proc['create_time'])}")

                elif choice == '6':
                    print("👋 Exiting without killing processes")

                else:
                    print("❌ Invalid choice")

            except KeyboardInterrupt:
                print("\n👋 Cancelled by user")
        else:
            print("\n✅ No problematic processes found")

        print("\n" + "=" * 50)
        print("💡 After killing processes, you can try:")
        print("   python run_lightweight_dashboard.py")
        print("   or")
        print("   python fix_startup_freeze.py")

        return True

    except Exception as e:
        print(f"❌ Error in main function: {e}")
        return False

if __name__ == "__main__":
    success = main()
    if not success:
        sys.exit(1)
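One simplification worth noting: psutil's own kill() is cross-platform, so the os.system('taskkill ...') branch in kill_process_safely is avoidable. A sketch of the same terminate-then-force flow using psutil alone:

def terminate_then_kill(proc, timeout=5):
    """Graceful terminate with a hard-kill fallback; psutil handles both platforms."""
    try:
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except psutil.TimeoutExpired:
            proc.kill()  # SIGKILL on Unix, TerminateProcess on Windows
            proc.wait(timeout=timeout)
        return True
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return False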
@@ -1,41 +0,0 @@
"""
Launch training with optimized short-term models only
"""

import os
import sys
from pathlib import Path

import torch

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import load_config
from core.training import TrainingManager
from core.models import OptimizedShortTermModel

def main():
    """Main training function using only optimized models"""
    config = load_config()

    # Initialize model
    model = OptimizedShortTermModel()

    # Load best model if it exists
    best_model_path = config.model_paths.get('ticks_model')
    if best_model_path and os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))

    # Initialize training
    trainer = TrainingManager(
        model=model,
        config=config,
        use_ticks=True,
        use_realtime=True
    )

    # Start training
    trainer.train()

if __name__ == "__main__":
    main()
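torch.load defaults to the device the checkpoint was saved on; a slightly more defensive load for the snippet above (a sketch, assuming the checkpoint is a plain state dict):

def load_best_weights(model, path):
    """Load weights onto CPU first so a GPU-saved checkpoint works on any machine."""
    if not path or not os.path.exists(path):
        return False
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    return True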
458
main.py
@@ -1,458 +0,0 @@
#!/usr/bin/env python3
"""
Streamlined Trading System - Web Dashboard + Training

Integrated system with both training loop and web dashboard:
- Training Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution
- Web Dashboard: Real-time monitoring and control interface
- 2-Action System: BUY/SELL with intelligent position management
- Always invested approach with smart risk/reward setup detection

Usage:
    python main.py [--symbol ETH/USDT] [--port 8050]
"""

import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'

import asyncio
import argparse
import logging
import sys
from pathlib import Path
from threading import Thread
import time
from safe_logging import setup_safe_logging

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import get_config, setup_logging, Config
from core.data_provider import DataProvider

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)

async def run_web_dashboard():
    """Run the streamlined web dashboard with 2-action system and always-invested approach"""
    try:
        logger.info("Starting Streamlined Trading Dashboard...")
        logger.info("2-Action System: BUY/SELL with intelligent position management")
        logger.info("Always Invested Approach: Smart risk/reward setup detection")
        logger.info("Integrated Training Pipeline: Live data -> Models -> Trading")

        # Get configuration
        config = get_config()

        # Initialize core components for streamlined pipeline
        from core.standardized_data_provider import StandardizedDataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor

        # Create standardized data provider (validated BaseDataInput, pivots, COB)
        data_provider = StandardizedDataProvider()

        # Start real-time streaming for BOM caching
        try:
            await data_provider.start_real_time_streaming()
            logger.info("[SUCCESS] Real-time data streaming started for BOM caching")
        except Exception as e:
            logger.warning(f"[WARNING] Real-time streaming failed: {e}")

        # Verify data connection with retry mechanism
        logger.info("[DATA] Verifying live data connection...")
        symbol = config.get('symbols', ['ETH/USDT'])[0]

        # Wait for data provider to initialize and fetch initial data
        max_retries = 10
        retry_delay = 2

        for attempt in range(max_retries):
            test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
            if test_df is not None and len(test_df) > 0:
                logger.info("[SUCCESS] Data connection verified")
                logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
                break
            else:
                if attempt < max_retries - 1:
                    logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
                    await asyncio.sleep(retry_delay)
                else:
                    logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
                    logger.warning("The system will attempt to fetch data as needed during operation")

        # Load model registry for integrated pipeline
        try:
            from models import get_model_registry
            model_registry = {}  # Use simple dict for now
            logger.info("[MODELS] Model registry initialized for training")
        except ImportError:
            model_registry = {}
            logger.warning("Model registry not available, using empty registry")

        # Initialize checkpoint management
        checkpoint_manager = get_checkpoint_manager()
        training_integration = get_training_integration()
        logger.info("Checkpoint management initialized for training pipeline")

        # Create unified orchestrator with full ML pipeline
        orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True,
            model_registry={}
        )
        logger.info("Unified Trading Orchestrator initialized with full ML pipeline")
        logger.info("Data Bus -> Models (DQN + CNN + COB) -> Decision Model -> Trading Signals")

        # Checkpoint management will be handled in the training loop
        logger.info("Checkpoint management will be initialized in training loop")

        # Unified orchestrator includes COB integration as part of data bus
        logger.info("COB Integration available - feeds into unified data bus")

        # Create trading executor for live execution
        trading_executor = TradingExecutor()

        # Start the training and monitoring loop
        logger.info("Starting Enhanced Training Pipeline")
        logger.info("Live Data Processing: ENABLED")
        logger.info("COB Integration: ENABLED (Real-time market microstructure)")
        logger.info("Integrated CNN Training: ENABLED")
        logger.info("Integrated RL Training: ENABLED")
        logger.info("Real-time Indicators & Pivots: ENABLED")
        logger.info("Live Trading Execution: ENABLED")
        logger.info("2-Action System: BUY/SELL with position intelligence")
        logger.info("Always Invested: Different thresholds for entry/exit")
        logger.info("Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
        logger.info("Starting training loop...")

        # Start the training loop
        logger.info("About to start training loop...")
        await start_training_loop(orchestrator, trading_executor)

    except Exception as e:
        logger.error(f"Error in streamlined dashboard: {e}")
        logger.error("Training stopped")
        import traceback
        logger.error(traceback.format_exc())

def start_web_ui(port=8051):
    """Start the main TradingDashboard UI in a separate thread"""
    try:
        logger.info("=" * 50)
        logger.info("Starting Main Trading Dashboard UI...")
        logger.info(f"Trading Dashboard: http://127.0.0.1:{port}")
        logger.info("COB Integration: ENABLED (Real-time order book visualization)")
        logger.info("=" * 50)

        # Import and create the Clean Trading Dashboard
        from web.clean_dashboard import CleanTradingDashboard
        from core.standardized_data_provider import StandardizedDataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor

        # Initialize components for the dashboard
        config = get_config()
        data_provider = StandardizedDataProvider()

        # Start real-time streaming for BOM caching (non-blocking)
        try:
            import threading

            def start_streaming():
                import asyncio
                asyncio.run(data_provider.start_real_time_streaming())

            streaming_thread = threading.Thread(target=start_streaming, daemon=True)
            streaming_thread.start()
            logger.info("[SUCCESS] Real-time streaming thread started for dashboard")
        except Exception as e:
            logger.warning(f"[WARNING] Dashboard streaming setup failed: {e}")

        # Load model registry for enhanced features
        try:
            from models import get_model_registry
            model_registry = {}  # Use simple dict for now
        except ImportError:
            model_registry = {}

        # Initialize checkpoint management for dashboard
        dashboard_checkpoint_manager = get_checkpoint_manager()
        dashboard_training_integration = get_training_integration()

        # Create unified orchestrator for the dashboard
        dashboard_orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True,
            model_registry={}
        )

        trading_executor = TradingExecutor("config.yaml")

        # Create the clean trading dashboard with enhanced features
        dashboard = CleanTradingDashboard(
            data_provider=data_provider,
            orchestrator=dashboard_orchestrator,
            trading_executor=trading_executor
        )

        logger.info("Clean Trading Dashboard created successfully")
        logger.info("Features: Live trading, COB visualization, ML pipeline monitoring, Position management")
        logger.info("✅ Unified orchestrator with decision-making model and checkpoint management")

        # Run the dashboard server (COB integration will start automatically)
        dashboard.run_server(host='127.0.0.1', port=port, debug=False)

    except Exception as e:
        logger.error(f"Error starting main trading dashboard UI: {e}")
        import traceback
        logger.error(traceback.format_exc())

async def start_training_loop(orchestrator, trading_executor):
    """Start the main training and monitoring loop with checkpoint management"""
    logger.info("=" * 70)
    logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
    logger.info("=" * 70)

    logger.info("Training loop function entered successfully")

    # Initialize checkpoint management for training loop
    checkpoint_manager = get_checkpoint_manager()
    training_integration = get_training_integration()

    # Training statistics for checkpoint management
    training_stats = {
        'iteration_count': 0,
        'total_decisions': 0,
        'successful_trades': 0,
        'best_performance': 0.0,
        'last_checkpoint_iteration': 0
    }

    try:
        # Start real-time processing (Basic orchestrator doesn't have this method)
        logger.info("Checking for real-time processing capabilities...")
        try:
            if hasattr(orchestrator, 'start_realtime_processing'):
                logger.info("Starting real-time processing...")
                await orchestrator.start_realtime_processing()
                logger.info("Real-time processing started")
            else:
                logger.info("Basic orchestrator - no real-time processing method available")
        except Exception as e:
            logger.warning(f"Real-time processing not available: {e}")

        logger.info("About to enter main training loop...")

        # Main training loop
        iteration = 0
        while True:
            iteration += 1
            training_stats['iteration_count'] = iteration

            logger.info(f"Training iteration {iteration}")

            # Make trading decisions using Basic orchestrator (single symbol method)
            decisions = {}
            symbols = ['ETH/USDT']  # Focus on ETH only for training

            for symbol in symbols:
                try:
                    decision = await orchestrator.make_trading_decision(symbol)
                    decisions[symbol] = decision
                except Exception as e:
                    logger.warning(f"Error making decision for {symbol}: {e}")
                    decisions[symbol] = None

            # Process decisions and collect training metrics
            iteration_decisions = 0
            iteration_performance = 0.0

            # Log decisions and performance
            for symbol, decision in decisions.items():
                if decision:
                    iteration_decisions += 1
                    logger.info(f"{symbol}: {decision.action} (confidence: {decision.confidence:.3f})")

                    # Track performance for checkpoint management
                    iteration_performance += decision.confidence

                    # Execute if confidence is high enough
                    if decision.confidence > 0.7:
                        logger.info(f"Executing {symbol}: {decision.action}")
                        training_stats['successful_trades'] += 1
                        # trading_executor.execute_action(decision)

            # Update training statistics; keep the previous best so the
            # improvement check below compares against it, not the new value
            training_stats['total_decisions'] += iteration_decisions
            previous_best = training_stats['best_performance']
            if iteration_performance > previous_best:
                training_stats['best_performance'] = iteration_performance

            # Save checkpoint every 50 iterations or when performance improves significantly
            should_save_checkpoint = (
                iteration % 50 == 0 or  # Regular interval
                iteration_performance > previous_best * 1.1 or  # 10% improvement over previous best
                iteration - training_stats['last_checkpoint_iteration'] >= 100  # Force save every 100 iterations
            )

            if should_save_checkpoint:
                try:
                    # Create performance metrics for checkpoint
                    performance_metrics = {
                        'avg_confidence': iteration_performance / max(iteration_decisions, 1),
                        'success_rate': training_stats['successful_trades'] / max(training_stats['total_decisions'], 1),
                        'total_decisions': training_stats['total_decisions'],
                        'iteration': iteration
                    }

                    # Save orchestrator state (if it has models)
                    if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
                        saved = orchestrator.rl_agent.save_checkpoint(iteration_performance)
                        if saved:
                            logger.info(f"✅ RL Agent checkpoint saved at iteration {iteration}")

                    if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
                        # Simulate CNN checkpoint save
                        logger.info(f"✅ CNN Model training state saved at iteration {iteration}")

                    if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
                        saved = orchestrator.extrema_trainer.save_checkpoint()
                        if saved:
                            logger.info(f"✅ ExtremaTrainer checkpoint saved at iteration {iteration}")

                    training_stats['last_checkpoint_iteration'] = iteration
                    logger.info(f"📊 Checkpoint management completed for iteration {iteration}")

                except Exception as e:
                    logger.warning(f"Checkpoint saving failed at iteration {iteration}: {e}")

            # Log performance metrics every 10 iterations
            if iteration % 10 == 0:
                metrics = orchestrator.get_performance_metrics()
                logger.info(f"Performance metrics: {metrics}")

                # Log training statistics
                logger.info(f"Training stats: {training_stats}")

                # Log checkpoint statistics
                checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
                logger.info(f"Checkpoints: {checkpoint_stats['total_checkpoints']} total, "
                            f"{checkpoint_stats['total_size_mb']:.2f} MB")

                # Log COB integration status (Basic orchestrator doesn't have COB features)
                symbols = getattr(orchestrator, 'symbols', ['ETH/USDT'])
                if hasattr(orchestrator, 'latest_cob_features'):
                    for symbol in symbols:
                        cob_features = orchestrator.latest_cob_features.get(symbol)
                        cob_state = orchestrator.latest_cob_state.get(symbol)
                        if cob_features is not None:
                            logger.info(f"{symbol} COB: CNN features {cob_features.shape}, DQN state {cob_state.shape if cob_state is not None else 'None'}")
                else:
                    logger.debug("Basic orchestrator - no COB integration features available")

            # Sleep between iterations
            await asyncio.sleep(5)  # 5 second intervals

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Error in training loop: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Save final checkpoints before shutdown
        try:
            logger.info("Saving final checkpoints before shutdown...")

            if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
                orchestrator.rl_agent.save_checkpoint(0.0, force_save=True)
                logger.info("✅ Final RL Agent checkpoint saved")

            if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
                orchestrator.extrema_trainer.save_checkpoint(force_save=True)
                logger.info("✅ Final ExtremaTrainer checkpoint saved")

            # Log final checkpoint statistics
            final_stats = checkpoint_manager.get_checkpoint_stats()
            logger.info(f"📊 Final checkpoint stats: {final_stats['total_checkpoints']} checkpoints, "
                        f"{final_stats['total_size_mb']:.2f} MB total")

        except Exception as e:
            logger.warning(f"Error saving final checkpoints: {e}")

        # Stop real-time processing (Basic orchestrator doesn't have these methods)
        try:
            if hasattr(orchestrator, 'stop_realtime_processing'):
                await orchestrator.stop_realtime_processing()
        except Exception as e:
            logger.warning(f"Error stopping real-time processing: {e}")

        try:
            if hasattr(orchestrator, 'stop_cob_integration'):
                await orchestrator.stop_cob_integration()
        except Exception as e:
            logger.warning(f"Error stopping COB integration: {e}")

        logger.info("Training loop stopped with checkpoint management")

async def main():
    """Main entry point with both training loop and web dashboard"""
    parser = argparse.ArgumentParser(description='Streamlined Trading System - Training + Web Dashboard')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Primary trading symbol (default: ETH/USDT)')
    parser.add_argument('--port', type=int, default=8050,
                        help='Web dashboard port (default: 8050)')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode')

    args = parser.parse_args()

    # Setup logging and ensure directories exist
    Path("logs").mkdir(exist_ok=True)
    Path("NN/models/saved").mkdir(parents=True, exist_ok=True)
    setup_safe_logging()

    try:
        logger.info("=" * 70)
        logger.info("STREAMLINED TRADING SYSTEM - TRAINING + MAIN DASHBOARD")
        logger.info(f"Primary Symbol: {args.symbol}")
        logger.info(f"Training Port: {args.port}")
        logger.info(f"Main Trading Dashboard: http://127.0.0.1:{args.port}")
        logger.info("2-Action System: BUY/SELL with intelligent position management")
        logger.info("Always Invested: Learning to spot high risk/reward setups")
        logger.info("Flow: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
        logger.info("Main Dashboard: Live trading, RL monitoring, Position management")
        logger.info("🔄 Checkpoint Management: Automatic training state persistence")
        # logger.info("📊 W&B Integration: Optional experiment tracking")
        logger.info("💾 Model Rotation: Keep best 5 checkpoints per model")
        logger.info("=" * 70)

        # Start main trading dashboard UI in a separate thread
        web_thread = Thread(target=lambda: start_web_ui(args.port), daemon=True)
        web_thread.start()
        logger.info("Main trading dashboard UI thread started")

        # Give web UI time to start
        await asyncio.sleep(2)

        # Run the training loop (this will run indefinitely)
        await run_web_dashboard()

        logger.info("[SUCCESS] Operation completed successfully!")

    except KeyboardInterrupt:
        logger.info("System shutdown requested by user")
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0

if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
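The checkpoint trigger in start_training_loop is easiest to reason about in isolation; the same rule as a pure function, with thresholds copied from the loop above:

def should_checkpoint(iteration, performance, previous_best, last_saved_iteration):
    """True every 50 iterations, on a >10% improvement over the previous best,
    or after 100 iterations without a save."""
    return (
        iteration % 50 == 0
        or performance > previous_best * 1.1
        or iteration - last_saved_iteration >= 100
    )

# e.g. should_checkpoint(73, 1.35, 1.20, 20) -> True, since 1.35 > 1.20 * 1.1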
133
main_clean.py
@@ -1,133 +0,0 @@
#!/usr/bin/env python3
"""
Clean Main Entry Point for Enhanced Trading Dashboard

This is the main entry point that safely launches the clean dashboard
with proper error handling and optimized settings.
"""

import os
import sys
import logging
import argparse
from typing import Optional

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Import core components
try:
    from core.config import setup_logging
    from core.data_provider import DataProvider
    from core.orchestrator import TradingOrchestrator
    from core.trading_executor import TradingExecutor
    from web.clean_dashboard import create_clean_dashboard
except ImportError as e:
    print(f"Error importing core modules: {e}")
    sys.exit(1)

logger = logging.getLogger(__name__)

def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
    """Create orchestrator with safe CNN model handling"""
    try:
        # Create orchestrator with basic configuration (uses correct constructor parameters)
        orchestrator = TradingOrchestrator(
            enhanced_rl_training=False  # Disable problematic training initially
        )

        logger.info("Trading orchestrator created successfully")
        return orchestrator

    except Exception as e:
        logger.error(f"Error creating orchestrator: {e}")
        logger.info("Continuing without orchestrator - dashboard will run in view-only mode")
        return None

def create_safe_trading_executor() -> Optional[TradingExecutor]:
    """Create trading executor with safe configuration"""
    try:
        # TradingExecutor only accepts config_path parameter
        trading_executor = TradingExecutor(config_path="config.yaml")

        logger.info("Trading executor created successfully")
        return trading_executor

    except Exception as e:
        logger.error(f"Error creating trading executor: {e}")
        logger.info("Continuing without trading executor - dashboard will be view-only")
        return None

def main():
    """Main entry point for clean dashboard"""
    parser = argparse.ArgumentParser(description='Enhanced Trading Dashboard')
    parser.add_argument('--port', type=int, default=8050, help='Dashboard port (default: 8050)')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host (default: 127.0.0.1)')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
    parser.add_argument('--no-training', action='store_true', help='Disable ML training for stability')

    args = parser.parse_args()

    # Setup logging
    try:
        setup_logging()
        logger.info("================================================================================")
        logger.info("CLEAN ENHANCED TRADING DASHBOARD")
        logger.info("================================================================================")
        logger.info(f"Starting on http://{args.host}:{args.port}")
        logger.info("Features: Real-time Charts, Trading Interface, Model Monitoring")
        logger.info("================================================================================")
    except Exception as e:
        print(f"Error setting up logging: {e}")
        # Continue without logging setup

    # Set environment variables for optimization
    os.environ['ENABLE_REALTIME_CHARTS'] = '1'
    if not args.no_training:
        os.environ['ENABLE_NN_MODELS'] = '1'

    try:
        # Create data provider
        logger.info("Initializing data provider...")
        data_provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])

        # Create orchestrator (with safe CNN handling)
        logger.info("Initializing trading orchestrator...")
        orchestrator = create_safe_orchestrator()

        # Create trading executor
        logger.info("Initializing trading executor...")
        trading_executor = create_safe_trading_executor()

        # Create and run dashboard
        logger.info("Creating clean dashboard...")
        dashboard = create_clean_dashboard(
            data_provider=data_provider,
            orchestrator=orchestrator,
            trading_executor=trading_executor
        )

        # Start the dashboard server
        logger.info(f"Starting dashboard server on http://{args.host}:{args.port}")
        dashboard.run_server(
            host=args.host,
            port=args.port,
            debug=args.debug
        )

    except KeyboardInterrupt:
        logger.info("Dashboard stopped by user")
    except Exception as e:
        logger.error(f"Error running dashboard: {e}")

        # Try to provide helpful error message
        if "model.fit" in str(e) or "CNN" in str(e):
            logger.error("CNN model training error detected. Try running with --no-training flag")
            logger.error("Command: python main_clean.py --no-training")

        sys.exit(1)
    finally:
        logger.info("Clean dashboard shutdown complete")

if __name__ == '__main__':
    main()
@@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Migrate Existing Models to Checkpoint System

This script migrates existing model files to the new checkpoint system
and creates proper database metadata entries.
"""

import os
import shutil
import logging
from datetime import datetime
from pathlib import Path
from utils.database_manager import get_database_manager, CheckpointMetadata
from utils.checkpoint_manager import save_checkpoint
from utils.text_logger import get_text_logger

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def migrate_existing_models():
    """Migrate existing models to checkpoint system"""
    print("=== Migrating Existing Models to Checkpoint System ===")

    db_manager = get_database_manager()
    text_logger = get_text_logger()

    # Define model migrations
    migrations = [
        {
            'model_name': 'enhanced_cnn',
            'model_type': 'cnn',
            'source_file': 'models/enhanced_cnn/ETH_USDT_cnn.pth',
            'performance_metrics': {'loss': 0.0187, 'accuracy': 0.92},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True}
        },
        {
            'model_name': 'dqn_agent',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_policy.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'policy'}
        },
        {
            'model_name': 'dqn_agent_target',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_target.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'target'}
        }
    ]

    migrated_count = 0

    for migration in migrations:
        source_path = Path(migration['source_file'])

        if not source_path.exists():
            logger.warning(f"Source file not found: {source_path}")
            continue

        try:
            # Create checkpoint directory
            checkpoint_dir = Path("models/checkpoints") / migration['model_name']
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Create checkpoint filename
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_id = f"{migration['model_name']}_{timestamp}"
            checkpoint_file = checkpoint_dir / f"{checkpoint_id}.pt"

            # Copy model file to checkpoint location
            shutil.copy2(source_path, checkpoint_file)
            logger.info(f"Copied {source_path} -> {checkpoint_file}")

            # Calculate file size
            file_size_mb = checkpoint_file.stat().st_size / (1024 * 1024)

            # Create checkpoint metadata
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=migration['model_name'],
                model_type=migration['model_type'],
                timestamp=datetime.now(),
                performance_metrics=migration['performance_metrics'],
                training_metadata=migration['training_metadata'],
                file_path=str(checkpoint_file),
                file_size_mb=file_size_mb,
                is_active=True
            )

            # Save to database
            if db_manager.save_checkpoint_metadata(metadata):
                logger.info(f"Saved checkpoint metadata: {checkpoint_id}")

                # Log to text file
                text_logger.log_checkpoint_event(
                    model_name=migration['model_name'],
                    event_type="MIGRATED",
                    checkpoint_id=checkpoint_id,
                    details=f"from {source_path}, size={file_size_mb:.1f}MB"
                )

                migrated_count += 1
            else:
                logger.error(f"Failed to save checkpoint metadata: {checkpoint_id}")

        except Exception as e:
            logger.error(f"Failed to migrate {migration['model_name']}: {e}")

    print(f"\nMigration completed: {migrated_count} models migrated")

    # Show current checkpoint status
    print("\n=== Current Checkpoint Status ===")
    for model_name in ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']:
        checkpoints = db_manager.list_checkpoints(model_name)
        if checkpoints:
            print(f"{model_name}: {len(checkpoints)} checkpoints")
            for checkpoint in checkpoints[:2]:  # Show first 2
                print(f"  - {checkpoint.checkpoint_id} ({checkpoint.file_size_mb:.1f}MB)")
        else:
            print(f"{model_name}: No checkpoints")

def verify_checkpoint_system():
    """Verify the checkpoint system is working"""
    print("\n=== Verifying Checkpoint System ===")

    db_manager = get_database_manager()

    # Test loading checkpoints
    for model_name in ['dqn_agent', 'enhanced_cnn']:
        metadata = db_manager.get_best_checkpoint_metadata(model_name)
        if metadata:
            file_exists = Path(metadata.file_path).exists()
            print(f"{model_name}: ✅ Metadata found, File exists: {file_exists}")
            if file_exists:
                print(f"  -> {metadata.checkpoint_id} ({metadata.file_size_mb:.1f}MB)")
            else:
                print(f"  -> ERROR: File missing: {metadata.file_path}")
        else:
            print(f"{model_name}: ❌ No checkpoint metadata found")

def create_test_checkpoint():
    """Create a test checkpoint to verify saving works"""
    print("\n=== Testing Checkpoint Saving ===")

    try:
        import torch
        import torch.nn as nn

        # Create a simple test model
        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 1)

            def forward(self, x):
                return self.linear(x)

        test_model = TestModel()

        # Save using the checkpoint system
        result = save_checkpoint(
            model=test_model,
            model_name="test_model",
            model_type="test",
            performance_metrics={"loss": 0.1, "accuracy": 0.95},
            training_metadata={"test": True, "created": datetime.now().isoformat()}
        )

        if result:
            print(f"✅ Test checkpoint saved successfully: {result.checkpoint_id}")

            # Verify it exists
            db_manager = get_database_manager()
            metadata = db_manager.get_best_checkpoint_metadata("test_model")
            if metadata and Path(metadata.file_path).exists():
                print(f"✅ Test checkpoint verified: {metadata.file_path}")

                # Clean up test checkpoint
                Path(metadata.file_path).unlink()
                print("🧹 Test checkpoint cleaned up")
            else:
                print("❌ Test checkpoint verification failed")
        else:
            print("❌ Test checkpoint saving failed")

    except Exception as e:
        print(f"❌ Test checkpoint creation failed: {e}")

def main():
    """Main migration process"""
    migrate_existing_models()
    verify_checkpoint_system()
    create_test_checkpoint()

    print("\n=== Migration Complete ===")
    print("The checkpoint system should now work properly!")
    print("Existing models have been migrated and the system is ready for use.")

if __name__ == "__main__":
    main()
558
model_manager.py
@@ -1,558 +0,0 @@
"""
Enhanced Model Management System for Trading Dashboard

This system provides:
- Automatic cleanup of old model checkpoints
- Best model tracking with performance metrics
- Configurable retention policies
- Startup model loading
- Performance-based model selection
"""

import os
import json
import shutil
import logging
import torch
import glob
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import numpy as np

logger = logging.getLogger(__name__)

@dataclass
class ModelMetrics:
    """Performance metrics for model evaluation"""
    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0

    def get_composite_score(self) -> float:
        """Calculate composite performance score"""
        # Weighted composite score
        weights = {
            'profit_factor': 0.3,
            'sharpe_ratio': 0.25,
            'win_rate': 0.2,
            'accuracy': 0.15,
            'confidence_score': 0.1
        }

        # Normalize values to 0-1 range
        normalized_pf = min(max(self.profit_factor / 3.0, 0), 1)  # PF of 3+ = 1.0
        normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1)  # Sharpe -2 to 2 -> 0 to 1
        normalized_win_rate = self.win_rate
        normalized_accuracy = self.accuracy
        normalized_confidence = self.confidence_score

        # Apply penalties for poor performance
        drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2)  # Penalty for >20% drawdown

        score = (
            weights['profit_factor'] * normalized_pf +
            weights['sharpe_ratio'] * normalized_sharpe +
            weights['win_rate'] * normalized_win_rate +
            weights['accuracy'] * normalized_accuracy +
            weights['confidence_score'] * normalized_confidence
        ) * drawdown_penalty

        return min(max(score, 0), 1)
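
    # --- Worked example (illustrative; not part of the original file) ---
    # For profit_factor=2.4, sharpe_ratio=1.2, win_rate=0.55, accuracy=0.60,
    # confidence_score=0.70 and max_drawdown=0.10, the normalized terms are
    # 0.8, 0.8, 0.55, 0.60 and 0.70, and the drawdown penalty is
    # 1 - 0.10/0.2 = 0.5, so:
    #   score = (0.3*0.8 + 0.25*0.8 + 0.2*0.55 + 0.15*0.60 + 0.1*0.70) * 0.5
    #         = 0.71 * 0.5 = 0.355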

@dataclass
class ModelInfo:
    """Complete model information and metadata"""
    model_type: str  # 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization"""
        data = asdict(self)
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create from dictionary"""
        data['creation_time'] = datetime.fromisoformat(data['creation_time'])
        data['last_updated'] = datetime.fromisoformat(data['last_updated'])
        data['metrics'] = ModelMetrics(**data['metrics'])
        return cls(**data)
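
# --- Illustrative round trip (not part of the original file) ---
# to_dict()/from_dict() are JSON-safe inverses; the datetime fields travel as
# ISO strings and the nested ModelMetrics as a plain dict:
#
#   info = ModelInfo(model_type='cnn', model_name='demo', file_path='demo.pt',
#                    creation_time=datetime.now(), last_updated=datetime.now(),
#                    file_size_mb=1.0, metrics=ModelMetrics())
#   restored = ModelInfo.from_dict(json.loads(json.dumps(info.to_dict())))
#   assert restored == info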

class ModelManager:
    """Enhanced model management system"""

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        self.base_dir = Path(base_dir)
        self.config = config or self._get_default_config()

        # Model directories
        self.models_dir = self.base_dir / "models"
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.registry_file = self.models_dir / "model_registry.json"
        self.best_models_dir = self.models_dir / "best_models"

        # Create directories
        self.best_models_dir.mkdir(parents=True, exist_ok=True)

        # Model registry
        self.model_registry: Dict[str, ModelInfo] = {}
        self._load_registry()

        logger.info(f"Model Manager initialized - Base: {self.base_dir}")
        logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            'max_models_per_type': 3,          # Keep top 3 models per type
            'max_total_models': 10,            # Maximum total models to keep
            'cleanup_frequency_hours': 24,     # Cleanup every 24 hours
            'min_performance_threshold': 0.3,  # Minimum composite score
            'max_checkpoint_age_days': 7,      # Delete checkpoints older than 7 days
            'auto_cleanup_enabled': True,
            'backup_before_cleanup': True,
            'model_size_limit_mb': 100,        # Individual model size limit
            'total_storage_limit_gb': 5.0      # Total storage limit
        }

    def _load_registry(self):
        """Load model registry from file"""
        try:
            if self.registry_file.exists():
                with open(self.registry_file, 'r') as f:
                    data = json.load(f)
                    self.model_registry = {
                        k: ModelInfo.from_dict(v) for k, v in data.items()
                    }
                logger.info(f"Loaded {len(self.model_registry)} models from registry")
            else:
                logger.info("No existing model registry found")
        except Exception as e:
            logger.error(f"Error loading model registry: {e}")
            self.model_registry = {}

    def _save_registry(self):
        """Save model registry to file"""
        try:
            self.models_dir.mkdir(parents=True, exist_ok=True)
            with open(self.registry_file, 'w') as f:
                data = {k: v.to_dict() for k, v in self.model_registry.items()}
                json.dump(data, f, indent=2, default=str)
            logger.info(f"Saved registry with {len(self.model_registry)} models")
        except Exception as e:
            logger.error(f"Error saving model registry: {e}")
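
    # --- Hardening sketch (an assumption, not part of the original file) ---
    # _save_registry() writes the JSON in place, so a crash mid-write can
    # corrupt the registry. An atomic variant writes to a temp file first and
    # renames it over the old file:
    #
    #   tmp = self.registry_file.with_suffix('.json.tmp')
    #   with open(tmp, 'w') as f:
    #       json.dump(data, f, indent=2, default=str)
    #   os.replace(tmp, self.registry_file)  # atomic rename over the old file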

    def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
        """
        Clean up all existing model files and prepare for 2-action system training

        Args:
            confirm: If True, perform the cleanup. If False, return what would be cleaned

        Returns:
            Dict with cleanup statistics
        """
        cleanup_stats = {
            'files_found': 0,
            'files_deleted': 0,
            'directories_cleaned': 0,
            'space_freed_mb': 0.0,
            'errors': []
        }

        # Model file patterns for both 2-action and legacy 3-action systems
        model_patterns = [
            "**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
            "**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
        ]

        # Directories to clean
        model_directories = [
            "models/saved",
            "NN/models/saved",
            "NN/models/saved/checkpoints",
            "NN/models/saved/realtime_checkpoints",
            "NN/models/saved/realtime_ticks_checkpoints",
            "model_backups"
        ]

        try:
            # Scan for files to be cleaned
            for directory in model_directories:
                dir_path = Path(self.base_dir) / directory
                if dir_path.exists():
                    for pattern in model_patterns:
                        for file_path in dir_path.glob(pattern):
                            if file_path.is_file():
                                cleanup_stats['files_found'] += 1
                                file_size = file_path.stat().st_size / (1024 * 1024)  # MB
                                cleanup_stats['space_freed_mb'] += file_size

                                if confirm:
                                    try:
                                        file_path.unlink()
                                        cleanup_stats['files_deleted'] += 1
                                        logger.info(f"Deleted model file: {file_path}")
                                    except Exception as e:
                                        cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")

            # Clean up empty checkpoint directories
            for directory in model_directories:
                dir_path = Path(self.base_dir) / directory
                if dir_path.exists():
                    for subdir in dir_path.rglob("*"):
                        if subdir.is_dir() and not any(subdir.iterdir()):
                            if confirm:
                                try:
                                    subdir.rmdir()
                                    cleanup_stats['directories_cleaned'] += 1
                                    logger.info(f"Removed empty directory: {subdir}")
                                except Exception as e:
                                    cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")

            if confirm:
                # Reset the registry for a fresh start with the 2-action system
                # (BUY/SELL). The registry maps model_name -> ModelInfo everywhere
                # else in this class (_save_registry serializes each value via
                # to_dict()), so it must be reset to an empty flat dict; the
                # 2-action marker is carried by each ModelInfo's model_version
                # ("2.0"), set in register_model().
                self.model_registry = {}
                self._save_registry()

                logger.info("=" * 60)
                logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
                logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
                logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
                logger.info("Registry reset for 2-action system (BUY/SELL)")
                logger.info("Ready for fresh training with intelligent position management")
                logger.info("=" * 60)
            else:
                logger.info("=" * 60)
                logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
                logger.info(f"Files to delete: {cleanup_stats['files_found']}")
                logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info("Run with confirm=True to perform cleanup")
                logger.info("=" * 60)

        except Exception as e:
            cleanup_stats['errors'].append(f"Cleanup error: {e}")
            logger.error(f"Error during model cleanup: {e}")

        return cleanup_stats
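
    # --- Illustrative usage (not part of the original file) ---
    # The confirm flag makes previews cheap: the same scan runs either way and
    # nothing is deleted until confirm=True.
    #
    #   manager = ModelManager()
    #   preview = manager.cleanup_all_existing_models(confirm=False)
    #   print(f"would delete {preview['files_found']} files, "
    #         f"{preview['space_freed_mb']:.1f} MB")
    #   manager.cleanup_all_existing_models(confirm=True)  # destructive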

    def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
        """
        Register a new model in the 2-action system

        Args:
            model_path: Path to the model file
            model_type: Type of model ('cnn', 'rl', 'transformer')
            metrics: Performance metrics

        Returns:
            str: Unique model name/ID
        """
        if not Path(model_path).exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Generate unique model name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_name = f"{model_type}_2action_{timestamp}"

        # Get file info
        file_path = Path(model_path)
        file_size_mb = file_path.stat().st_size / (1024 * 1024)

        # Default metrics for 2-action system
        if metrics is None:
            metrics = ModelMetrics(
                accuracy=0.0,
                profit_factor=1.0,
                win_rate=0.5,
                sharpe_ratio=0.0,
                max_drawdown=0.0,
                confidence_score=0.5
            )

        # Create model info
        model_info = ModelInfo(
            model_type=model_type,
            model_name=model_name,
            file_path=str(file_path.absolute()),
            creation_time=datetime.now(),
            last_updated=datetime.now(),
            file_size_mb=file_size_mb,
            metrics=metrics,
            model_version="2.0"  # 2-action system version
        )

        # Add to registry. The registry maps model_name -> ModelInfo (flat),
        # consistent with _load_registry/_save_registry and get_models_by_type;
        # counts and last-updated timestamps previously tracked under a nested
        # 'metadata' key are derivable from the entries themselves.
        self.model_registry[model_name] = model_info
        self._save_registry()

        # Cleanup old models if necessary
        self._cleanup_models_by_type(model_type)

        logger.info(f"Registered 2-action model: {model_name}")
        logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
        logger.info(f"Performance score: {metrics.get_composite_score():.4f}")

        return model_name

    def _should_keep_model(self, model_info: ModelInfo) -> bool:
        """Determine if model should be kept based on performance"""
        score = model_info.metrics.get_composite_score()

        # Check minimum threshold
        if score < self.config['min_performance_threshold']:
            return False

        # Check size limit
        if model_info.file_size_mb > self.config['model_size_limit_mb']:
            logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
            return False

        # Check if better than existing models of same type
        existing_models = self.get_models_by_type(model_info.model_type)
        if len(existing_models) >= self.config['max_models_per_type']:
            # Find worst performing model
            worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
            if score <= worst_model.metrics.get_composite_score():
                return False

        return True

    def _cleanup_models_by_type(self, model_type: str):
        """Cleanup old models of specific type, keeping only the best ones"""
        models_of_type = self.get_models_by_type(model_type)
        max_keep = self.config['max_models_per_type']

        if len(models_of_type) <= max_keep:
            return

        # Sort by performance score
        sorted_models = sorted(
            models_of_type.items(),
            key=lambda x: x[1].metrics.get_composite_score(),
            reverse=True
        )

        # Keep only the best models
        models_to_keep = sorted_models[:max_keep]
        models_to_remove = sorted_models[max_keep:]

        for model_name, model_info in models_to_remove:
            try:
                # Remove file
                model_path = Path(model_info.file_path)
                if model_path.exists():
                    model_path.unlink()

                # Remove from registry
                del self.model_registry[model_name]

                logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")

            except Exception as e:
                logger.error(f"Error removing model {model_name}: {e}")

    def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
        """Get all models of a specific type"""
        return {
            name: info for name, info in self.model_registry.items()
            if info.model_type == model_type
        }

    def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
        """Get the best performing model of a specific type"""
        models_of_type = self.get_models_by_type(model_type)

        if not models_of_type:
            return None

        return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())

    def load_best_models(self) -> Dict[str, Any]:
        """Load the best models for each type"""
        loaded_models = {}

        for model_type in ['cnn', 'rl', 'transformer']:
            best_model = self.get_best_model(model_type)

            if best_model:
                try:
                    model_path = Path(best_model.file_path)
                    if model_path.exists():
                        # Load the model
                        model_data = torch.load(model_path, map_location='cpu')
                        loaded_models[model_type] = {
                            'model': model_data,
                            'info': best_model,
                            'path': str(model_path)
                        }
                        logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
                                    f"(Score: {best_model.metrics.get_composite_score():.3f})")
                    else:
                        logger.warning(f"Best {model_type} model file not found: {model_path}")
                except Exception as e:
                    logger.error(f"Error loading {model_type} model: {e}")
            else:
                logger.info(f"No {model_type} model available")

        return loaded_models

    def update_model_performance(self, model_name: str, metrics: ModelMetrics):
        """Update performance metrics for a model"""
        if model_name in self.model_registry:
            self.model_registry[model_name].metrics = metrics
            self.model_registry[model_name].last_updated = datetime.now()
            self._save_registry()

            logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
        else:
            logger.warning(f"Model {model_name} not found in registry")

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage usage statistics"""
        total_size_mb = 0
        model_count = 0

        for model_info in self.model_registry.values():
            total_size_mb += model_info.file_size_mb
            model_count += 1

        # Check actual storage usage
        actual_size_mb = 0
        if self.best_models_dir.exists():
            actual_size_mb = sum(
                f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
            ) / 1024 / 1024

        return {
            'total_models': model_count,
            'registered_size_mb': total_size_mb,
            'actual_size_mb': actual_size_mb,
            'storage_limit_gb': self.config['total_storage_limit_gb'],
            'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
            'models_by_type': {
                model_type: len(self.get_models_by_type(model_type))
                for model_type in ['cnn', 'rl', 'transformer']
            }
        }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get model performance leaderboard"""
        leaderboard = []

        for model_name, model_info in self.model_registry.items():
            leaderboard.append({
                'name': model_name,
                'type': model_info.model_type,
                'score': model_info.metrics.get_composite_score(),
                'profit_factor': model_info.metrics.profit_factor,
                'win_rate': model_info.metrics.win_rate,
                'sharpe_ratio': model_info.metrics.sharpe_ratio,
                'size_mb': model_info.file_size_mb,
                'age_days': (datetime.now() - model_info.creation_time).days,
                'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
            })

        # Sort by score
        leaderboard.sort(key=lambda x: x['score'], reverse=True)

        return leaderboard
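
    # --- Illustrative usage (not part of the original file) ---
    #   for row in manager.get_model_leaderboard()[:5]:
    #       print(f"{row['name']:<32} {row['type']:<4} score={row['score']:.3f} "
    #             f"pf={row['profit_factor']:.2f} wr={row['win_rate']:.0%}")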

    def cleanup_checkpoints(self) -> Dict[str, Any]:
        """Clean up old checkpoint files"""
        cleanup_summary = {
            'deleted_files': 0,
            'freed_space_mb': 0,
            'errors': []
        }

        cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])

        # Search for checkpoint files
        checkpoint_patterns = [
            "**/checkpoint_*.pt",
            "**/model_*.pt",
            "**/*checkpoint*",
            "**/epoch_*.pt"
        ]

        for pattern in checkpoint_patterns:
            for file_path in self.base_dir.rglob(pattern):
                if "best_models" not in str(file_path) and file_path.is_file():
                    try:
                        file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
                        if file_time < cutoff_date:
                            size_mb = file_path.stat().st_size / 1024 / 1024
                            file_path.unlink()
                            cleanup_summary['deleted_files'] += 1
                            cleanup_summary['freed_space_mb'] += size_mb
                    except Exception as e:
                        error_msg = f"Error deleting checkpoint {file_path}: {e}"
                        logger.error(error_msg)
                        cleanup_summary['errors'].append(error_msg)

        if cleanup_summary['deleted_files'] > 0:
            logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
                        f"freed {cleanup_summary['freed_space_mb']:.1f}MB")

        return cleanup_summary


def create_model_manager() -> ModelManager:
    """Create and initialize the global model manager"""
    return ModelManager()


# Example usage
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Create model manager
    manager = ModelManager()

    # Clean up all existing models (with confirmation)
    print("WARNING: This will delete ALL existing models!")
    print("Type 'CONFIRM' to proceed:")
    user_input = input().strip()

    if user_input == "CONFIRM":
        cleanup_result = manager.cleanup_all_existing_models(confirm=True)
        print("\nCleanup complete:")
        print(f"- Deleted {cleanup_result['files_deleted']} files")
        print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
        print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")

        if cleanup_result['errors']:
            print(f"- {len(cleanup_result['errors'])} errors occurred")
    else:
        print("Cleanup cancelled")
@@ -1,193 +0,0 @@
#!/usr/bin/env python3
"""
Position Sync Enhancement - Fix P&L and Win Rate Calculation

This script enhances the position synchronization and P&L calculation
to properly account for leverage in the trading system.
"""

import os
import sys
import logging
from pathlib import Path
from datetime import datetime

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import get_config, setup_logging
from core.trading_executor import TradingExecutor, TradeRecord

# Setup logging
setup_logging()
logger = logging.getLogger(__name__)

def analyze_trade_records():
    """Analyze trade records for P&L calculation issues"""
    logger.info("Analyzing trade records for P&L calculation issues...")

    # Initialize trading executor
    trading_executor = TradingExecutor()

    # Get trade records
    trade_records = trading_executor.trade_records

    if not trade_records:
        logger.warning("No trade records found.")
        return

    logger.info(f"Found {len(trade_records)} trade records.")

    # Analyze P&L calculation
    total_pnl = 0.0
    total_gross_pnl = 0.0
    total_fees = 0.0
    winning_trades = 0
    losing_trades = 0
    breakeven_trades = 0

    for trade in trade_records:
        # Calculate correct P&L with leverage
        entry_value = trade.entry_price * trade.quantity
        exit_value = trade.exit_price * trade.quantity

        if trade.side == 'LONG':
            gross_pnl = (exit_value - entry_value) * trade.leverage
        else:  # SHORT
            gross_pnl = (entry_value - exit_value) * trade.leverage

        # Calculate fees
        fees = (entry_value + exit_value) * 0.001  # 0.1% fee on both entry and exit

        # Calculate net P&L
        net_pnl = gross_pnl - fees

        # Compare with stored values
        pnl_diff = abs(net_pnl - trade.pnl)
        if pnl_diff > 0.01:  # More than 1 cent difference
            logger.warning(f"P&L calculation issue detected for trade {trade.entry_time}:")
            logger.warning(f"  Stored P&L: ${trade.pnl:.2f}")
            logger.warning(f"  Calculated P&L: ${net_pnl:.2f}")
            logger.warning(f"  Difference: ${pnl_diff:.2f}")
            logger.warning(f"  Leverage used: {trade.leverage}x")

        # Update statistics
        total_pnl += net_pnl
        total_gross_pnl += gross_pnl
        total_fees += fees

        if net_pnl > 0.01:  # More than 1 cent profit
            winning_trades += 1
        elif net_pnl < -0.01:  # More than 1 cent loss
            losing_trades += 1
        else:
            breakeven_trades += 1

    # Calculate win rate
    total_trades = winning_trades + losing_trades + breakeven_trades
    win_rate = (winning_trades / total_trades * 100) if total_trades > 0 else 0.0

    logger.info("\nTrade Analysis Results:")
    logger.info(f"  Total trades: {total_trades}")
    logger.info(f"  Winning trades: {winning_trades}")
    logger.info(f"  Losing trades: {losing_trades}")
    logger.info(f"  Breakeven trades: {breakeven_trades}")
    logger.info(f"  Win rate: {win_rate:.1f}%")
    logger.info(f"  Total P&L: ${total_pnl:.2f}")
    logger.info(f"  Total gross P&L: ${total_gross_pnl:.2f}")
    logger.info(f"  Total fees: ${total_fees:.2f}")

    # Check for leverage issues
    leverage_issues = False
    for trade in trade_records:
        if trade.leverage <= 1.0:
            leverage_issues = True
            logger.warning(f"Low leverage detected: {trade.leverage}x for trade at {trade.entry_time}")

    if leverage_issues:
        logger.warning("\nLeverage issues detected. Consider fixing the leverage calculation.")
        logger.info("Recommended fix: Ensure leverage is properly set in the trading executor.")
    else:
        logger.info("\nNo leverage issues detected.")
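
# --- Worked example (illustrative; not part of the original file) ---
# LONG 0.1 ETH, entry 2000.0, exit 2050.0, 10x leverage:
#   entry_value = 2000.0 * 0.1 = 200.0
#   exit_value  = 2050.0 * 0.1 = 205.0
#   gross_pnl   = (205.0 - 200.0) * 10 = 50.0
#   fees        = (200.0 + 205.0) * 0.001 = 0.405
#   net_pnl     = 50.0 - 0.405 = 49.595
# A stored pnl differing from 49.595 by more than $0.01 is flagged above.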

def fix_leverage_calculation():
    """Fix leverage calculation in the trading executor"""
    logger.info("Fixing leverage calculation in the trading executor...")

    # Initialize trading executor
    trading_executor = TradingExecutor()

    # Get current leverage
    current_leverage = trading_executor.current_leverage
    logger.info(f"Current leverage setting: {current_leverage}x")

    # Check if leverage is properly set
    if current_leverage <= 1:
        logger.warning("Leverage is set too low. Updating to 20x...")
        trading_executor.current_leverage = 20
        logger.info(f"Updated leverage to {trading_executor.current_leverage}x")
    else:
        logger.info("Leverage is already set correctly.")

    # Update trade records with correct leverage
    updated_count = 0
    for i, trade in enumerate(trading_executor.trade_records):
        if trade.leverage <= 1.0:
            # Create updated trade record
            updated_trade = TradeRecord(
                symbol=trade.symbol,
                side=trade.side,
                quantity=trade.quantity,
                entry_price=trade.entry_price,
                exit_price=trade.exit_price,
                entry_time=trade.entry_time,
                exit_time=trade.exit_time,
                pnl=trade.pnl,
                fees=trade.fees,
                confidence=trade.confidence,
                hold_time_seconds=trade.hold_time_seconds,
                leverage=trading_executor.current_leverage,  # Use current leverage setting
                position_size_usd=trade.position_size_usd,
                gross_pnl=trade.gross_pnl,
                net_pnl=trade.net_pnl
            )

            # Recalculate P&L with correct leverage
            entry_value = updated_trade.entry_price * updated_trade.quantity
            exit_value = updated_trade.exit_price * updated_trade.quantity

            if updated_trade.side == 'LONG':
                updated_trade.gross_pnl = (exit_value - entry_value) * updated_trade.leverage
            else:  # SHORT
                updated_trade.gross_pnl = (entry_value - exit_value) * updated_trade.leverage

            # Recalculate fees
            updated_trade.fees = (entry_value + exit_value) * 0.001  # 0.1% fee on both entry and exit

            # Recalculate net P&L
            updated_trade.net_pnl = updated_trade.gross_pnl - updated_trade.fees
            updated_trade.pnl = updated_trade.net_pnl

            # Update trade record
            trading_executor.trade_records[i] = updated_trade
            updated_count += 1

    logger.info(f"Updated {updated_count} trade records with correct leverage.")

    # Save updated trade records
    # Note: This is a placeholder. In a real implementation, you would need to
    # persist the updated trade records to storage.
    logger.info("Changes will take effect on next dashboard restart.")

    return updated_count > 0


if __name__ == "__main__":
    logger.info("=" * 70)
    logger.info("POSITION SYNC ENHANCEMENT")
    logger.info("=" * 70)

    if len(sys.argv) > 1 and sys.argv[1] == 'fix':
        fix_leverage_calculation()
    else:
        analyze_trade_records()
124
read_logs.py
@@ -1,124 +0,0 @@
#!/usr/bin/env python
"""
Log Reader Utility

This script provides a convenient way to read and filter log files during
development.
"""

import os
import sys
import time
import argparse
from datetime import datetime

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Read and filter log files')
    parser.add_argument('--file', type=str, help='Log file to read (defaults to most recent .log file)')
    parser.add_argument('--tail', type=int, default=50, help='Number of lines to show from the end')
    parser.add_argument('--follow', '-f', action='store_true', help='Follow the file as it grows')
    parser.add_argument('--filter', type=str, help='Only show lines containing this string')
    parser.add_argument('--list', action='store_true', help='List all log files sorted by modification time')
    return parser.parse_args()

def get_most_recent_log():
    """Find the most recently modified log file"""
    log_files = [f for f in os.listdir('.') if f.endswith('.log')]
    if not log_files:
        print("No log files found in current directory.")
        sys.exit(1)

    # Sort by modification time (newest first)
    log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    return log_files[0]

def list_log_files():
    """List all log files sorted by modification time"""
    log_files = [f for f in os.listdir('.') if f.endswith('.log')]
    if not log_files:
        print("No log files found in current directory.")
        sys.exit(1)

    # Sort by modification time (newest first)
    log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)

    print(f"{'LAST MODIFIED':<20} {'SIZE':<10} FILENAME")
    print("-" * 60)
    for log_file in log_files:
        mtime = datetime.fromtimestamp(os.path.getmtime(log_file))
        size = os.path.getsize(log_file)
        size_str = f"{size / 1024:.1f} KB" if size > 1024 else f"{size} B"
        print(f"{mtime.strftime('%Y-%m-%d %H:%M:%S'):<20} {size_str:<10} {log_file}")
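
# --- Example invocations (illustrative; not part of the original file) ---
#   python read_logs.py --list                 # newest-first inventory
#   python read_logs.py --tail 100 --filter ERROR
#   python read_logs.py --file trading.log -f  # follow like `tail -f`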

def read_log_tail(file_path, num_lines, filter_text=None):
    """Read the last N lines of a file"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Read all lines (inefficient but simple)
            lines = f.readlines()

            # Filter if needed
            if filter_text:
                lines = [line for line in lines if filter_text in line]

            # Get the last N lines
            last_lines = lines[-num_lines:] if len(lines) > num_lines else lines
            return last_lines
    except Exception as e:
        print(f"Error reading file: {str(e)}")
        sys.exit(1)

def follow_log(file_path, filter_text=None):
    """Follow the log file as it grows (like tail -f)"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Go to the end of the file
            f.seek(0, 2)

            while True:
                line = f.readline()
                if line:
                    if not filter_text or filter_text in line:
                        # Remove newlines at the end to avoid double spacing
                        print(line.rstrip())
                else:
                    time.sleep(0.1)  # Sleep briefly to avoid consuming CPU
    except KeyboardInterrupt:
        print("\nLog reading stopped.")
    except Exception as e:
        print(f"Error following file: {str(e)}")
        sys.exit(1)
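
# --- Alternative sketch (not part of the original file) ---
# read_log_tail() reads the whole file into memory; for large logs a bounded
# deque streams the file and keeps only the last N (optionally filtered) lines:
def read_log_tail_bounded(file_path, num_lines, filter_text=None):
    """Memory-bounded variant of read_log_tail using collections.deque."""
    from collections import deque
    with open(file_path, 'r', encoding='utf-8') as f:
        # deque with maxlen discards older lines as newer ones arrive
        return list(deque(
            (line for line in f if not filter_text or filter_text in line),
            maxlen=num_lines
        ))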

def main():
    """Main function"""
    args = parse_args()

    # List all log files if requested
    if args.list:
        list_log_files()
        return

    # Determine which file to read
    file_path = args.file
    if not file_path:
        file_path = get_most_recent_log()
        print(f"Reading most recent log file: {file_path}")

    # Follow mode (like tail -f)
    if args.follow:
        print(f"Following {file_path} (Press Ctrl+C to stop)...")
        # First print the tail
        for line in read_log_tail(file_path, args.tail, args.filter):
            print(line.rstrip())
        print("-" * 80)
        print("Waiting for new content...")
        # Then follow
        follow_log(file_path, args.filter)
    else:
        # Just print the tail
        for line in read_log_tail(file_path, args.tail, args.filter):
            print(line.rstrip())


if __name__ == "__main__":
    main()
@@ -1,31 +0,0 @@
#!/usr/bin/env python3
"""
Script to reset the database manager instance to trigger migration in a running system
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from utils.database_manager import reset_database_manager
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    """Reset the database manager to trigger migration"""
    try:
        logger.info("Resetting database manager to trigger migration...")
        reset_database_manager()
        logger.info("✅ Database manager reset successfully!")
        logger.info("The migration will run automatically on the next database access.")
        return True
    except Exception as e:
        logger.error(f"❌ Failed to reset database manager: {e}")
        return False

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
@@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Reset Models and Fix Action Mapping

This script:
1. Deletes existing model files
2. Creates new model files with consistent action mapping
3. Updates action mapping in key files
"""

import os
import shutil
import logging
import sys
import torch
import numpy as np
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def ensure_directory(directory):
    """Ensure directory exists"""
    if not os.path.exists(directory):
        os.makedirs(directory)
        logger.info(f"Created directory: {directory}")

def delete_directory_contents(directory):
    """Delete all files in a directory"""
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
                logger.info(f"Deleted: {file_path}")
            except Exception as e:
                logger.error(f"Failed to delete {file_path}. Reason: {e}")

def create_backup_directory():
    """Create a backup directory with timestamp"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_dir = f"models/backup_{timestamp}"
    ensure_directory(backup_dir)
    return backup_dir

def backup_models():
    """Backup existing models"""
    backup_dir = create_backup_directory()

    # List of model directories to backup
    model_dirs = [
        "models/enhanced_rl",
        "models/enhanced_cnn",
        "models/realtime_rl_cob",
        "models/rl",
        "models/cnn"
    ]

    for model_dir in model_dirs:
        if os.path.exists(model_dir):
            dest_dir = os.path.join(backup_dir, os.path.basename(model_dir))
            ensure_directory(dest_dir)

            # Copy files
            for filename in os.listdir(model_dir):
                file_path = os.path.join(model_dir, filename)
                if os.path.isfile(file_path):
                    shutil.copy2(file_path, dest_dir)
                    logger.info(f"Backed up: {file_path} to {dest_dir}")

    return backup_dir

def initialize_dqn_model():
    """Initialize a new DQN model with consistent action mapping"""
    try:
        # Import necessary modules
        sys.path.append(os.path.dirname(os.path.abspath(__file__)))
        from NN.models.dqn_agent import DQNAgent

        # Define state shape for BTC and ETH
        state_shape = (100,)  # Default feature dimension

        # Create models directory
        ensure_directory("models/enhanced_rl")

        # Initialize DQN with 3 actions (BUY=0, SELL=1, HOLD=2)
        dqn_btc = DQNAgent(
            state_shape=state_shape,
            n_actions=3,  # BUY=0, SELL=1, HOLD=2
            learning_rate=0.001,
            epsilon=0.5,  # Start with moderate exploration
            epsilon_min=0.01,
            epsilon_decay=0.995,
            model_name="BTC_USDT_dqn"
        )

        dqn_eth = DQNAgent(
            state_shape=state_shape,
            n_actions=3,  # BUY=0, SELL=1, HOLD=2
            learning_rate=0.001,
            epsilon=0.5,  # Start with moderate exploration
            epsilon_min=0.01,
            epsilon_decay=0.995,
            model_name="ETH_USDT_dqn"
        )

        # Save initial models
        torch.save(dqn_btc.policy_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_policy.pth")
        torch.save(dqn_btc.target_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_target.pth")
        torch.save(dqn_eth.policy_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_policy.pth")
        torch.save(dqn_eth.target_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_target.pth")

        logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize DQN models: {e}")
        return False

def initialize_cnn_model():
    """Initialize a new CNN model with consistent action mapping"""
    try:
        # Import necessary modules
        sys.path.append(os.path.dirname(os.path.abspath(__file__)))
        from NN.models.enhanced_cnn import EnhancedCNN

        # Define input dimension and number of actions
        input_dim = 100  # Default feature dimension
        n_actions = 3  # BUY=0, SELL=1, HOLD=2

        # Create models directory
        ensure_directory("models/enhanced_cnn")

        # Initialize CNN models for BTC and ETH
        cnn_btc = EnhancedCNN(input_dim, n_actions)
        cnn_eth = EnhancedCNN(input_dim, n_actions)

        # Save initial models
        torch.save(cnn_btc.state_dict(), "models/enhanced_cnn/BTC_USDT_cnn.pth")
        torch.save(cnn_eth.state_dict(), "models/enhanced_cnn/ETH_USDT_cnn.pth")

        logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize CNN models: {e}")
        return False

def initialize_realtime_rl_model():
    """Initialize a new realtime RL model with consistent action mapping"""
    try:
        # Create models directory
        ensure_directory("models/realtime_rl_cob")

        # Create empty model files to ensure directory is not empty
        with open("models/realtime_rl_cob/README.txt", "w") as f:
            f.write("Realtime RL COB models will be saved here.\n")
            f.write("Action mapping: BUY=0, SELL=1, HOLD=2\n")

        logger.info("Initialized realtime RL model directory")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize realtime RL models: {e}")
        return False
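
# --- Illustrative sanity check (an assumption, not part of the original file;
# treats the last '*weight' tensor in the state dict as the action head) ---
#   state = torch.load("models/enhanced_cnn/BTC_USDT_cnn.pth", map_location='cpu')
#   head_out = [t.shape[0] for k, t in state.items() if k.endswith('weight')][-1]
#   assert head_out == 3, "action head does not match BUY=0, SELL=1, HOLD=2"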

def main():
    """Main function to reset models and fix action mapping"""
    logger.info("Starting model reset and action mapping fix")

    # Backup existing models
    backup_dir = backup_models()
    logger.info(f"Backed up existing models to {backup_dir}")

    # Delete existing model files
    model_dirs = [
        "models/enhanced_rl",
        "models/enhanced_cnn",
        "models/realtime_rl_cob"
    ]

    for model_dir in model_dirs:
        delete_directory_contents(model_dir)
        logger.info(f"Deleted contents of {model_dir}")

    # Initialize new models with consistent action mapping
    dqn_success = initialize_dqn_model()
    cnn_success = initialize_cnn_model()
    rl_success = initialize_realtime_rl_model()

    if dqn_success and cnn_success and rl_success:
        logger.info("Successfully reset models and fixed action mapping")
        logger.info("New action mapping: BUY=0, SELL=1, HOLD=2")
    else:
        logger.error("Failed to reset models and fix action mapping")

    logger.info("Model reset complete")


if __name__ == "__main__":
    main()
@@ -1,325 +1,26 @@
#!/usr/bin/env python3
"""
Run Clean Trading Dashboard with Full Training Pipeline
Integrated system with both training loop and clean web dashboard
"""

import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'

# Fix matplotlib backend issue - set non-interactive backend before any imports
import matplotlib
matplotlib.use('Agg')  # Use non-interactive Agg backend

import asyncio
import logging
import sys
import platform
from safe_logging import setup_safe_logging
import threading
import time
from pathlib import Path

# Windows-specific async event loop configuration
if platform.system() == "Windows":
    # Use ProactorEventLoop on Windows for better I/O handling
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import get_config, setup_logging
from core.data_provider import DataProvider

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration

# Setup logging
setup_safe_logging()
logger = logging.getLogger(__name__)

async def start_training_pipeline(orchestrator, trading_executor):
    """Start the training pipeline in the background with comprehensive error handling"""
    logger.info("=" * 70)
    logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD")
    logger.info("=" * 70)

    # Set up async exception handler
    def handle_async_exception(loop, context):
        """Handle uncaught async exceptions"""
        exception = context.get('exception')
        if exception:
            logger.error(f"Uncaught async exception: {exception}")
            logger.error(f"Context: {context}")
        else:
            logger.error(f"Async error: {context.get('message', 'Unknown error')}")

    # Get current event loop and set exception handler
    loop = asyncio.get_running_loop()
    loop.set_exception_handler(handle_async_exception)

    # Initialize checkpoint management
    checkpoint_manager = get_checkpoint_manager()
    training_integration = get_training_integration()

    # Training statistics
    training_stats = {
        'iteration_count': 0,
        'total_decisions': 0,
        'successful_trades': 0,
        'best_performance': 0.0,
        'last_checkpoint_iteration': 0
    }

    try:
        # Start real-time processing with error handling
        try:
            if hasattr(orchestrator, 'start_realtime_processing'):
                await orchestrator.start_realtime_processing()
                logger.info("Real-time processing started")
        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")

        # Start COB integration with error handling
        try:
            if hasattr(orchestrator, 'start_cob_integration'):
                await orchestrator.start_cob_integration()
                logger.info("COB integration started - 5-minute data matrix active")
            else:
                logger.info("COB integration not available")
        except Exception as e:
            logger.error(f"Error starting COB integration: {e}")

        # Main training loop
        iteration = 0
        last_checkpoint_time = time.time()

        while True:
            try:
                iteration += 1
                training_stats['iteration_count'] = iteration

                # Get symbols to process
                symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']

                # Process each symbol
                for symbol in symbols:
                    try:
                        # Make trading decision (this triggers model training)
                        decision = await orchestrator.make_trading_decision(symbol)
                        if decision:
                            training_stats['total_decisions'] += 1
                            logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")

                    except Exception as e:
                        logger.warning(f"Error processing {symbol}: {e}")

                # Status logging every 100 iterations
                if iteration % 100 == 0:
                    current_time = time.time()
                    elapsed = current_time - last_checkpoint_time

                    logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")

                    # Models will save their own checkpoints when performance improves
                    training_stats['last_checkpoint_iteration'] = iteration
                    last_checkpoint_time = current_time

                # Brief pause to prevent overwhelming the system
                await asyncio.sleep(0.1)  # 100ms between iterations

            except Exception as e:
                logger.error(f"Training loop error: {e}")
                await asyncio.sleep(5)  # Wait longer on error

    except Exception as e:
        logger.error(f"Training pipeline error: {e}")
        import traceback
        logger.error(traceback.format_exc())
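
# --- Illustrative standalone usage (not part of the original file) ---
# The pipeline is an ordinary coroutine, so it can be exercised without the
# dashboard (imports as in start_clean_dashboard_with_training below):
#
#   orchestrator = TradingOrchestrator(StandardizedDataProvider(),
#                                      enhanced_rl_training=True)
#   executor = TradingExecutor(config_path="config.yaml")
#   asyncio.run(start_training_pipeline(orchestrator, executor))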

def start_clean_dashboard_with_training():
    """Start clean dashboard with full training pipeline"""
    try:
        logger.info("=" * 80)
        logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
        logger.info("=" * 80)
        logger.info("Features: Real-time Training, COB Integration, Clean UI")
        logger.info("Universal Data Stream: ENABLED")
        logger.info("Neural Decision Fusion: ENABLED")
        logger.info("COB Integration: ENABLED")
        logger.info("GPU Training: ENABLED")
        logger.info("TensorBoard Integration: ENABLED")
        logger.info("Multi-symbol: ETH/USDT, BTC/USDT")

        # Get port from environment or use default
        dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
        tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
        logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
        logger.info(f"TensorBoard: http://127.0.0.1:{tensorboard_port}")
        logger.info("=" * 80)

        # Check environment variables
        enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
        enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
        enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'

        logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
        logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
        logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")

        # Get configuration
        config = get_config()

        # Initialize core components with standardized versions
        from core.standardized_data_provider import StandardizedDataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor

        # Create standardized data provider
        data_provider = StandardizedDataProvider()
        logger.info("StandardizedDataProvider created with BaseDataInput support")

        # Create enhanced orchestrator with standardized data provider
        orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
        logger.info("Enhanced Trading Orchestrator created with COB integration")

        # Create trading executor
        trading_executor = TradingExecutor(config_path="config.yaml")
        logger.info(f"Creating trading executor with {trading_executor.primary_name} configuration...")

        # Connect trading executor to orchestrator
        orchestrator.trading_executor = trading_executor
        logger.info("Trading Executor connected to Orchestrator")

        # Initialize system resource monitoring
        from utils.system_monitor import start_system_monitoring
        system_monitor = start_system_monitoring()

        # Set up cleanup callback for memory management
        def cleanup_callback():
            """Custom cleanup for memory management"""
            try:
                # Clear orchestrator caches
                if hasattr(orchestrator, 'recent_decisions'):
                    for symbol in orchestrator.recent_decisions:
                        if len(orchestrator.recent_decisions[symbol]) > 50:
                            orchestrator.recent_decisions[symbol] = orchestrator.recent_decisions[symbol][-25:]

                # Clear data provider caches
                if hasattr(data_provider, 'clear_old_data'):
                    data_provider.clear_old_data()

                logger.info("Custom memory cleanup completed")
            except Exception as e:
                logger.error(f"Error in custom cleanup: {e}")

        system_monitor.set_callbacks(cleanup=cleanup_callback)
        logger.info("System resource monitoring started with memory cleanup")

        # Import clean dashboard
        from web.clean_dashboard import create_clean_dashboard

        # Create clean dashboard
        logger.info("Creating clean dashboard...")
        dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
        logger.info("Clean Trading Dashboard created")

        # Add memory cleanup method to dashboard
        def cleanup_dashboard_memory():
            """Clean up dashboard memory caches"""
            try:
                if hasattr(dashboard, 'recent_decisions'):
                    dashboard.recent_decisions = dashboard.recent_decisions[-50:]  # Keep last 50
                if hasattr(dashboard, 'closed_trades'):
                    dashboard.closed_trades = dashboard.closed_trades[-100:]  # Keep last 100
                if hasattr(dashboard, 'tick_cache'):
                    dashboard.tick_cache = dashboard.tick_cache[-1000:]  # Keep last 1000
                logger.debug("Dashboard memory cleanup completed")
            except Exception as e:
                logger.error(f"Error in dashboard memory cleanup: {e}")

        # Set cleanup method on dashboard
        dashboard.cleanup_memory = cleanup_dashboard_memory

        # Start training pipeline in background thread with enhanced error handling
        def training_worker():
            """Run training pipeline in background with comprehensive error handling"""
            try:
                asyncio.run(start_training_pipeline(orchestrator, trading_executor))
            except KeyboardInterrupt:
                logger.info("Training worker stopped by user")
            except Exception as e:
                logger.error(f"Training worker error: {e}")
                import traceback
                logger.error(f"Training worker traceback: {traceback.format_exc()}")
                # Don't exit - let main thread handle restart

        training_thread = threading.Thread(target=training_worker, daemon=True)
        training_thread.start()
        logger.info("Training pipeline started in background with error handling")

        # Wait a moment for training to initialize
        time.sleep(3)

        # Start TensorBoard in background
        from web.tensorboard_integration import get_tensorboard_integration
        tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
        tensorboard_integration = get_tensorboard_integration(log_dir="runs", port=tensorboard_port)

        # Start TensorBoard server
        tensorboard_started = tensorboard_integration.start_tensorboard(open_browser=False)
        if tensorboard_started:
            logger.info(f"TensorBoard started at {tensorboard_integration.get_tensorboard_url()}")
        else:
            logger.warning("Failed to start TensorBoard - training metrics will not be visualized")

        # Start dashboard server with error handling (this blocks)
        logger.info("Starting Clean Dashboard Server with error handling...")
        try:
            dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
        except Exception as e:
            logger.error(f"Dashboard server error: {e}")
            import traceback
            logger.error(f"Dashboard server traceback: {traceback.format_exc()}")
            raise  # Re-raise to trigger main error handling

    except KeyboardInterrupt:
        logger.info("System stopped by user")
        # Stop TensorBoard
        try:
            tensorboard_integration = get_tensorboard_integration()
            tensorboard_integration.stop_tensorboard()
        except Exception:
            pass
    except Exception as e:
        logger.error(f"Error running clean dashboard with training: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
from core.config import setup_logging, get_config
from core.trading_executor import TradingExecutor
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
from web.clean_dashboard import CleanTradingDashboard


def main():
    """Main function with comprehensive error handling"""
    try:
        start_clean_dashboard_with_training()
    except KeyboardInterrupt:
        logger.info("Dashboard stopped by user (Ctrl+C)")
        sys.exit(0)
    except Exception as e:
        logger.error(f"Critical error in main: {e}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)

setup_logging()
cfg = get_config()


if __name__ == "__main__":
    # Ensure logging is flushed on exit
    import atexit
    def flush_logs():
        logging.shutdown()
    atexit.register(flush_logs)
    data_provider = StandardizedDataProvider()
    trading_executor = TradingExecutor()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    dashboard = CleanTradingDashboard(
        data_provider=data_provider,
        orchestrator=orchestrator,
        trading_executor=trading_executor
    )
    logging.getLogger(__name__).info("Starting Clean Trading Dashboard at http://127.0.0.1:8050")
    dashboard.run_server(host='127.0.0.1', port=8050, debug=False)


if __name__ == '__main__':
    main()
@@ -1,501 +0,0 @@
#!/usr/bin/env python3
"""
Continuous Full Training System (RL + CNN)

This system runs continuous training for both RL and CNN models using the enhanced
DataProvider for consistent data streaming to both models and the dashboard.

Features:
- Single DataProvider instance for all data needs
- Continuous RL training with real-time market data
- CNN training with perfect move detection
- Real-time performance monitoring
- Automatic model checkpointing
- Integration with live trading dashboard
"""

import asyncio
import logging
import time
import signal
import sys
from datetime import datetime, timedelta
from threading import Thread, Event
from typing import Dict, Any

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/continuous_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Import our components
from core.config import get_config
from core.data_provider import DataProvider, MarketTick
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
||||
class ContinuousTrainingSystem:
|
||||
"""Comprehensive continuous training system for RL + CNN models"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the continuous training system"""
|
||||
self.config = get_config()
|
||||
|
||||
# Single DataProvider instance for all data needs
|
||||
self.data_provider = DataProvider(
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
timeframes=['1s', '1m', '1h', '1d']
|
||||
)
|
||||
|
||||
# Enhanced orchestrator for AI trading
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
|
||||
# Dashboard for monitoring
|
||||
self.dashboard = None
|
||||
|
||||
# Training control
|
||||
self.running = False
|
||||
self.shutdown_event = Event()
|
||||
|
||||
# Checkpoint management
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.training_integration = get_training_integration()
|
||||
|
||||
# Performance tracking
|
||||
self.training_stats = {
|
||||
'start_time': None,
|
||||
'rl_training_cycles': 0,
|
||||
'cnn_training_cycles': 0,
|
||||
'perfect_moves_detected': 0,
|
||||
'total_ticks_processed': 0,
|
||||
'models_saved': 0,
|
||||
'last_checkpoint': None,
|
||||
'best_rl_reward': float('-inf'),
|
||||
'best_cnn_accuracy': 0.0
|
||||
}
|
||||
|
||||
# Training intervals
|
||||
self.rl_training_interval = 300 # 5 minutes
|
||||
self.cnn_training_interval = 600 # 10 minutes
|
||||
self.checkpoint_interval = 1800 # 30 minutes
|
||||
|
||||
logger.info("Continuous Training System initialized with checkpoint management")
|
||||
logger.info(f"RL training interval: {self.rl_training_interval}s")
|
||||
logger.info(f"CNN training interval: {self.cnn_training_interval}s")
|
||||
logger.info(f"Checkpoint interval: {self.checkpoint_interval}s")
|
||||
|
||||
async def start(self, run_dashboard: bool = True):
|
||||
"""Start the continuous training system"""
|
||||
logger.info("Starting Continuous Training System...")
|
||||
self.running = True
|
||||
self.training_stats['start_time'] = datetime.now()
|
||||
|
||||
try:
|
||||
# Start DataProvider streaming
|
||||
logger.info("Starting DataProvider real-time streaming...")
|
||||
await self.data_provider.start_real_time_streaming()
|
||||
|
||||
# Subscribe to tick data for training
|
||||
subscriber_id = self.data_provider.subscribe_to_ticks(
|
||||
callback=self._handle_training_tick,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
subscriber_name="ContinuousTraining"
|
||||
)
|
||||
logger.info(f"Subscribed to training tick stream: {subscriber_id}")
|
||||
|
||||
# Start training threads
|
||||
training_tasks = [
|
||||
asyncio.create_task(self._rl_training_loop()),
|
||||
asyncio.create_task(self._cnn_training_loop()),
|
||||
asyncio.create_task(self._checkpoint_loop()),
|
||||
asyncio.create_task(self._monitoring_loop())
|
||||
]
|
||||
|
||||
# Start dashboard if requested
|
||||
if run_dashboard:
|
||||
dashboard_task = asyncio.create_task(self._run_dashboard())
|
||||
training_tasks.append(dashboard_task)
|
||||
|
||||
logger.info("All training components started successfully")
|
||||
|
||||
# Wait for shutdown signal
|
||||
await self._wait_for_shutdown()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous training system: {e}")
|
||||
raise
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
def _handle_training_tick(self, tick: MarketTick):
|
||||
"""Handle incoming tick data for training"""
|
||||
try:
|
||||
self.training_stats['total_ticks_processed'] += 1
|
||||
|
||||
# Process tick through orchestrator for RL training
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'process_tick'):
|
||||
self.orchestrator.process_tick(tick)
|
||||
|
||||
# Log every 1000 ticks
|
||||
if self.training_stats['total_ticks_processed'] % 1000 == 0:
|
||||
logger.info(f"Processed {self.training_stats['total_ticks_processed']} training ticks")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing training tick: {e}")
|
||||
|
||||
async def _rl_training_loop(self):
|
||||
"""Continuous RL training loop"""
|
||||
logger.info("Starting RL training loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Perform RL training cycle
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
logger.info("Starting RL training cycle...")
|
||||
|
||||
# Get recent market data for training
|
||||
training_data = self._prepare_rl_training_data()
|
||||
|
||||
if training_data is not None:
|
||||
# Train RL agent
|
||||
training_results = await self._train_rl_agent(training_data)
|
||||
|
||||
if training_results:
|
||||
self.training_stats['rl_training_cycles'] += 1
|
||||
logger.info(f"RL training cycle {self.training_stats['rl_training_cycles']} completed")
|
||||
logger.info(f"Training results: {training_results}")
|
||||
else:
|
||||
logger.warning("No training data available for RL agent")
|
||||
|
||||
# Wait for next training cycle
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.rl_training_interval - elapsed)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in RL training loop: {e}")
|
||||
await asyncio.sleep(60) # Wait before retrying
|
||||
|
||||
async def _cnn_training_loop(self):
|
||||
"""Continuous CNN training loop"""
|
||||
logger.info("Starting CNN training loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Perform CNN training cycle
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
logger.info("Starting CNN training cycle...")
|
||||
|
||||
# Detect perfect moves for CNN training
|
||||
perfect_moves = self._detect_perfect_moves()
|
||||
|
||||
if perfect_moves:
|
||||
self.training_stats['perfect_moves_detected'] += len(perfect_moves)
|
||||
|
||||
# Train CNN with perfect moves
|
||||
training_results = await self._train_cnn_model(perfect_moves)
|
||||
|
||||
if training_results:
|
||||
self.training_stats['cnn_training_cycles'] += 1
|
||||
logger.info(f"CNN training cycle {self.training_stats['cnn_training_cycles']} completed")
|
||||
logger.info(f"Perfect moves processed: {len(perfect_moves)}")
|
||||
else:
|
||||
logger.info("No perfect moves detected for CNN training")
|
||||
|
||||
# Wait for next training cycle
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.cnn_training_interval - elapsed)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN training loop: {e}")
|
||||
await asyncio.sleep(60) # Wait before retrying
|
||||
|
||||
async def _checkpoint_loop(self):
|
||||
"""Automatic model checkpointing loop"""
|
||||
logger.info("Starting checkpoint loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
await asyncio.sleep(self.checkpoint_interval)
|
||||
|
||||
logger.info("Creating model checkpoints...")
|
||||
|
||||
# Save RL model
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
rl_checkpoint = await self._save_rl_checkpoint()
|
||||
if rl_checkpoint:
|
||||
logger.info(f"RL checkpoint saved: {rl_checkpoint}")
|
||||
|
||||
# Save CNN model
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
cnn_checkpoint = await self._save_cnn_checkpoint()
|
||||
if cnn_checkpoint:
|
||||
logger.info(f"CNN checkpoint saved: {cnn_checkpoint}")
|
||||
|
||||
self.training_stats['models_saved'] += 1
|
||||
self.training_stats['last_checkpoint'] = datetime.now()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in checkpoint loop: {e}")
|
||||
|
||||
async def _monitoring_loop(self):
|
||||
"""System monitoring and performance tracking loop"""
|
||||
logger.info("Starting monitoring loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
await asyncio.sleep(300) # Monitor every 5 minutes
|
||||
|
||||
# Log system statistics
|
||||
uptime = datetime.now() - self.training_stats['start_time']
|
||||
|
||||
logger.info("=== CONTINUOUS TRAINING SYSTEM STATUS ===")
|
||||
logger.info(f"Uptime: {uptime}")
|
||||
logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
|
||||
logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
|
||||
logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
|
||||
logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
|
||||
logger.info(f"Models saved: {self.training_stats['models_saved']}")
|
||||
|
||||
# DataProvider statistics
|
||||
if hasattr(self.data_provider, 'get_subscriber_stats'):
|
||||
subscriber_stats = self.data_provider.get_subscriber_stats()
|
||||
logger.info(f"Active subscribers: {subscriber_stats.get('active_subscribers', 0)}")
|
||||
logger.info(f"Total ticks distributed: {subscriber_stats.get('distribution_stats', {}).get('total_ticks_distributed', 0)}")
|
||||
|
||||
# Orchestrator performance
|
||||
if hasattr(self.orchestrator, 'get_performance_metrics'):
|
||||
perf_metrics = self.orchestrator.get_performance_metrics()
|
||||
logger.info(f"Orchestrator performance: {perf_metrics}")
|
||||
|
||||
logger.info("==========================================")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitoring loop: {e}")
|
||||
|
||||
async def _run_dashboard(self):
|
||||
"""Run the dashboard in a separate thread"""
|
||||
try:
|
||||
logger.info("Starting live trading dashboard...")
|
||||
|
||||
def run_dashboard():
|
||||
self.dashboard = RealTimeScalpingDashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.orchestrator
|
||||
)
|
||||
self.dashboard.run(host='127.0.0.1', port=8051, debug=False)
|
||||
|
||||
dashboard_thread = Thread(target=run_dashboard, daemon=True)
|
||||
dashboard_thread.start()
|
||||
|
||||
logger.info("Dashboard started at http://127.0.0.1:8051")
|
||||
|
||||
# Keep dashboard thread alive
|
||||
while self.running:
|
||||
await asyncio.sleep(10)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running dashboard: {e}")
|
||||
|
||||
def _prepare_rl_training_data(self) -> Dict[str, Any]:
|
||||
"""Prepare training data for RL agent"""
|
||||
try:
|
||||
# Get recent market data from DataProvider
|
||||
eth_data = self.data_provider.get_latest_candles('ETH/USDT', '1m', limit=1000)
|
||||
btc_data = self.data_provider.get_latest_candles('BTC/USDT', '1m', limit=1000)
|
||||
|
||||
if eth_data is not None and not eth_data.empty:
|
||||
return {
|
||||
'eth_data': eth_data,
|
||||
'btc_data': btc_data,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing RL training data: {e}")
|
||||
return None
|
||||
|
||||
def _detect_perfect_moves(self) -> list:
|
||||
"""Detect perfect trading moves for CNN training"""
|
||||
try:
|
||||
# Get recent tick data
|
||||
recent_ticks = self.data_provider.get_recent_ticks('ETHUSDT', count=500)
|
||||
|
||||
if not recent_ticks:
|
||||
return []
|
||||
|
||||
# Simple perfect move detection (can be enhanced)
|
||||
perfect_moves = []
|
||||
|
||||
for i in range(1, len(recent_ticks) - 1):
|
||||
prev_tick = recent_ticks[i-1]
|
||||
curr_tick = recent_ticks[i]
|
||||
next_tick = recent_ticks[i+1]
|
||||
|
||||
# Detect significant price movements
|
||||
price_change = (next_tick.price - curr_tick.price) / curr_tick.price
|
||||
|
||||
if abs(price_change) > 0.001: # 0.1% movement
|
||||
perfect_moves.append({
|
||||
'timestamp': curr_tick.timestamp,
|
||||
'price': curr_tick.price,
|
||||
'action': 'BUY' if price_change > 0 else 'SELL',
|
||||
'confidence': min(abs(price_change) * 100, 1.0)
|
||||
})
|
||||
|
||||
return perfect_moves[-10:] # Return last 10 perfect moves
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting perfect moves: {e}")
|
||||
return []
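# Minimal vectorized sketch of the same 0.1% threshold rule used above, assuming
# the ticks have already been collected into a pandas DataFrame with 'timestamp'
# and 'price' columns. This is an illustration, not the tick-object loop the
# original method uses.
import pandas as pd

def detect_perfect_moves_df(ticks: pd.DataFrame, threshold: float = 0.001) -> pd.DataFrame:
    """Label each tick whose next tick moved by more than `threshold` (0.1%)."""
    df = ticks.copy()
    # Forward-looking return: (next price - current price) / current price
    df["fwd_change"] = df["price"].shift(-1).sub(df["price"]).div(df["price"])
    moves = df[df["fwd_change"].abs() > threshold].copy()
    moves["action"] = moves["fwd_change"].apply(lambda c: "BUY" if c > 0 else "SELL")
    moves["confidence"] = (moves["fwd_change"].abs() * 100).clip(upper=1.0)
    return moves[["timestamp", "price", "action", "confidence"]].tail(10)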
|
||||
|
||||
async def _train_rl_agent(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Train the RL agent with market data"""
|
||||
try:
|
||||
# Placeholder for RL training logic
|
||||
# This would integrate with the actual RL agent
|
||||
|
||||
logger.info("Training RL agent with market data...")
|
||||
|
||||
# Simulate training time
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return {
|
||||
'loss': 0.05,
|
||||
'reward': 0.75,
|
||||
'episodes': 100
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training RL agent: {e}")
|
||||
return None
|
||||
|
||||
async def _train_cnn_model(self, perfect_moves: list) -> Dict[str, Any]:
|
||||
"""Train the CNN model with perfect moves"""
|
||||
try:
|
||||
# Placeholder for CNN training logic
|
||||
# This would integrate with the actual CNN model
|
||||
|
||||
logger.info(f"Training CNN model with {len(perfect_moves)} perfect moves...")
|
||||
|
||||
# Simulate training time
|
||||
await asyncio.sleep(2)
|
||||
|
||||
return {
|
||||
'accuracy': 0.92,
|
||||
'loss': 0.08,
|
||||
'perfect_moves_processed': len(perfect_moves)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training CNN model: {e}")
|
||||
return None
|
||||
|
||||
async def _save_rl_checkpoint(self) -> str:
|
||||
"""Save RL model checkpoint"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_path = f"models/rl/checkpoint_rl_{timestamp}.pt"
|
||||
|
||||
# Placeholder for actual model saving
|
||||
logger.info(f"Saving RL checkpoint to {checkpoint_path}")
|
||||
|
||||
return checkpoint_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving RL checkpoint: {e}")
|
||||
return None
|
||||
|
||||
async def _save_cnn_checkpoint(self) -> str:
|
||||
"""Save CNN model checkpoint"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_path = f"models/cnn/checkpoint_cnn_{timestamp}.pt"
|
||||
|
||||
# Placeholder for actual model saving
|
||||
logger.info(f"Saving CNN checkpoint to {checkpoint_path}")
|
||||
|
||||
return checkpoint_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving CNN checkpoint: {e}")
|
||||
return None
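# Hedged sketch of what the placeholder saves above might do if wired to a real
# PyTorch model. The attribute names (agent.model, agent.optimizer) are assumptions
# about the agent object, not taken from this file; torch.save itself is the
# standard PyTorch API.
import torch

def save_torch_checkpoint(agent, checkpoint_path: str, stats: dict) -> None:
    """Persist model and optimizer state plus training stats in one file."""
    torch.save({
        "model_state_dict": agent.model.state_dict(),
        "optimizer_state_dict": agent.optimizer.state_dict(),
        "training_stats": stats,
    }, checkpoint_path)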
|
||||
|
||||
async def _wait_for_shutdown(self):
|
||||
"""Wait for shutdown signal"""
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, shutting down...")
|
||||
self.shutdown_event.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Wait for shutdown event
|
||||
while not self.shutdown_event.is_set():
|
||||
await asyncio.sleep(1)
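# Alternative sketch: on Unix, asyncio can register signal handlers directly on the
# running loop, which removes the 1-second polling above. loop.add_signal_handler
# is not available with the Windows Proactor loop, so this is platform-dependent
# and offered only as an option, not as the original implementation.
import asyncio
import signal

async def wait_for_signals() -> None:
    """Block until SIGINT or SIGTERM arrives, without polling."""
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, stop.set)
    await stop.wait()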
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the continuous training system"""
|
||||
logger.info("Stopping Continuous Training System...")
|
||||
self.running = False
|
||||
|
||||
try:
|
||||
# Stop DataProvider streaming
|
||||
if self.data_provider:
|
||||
await self.data_provider.stop_real_time_streaming()
|
||||
|
||||
# Final checkpoint
|
||||
logger.info("Creating final checkpoints...")
|
||||
await self._save_rl_checkpoint()
|
||||
await self._save_cnn_checkpoint()
|
||||
|
||||
# Log final statistics
|
||||
uptime = datetime.now() - self.training_stats['start_time']
|
||||
logger.info("=== FINAL TRAINING STATISTICS ===")
|
||||
logger.info(f"Total uptime: {uptime}")
|
||||
logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
|
||||
logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
|
||||
logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
|
||||
logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
|
||||
logger.info(f"Models saved: {self.training_stats['models_saved']}")
|
||||
logger.info("=================================")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
|
||||
logger.info("Continuous Training System stopped")
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
logger.info("Starting Continuous Full Training System (RL + CNN)")
|
||||
|
||||
# Create and start the training system
|
||||
training_system = ContinuousTrainingSystem()
|
||||
|
||||
try:
|
||||
await training_system.start(run_dashboard=True)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,269 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Crash-Safe Dashboard Runner
|
||||
|
||||
This runner is designed to prevent crashes by:
|
||||
1. Isolating imports with try/except blocks
|
||||
2. Minimal initialization
|
||||
3. Graceful error handling
|
||||
4. No complex training loops
|
||||
5. Safe component loading
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
from datetime import datetime  # used by _get_system_status timestamps
|
||||
|
||||
# Fix environment issues before any imports
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
||||
os.environ['OMP_NUM_THREADS'] = '1' # Minimal threads
|
||||
os.environ['MPLBACKEND'] = 'Agg'
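# These variables only take effect if set before numpy/torch and matplotlib are
# imported: KMP_DUPLICATE_LIB_OK suppresses the "duplicate OpenMP runtime" abort,
# OMP_NUM_THREADS caps OpenMP worker threads, and MPLBACKEND=Agg selects a non-GUI
# matplotlib backend so no display is required.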
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup basic logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reduce noise from other loggers
|
||||
logging.getLogger('werkzeug').setLevel(logging.ERROR)
|
||||
logging.getLogger('dash').setLevel(logging.ERROR)
|
||||
logging.getLogger('matplotlib').setLevel(logging.ERROR)
|
||||
|
||||
class CrashSafeDashboard:
|
||||
"""Crash-safe dashboard with minimal dependencies"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize with safe error handling"""
|
||||
self.components = {}
|
||||
self.dashboard_app = None
|
||||
self.initialization_errors = []
|
||||
|
||||
logger.info("Initializing crash-safe dashboard...")
|
||||
|
||||
def safe_import(self, module_name, class_name=None):
|
||||
"""Safely import modules with error handling"""
|
||||
try:
|
||||
if class_name:
|
||||
module = __import__(module_name, fromlist=[class_name])
|
||||
return getattr(module, class_name)
|
||||
else:
|
||||
return __import__(module_name)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to import {module_name}.{class_name if class_name else ''}: {e}"
|
||||
logger.error(error_msg)
|
||||
self.initialization_errors.append(error_msg)
|
||||
return None
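# Example of how safe_import is meant to be used (illustrative only; the module and
# class names match the imports attempted elsewhere in this file):
#
#     DataProvider = self.safe_import('core.data_provider', 'DataProvider')
#     if DataProvider is not None:
#         self.components['data_provider'] = DataProvider()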
|
||||
|
||||
def initialize_core_components(self):
|
||||
"""Initialize core components safely"""
|
||||
logger.info("Initializing core components...")
|
||||
|
||||
# Try to import and initialize config
|
||||
try:
|
||||
from core.config import get_config, setup_logging
|
||||
setup_logging()
|
||||
self.components['config'] = get_config()
|
||||
logger.info("✓ Config loaded")
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Config failed: {e}")
|
||||
self.initialization_errors.append(f"Config: {e}")
|
||||
|
||||
# Try to initialize data provider
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
self.components['data_provider'] = DataProvider()
|
||||
logger.info("✓ Data provider initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Data provider failed: {e}")
|
||||
self.initialization_errors.append(f"Data provider: {e}")
|
||||
|
||||
# Try to initialize trading executor
|
||||
try:
|
||||
from core.trading_executor import TradingExecutor
|
||||
self.components['trading_executor'] = TradingExecutor()
|
||||
logger.info("✓ Trading executor initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Trading executor failed: {e}")
|
||||
self.initialization_errors.append(f"Trading executor: {e}")
|
||||
|
||||
# Try to initialize orchestrator (WITHOUT training to avoid crashes)
|
||||
try:
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
self.components['orchestrator'] = TradingOrchestrator(
|
||||
data_provider=self.components.get('data_provider'),
|
||||
enhanced_rl_training=False # DISABLED to prevent crashes
|
||||
)
|
||||
logger.info("✓ Orchestrator initialized (training disabled)")
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Orchestrator failed: {e}")
|
||||
self.initialization_errors.append(f"Orchestrator: {e}")
|
||||
|
||||
def create_minimal_dashboard(self):
|
||||
"""Create minimal dashboard without complex features"""
|
||||
try:
|
||||
import dash
|
||||
from dash import html, dcc
|
||||
|
||||
# Create minimal Dash app
|
||||
self.dashboard_app = dash.Dash(__name__)
|
||||
|
||||
# Create simple layout
|
||||
self.dashboard_app.layout = html.Div([
|
||||
html.H1("Trading Dashboard - Safe Mode", style={'textAlign': 'center'}),
|
||||
html.Hr(),
|
||||
|
||||
# Status section
|
||||
html.Div([
|
||||
html.H3("System Status"),
|
||||
html.Div(id="system-status", children=self._get_system_status()),
|
||||
], style={'margin': '20px'}),
|
||||
|
||||
# Error section
|
||||
html.Div([
|
||||
html.H3("Initialization Status"),
|
||||
html.Div(id="init-status", children=self._get_init_status()),
|
||||
], style={'margin': '20px'}),
|
||||
|
||||
# Simple refresh interval
|
||||
dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=5000, # Update every 5 seconds
|
||||
n_intervals=0
|
||||
)
|
||||
])
|
||||
|
||||
# Add simple callback
|
||||
@self.dashboard_app.callback(
|
||||
[dash.dependencies.Output('system-status', 'children'),
|
||||
dash.dependencies.Output('init-status', 'children')],
|
||||
[dash.dependencies.Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_status(n):
|
||||
try:
|
||||
return self._get_system_status(), self._get_init_status()
|
||||
except Exception as e:
|
||||
logger.error(f"Callback error: {e}")
|
||||
return f"Callback error: {e}", "Error in callback"
|
||||
|
||||
logger.info("✓ Minimal dashboard created")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Dashboard creation failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def _get_system_status(self):
|
||||
"""Get system status for display"""
|
||||
try:
|
||||
status_items = []
|
||||
|
||||
# Check components
|
||||
for name, component in self.components.items():
|
||||
if component is not None:
|
||||
status_items.append(html.P(f"✓ {name.replace('_', ' ').title()}: OK",
|
||||
style={'color': 'green'}))
|
||||
else:
|
||||
status_items.append(html.P(f"✗ {name.replace('_', ' ').title()}: Failed",
|
||||
style={'color': 'red'}))
|
||||
|
||||
# Add timestamp
|
||||
status_items.append(html.P(f"Last update: {datetime.now().strftime('%H:%M:%S')}",
|
||||
style={'color': 'gray', 'fontSize': '12px'}))
|
||||
|
||||
return status_items
|
||||
|
||||
except Exception as e:
|
||||
return [html.P(f"Status error: {e}", style={'color': 'red'})]
|
||||
|
||||
def _get_init_status(self):
|
||||
"""Get initialization status for display"""
|
||||
try:
|
||||
if not self.initialization_errors:
|
||||
return [html.P("✓ All components initialized successfully", style={'color': 'green'})]
|
||||
|
||||
error_items = [html.P("⚠️ Some components failed to initialize:", style={'color': 'orange'})]
|
||||
|
||||
for error in self.initialization_errors:
|
||||
error_items.append(html.P(f"• {error}", style={'color': 'red', 'fontSize': '12px'}))
|
||||
|
||||
return error_items
|
||||
|
||||
except Exception as e:
|
||||
return [html.P(f"Init status error: {e}", style={'color': 'red'})]
|
||||
|
||||
def run(self, port=8051):
|
||||
"""Run the crash-safe dashboard"""
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info("CRASH-SAFE DASHBOARD")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Mode: Safe mode with minimal features")
|
||||
logger.info("Training: Completely disabled")
|
||||
logger.info("Focus: System stability and basic monitoring")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize components
|
||||
self.initialize_core_components()
|
||||
|
||||
# Create dashboard
|
||||
if not self.create_minimal_dashboard():
|
||||
logger.error("Failed to create dashboard")
|
||||
return False
|
||||
|
||||
# Report initialization status
|
||||
if self.initialization_errors:
|
||||
logger.warning(f"Dashboard starting with {len(self.initialization_errors)} component failures")
|
||||
for error in self.initialization_errors:
|
||||
logger.warning(f" - {error}")
|
||||
else:
|
||||
logger.info("All components initialized successfully")
|
||||
|
||||
# Start dashboard
|
||||
logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
self.dashboard_app.run_server(
|
||||
host='127.0.0.1',
|
||||
port=port,
|
||||
debug=False,
|
||||
use_reloader=False,
|
||||
threaded=True
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main function with comprehensive error handling"""
|
||||
try:
|
||||
dashboard = CrashSafeDashboard()
|
||||
success = dashboard.run()
|
||||
|
||||
if success:
|
||||
logger.info("Dashboard completed successfully")
|
||||
else:
|
||||
logger.error("Dashboard failed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,525 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Enhanced RL Training Launcher with Real Data Integration
|
||||
|
||||
This script launches the comprehensive RL training system that uses:
|
||||
- Real-time tick data (300s window for momentum detection)
|
||||
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
|
||||
- BTC reference data for correlation
|
||||
- CNN hidden features and predictions
|
||||
- Williams Market Structure pivot points
|
||||
- Market microstructure analysis
|
||||
|
||||
The RL model will receive ~13,400 features instead of the previous ~100 basic features.
|
||||
Training metrics are automatically logged to TensorBoard for visualization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('enhanced_rl_training.log'),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import our enhanced components
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from training.enhanced_rl_trainer import EnhancedRLTrainer
|
||||
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
|
||||
from training.williams_market_structure import WilliamsMarketStructure
|
||||
from training.cnn_rl_bridge import CNNRLBridge
|
||||
from utils.tensorboard_logger import TensorBoardLogger
|
||||
|
||||
class EnhancedRLTrainingSystem:
|
||||
"""Comprehensive RL training system with real data integration"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the enhanced RL training system"""
|
||||
self.config = get_config()
|
||||
self.running = False
|
||||
self.data_provider = None
|
||||
self.orchestrator = None
|
||||
self.rl_trainer = None
|
||||
|
||||
# Performance tracking
|
||||
self.training_stats = {
|
||||
'training_sessions': 0,
|
||||
'total_experiences': 0,
|
||||
'avg_state_size': 0,
|
||||
'data_quality_score': 0.0,
|
||||
'last_training_time': None
|
||||
}
|
||||
|
||||
# Initialize TensorBoard logger
|
||||
experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self.tb_logger = TensorBoardLogger(
|
||||
log_dir="runs",
|
||||
experiment_name=experiment_name,
|
||||
enabled=True
|
||||
)
|
||||
|
||||
logger.info("Enhanced RL Training System initialized")
|
||||
logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
|
||||
logger.info("Features:")
|
||||
logger.info("- Real-time tick data processing (300s window)")
|
||||
logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
|
||||
logger.info("- BTC correlation analysis")
|
||||
logger.info("- CNN feature integration")
|
||||
logger.info("- Williams Market Structure pivot points")
|
||||
logger.info("- ~13,400 feature state vector (vs previous ~100)")
|
||||
|
||||
# async def initialize(self):
|
||||
# """Initialize all components"""
|
||||
# try:
|
||||
# logger.info("Initializing enhanced RL training components...")
|
||||
|
||||
# # Initialize data provider with real-time streaming
|
||||
# logger.info("Setting up data provider with real-time streaming...")
|
||||
# self.data_provider = DataProvider(
|
||||
# symbols=self.config.symbols,
|
||||
# timeframes=self.config.timeframes
|
||||
# )
|
||||
|
||||
# # Start real-time data streaming
|
||||
# await self.data_provider.start_real_time_streaming()
|
||||
# logger.info("Real-time data streaming started")
|
||||
|
||||
# # Wait for initial data collection
|
||||
# logger.info("Collecting initial market data...")
|
||||
# await asyncio.sleep(30) # Allow 30 seconds for data collection
|
||||
|
||||
# # Initialize enhanced orchestrator
|
||||
# logger.info("Initializing enhanced orchestrator...")
|
||||
# self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
|
||||
# # Initialize enhanced RL trainer with comprehensive state building
|
||||
# logger.info("Initializing enhanced RL trainer...")
|
||||
# self.rl_trainer = EnhancedRLTrainer(
|
||||
# config=self.config,
|
||||
# orchestrator=self.orchestrator
|
||||
# )
|
||||
|
||||
# # Verify data availability
|
||||
# data_status = await self._verify_data_availability()
|
||||
# if not data_status['has_sufficient_data']:
|
||||
# logger.warning("Insufficient data detected. Continuing with limited training.")
|
||||
# logger.warning(f"Data status: {data_status}")
|
||||
# else:
|
||||
# logger.info("Sufficient data available for comprehensive RL training")
|
||||
# logger.info(f"Tick data: {data_status['tick_count']} ticks")
|
||||
# logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
|
||||
|
||||
# self.running = True
|
||||
# logger.info("Enhanced RL training system initialized successfully")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error during initialization: {e}")
|
||||
# raise
|
||||
|
||||
# async def _verify_data_availability(self) -> Dict[str, any]:
|
||||
# """Verify that we have sufficient data for training"""
|
||||
# try:
|
||||
# data_status = {
|
||||
# 'has_sufficient_data': False,
|
||||
# 'tick_count': 0,
|
||||
# 'ohlcv_bars': 0,
|
||||
# 'symbols_with_data': [],
|
||||
# 'missing_data': []
|
||||
# }
|
||||
|
||||
# for symbol in self.config.symbols:
|
||||
# # Check tick data
|
||||
# recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
|
||||
# tick_count = len(recent_ticks)
|
||||
|
||||
# # Check OHLCV data
|
||||
# ohlcv_bars = 0
|
||||
# for timeframe in ['1s', '1m', '1h', '1d']:
|
||||
# try:
|
||||
# df = self.data_provider.get_historical_data(
|
||||
# symbol=symbol,
|
||||
# timeframe=timeframe,
|
||||
# limit=50,
|
||||
# refresh=True
|
||||
# )
|
||||
# if df is not None and not df.empty:
|
||||
# ohlcv_bars += len(df)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
|
||||
|
||||
# data_status['tick_count'] += tick_count
|
||||
# data_status['ohlcv_bars'] += ohlcv_bars
|
||||
|
||||
# if tick_count >= 50 and ohlcv_bars >= 100:
|
||||
# data_status['symbols_with_data'].append(symbol)
|
||||
# else:
|
||||
# data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
|
||||
|
||||
# # Consider data sufficient if we have at least one symbol with good data
|
||||
# data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
|
||||
|
||||
# return data_status
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error verifying data availability: {e}")
|
||||
# return {'has_sufficient_data': False, 'error': str(e)}
|
||||
|
||||
# async def run_training_loop(self):
|
||||
# """Run the main training loop with real data"""
|
||||
# logger.info("Starting enhanced RL training loop...")
|
||||
|
||||
# training_cycle = 0
|
||||
# last_state_size_log = time.time()
|
||||
|
||||
# try:
|
||||
# while self.running:
|
||||
# training_cycle += 1
|
||||
# cycle_start_time = time.time()
|
||||
|
||||
# logger.info(f"Training cycle {training_cycle} started")
|
||||
|
||||
# # Get comprehensive market states with real data
|
||||
# market_states = await self._get_comprehensive_market_states()
|
||||
|
||||
# if not market_states:
|
||||
# logger.warning("No market states available. Waiting for data...")
|
||||
# await asyncio.sleep(60)
|
||||
# continue
|
||||
|
||||
# # Train RL agents with comprehensive states
|
||||
# training_results = await self._train_rl_agents(market_states)
|
||||
|
||||
# # Update performance tracking
|
||||
# self._update_training_stats(training_results, market_states)
|
||||
|
||||
# # Log training progress
|
||||
# cycle_duration = time.time() - cycle_start_time
|
||||
# logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
|
||||
|
||||
# # Log state size periodically
|
||||
# if time.time() - last_state_size_log > 300: # Every 5 minutes
|
||||
# self._log_state_size_info(market_states)
|
||||
# last_state_size_log = time.time()
|
||||
|
||||
# # Save models periodically
|
||||
# if training_cycle % 10 == 0:
|
||||
# await self._save_training_progress()
|
||||
|
||||
# # Wait before next training cycle
|
||||
# await asyncio.sleep(300) # Train every 5 minutes
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in training loop: {e}")
|
||||
# raise
|
||||
|
||||
# async def _get_comprehensive_market_states(self) -> Dict[str, any]:
|
||||
# """Get comprehensive market states with all required data"""
|
||||
# try:
|
||||
# # Get market states from orchestrator
|
||||
# universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
|
||||
# market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
|
||||
|
||||
# # Verify data quality
|
||||
# quality_score = self._calculate_data_quality(market_states)
|
||||
# self.training_stats['data_quality_score'] = quality_score
|
||||
|
||||
# if quality_score < 0.5:
|
||||
# logger.warning(f"Low data quality detected: {quality_score:.2f}")
|
||||
|
||||
# return market_states
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error getting comprehensive market states: {e}")
|
||||
# return {}
|
||||
|
||||
# def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
|
||||
# """Calculate data quality score based on available data"""
|
||||
# try:
|
||||
# if not market_states:
|
||||
# return 0.0
|
||||
|
||||
# total_score = 0.0
|
||||
# total_symbols = len(market_states)
|
||||
|
||||
# for symbol, state in market_states.items():
|
||||
# symbol_score = 0.0
|
||||
|
||||
# # Score based on tick data availability
|
||||
# if hasattr(state, 'raw_ticks') and state.raw_ticks:
|
||||
# tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks
|
||||
# symbol_score += tick_score * 0.3
|
||||
|
||||
# # Score based on OHLCV data availability
|
||||
# if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
|
||||
# ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes
|
||||
# symbol_score += min(ohlcv_score, 1.0) * 0.4
|
||||
|
||||
# # Score based on CNN features
|
||||
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
||||
# symbol_score += 0.15
|
||||
|
||||
# # Score based on pivot points
|
||||
# if hasattr(state, 'pivot_points') and state.pivot_points:
|
||||
# symbol_score += 0.15
|
||||
|
||||
# total_score += symbol_score
|
||||
|
||||
# return total_score / total_symbols if total_symbols > 0 else 0.0
|
||||
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error calculating data quality: {e}")
|
||||
# return 0.5 # Default to medium quality
|
||||
|
||||
async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]:
|
||||
"""Train RL agents with comprehensive market states"""
|
||||
try:
|
||||
training_results = {
|
||||
'symbols_trained': [],
|
||||
'total_experiences': 0,
|
||||
'avg_state_size': 0,
|
||||
'training_errors': [],
|
||||
'losses': {},
|
||||
'rewards': {}
|
||||
}
|
||||
|
||||
for symbol, market_state in market_states.items():
|
||||
try:
|
||||
# Convert market state to comprehensive RL state
|
||||
rl_state = self.rl_trainer._market_state_to_rl_state(market_state)
|
||||
|
||||
if rl_state is not None and len(rl_state) > 0:
|
||||
# Record state size
|
||||
state_size = len(rl_state)
|
||||
training_results['avg_state_size'] += state_size
|
||||
|
||||
# Log state size to TensorBoard
|
||||
self.tb_logger.log_scalar(
|
||||
f'State/{symbol}/Size',
|
||||
state_size,
|
||||
self.training_stats['training_sessions']
|
||||
)
|
||||
|
||||
# Simulate trading action for experience generation
|
||||
# In real implementation, this would be actual trading decisions
|
||||
action = self._simulate_trading_action(symbol, rl_state)
|
||||
|
||||
# Generate reward based on market outcome
|
||||
reward = self._calculate_training_reward(symbol, market_state, action)
|
||||
|
||||
# Store reward for TensorBoard logging
|
||||
training_results['rewards'][symbol] = reward
|
||||
|
||||
# Log action and reward to TensorBoard
|
||||
self.tb_logger.log_scalars(f'Actions/{symbol}', {
|
||||
'action': action,
|
||||
'reward': reward
|
||||
}, self.training_stats['training_sessions'])
|
||||
|
||||
# Add experience to RL agent
|
||||
agent = self.rl_trainer.agents.get(symbol)
|
||||
if agent:
|
||||
# Create next state (would be actual next market state in real scenario)
|
||||
next_state = rl_state # Simplified for now
|
||||
|
||||
agent.remember(
|
||||
state=rl_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=False
|
||||
)
|
||||
|
||||
# Train agent if enough experiences
|
||||
if len(agent.replay_buffer) >= agent.batch_size:
|
||||
loss = agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"Agent {symbol} training loss: {loss:.4f}")
|
||||
|
||||
# Store loss for TensorBoard logging
|
||||
training_results['losses'][symbol] = loss
|
||||
|
||||
# Log loss to TensorBoard
|
||||
self.tb_logger.log_scalar(
|
||||
f'Training/{symbol}/Loss',
|
||||
loss,
|
||||
self.training_stats['training_sessions']
|
||||
)
|
||||
|
||||
training_results['symbols_trained'].append(symbol)
|
||||
training_results['total_experiences'] += 1
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error training {symbol}: {e}"
|
||||
logger.warning(error_msg)
|
||||
training_results['training_errors'].append(error_msg)
|
||||
|
||||
# Calculate average state size
|
||||
if len(training_results['symbols_trained']) > 0:
|
||||
training_results['avg_state_size'] /= len(training_results['symbols_trained'])
|
||||
|
||||
# Log overall training metrics to TensorBoard
|
||||
self.tb_logger.log_scalars('Training/Overall', {
|
||||
'symbols_trained': len(training_results['symbols_trained']),
|
||||
'experiences': training_results['total_experiences'],
|
||||
'avg_state_size': training_results['avg_state_size'],
|
||||
'errors': len(training_results['training_errors'])
|
||||
}, self.training_stats['training_sessions'])
|
||||
|
||||
return training_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training RL agents: {e}")
|
||||
return {'error': str(e)}
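# The loop above assumes a DQN-style agent interface: remember(...) appends an
# experience, replay_buffer exposes its length, and replay() performs one training
# step and returns a loss. A minimal stub of that assumed contract (illustrative
# only, not the EnhancedRLTrainer implementation):
import random
from collections import deque

class MinimalReplayAgent:
    def __init__(self, batch_size: int = 32, capacity: int = 10_000):
        self.batch_size = batch_size
        self.replay_buffer = deque(maxlen=capacity)

    def remember(self, state, action, reward, next_state, done) -> None:
        self.replay_buffer.append((state, action, reward, next_state, done))

    def replay(self) -> float:
        # Sample a batch; a real agent would run a gradient step here.
        batch = random.sample(list(self.replay_buffer), self.batch_size)
        return 0.0  # dummy loss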
|
||||
|
||||
# def _simulate_trading_action(self, symbol: str, rl_state) -> int:
|
||||
# """Simulate trading action for training (would be real decision in production)"""
|
||||
# # Simple simulation based on state features
|
||||
# if len(rl_state) > 100:
|
||||
# # Use momentum features to decide action
|
||||
# momentum_features = rl_state[:100] # First 100 features assumed to be momentum
|
||||
# avg_momentum = sum(momentum_features) / len(momentum_features)
|
||||
|
||||
# if avg_momentum > 0.6:
|
||||
# return 1 # BUY
|
||||
# elif avg_momentum < 0.4:
|
||||
# return 2 # SELL
|
||||
# else:
|
||||
# return 0 # HOLD
|
||||
# else:
|
||||
# return 0 # HOLD as default
|
||||
|
||||
# def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
|
||||
# """Calculate training reward based on market state and action"""
|
||||
# try:
|
||||
# # Simple reward calculation based on market conditions
|
||||
# base_reward = 0.0
|
||||
|
||||
# # Reward based on volatility alignment
|
||||
# if hasattr(market_state, 'volatility'):
|
||||
# if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility
|
||||
# base_reward += 0.1
|
||||
# elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility
|
||||
# base_reward += 0.1
|
||||
|
||||
# # Reward based on trend alignment
|
||||
# if hasattr(market_state, 'trend_strength'):
|
||||
# if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend
|
||||
# base_reward += 0.2
|
||||
# elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend
|
||||
# base_reward += 0.2
|
||||
|
||||
# return base_reward
|
||||
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error calculating reward for {symbol}: {e}")
|
||||
# return 0.0
|
||||
|
||||
# def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
|
||||
# """Update training statistics"""
|
||||
# self.training_stats['training_sessions'] += 1
|
||||
# self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
|
||||
# self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
|
||||
# self.training_stats['last_training_time'] = datetime.now()
|
||||
|
||||
# # Log statistics periodically
|
||||
# if self.training_stats['training_sessions'] % 10 == 0:
|
||||
# logger.info("Training Statistics:")
|
||||
# logger.info(f" Sessions: {self.training_stats['training_sessions']}")
|
||||
# logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
|
||||
# logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
|
||||
# logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")
|
||||
|
||||
# def _log_state_size_info(self, market_states: Dict[str, any]):
|
||||
# """Log information about state sizes for debugging"""
|
||||
# for symbol, state in market_states.items():
|
||||
# info = []
|
||||
|
||||
# if hasattr(state, 'raw_ticks'):
|
||||
# info.append(f"ticks: {len(state.raw_ticks)}")
|
||||
|
||||
# if hasattr(state, 'ohlcv_data'):
|
||||
# total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
|
||||
# info.append(f"OHLCV bars: {total_bars}")
|
||||
|
||||
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
||||
# info.append("CNN features: available")
|
||||
|
||||
# if hasattr(state, 'pivot_points') and state.pivot_points:
|
||||
# info.append("pivot points: available")
|
||||
|
||||
# logger.info(f"{symbol} state data: {', '.join(info)}")
|
||||
|
||||
# async def _save_training_progress(self):
|
||||
# """Save training progress and models"""
|
||||
# try:
|
||||
# if self.rl_trainer:
|
||||
# self.rl_trainer._save_all_models()
|
||||
# logger.info("Training progress saved")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error saving training progress: {e}")
|
||||
|
||||
# async def shutdown(self):
|
||||
# """Graceful shutdown"""
|
||||
# logger.info("Shutting down enhanced RL training system...")
|
||||
# self.running = False
|
||||
|
||||
# # Save final state
|
||||
# await self._save_training_progress()
|
||||
|
||||
# # Stop data provider
|
||||
# if self.data_provider:
|
||||
# await self.data_provider.stop_real_time_streaming()
|
||||
|
||||
# logger.info("Enhanced RL training system shutdown complete")
|
||||
|
||||
# async def main():
|
||||
# """Main function to run enhanced RL training"""
|
||||
# system = None
|
||||
|
||||
# def signal_handler(signum, frame):
|
||||
# logger.info("Received shutdown signal")
|
||||
# if system:
|
||||
# asyncio.create_task(system.shutdown())
|
||||
|
||||
# # Set up signal handlers
|
||||
# signal.signal(signal.SIGINT, signal_handler)
|
||||
# signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# try:
|
||||
# # Create and initialize the training system
|
||||
# system = EnhancedRLTrainingSystem()
|
||||
# await system.initialize()
|
||||
|
||||
# logger.info("Enhanced RL Training System is now running...")
|
||||
# logger.info("The RL model now receives ~13,400 features instead of ~100!")
|
||||
# logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# # Run the training loop
|
||||
# await system.run_training_loop()
|
||||
|
||||
# except KeyboardInterrupt:
|
||||
# logger.info("Training interrupted by user")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in main training loop: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# if system:
|
||||
# await system.shutdown()
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
@@ -1,95 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run Dashboard with Enhanced Training System Enabled
|
||||
|
||||
This script starts the trading dashboard with the enhanced real-time
|
||||
training system automatically enabled and running.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import create_clean_dashboard
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def main():
|
||||
"""Start dashboard with enhanced training enabled"""
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTING DASHBOARD WITH ENHANCED TRAINING SYSTEM")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# 1. Initialize components with enhanced training
|
||||
logger.info("1. Initializing components...")
|
||||
data_provider = DataProvider()
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# 2. Create orchestrator with enhanced training ENABLED
|
||||
logger.info("2. Creating orchestrator with enhanced training...")
|
||||
orchestrator = TradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
enhanced_rl_training=True # 🔥 THIS ENABLES ENHANCED TRAINING
|
||||
)
|
||||
|
||||
# 3. Verify enhanced training is available
|
||||
logger.info("3. Verifying enhanced training system...")
|
||||
if orchestrator.enhanced_training_system:
|
||||
logger.info("✅ Enhanced training system available")
|
||||
logger.info(f" - Training enabled: {orchestrator.training_enabled}")
|
||||
|
||||
# 4. Start enhanced training
|
||||
logger.info("4. Starting enhanced training system...")
|
||||
start_result = orchestrator.start_enhanced_training()
|
||||
if start_result:
|
||||
logger.info("✅ Enhanced training started successfully")
|
||||
else:
|
||||
logger.warning("⚠️ Enhanced training start failed")
|
||||
else:
|
||||
logger.warning("⚠️ Enhanced training system not available")
|
||||
|
||||
# 5. Create dashboard
|
||||
logger.info("5. Creating dashboard...")
|
||||
dashboard = create_clean_dashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
|
||||
# 6. Connect training system to dashboard
|
||||
logger.info("6. Connecting training system to dashboard...")
|
||||
orchestrator.set_training_dashboard(dashboard)
|
||||
|
||||
# 7. Start dashboard
|
||||
logger.info("7. Starting dashboard...")
|
||||
logger.info("🎉 Dashboard with enhanced training is now running!")
|
||||
logger.info(" - Enhanced training: ENABLED")
|
||||
logger.info(" - Real-time learning: ACTIVE")
|
||||
logger.info(" - Dashboard URL: http://127.0.0.1:8051")
|
||||
|
||||
# Keep the event loop alive for one hour, then exit
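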
|
||||
await asyncio.sleep(3600) # Run for 1 hour
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting dashboard: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,510 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integrated Real-time RL COB Trading System with Dashboard
|
||||
|
||||
This script starts both:
|
||||
1. RealtimeRLCOBTrader - 1B parameter RL model with real-time training
|
||||
2. COB Dashboard - Real-time visualization with RL predictions
|
||||
|
||||
The RL predictions are integrated into the dashboard for live visualization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from aiohttp import web
|
||||
|
||||
# Local imports
|
||||
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.config import load_config
|
||||
from web.cob_realtime_dashboard import COBDashboardServer
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('logs/integrated_rl_cob_system.log'),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class IntegratedRLCOBSystem:
|
||||
"""
|
||||
Integrated Real-time RL COB Trading System with Dashboard
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str = "config.yaml"):
|
||||
"""Initialize integrated system with configuration"""
|
||||
self.config = load_config(config_path)
|
||||
self.trader = None
|
||||
self.dashboard = None
|
||||
self.trading_executor = None
|
||||
self.running = False
|
||||
|
||||
# RL prediction storage for dashboard
|
||||
self.rl_predictions: Dict[str, list] = {}
|
||||
self.prediction_history: Dict[str, list] = {}
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
|
||||
logger.info("IntegratedRLCOBSystem initialized")
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals"""
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
self.running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start the integrated RL COB trading system with dashboard"""
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info("INTEGRATED RL COB SYSTEM STARTING")
|
||||
logger.info("Real-time RL Trading + Dashboard")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
await self._initialize_trading_executor()
|
||||
|
||||
# Initialize RL trader with prediction callback
|
||||
await self._initialize_rl_trader()
|
||||
|
||||
# Initialize dashboard with RL integration
|
||||
await self._initialize_dashboard()
|
||||
|
||||
# Start the integrated system
|
||||
await self._start_integrated_system()
|
||||
|
||||
# Run main loop
|
||||
await self._run_main_loop()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error in integrated system: {e}")
|
||||
raise
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def _initialize_trading_executor(self):
|
||||
"""Initialize the trading executor"""
|
||||
logger.info("Initializing Trading Executor...")
|
||||
|
||||
# Get trading configuration
|
||||
trading_config = self.config.get('trading', {})
|
||||
mexc_config = self.config.get('mexc', {})
|
||||
|
||||
# Determine if we should run in simulation mode
|
||||
simulation_mode = mexc_config.get('simulation_mode', True)
|
||||
|
||||
if simulation_mode:
|
||||
logger.info("Running in SIMULATION mode - no real trades will be executed")
|
||||
else:
|
||||
logger.warning("Running in LIVE TRADING mode - real money at risk!")
|
||||
|
||||
# Add safety confirmation for live trading
|
||||
confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
|
||||
if confirmation != 'CONFIRM_LIVE_TRADING':
|
||||
logger.info("Live trading not confirmed, switching to simulation mode")
|
||||
simulation_mode = True
|
||||
|
||||
# Initialize trading executor with config path
|
||||
self.trading_executor = TradingExecutor("config.yaml")
|
||||
|
||||
logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")
|
||||
|
||||
async def _initialize_rl_trader(self):
|
||||
"""Initialize the RL trader with prediction callbacks"""
|
||||
logger.info("Initializing Real-time RL COB Trader...")
|
||||
|
||||
# Get RL configuration
|
||||
rl_config = self.config.get('realtime_rl', {})
|
||||
|
||||
# Trading symbols
|
||||
symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
|
||||
|
||||
# Initialize prediction storage
|
||||
for symbol in symbols:
|
||||
self.rl_predictions[symbol] = []
|
||||
self.prediction_history[symbol] = []
|
||||
|
||||
# RL parameters
|
||||
inference_interval_ms = rl_config.get('inference_interval_ms', 200)
|
||||
min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
|
||||
required_confident_predictions = rl_config.get('required_confident_predictions', 3)
|
||||
model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
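# Illustrative config.yaml section matching the keys read above (the values shown
# are just the in-code defaults; the real file may differ):
#
#     realtime_rl:
#       symbols: ["BTC/USDT", "ETH/USDT"]
#       inference_interval_ms: 200
#       min_confidence_threshold: 0.7
#       required_confident_predictions: 3
#       model_checkpoint_dir: models/realtime_rl_cob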
|
||||
|
||||
# Initialize RL trader
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=symbols,
|
||||
trading_executor=self.trading_executor,
|
||||
model_checkpoint_dir=model_checkpoint_dir,
|
||||
inference_interval_ms=inference_interval_ms,
|
||||
min_confidence_threshold=min_confidence_threshold,
|
||||
required_confident_predictions=required_confident_predictions
|
||||
)
|
||||
|
||||
# Monkey-patch the trader to capture predictions
|
||||
original_add_signal = self.trader._add_signal
|
||||
def enhanced_add_signal(symbol: str, prediction: PredictionResult):
|
||||
# Call original method
|
||||
original_add_signal(symbol, prediction)
|
||||
# Capture prediction for dashboard
|
||||
self._on_rl_prediction(symbol, prediction)
|
||||
|
||||
self.trader._add_signal = enhanced_add_signal
|
||||
|
||||
logger.info(f"RL Trader initialized for symbols: {symbols}")
|
||||
logger.info(f"Inference interval: {inference_interval_ms}ms")
|
||||
logger.info(f"Confidence threshold: {min_confidence_threshold}")
|
||||
logger.info(f"Required predictions: {required_confident_predictions}")
|
||||
|
||||
    def _on_rl_prediction(self, symbol: str, prediction: PredictionResult):
        """Handle RL predictions for dashboard integration"""
        try:
            # Convert prediction to dashboard format
            prediction_data = {
                'timestamp': prediction.timestamp.isoformat(),
                'direction': prediction.predicted_direction,  # 0=DOWN, 1=SIDEWAYS, 2=UP
                'confidence': prediction.confidence,
                'predicted_change': prediction.predicted_change,
                'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][prediction.predicted_direction],
                'color': ['red', 'gray', 'green'][prediction.predicted_direction]
            }

            # Add to current predictions (for live display)
            self.rl_predictions[symbol].append(prediction_data)
            if len(self.rl_predictions[symbol]) > 100:  # Keep last 100
                self.rl_predictions[symbol] = self.rl_predictions[symbol][-100:]

            # Add to history (for chart overlay)
            self.prediction_history[symbol].append(prediction_data)
            if len(self.prediction_history[symbol]) > 1000:  # Keep last 1000
                self.prediction_history[symbol] = self.prediction_history[symbol][-1000:]

            logger.debug(f"Captured RL prediction for {symbol}: {prediction.predicted_direction} "
                         f"(confidence: {prediction.confidence:.3f})")

        except Exception as e:
            logger.error(f"Error capturing RL prediction: {e}")

    async def _initialize_dashboard(self):
        """Initialize the COB dashboard with RL integration"""
        logger.info("Initializing COB Dashboard with RL Integration...")

        # Get dashboard configuration
        dashboard_config = self.config.get('dashboard', {})
        host = dashboard_config.get('host', 'localhost')
        port = dashboard_config.get('port', 8053)

        # Create enhanced dashboard server
        self.dashboard = EnhancedCOBDashboardServer(
            host=host,
            port=port,
            rl_system=self  # Pass reference to get predictions
        )

        logger.info(f"COB Dashboard initialized at http://{host}:{port}")

    async def _start_integrated_system(self):
        """Start the complete integrated system"""
        logger.info("Starting Integrated RL COB System...")

        # Start RL trader first (this initializes COB integration)
        await self.trader.start()
        logger.info("RL Trader started")

        # Start dashboard (uses same COB integration)
        await self.dashboard.start()
        logger.info("COB Dashboard started")

        self.running = True

        logger.info("INTEGRATED SYSTEM FULLY OPERATIONAL!")
        logger.info("1B parameter RL model: ACTIVE")
        logger.info("Real-time COB data: STREAMING")
        logger.info("Signal accumulation: ACTIVE")
        logger.info("Live predictions: VISIBLE IN DASHBOARD")
        logger.info("Continuous training: ACTIVE")
        logger.info(f"Dashboard URL: http://{self.dashboard.host}:{self.dashboard.port}")

    async def _run_main_loop(self):
        """Main monitoring and statistics loop"""
        logger.info("Starting integrated system monitoring...")

        last_stats_time = datetime.now()
        stats_interval = 60  # Print stats every 60 seconds

        while self.running:
            try:
                # Sleep for a bit
                await asyncio.sleep(10)

                # Print periodic statistics
                current_time = datetime.now()
                if (current_time - last_stats_time).total_seconds() >= stats_interval:
                    await self._print_integrated_stats()
                    last_stats_time = current_time

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in main loop: {e}")
                await asyncio.sleep(5)

        logger.info("Integrated system monitoring stopped")

    async def _print_integrated_stats(self):
        """Print comprehensive integrated system statistics"""
        try:
            logger.info("=" * 80)
            logger.info("INTEGRATED RL COB SYSTEM STATISTICS")
            logger.info("=" * 80)

            # RL Trader Statistics
            if self.trader:
                rl_stats = self.trader.get_performance_stats()
                logger.info("\n🤖 RL TRADER PERFORMANCE:")

                for symbol in self.trader.symbols:
                    training_stats = rl_stats.get('training_stats', {}).get(symbol, {})
                    inference_stats = rl_stats.get('inference_stats', {}).get(symbol, {})
                    signal_stats = rl_stats.get('signal_stats', {}).get(symbol, {})

                    logger.info(f"\n 📈 {symbol}:")
                    logger.info(f" Predictions: {training_stats.get('total_predictions', 0)}")
                    logger.info(f" Success Rate: {signal_stats.get('success_rate', 0):.1%}")
                    logger.info(f" Avg Inference: {inference_stats.get('average_inference_time_ms', 0):.1f}ms")
                    logger.info(f" Current Signals: {signal_stats.get('current_signals', 0)}")

                    # RL prediction stats for dashboard
                    recent_predictions = len(self.rl_predictions.get(symbol, []))
                    total_predictions = len(self.prediction_history.get(symbol, []))
                    logger.info(f" Dashboard Predictions: {recent_predictions} recent, {total_predictions} total")

            # Dashboard Statistics
            if self.dashboard:
                logger.info("\nDASHBOARD STATISTICS:")
                logger.info(f" Active Connections: {len(self.dashboard.websocket_connections)}")
                logger.info(f" Server Status: {'RUNNING' if self.dashboard.site else 'STOPPED'}")
                logger.info(f" URL: http://{self.dashboard.host}:{self.dashboard.port}")

            # Trading Executor Statistics
            if self.trading_executor:
                positions = self.trading_executor.get_positions()
                trade_history = self.trading_executor.get_trade_history()

                logger.info("\n💰 TRADING STATISTICS:")
                logger.info(f" Active Positions: {len(positions)}")
                logger.info(f" Total Trades: {len(trade_history)}")

                if trade_history:
                    total_pnl = sum(trade.pnl for trade in trade_history)
                    profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
                    win_rate = (profitable_trades / len(trade_history)) * 100

                    logger.info(f" Total P&L: ${total_pnl:.2f}")
                    logger.info(f" Win Rate: {win_rate:.1f}%")

            logger.info("=" * 80)

        except Exception as e:
            logger.error(f"Error printing integrated stats: {e}")

    async def stop(self):
        """Stop the integrated system gracefully"""
        if not self.running:
            return

        logger.info("Stopping Integrated RL COB System...")

        self.running = False

        # Stop dashboard
        if self.dashboard:
            await self.dashboard.stop()
            logger.info("Dashboard stopped")

        # Stop RL trader
        if self.trader:
            await self.trader.stop()
            logger.info("RL Trader stopped")

        logger.info("🏁 Integrated system stopped successfully")

    def get_rl_predictions(self, symbol: str) -> Dict[str, Any]:
        """Get RL predictions for dashboard display"""
        return {
            'recent_predictions': self.rl_predictions.get(symbol, []),
            'prediction_history': self.prediction_history.get(symbol, []),
            'total_predictions': len(self.prediction_history.get(symbol, [])),
            'recent_count': len(self.rl_predictions.get(symbol, []))
        }


class EnhancedCOBDashboardServer(COBDashboardServer):
    """Enhanced COB Dashboard with RL prediction integration"""

    def __init__(self, host: str = 'localhost', port: int = 8053, rl_system: IntegratedRLCOBSystem = None):
        super().__init__(host, port)
        self.rl_system = rl_system

        # Add RL prediction routes
        self._setup_rl_routes()

        logger.info("Enhanced COB Dashboard with RL predictions initialized")

    async def serve_dashboard(self, request):
        """Serve the enhanced dashboard HTML with RL predictions"""
        try:
            # Read the enhanced dashboard HTML
            dashboard_path = os.path.join(os.path.dirname(__file__), 'enhanced_cob_dashboard.html')

            if os.path.exists(dashboard_path):
                with open(dashboard_path, 'r', encoding='utf-8') as f:
                    html_content = f.read()
                return web.Response(text=html_content, content_type='text/html')
            else:
                # Fallback to basic dashboard
                logger.warning("Enhanced dashboard HTML not found, using basic dashboard")
                return await super().serve_dashboard(request)

        except Exception as e:
            logger.error(f"Error serving enhanced dashboard: {e}")
            return web.Response(text="Dashboard error", status=500)

    def _setup_rl_routes(self):
        """Setup additional routes for RL predictions"""
        self.app.router.add_get('/api/rl-predictions/{symbol}', self.get_rl_predictions)
        self.app.router.add_get('/api/rl-status', self.get_rl_status)

    async def get_rl_predictions(self, request):
        """Get RL predictions for a symbol"""
        try:
            symbol = request.match_info['symbol']
            symbol = symbol.replace('%2F', '/')

            if symbol not in self.symbols:
                return web.json_response({
                    'error': f'Symbol {symbol} not supported'
                }, status=400)

            if not self.rl_system:
                return web.json_response({
                    'error': 'RL system not available'
                }, status=503)

            predictions = self.rl_system.get_rl_predictions(symbol)

            return web.json_response({
                'symbol': symbol,
                'timestamp': datetime.now().isoformat(),
                'predictions': predictions
            })

        except Exception as e:
            logger.error(f"Error getting RL predictions: {e}")
            return web.json_response({
                'error': str(e)
            }, status=500)

    async def get_rl_status(self, request):
        """Get RL system status"""
        try:
            if not self.rl_system or not self.rl_system.trader:
                return web.json_response({
                    'status': 'inactive',
                    'error': 'RL system not available'
                })

            rl_stats = self.rl_system.trader.get_performance_stats()

            status = {
                'status': 'active',
                'symbols': self.rl_system.trader.symbols,
                'model_info': rl_stats.get('model_info', {}),
                'inference_interval_ms': self.rl_system.trader.inference_interval_ms,
                'confidence_threshold': self.rl_system.trader.min_confidence_threshold,
                'required_predictions': self.rl_system.trader.required_confident_predictions,
                'device': str(self.rl_system.trader.device),
                'running': self.rl_system.trader.running
            }

            return web.json_response(status)

        except Exception as e:
            logger.error(f"Error getting RL status: {e}")
            return web.json_response({
                'status': 'error',
                'error': str(e)
            }, status=500)

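For reference, the two routes registered in _setup_rl_routes can be polled over plain HTTP once the dashboard is running. A minimal client sketch, assuming the default localhost:8053 host/port used elsewhere in this script (not part of the original file):

# Hypothetical polling client for the RL endpoints above.
import json
from urllib.request import urlopen
from urllib.parse import quote

BASE = "http://localhost:8053"  # assumption: default dashboard host/port

with urlopen(f"{BASE}/api/rl-status", timeout=5) as resp:
    print(json.load(resp))                   # overall RL system status

symbol = quote("ETH/USDT", safe="")          # handler maps %2F back to '/'
with urlopen(f"{BASE}/api/rl-predictions/{symbol}", timeout=5) as resp:
    print(json.load(resp)["predictions"]["recent_count"])
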
    async def _broadcast_cob_update(self, symbol: str, data: Dict):
        """Enhanced COB update broadcast with RL predictions"""
        try:
            # Get RL predictions if available
            rl_data = {}
            if self.rl_system:
                rl_predictions = self.rl_system.get_rl_predictions(symbol)
                rl_data = {
                    'rl_predictions': rl_predictions.get('recent_predictions', [])[-10:],  # Last 10
                    'prediction_count': rl_predictions.get('total_predictions', 0)
                }

            # Enhanced data with RL predictions
            enhanced_data = {
                **data,
                'rl_data': rl_data
            }

            # Broadcast to all WebSocket connections
            message = json.dumps({
                'type': 'cob_update',
                'symbol': symbol,
                'timestamp': datetime.now().isoformat(),
                'data': enhanced_data
            }, default=str)

            # Send to all connected clients
            disconnected = []
            for ws in self.websocket_connections:
                try:
                    await ws.send_str(message)
                except Exception as e:
                    logger.debug(f"WebSocket send failed: {e}")
                    disconnected.append(ws)

            # Remove disconnected clients
            for ws in disconnected:
                self.websocket_connections.discard(ws)

        except Exception as e:
            logger.error(f"Error broadcasting enhanced COB update: {e}")


async def main():
    """Main entry point for integrated RL COB system"""
    # Create logs directory
    os.makedirs('logs', exist_ok=True)

    # Initialize and start integrated system
    system = IntegratedRLCOBSystem()

    try:
        await system.start()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as e:
        logger.error(f"System error: {e}")
        raise
    finally:
        await system.stop()


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,52 +0,0 @@
#!/usr/bin/env python3
"""
One-Click MEXC Browser Launcher

Simply run this script to start capturing MEXC futures trading requests.
"""

import sys
import os

# Add project root to path
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)


def main():
    """Launch MEXC browser automation"""
    print("🚀 MEXC Futures Request Interceptor")
    print("=" * 50)
    print("This will automatically:")
    print("✅ Install ChromeDriver")
    print("✅ Open MEXC futures page")
    print("✅ Capture all API requests")
    print("✅ Extract session cookies")
    print("✅ Save data to JSON files")
    print("\nRequirements will be installed automatically if missing.")

    try:
        # First try to run the auto browser directly
        from core.mexc_webclient.auto_browser import main as run_auto_browser
        run_auto_browser()

    except ImportError as e:
        print(f"\n⚠️ Import error: {e}")
        print("Installing requirements first...")

        # Try to install requirements and run setup
        try:
            from setup_mexc_browser import main as setup_main
            setup_main()
        except ImportError:
            print("❌ Could not find setup script")
            print("Please run: pip install selenium webdriver-manager")

    except Exception as e:
        print(f"❌ Error: {e}")
        print("\nTroubleshooting:")
        print("1. Make sure you have Chrome browser installed")
        print("2. Check your internet connection")
        print("3. Try running: pip install selenium webdriver-manager")


if __name__ == "__main__":
    main()
@@ -1,451 +0,0 @@
#!/usr/bin/env python3
"""
Optimized COB System - Eliminates Redundant Implementations

This optimized script runs both the COB dashboard and the 1B parameter RL trading
system in a single process with shared data sources to eliminate redundancies:

BEFORE (Redundant):
- Dashboard: Own COBIntegration instance
- RL Trader: Own COBIntegration instance
- Training: Own COBIntegration instance
= 3x WebSocket connections, 3x order book processing

AFTER (Optimized):
- Shared COBIntegration instance
- Single WebSocket connection per exchange
- Shared order book processing and caching
= 1x connections, 1x processing, shared memory

Estimated resource savings: ~60% memory, ~70% network bandwidth
"""

import asyncio
import logging
import signal
import sys
import json
import os
from datetime import datetime
from typing import Dict, Any, List, Optional
from aiohttp import web
import threading

# Local imports
from core.cob_integration import COBIntegration
from core.data_provider import DataProvider
from core.trading_executor import TradingExecutor
from core.config import load_config
from web.cob_realtime_dashboard import COBDashboardServer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/optimized_cob_system.log'),
        logging.StreamHandler(sys.stdout)
    ]
)

logger = logging.getLogger(__name__)


class OptimizedCOBSystem:
    """
    Optimized COB System - single COB instance shared across all components
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Initialize optimized system with shared resources"""
        self.config = load_config(config_path)
        self.running = False

        # Shared components (eliminate redundancy)
        self.data_provider = DataProvider()
        self.shared_cob_integration: Optional[COBIntegration] = None
        self.trading_executor: Optional[TradingExecutor] = None

        # RL trader driven by the shared COB feed (created in _initialize_shared_components)
        self.rl_trader: Optional['OptimizedRLTrader'] = None

        # Dashboard using shared COB
        self.dashboard_server: Optional[COBDashboardServer] = None

        # Performance tracking
        self.performance_stats = {
            'start_time': None,
            'cob_updates_processed': 0,
            'dashboard_connections': 0,
            'rl_predictions': 0,
            'trades_executed': 0,
            'memory_saved_mb': 0
        }

        # Setup signal handlers
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        logger.info("OptimizedCOBSystem initialized - Eliminating redundant implementations")

    def _signal_handler(self, signum, frame):
        """Handle shutdown signals"""
        logger.info(f"Received signal {signum}, initiating graceful shutdown...")
        self.running = False

    async def start(self):
        """Start the optimized COB system"""
        try:
            logger.info("=" * 70)
            logger.info("🚀 OPTIMIZED COB SYSTEM STARTING")
            logger.info("=" * 70)
            logger.info("Eliminating redundant COB implementations...")
            logger.info("Single shared COB integration for all components")
            logger.info("=" * 70)

            # Initialize shared components
            await self._initialize_shared_components()

            # Initialize dashboard with shared COB
            await self._initialize_optimized_dashboard()

            # Start the integrated system
            await self._start_optimized_system()

            # Run main monitoring loop
            await self._run_optimized_loop()

        except Exception as e:
            logger.error(f"Critical error in optimized system: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise
        finally:
            await self.stop()

    async def _initialize_shared_components(self):
        """Initialize shared components (eliminates redundancy)"""
        logger.info("1. Initializing shared COB integration...")

        # Single COB integration instance for entire system
        self.shared_cob_integration = COBIntegration(
            data_provider=self.data_provider,
            symbols=['BTC/USDT', 'ETH/USDT']
        )

        # Start the shared COB integration
        await self.shared_cob_integration.start()

        logger.info("2. Initializing trading executor...")

        # Trading executor configuration
        trading_config = self.config.get('trading', {})
        mexc_config = self.config.get('mexc', {})
        simulation_mode = mexc_config.get('simulation_mode', True)

        self.trading_executor = TradingExecutor()

        # RL trader subscribing to the shared COB feed; it is started in
        # _start_optimized_system and polled by _update_performance_stats.
        self.rl_trader = OptimizedRLTrader(
            symbols=['BTC/USDT', 'ETH/USDT'],
            shared_cob=self.shared_cob_integration,
            trading_executor=self.trading_executor,
            performance_tracker=self.performance_stats
        )

        logger.info("✅ Shared components initialized")
        logger.info(f" Single COB integration: {len(self.shared_cob_integration.symbols)} symbols")
        logger.info(f" Trading mode: {'SIMULATION' if simulation_mode else 'LIVE'}")

    async def _initialize_optimized_dashboard(self):
        """Initialize dashboard that uses shared COB (no redundant instance)"""
        logger.info("3. Initializing optimized dashboard...")

        # Create dashboard and replace its COB with our shared one
        self.dashboard_server = COBDashboardServer(host='localhost', port=8053)

        # Replace the dashboard's COB integration with our shared one
        self.dashboard_server.cob_integration = self.shared_cob_integration

        logger.info("✅ Optimized dashboard initialized with shared COB")

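Patching the attribute after construction works, but the EnhancedCOBDashboard subclass defined near the end of this file accepts the shared instance directly. A minimal sketch of that alternative wiring inside _initialize_optimized_dashboard (an option, not what the method above does):

        # Sketch only: pass the shared COB through the constructor instead of
        # overwriting cob_integration afterwards (EnhancedCOBDashboard is defined below).
        self.dashboard_server = EnhancedCOBDashboard(
            host='localhost',
            port=8053,
            shared_cob=self.shared_cob_integration,
            performance_tracker=self.performance_stats,
        )
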
    async def _start_optimized_system(self):
        """Start the optimized system with shared resources"""
        logger.info("4. Starting optimized system...")

        self.running = True
        self.performance_stats['start_time'] = datetime.now()

        # Start dashboard server with shared COB
        await self.dashboard_server.start()

        # Start RL trader
        await self.rl_trader.start()

        # Estimate memory savings
        estimated_savings = self._calculate_memory_savings()
        self.performance_stats['memory_saved_mb'] = estimated_savings

        logger.info("🚀 Optimized COB System started successfully!")
        logger.info(f"💾 Estimated memory savings: {estimated_savings:.0f} MB")
        logger.info("🌐 Dashboard: http://localhost:8053")
        logger.info("🤖 RL Training: Active with 1B parameters")
        logger.info("📊 Shared COB: Single integration for all components")
        logger.info("🔄 System Status: OPTIMIZED - No redundant implementations")

    def _calculate_memory_savings(self) -> float:
        """Calculate estimated memory savings from eliminating redundancy"""
        # Estimates based on typical COB memory usage
        cob_integration_memory_mb = 512  # Order books, caches, connections
        websocket_connection_memory_mb = 64  # Per exchange connection

        # Before: 3 separate COB integrations (dashboard + RL trader + training)
        before_memory = 3 * cob_integration_memory_mb + 3 * websocket_connection_memory_mb

        # After: 1 shared COB integration
        after_memory = 1 * cob_integration_memory_mb + 1 * websocket_connection_memory_mb

        savings = before_memory - after_memory
        return savings

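With the estimates hard-coded above, before = 3 x (512 + 64) = 1728 MB and after = 512 + 64 = 576 MB, so the method always returns 1152 MB, which is roughly the "~60% memory" saving quoted in the module docstring (1152 / 1728 is about 67%).
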
    async def _run_optimized_loop(self):
        """Main optimized monitoring loop"""
        logger.info("Starting optimized monitoring loop...")

        last_stats_time = datetime.now()
        stats_interval = 60  # Print stats every 60 seconds

        while self.running:
            try:
                # Sleep for a bit
                await asyncio.sleep(10)

                # Update performance stats
                self._update_performance_stats()

                # Print periodic statistics
                current_time = datetime.now()
                if (current_time - last_stats_time).total_seconds() >= stats_interval:
                    await self._print_optimized_stats()
                    last_stats_time = current_time

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in optimized loop: {e}")
                await asyncio.sleep(5)

        logger.info("Optimized monitoring loop stopped")

    def _update_performance_stats(self):
        """Update performance statistics"""
        try:
            # Get stats from shared COB integration
            if self.shared_cob_integration:
                cob_stats = self.shared_cob_integration.get_statistics()
                self.performance_stats['cob_updates_processed'] = cob_stats.get('total_signals', {}).get('BTC/USDT', 0)

            # Get stats from dashboard
            if self.dashboard_server:
                dashboard_stats = self.dashboard_server.get_stats()
                self.performance_stats['dashboard_connections'] = dashboard_stats.get('active_connections', 0)

            # Get stats from RL trader
            if self.rl_trader:
                rl_stats = self.rl_trader.get_stats()
                self.performance_stats['rl_predictions'] = rl_stats.get('total_predictions', 0)

            # Get stats from trading executor
            if self.trading_executor:
                trade_history = self.trading_executor.get_trade_history()
                self.performance_stats['trades_executed'] = len(trade_history)

        except Exception as e:
            logger.warning(f"Error updating performance stats: {e}")

    async def _print_optimized_stats(self):
        """Print comprehensive optimized system statistics"""
        try:
            stats = self.performance_stats
            uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0

            logger.info("=" * 80)
            logger.info("🚀 OPTIMIZED COB SYSTEM PERFORMANCE STATISTICS")
            logger.info("=" * 80)

            logger.info("📊 Resource Optimization:")
            logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
            logger.info(f" Uptime: {uptime:.0f} seconds")
            logger.info(f" COB Updates: {stats['cob_updates_processed']}")

            logger.info("\n🌐 Dashboard Statistics:")
            logger.info(f" Active Connections: {stats['dashboard_connections']}")
            logger.info(f" Server Status: {'RUNNING' if self.dashboard_server else 'STOPPED'}")

            logger.info("\n🤖 RL Trading Statistics:")
            logger.info(f" Total Predictions: {stats['rl_predictions']}")
            logger.info(f" Trades Executed: {stats['trades_executed']}")
            logger.info(f" Trainer Status: {'ACTIVE' if self.rl_trader else 'STOPPED'}")

            # Shared COB statistics
            if self.shared_cob_integration:
                cob_stats = self.shared_cob_integration.get_statistics()
                logger.info("\n📈 Shared COB Integration:")
                logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}")
                logger.info(f" Streaming: {cob_stats.get('is_streaming', False)}")
                logger.info(f" CNN Callbacks: {cob_stats.get('cnn_callbacks', 0)}")
                logger.info(f" DQN Callbacks: {cob_stats.get('dqn_callbacks', 0)}")
                logger.info(f" Dashboard Callbacks: {cob_stats.get('dashboard_callbacks', 0)}")

            logger.info("=" * 80)
            logger.info("✅ OPTIMIZATION STATUS: Redundancy eliminated, shared resources active")

        except Exception as e:
            logger.error(f"Error printing optimized stats: {e}")

    async def stop(self):
        """Stop the optimized system gracefully"""
        if not self.running:
            return

        logger.info("Stopping Optimized COB System...")

        self.running = False

        # Stop RL trader
        if self.rl_trader:
            await self.rl_trader.stop()
            logger.info("✅ RL Trader stopped")

        # Stop dashboard
        if self.dashboard_server:
            await self.dashboard_server.stop()
            logger.info("✅ Dashboard stopped")

        # Stop shared COB integration (last, as others depend on it)
        if self.shared_cob_integration:
            await self.shared_cob_integration.stop()
            logger.info("✅ Shared COB integration stopped")

        # Print final optimization report
        await self._print_final_optimization_report()

        logger.info("Optimized COB System stopped successfully")

    async def _print_final_optimization_report(self):
        """Print final optimization report"""
        stats = self.performance_stats
        uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0

        logger.info("\n📊 FINAL OPTIMIZATION REPORT:")
        logger.info(f" Total Runtime: {uptime:.0f} seconds")
        logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
        logger.info(f" COB Updates Processed: {stats['cob_updates_processed']}")
        logger.info(f" RL Predictions Made: {stats['rl_predictions']}")
        logger.info(f" Trades Executed: {stats['trades_executed']}")
        logger.info(" ✅ Redundant implementations eliminated")
        logger.info(" ✅ Shared COB integration successful")


# Simplified components that use shared COB (no redundant integrations)

class EnhancedCOBDashboard(COBDashboardServer):
    """Enhanced dashboard that uses shared COB integration"""

    def __init__(self, host: str = 'localhost', port: int = 8053,
                 shared_cob: COBIntegration = None, performance_tracker: Dict = None):
        # Initialize parent without creating new COB integration
        self.shared_cob = shared_cob
        self.performance_tracker = performance_tracker or {}
        super().__init__(host, port)

        # Use shared COB instead of creating new one
        self.cob_integration = shared_cob
        logger.info("Enhanced dashboard using shared COB integration (no redundancy)")

    def get_stats(self) -> Dict[str, Any]:
        """Get dashboard statistics"""
        return {
            'active_connections': len(self.websocket_connections),
            'using_shared_cob': self.shared_cob is not None,
            'server_running': self.runner is not None
        }

class OptimizedRLTrader:
    """Optimized RL trader that uses shared COB integration"""

    def __init__(self, symbols: List[str], shared_cob: COBIntegration,
                 trading_executor: TradingExecutor, performance_tracker: Dict = None):
        self.symbols = symbols
        self.shared_cob = shared_cob
        self.trading_executor = trading_executor
        self.performance_tracker = performance_tracker or {}
        self.running = False

        # Subscribe to shared COB updates instead of creating new integration
        self.subscription_id = None
        self.prediction_count = 0

        logger.info("Optimized RL trader using shared COB integration (no redundancy)")

    async def start(self):
        """Start RL trader with shared COB"""
        self.running = True

        # Subscribe to shared COB updates
        self.subscription_id = self.shared_cob.add_dqn_callback(self._on_cob_update)

        # Start prediction loop
        asyncio.create_task(self._prediction_loop())

        logger.info("Optimized RL trader started with shared COB subscription")

    async def stop(self):
        """Stop RL trader"""
        self.running = False
        logger.info("Optimized RL trader stopped")

    async def _on_cob_update(self, symbol: str, data: Dict):
        """Handle COB updates from shared integration"""
        try:
            # Process RL prediction using shared data
            self.prediction_count += 1

            # Simple prediction logic (placeholder)
            confidence = 0.75  # Example confidence

            if self.prediction_count % 100 == 0:
                logger.info(f"RL Prediction #{self.prediction_count} for {symbol} (confidence: {confidence:.2f})")

        except Exception as e:
            logger.error(f"Error in RL update: {e}")

    async def _prediction_loop(self):
        """Main prediction loop"""
        while self.running:
            try:
                # RL model inference would go here
                await asyncio.sleep(0.2)  # 200ms inference interval
            except Exception as e:
                logger.error(f"Error in prediction loop: {e}")
                await asyncio.sleep(1)

    def get_stats(self) -> Dict[str, Any]:
        """Get RL trader statistics"""
        return {
            'total_predictions': self.prediction_count,
            'using_shared_cob': self.shared_cob is not None,
            'subscription_active': self.subscription_id is not None
        }


async def main():
    """Main entry point for optimized COB system"""
    try:
        # Create logs directory
        os.makedirs('logs', exist_ok=True)

        # Initialize and start optimized system
        system = OptimizedCOBSystem()
        await system.start()

    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
    except Exception as e:
        logger.error(f"Critical error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    # Set event loop policy for Windows compatibility
    if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    asyncio.run(main())
@@ -1,324 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Real-time RL COB Trader Launcher
|
||||
|
||||
Launch script for the real-time reinforcement learning trader that:
|
||||
1. Uses COB data for training a 1B parameter model
|
||||
2. Performs inference every 200ms
|
||||
3. Accumulates confident signals for trade execution
|
||||
4. Trains continuously in real-time based on outcomes
|
||||
|
||||
This script provides a complete trading system integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
# Local imports
|
||||
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.config import load_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('logs/realtime_rl_cob_trader.log'),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RealtimeRLCOBTraderLauncher:
|
||||
"""
|
||||
Launcher for Real-time RL COB Trader system
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str = "config.yaml"):
|
||||
"""Initialize launcher with configuration"""
|
||||
self.config_path = config_path
|
||||
self.config = load_config(config_path)
|
||||
self.trader: Optional[RealtimeRLCOBTrader] = None
|
||||
self.trading_executor: Optional[TradingExecutor] = None
|
||||
self.running = False
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
|
||||
logger.info("RealtimeRLCOBTraderLauncher initialized")
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals"""
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
self.running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start the real-time RL COB trading system"""
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info("REAL-TIME RL COB TRADER SYSTEM STARTING")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
await self._initialize_trading_executor()
|
||||
|
||||
# Initialize RL trader
|
||||
await self._initialize_rl_trader()
|
||||
|
||||
# Start the trading system
|
||||
await self._start_trading_system()
|
||||
|
||||
# Run main loop
|
||||
await self._run_main_loop()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error in trader launcher: {e}")
|
||||
raise
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def _initialize_trading_executor(self):
|
||||
"""Initialize the trading executor"""
|
||||
logger.info("Initializing Trading Executor...")
|
||||
|
||||
# Get trading configuration
|
||||
trading_config = self.config.get('trading', {})
|
||||
mexc_config = self.config.get('mexc', {})
|
||||
|
||||
# Determine if we should run in simulation mode
|
||||
simulation_mode = mexc_config.get('simulation_mode', True)
|
||||
|
||||
if simulation_mode:
|
||||
logger.info("Running in SIMULATION mode - no real trades will be executed")
|
||||
else:
|
||||
logger.warning("Running in LIVE TRADING mode - real money at risk!")
|
||||
|
||||
# Add safety confirmation for live trading
|
||||
confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
|
||||
if confirmation != 'CONFIRM_LIVE_TRADING':
|
||||
logger.info("Live trading not confirmed, switching to simulation mode")
|
||||
simulation_mode = True
|
||||
|
||||
# Initialize trading executor
|
||||
self.trading_executor = TradingExecutor(self.config_path)
|
||||
|
||||
logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")
|
||||
|
||||
async def _initialize_rl_trader(self):
|
||||
"""Initialize the RL trader"""
|
||||
logger.info("Initializing Real-time RL COB Trader...")
|
||||
|
||||
# Get RL configuration
|
||||
rl_config = self.config.get('realtime_rl', {})
|
||||
|
||||
# Trading symbols
|
||||
symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
|
||||
|
||||
# RL parameters
|
||||
inference_interval_ms = rl_config.get('inference_interval_ms', 200)
|
||||
min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
|
||||
required_confident_predictions = rl_config.get('required_confident_predictions', 3)
|
||||
model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
|
||||
|
||||
# Initialize RL trader
|
||||
if self.trading_executor is None:
|
||||
raise RuntimeError("Trading executor not initialized")
|
||||
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=symbols,
|
||||
trading_executor=self.trading_executor,
|
||||
model_checkpoint_dir=model_checkpoint_dir,
|
||||
inference_interval_ms=inference_interval_ms,
|
||||
min_confidence_threshold=min_confidence_threshold,
|
||||
required_confident_predictions=required_confident_predictions
|
||||
)
|
||||
|
||||
logger.info(f"RL Trader initialized for symbols: {symbols}")
|
||||
logger.info(f"Inference interval: {inference_interval_ms}ms")
|
||||
logger.info(f"Confidence threshold: {min_confidence_threshold}")
|
||||
logger.info(f"Required predictions: {required_confident_predictions}")
|
||||
|
||||
async def _start_trading_system(self):
|
||||
"""Start the complete trading system"""
|
||||
logger.info("Starting Real-time RL COB Trading System...")
|
||||
|
||||
# Start RL trader (this will start COB integration internally)
|
||||
if self.trader is None:
|
||||
raise RuntimeError("RL trader not initialized")
|
||||
await self.trader.start()
|
||||
|
||||
self.running = True
|
||||
|
||||
logger.info("✅ Real-time RL COB Trading System started successfully!")
|
||||
logger.info("🔥 1B parameter model training and inference active")
|
||||
logger.info("📊 COB data streaming and processing")
|
||||
logger.info("🎯 Signal accumulation and trade execution ready")
|
||||
logger.info("⚡ Real-time training on prediction outcomes")
|
||||
|
||||
async def _run_main_loop(self):
|
||||
"""Main monitoring and statistics loop"""
|
||||
logger.info("Starting main monitoring loop...")
|
||||
|
||||
last_stats_time = datetime.now()
|
||||
stats_interval = 60 # Print stats every 60 seconds
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Sleep for a bit
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Print periodic statistics
|
||||
current_time = datetime.now()
|
||||
if (current_time - last_stats_time).total_seconds() >= stats_interval:
|
||||
await self._print_performance_stats()
|
||||
last_stats_time = current_time
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main loop: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
logger.info("Main monitoring loop stopped")
|
||||
|
||||
async def _print_performance_stats(self):
|
||||
"""Print comprehensive performance statistics"""
|
||||
try:
|
||||
if not self.trader:
|
||||
return
|
||||
|
||||
stats = self.trader.get_performance_stats()
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("🔥 REAL-TIME RL COB TRADER PERFORMANCE STATISTICS")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Model information
|
||||
logger.info("📊 Model Information:")
|
||||
for symbol, model_info in stats.get('model_info', {}).items():
|
||||
total_params = model_info.get('total_parameters', 0)
|
||||
logger.info(f" {symbol}: {total_params:,} parameters ({total_params/1e9:.2f}B)")
|
||||
|
||||
# Training statistics
|
||||
logger.info("\n🧠 Training Statistics:")
|
||||
for symbol, training_stats in stats.get('training_stats', {}).items():
|
||||
total_preds = training_stats.get('total_predictions', 0)
|
||||
successful_preds = training_stats.get('successful_predictions', 0)
|
||||
success_rate = (successful_preds / max(1, total_preds)) * 100
|
||||
avg_loss = training_stats.get('average_loss', 0.0)
|
||||
training_steps = training_stats.get('total_training_steps', 0)
|
||||
last_training = training_stats.get('last_training_time')
|
||||
|
||||
logger.info(f" {symbol}:")
|
||||
logger.info(f" Predictions: {total_preds} (Success: {success_rate:.1f}%)")
|
||||
logger.info(f" Training Steps: {training_steps}")
|
||||
logger.info(f" Average Loss: {avg_loss:.6f}")
|
||||
if last_training:
|
||||
logger.info(f" Last Training: {last_training}")
|
||||
|
||||
# Inference statistics
|
||||
logger.info("\n⚡ Inference Statistics:")
|
||||
for symbol, inference_stats in stats.get('inference_stats', {}).items():
|
||||
total_inferences = inference_stats.get('total_inferences', 0)
|
||||
avg_time = inference_stats.get('average_inference_time_ms', 0.0)
|
||||
last_inference = inference_stats.get('last_inference_time')
|
||||
|
||||
logger.info(f" {symbol}:")
|
||||
logger.info(f" Total Inferences: {total_inferences}")
|
||||
logger.info(f" Average Time: {avg_time:.1f}ms")
|
||||
if last_inference:
|
||||
logger.info(f" Last Inference: {last_inference}")
|
||||
|
||||
# Signal statistics
|
||||
logger.info("\n🎯 Signal Accumulation:")
|
||||
for symbol, signal_stats in stats.get('signal_stats', {}).items():
|
||||
current_signals = signal_stats.get('current_signals', 0)
|
||||
confidence_sum = signal_stats.get('confidence_sum', 0.0)
|
||||
success_rate = signal_stats.get('success_rate', 0.0) * 100
|
||||
|
||||
logger.info(f" {symbol}:")
|
||||
logger.info(f" Current Signals: {current_signals}")
|
||||
logger.info(f" Confidence Sum: {confidence_sum:.2f}")
|
||||
logger.info(f" Historical Success Rate: {success_rate:.1f}%")
|
||||
|
||||
# Trading executor statistics
|
||||
if self.trading_executor:
|
||||
positions = self.trading_executor.get_positions()
|
||||
trade_history = self.trading_executor.get_trade_history()
|
||||
|
||||
logger.info("\n💰 Trading Statistics:")
|
||||
logger.info(f" Active Positions: {len(positions)}")
|
||||
logger.info(f" Total Trades: {len(trade_history)}")
|
||||
|
||||
if trade_history:
|
||||
# Calculate P&L statistics
|
||||
total_pnl = sum(trade.pnl for trade in trade_history)
|
||||
profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
|
||||
win_rate = (profitable_trades / len(trade_history)) * 100
|
||||
|
||||
logger.info(f" Total P&L: ${total_pnl:.2f}")
|
||||
logger.info(f" Win Rate: {win_rate:.1f}%")
|
||||
|
||||
# Show active positions
|
||||
if positions:
|
||||
logger.info("\n📍 Active Positions:")
|
||||
for symbol, position in positions.items():
|
||||
logger.info(f" {symbol}: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
|
||||
|
||||
logger.info("=" * 80)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error printing performance stats: {e}")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the trading system gracefully"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
logger.info("Stopping Real-time RL COB Trading System...")
|
||||
|
||||
self.running = False
|
||||
|
||||
# Stop RL trader
|
||||
if self.trader:
|
||||
await self.trader.stop()
|
||||
logger.info("✅ RL Trader stopped")
|
||||
|
||||
# Print final statistics
|
||||
if self.trader:
|
||||
logger.info("\n📊 Final Performance Summary:")
|
||||
await self._print_performance_stats()
|
||||
|
||||
logger.info("Real-time RL COB Trading System stopped successfully")
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
try:
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Initialize and start launcher
|
||||
launcher = RealtimeRLCOBTraderLauncher()
|
||||
await launcher.start()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt, shutting down...")
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set event loop policy for Windows compatibility
|
||||
if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -1,218 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Dashboard Runner - Fixed version for testing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
# Fix OpenMP library conflicts
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
|
||||
# Fix matplotlib backend
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_simple_dashboard():
|
||||
"""Create a simple working dashboard"""
|
||||
try:
|
||||
import dash
|
||||
from dash import html, dcc, Input, Output
|
||||
import plotly.graph_objs as go
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Create Dash app
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
# Simple layout
|
||||
app.layout = html.Div([
|
||||
html.H1("Trading System Dashboard", style={'textAlign': 'center', 'color': '#2c3e50'}),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H3("System Status", style={'color': '#27ae60'}),
|
||||
html.P(id='system-status', children="System: RUNNING", style={'fontSize': '18px'}),
|
||||
html.P(id='current-time', children=f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"),
|
||||
], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}),
|
||||
|
||||
html.Div([
|
||||
html.H3("Trading Stats", style={'color': '#3498db'}),
|
||||
html.P("Total Trades: 0"),
|
||||
html.P("Success Rate: 0%"),
|
||||
html.P("Current PnL: $0.00"),
|
||||
], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}),
|
||||
]),
|
||||
|
||||
html.Div([
|
||||
dcc.Graph(id='price-chart'),
|
||||
], style={'padding': '20px'}),
|
||||
|
||||
html.Div([
|
||||
dcc.Graph(id='performance-chart'),
|
||||
], style={'padding': '20px'}),
|
||||
|
||||
# Auto-refresh component
|
||||
dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=5000, # Update every 5 seconds
|
||||
n_intervals=0
|
||||
)
|
||||
])
|
||||
|
||||
# Callback for updating time
|
||||
@app.callback(
|
||||
Output('current-time', 'children'),
|
||||
Input('interval-component', 'n_intervals')
|
||||
)
|
||||
def update_time(n):
|
||||
return f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# Callback for price chart
|
||||
@app.callback(
|
||||
Output('price-chart', 'figure'),
|
||||
Input('interval-component', 'n_intervals')
|
||||
)
|
||||
def update_price_chart(n):
|
||||
# Generate sample data
|
||||
dates = pd.date_range(start=datetime.now() - timedelta(hours=24),
|
||||
end=datetime.now(), freq='1H')
|
||||
prices = 3000 + np.cumsum(np.random.randn(len(dates)) * 10)
|
||||
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(
|
||||
x=dates,
|
||||
y=prices,
|
||||
mode='lines',
|
||||
name='ETH/USDT',
|
||||
line=dict(color='#3498db', width=2)
|
||||
))
|
||||
|
||||
fig.update_layout(
|
||||
title='ETH/USDT Price Chart (24H)',
|
||||
xaxis_title='Time',
|
||||
yaxis_title='Price (USD)',
|
||||
template='plotly_white',
|
||||
height=400
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
# Callback for performance chart
|
||||
@app.callback(
|
||||
Output('performance-chart', 'figure'),
|
||||
Input('interval-component', 'n_intervals')
|
||||
)
|
||||
def update_performance_chart(n):
|
||||
# Generate sample performance data
|
||||
dates = pd.date_range(start=datetime.now() - timedelta(days=7),
|
||||
end=datetime.now(), freq='1D')
|
||||
performance = np.cumsum(np.random.randn(len(dates)) * 0.02) * 100
|
||||
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(
|
||||
x=dates,
|
||||
y=performance,
|
||||
mode='lines+markers',
|
||||
name='Portfolio Performance',
|
||||
line=dict(color='#27ae60', width=3),
|
||||
marker=dict(size=6)
|
||||
))
|
||||
|
||||
fig.update_layout(
|
||||
title='Portfolio Performance (7 Days)',
|
||||
xaxis_title='Date',
|
||||
yaxis_title='Performance (%)',
|
||||
template='plotly_white',
|
||||
height=400
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
return app
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating dashboard: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def test_data_provider():
|
||||
"""Test data provider in background"""
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
from core.api_rate_limiter import get_rate_limiter
|
||||
|
||||
logger.info("Testing data provider...")
|
||||
|
||||
# Create data provider
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT'],
|
||||
timeframes=['1m', '5m']
|
||||
)
|
||||
|
||||
# Test getting data
|
||||
df = data_provider.get_historical_data('ETH/USDT', '1m', limit=10)
|
||||
if df is not None and len(df) > 0:
|
||||
logger.info(f"✓ Data provider working: {len(df)} candles retrieved")
|
||||
else:
|
||||
logger.warning("⚠ Data provider returned no data (rate limiting)")
|
||||
|
||||
# Test rate limiter status
|
||||
rate_limiter = get_rate_limiter()
|
||||
status = rate_limiter.get_all_endpoint_status()
|
||||
logger.info(f"Rate limiter status: {status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data provider test error: {e}")
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("SIMPLE DASHBOARD RUNNER - TESTING SYSTEM")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Test data provider in background
|
||||
data_thread = threading.Thread(target=test_data_provider, daemon=True)
|
||||
data_thread.start()
|
||||
|
||||
# Create and run dashboard
|
||||
app = create_simple_dashboard()
|
||||
if app is None:
|
||||
logger.error("Failed to create dashboard")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Starting dashboard server...")
|
||||
logger.info("Dashboard URL: http://127.0.0.1:8050")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run the dashboard
|
||||
app.run(debug=False, host='127.0.0.1', port=8050, use_reloader=False)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard error: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,275 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Stable Dashboard Runner - Prioritizes System Stability
|
||||
|
||||
This runner focuses on:
|
||||
1. System stability and reliability
|
||||
2. Core trading functionality
|
||||
3. Minimal resource usage
|
||||
4. Robust error handling
|
||||
5. Graceful degradation
|
||||
|
||||
Deferred features (until stability is achieved):
|
||||
- TensorBoard integration
|
||||
- Complex training loops
|
||||
- Advanced visualizations
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
import signal
|
||||
from pathlib import Path
|
||||
|
||||
# Fix environment issues before imports
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
||||
os.environ['OMP_NUM_THREADS'] = '2' # Reduced from 4 for stability
|
||||
|
||||
# Fix matplotlib backend
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging, get_config
|
||||
from system_stability_audit import SystemStabilityAuditor
|
||||
|
||||
# Setup logging with reduced verbosity for stability
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reduce logging noise from other modules
|
||||
logging.getLogger('werkzeug').setLevel(logging.ERROR)
|
||||
logging.getLogger('dash').setLevel(logging.ERROR)
|
||||
logging.getLogger('matplotlib').setLevel(logging.ERROR)
|
||||
|
||||
class StableDashboardRunner:
|
||||
"""
|
||||
Stable dashboard runner with focus on reliability
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize stable dashboard runner"""
|
||||
self.config = get_config()
|
||||
self.running = False
|
||||
self.dashboard = None
|
||||
self.stability_auditor = None
|
||||
|
||||
# Core components
|
||||
self.data_provider = None
|
||||
self.orchestrator = None
|
||||
self.trading_executor = None
|
||||
|
||||
# Stability monitoring
|
||||
self.last_health_check = time.time()
|
||||
self.health_check_interval = 30 # Check every 30 seconds
|
||||
|
||||
logger.info("Stable Dashboard Runner initialized")
|
||||
|
||||
def initialize_components(self):
|
||||
"""Initialize core components with error handling"""
|
||||
try:
|
||||
logger.info("Initializing core components...")
|
||||
|
||||
# Initialize data provider
|
||||
from core.data_provider import DataProvider
|
||||
self.data_provider = DataProvider()
|
||||
logger.info("✓ Data provider initialized")
|
||||
|
||||
# Initialize trading executor
|
||||
from core.trading_executor import TradingExecutor
|
||||
self.trading_executor = TradingExecutor()
|
||||
logger.info("✓ Trading executor initialized")
|
||||
|
||||
# Initialize orchestrator with minimal features for stability
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
self.orchestrator = TradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
enhanced_rl_training=False # Disabled for stability
|
||||
)
|
||||
logger.info("✓ Orchestrator initialized (training disabled for stability)")
|
||||
|
||||
# Initialize dashboard
|
||||
from web.clean_dashboard import CleanTradingDashboard
|
||||
self.dashboard = CleanTradingDashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.orchestrator,
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
logger.info("✓ Dashboard initialized")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing components: {e}")
|
||||
return False
|
||||
|
||||
def start_stability_monitoring(self):
|
||||
"""Start system stability monitoring"""
|
||||
try:
|
||||
self.stability_auditor = SystemStabilityAuditor()
|
||||
self.stability_auditor.start_monitoring()
|
||||
logger.info("✓ Stability monitoring started")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting stability monitoring: {e}")
|
||||
|
||||
def health_check(self):
|
||||
"""Perform system health check"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
if current_time - self.last_health_check < self.health_check_interval:
|
||||
return
|
||||
|
||||
self.last_health_check = current_time
|
||||
|
||||
# Check stability score
|
||||
if self.stability_auditor:
|
||||
report = self.stability_auditor.get_stability_report()
|
||||
stability_score = report.get('stability_score', 0)
|
||||
|
||||
if stability_score < 50:
|
||||
logger.warning(f"Low stability score: {stability_score:.1f}/100")
|
||||
# Attempt to fix issues
|
||||
self.stability_auditor.fix_common_issues()
|
||||
elif stability_score < 80:
|
||||
logger.info(f"Moderate stability: {stability_score:.1f}/100")
|
||||
else:
|
||||
logger.debug(f"Good stability: {stability_score:.1f}/100")
|
||||
|
||||
# Check component health
|
||||
if self.dashboard and hasattr(self.dashboard, 'app'):
|
||||
logger.debug("✓ Dashboard responsive")
|
||||
|
||||
if self.data_provider:
|
||||
logger.debug("✓ Data provider active")
|
||||
|
||||
if self.orchestrator:
|
||||
logger.debug("✓ Orchestrator active")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in health check: {e}")
|
||||
|
||||
def run(self):
|
||||
"""Run the stable dashboard"""
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info("STABLE TRADING DASHBOARD")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Priority: System Stability & Core Functionality")
|
||||
logger.info("Training: Disabled (will be enabled after stability)")
|
||||
logger.info("TensorBoard: Deferred (documented in design)")
|
||||
logger.info("Focus: Dashboard, Data, Basic Trading")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize components
|
||||
if not self.initialize_components():
|
||||
logger.error("Failed to initialize components")
|
||||
return False
|
||||
|
||||
# Start stability monitoring
|
||||
self.start_stability_monitoring()
|
||||
|
||||
# Start health check thread
|
||||
health_thread = threading.Thread(target=self._health_check_loop, daemon=True)
|
||||
health_thread.start()
|
||||
|
||||
# Get dashboard port
|
||||
port = int(os.environ.get('DASHBOARD_PORT', '8051'))
|
||||
|
||||
logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
self.running = True
|
||||
|
||||
# Start dashboard (this blocks)
|
||||
if self.dashboard and hasattr(self.dashboard, 'app'):
|
||||
self.dashboard.app.run_server(
|
||||
host='127.0.0.1',
|
||||
port=port,
|
||||
debug=False,
|
||||
use_reloader=False, # Disable reloader for stability
|
||||
threaded=True
|
||||
)
|
||||
else:
|
||||
logger.error("Dashboard not properly initialized")
|
||||
return False
|
||||
|
||||
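# Note (added for clarity): if run_server() returns normally (server stopped
# without raising), execution falls through here and run() returns None,
# which main() reports as a failure.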
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
self.shutdown()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error running dashboard: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def _health_check_loop(self):
|
||||
"""Health check loop running in background"""
|
||||
while self.running:
|
||||
try:
|
||||
self.health_check()
|
||||
time.sleep(self.health_check_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in health check loop: {e}")
|
||||
time.sleep(60) # Wait longer on error
|
||||
|
||||
def shutdown(self):
|
||||
"""Graceful shutdown"""
|
||||
try:
|
||||
logger.info("Shutting down stable dashboard...")
|
||||
self.running = False
|
||||
|
||||
# Stop stability monitoring
|
||||
if self.stability_auditor:
|
||||
self.stability_auditor.stop_monitoring()
|
||||
logger.info("✓ Stability monitoring stopped")
|
||||
|
||||
# Stop components
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'stop'):
|
||||
self.orchestrator.stop()
|
||||
logger.info("✓ Orchestrator stopped")
|
||||
|
||||
if self.data_provider and hasattr(self.data_provider, 'stop'):
|
||||
self.data_provider.stop()
|
||||
logger.info("✓ Data provider stopped")
|
||||
|
||||
logger.info("Stable dashboard shutdown complete")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle shutdown signals"""
|
||||
logger.info("Received shutdown signal")
|
||||
sys.exit(0)
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
runner = StableDashboardRunner()
|
||||
success = runner.run()
|
||||
|
||||
if success:
|
||||
logger.info("Dashboard completed successfully")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("Dashboard failed")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,64 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run Templated Trading Dashboard
|
||||
Demonstrates the new MVC template-based architecture
|
||||
"""
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from web.templated_dashboard import create_templated_dashboard
|
||||
from web.dashboard_model import create_sample_dashboard_data
|
||||
from web.template_renderer import DashboardTemplateRenderer
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the templated dashboard"""
|
||||
try:
|
||||
logger.info("=== TEMPLATED DASHBOARD DEMO ===")
|
||||
|
||||
# Test the template system first
|
||||
logger.info("Testing template system...")
|
||||
|
||||
# Create sample data
|
||||
sample_data = create_sample_dashboard_data()
|
||||
logger.info(f"Created sample data with {len(sample_data.metrics)} metrics")
|
||||
|
||||
# Test template renderer
|
||||
renderer = DashboardTemplateRenderer()
|
||||
logger.info("Template renderer initialized")
|
||||
|
||||
# Create templated dashboard
|
||||
logger.info("Creating templated dashboard...")
|
||||
dashboard = create_templated_dashboard()
|
||||
|
||||
logger.info("Dashboard created successfully!")
|
||||
logger.info("Template-based MVC architecture features:")
|
||||
logger.info(" ✓ HTML templates separated from Python code")
|
||||
logger.info(" ✓ Data models for structured data")
|
||||
logger.info(" ✓ Template renderer for clean separation")
|
||||
logger.info(" ✓ Easy to modify HTML without touching Python")
|
||||
logger.info(" ✓ Reusable components and templates")
|
||||
|
||||
# Run the dashboard
|
||||
logger.info("Starting templated dashboard server...")
|
||||
dashboard.run_server(host='127.0.0.1', port=8052, debug=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running templated dashboard: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TensorBoard Launch Script
|
||||
|
||||
Starts TensorBoard server for monitoring training progress.
|
||||
Visualizes training metrics, rewards, state information, and model performance.
|
||||
|
||||
This script can be run standalone or integrated with the dashboard.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import webbrowser
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def start_tensorboard(logdir="runs", port=6006, open_browser=True):
|
||||
"""
|
||||
Start TensorBoard server programmatically
|
||||
|
||||
Args:
|
||||
logdir: Directory containing TensorBoard logs
|
||||
port: Port to run TensorBoard on
|
||||
open_browser: Whether to open browser automatically
|
||||
|
||||
Returns:
|
||||
subprocess.Popen: TensorBoard process
|
||||
"""
|
||||
# Set log directory
|
||||
runs_dir = Path(logdir)
|
||||
if not runs_dir.exists():
|
||||
logger.warning(f"No '{logdir}' directory found. Creating it.")
|
||||
runs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if there are any log directories
|
||||
log_dirs = list(runs_dir.glob("*"))
|
||||
if not log_dirs:
|
||||
logger.warning(f"No training logs found in '{logdir}' directory.")
|
||||
else:
|
||||
logger.info(f"Found {len(log_dirs)} training sessions")
|
||||
|
||||
# List available sessions
|
||||
logger.info("Available training sessions:")
|
||||
for i, log_dir in enumerate(sorted(log_dirs), 1):
|
||||
logger.info(f" {i}. {log_dir.name}")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting TensorBoard on port {port}...")
|
||||
|
||||
# Try to open browser automatically if requested
|
||||
if open_browser:
|
||||
try:
|
||||
webbrowser.open(f"http://localhost:{port}")
|
||||
logger.info("Browser opened automatically")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not open browser automatically: {e}")
|
||||
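# Note (added for clarity): the browser is opened before the TensorBoard
# process is launched below, so the page may need a manual refresh once the
# server is actually up.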
|
||||
# Start TensorBoard process with enhanced options
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"tensorboard.main",
|
||||
"--logdir", str(runs_dir),
|
||||
"--port", str(port),
|
||||
"--samples_per_plugin", "images=100,audio=100,text=100",
|
||||
"--reload_interval", "5", # Reload data every 5 seconds
|
||||
"--reload_multifile", "true" # Better handling of multiple log files
|
||||
]
|
||||
|
||||
logger.info("TensorBoard is running with enhanced training visualization!")
|
||||
logger.info(f"View training metrics at: http://localhost:{port}")
|
||||
logger.info("Available dashboards:")
|
||||
logger.info(" - SCALARS: Training metrics, rewards, and losses")
|
||||
logger.info(" - HISTOGRAMS: Feature distributions and model weights")
|
||||
logger.info(" - TIME SERIES: Training progress over time")
|
||||
|
||||
# Start TensorBoard process
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True
|
||||
)
|
||||
|
||||
# Return process for management
|
||||
return process
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error("TensorBoard not found. Install with: pip install tensorboard")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting TensorBoard: {e}")
|
||||
return None
|
||||
|
||||
def main():
|
||||
"""Launch TensorBoard with enhanced visualization options"""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="Launch TensorBoard for training visualization")
|
||||
parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
|
||||
parser.add_argument("--logdir", type=str, default="runs", help="Directory containing TensorBoard logs")
|
||||
parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
|
||||
parser.add_argument("--dashboard-integration", action="store_true", help="Run in dashboard integration mode")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Start TensorBoard
|
||||
process = start_tensorboard(
|
||||
logdir=args.logdir,
|
||||
port=args.port,
|
||||
open_browser=not args.no_browser
|
||||
)
|
||||
|
||||
if process is None:
|
||||
return 1
|
||||
|
||||
# If running in dashboard integration mode, return immediately
|
||||
if args.dashboard_integration:
|
||||
return 0
|
||||
|
||||
# Otherwise, wait for process to complete
|
||||
try:
|
||||
print("\n" + "="*70)
|
||||
print("🔥 TensorBoard is running with enhanced training visualization!")
|
||||
print(f"📈 View training metrics at: http://localhost:{args.port}")
|
||||
print("⏹️ Press Ctrl+C to stop TensorBoard")
|
||||
print("="*70 + "\n")
|
||||
|
||||
# Wait for process to complete or user interrupt
|
||||
process.wait()
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 TensorBoard stopped")
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
182
run_tests.py
@@ -1,182 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified Test Runner for Trading System
|
||||
|
||||
This script provides a unified interface to run all tests in the system:
|
||||
- Essential functionality tests
|
||||
- Model persistence tests
|
||||
- Training integration tests
|
||||
- Indicators and signals tests
|
||||
- Remaining individual test files
|
||||
|
||||
Usage:
|
||||
python run_tests.py # Run all tests
|
||||
python run_tests.py essential # Run essential tests only
|
||||
python run_tests.py persistence # Run model persistence tests only
|
||||
python run_tests.py training # Run training integration tests only
|
||||
python run_tests.py indicators # Run indicators and signals tests only
|
||||
python run_tests.py individual # Run individual test files only
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from safe_logging import setup_safe_logging
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run_test_module(module_path, test_type="all"):
|
||||
"""Run a specific test module"""
|
||||
try:
|
||||
cmd = [sys.executable, str(module_path)]
|
||||
if test_type != "all":
|
||||
cmd.append(test_type)
|
||||
|
||||
logger.info(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_root)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ {module_path.name} passed")
|
||||
if result.stdout:
|
||||
logger.info(result.stdout)
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ {module_path.name} failed")
|
||||
if result.stderr:
|
||||
logger.error(result.stderr)
|
||||
if result.stdout:
|
||||
logger.error(result.stdout)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running {module_path}: {e}")
|
||||
return False
|
||||
|
||||
def run_essential_tests():
|
||||
"""Run essential functionality tests"""
|
||||
logger.info("=== Running Essential Tests ===")
|
||||
return run_test_module(project_root / "tests" / "test_essential.py")
|
||||
|
||||
def run_persistence_tests():
|
||||
"""Run model persistence tests"""
|
||||
logger.info("=== Running Model Persistence Tests ===")
|
||||
return run_test_module(project_root / "tests" / "test_model_persistence.py")
|
||||
|
||||
def run_training_tests():
|
||||
"""Run training integration tests"""
|
||||
logger.info("=== Running Training Integration Tests ===")
|
||||
return run_test_module(project_root / "tests" / "test_training_integration.py")
|
||||
|
||||
def run_indicators_tests():
|
||||
"""Run indicators and signals tests"""
|
||||
logger.info("=== Running Indicators and Signals Tests ===")
|
||||
return run_test_module(project_root / "tests" / "test_indicators_and_signals.py")
|
||||
|
||||
def run_individual_tests():
|
||||
"""Run remaining individual test files"""
|
||||
logger.info("=== Running Individual Test Files ===")
|
||||
|
||||
individual_tests = [
|
||||
"test_positions.py",
|
||||
"test_tick_cache.py",
|
||||
"test_timestamps.py"
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_file in individual_tests:
|
||||
test_path = project_root / test_file
|
||||
if test_path.exists():
|
||||
logger.info(f"Running {test_file}...")
|
||||
result = run_test_module(test_path)
|
||||
results.append(result)
|
||||
else:
|
||||
logger.warning(f"Test file not found: {test_file}")
|
||||
results.append(False)
|
||||
|
||||
return all(results)
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all test suites"""
|
||||
logger.info("🧪 Running All Trading System Tests")
|
||||
logger.info("=" * 60)
|
||||
|
||||
test_suites = [
|
||||
("Essential Tests", run_essential_tests),
|
||||
("Model Persistence Tests", run_persistence_tests),
|
||||
("Training Integration Tests", run_training_tests),
|
||||
("Indicators and Signals Tests", run_indicators_tests),
|
||||
("Individual Tests", run_individual_tests),
|
||||
]
|
||||
|
||||
results = []
|
||||
for suite_name, suite_func in test_suites:
|
||||
logger.info(f"\n📋 {suite_name}")
|
||||
logger.info("-" * 40)
|
||||
try:
|
||||
result = suite_func()
|
||||
results.append((suite_name, result))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ {suite_name} crashed: {e}")
|
||||
results.append((suite_name, False))
|
||||
|
||||
# Print summary
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("📊 TEST RESULTS SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
|
||||
passed = 0
|
||||
for suite_name, result in results:
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
logger.info(f"{status}: {suite_name}")
|
||||
if result:
|
||||
passed += 1
|
||||
|
||||
logger.info(f"\nPassed: {passed}/{len(results)} test suites")
|
||||
|
||||
if passed == len(results):
|
||||
logger.info("🎉 All tests passed! Trading system is working correctly.")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"⚠️ {len(results) - passed} test suite(s) failed. Please check the issues above.")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test runner"""
|
||||
setup_safe_logging()
|
||||
|
||||
# Parse command line arguments
|
||||
if len(sys.argv) > 1:
|
||||
test_type = sys.argv[1].lower()
|
||||
|
||||
if test_type == "essential":
|
||||
success = run_essential_tests()
|
||||
elif test_type == "persistence":
|
||||
success = run_persistence_tests()
|
||||
elif test_type == "training":
|
||||
success = run_training_tests()
|
||||
elif test_type == "indicators":
|
||||
success = run_indicators_tests()
|
||||
elif test_type == "individual":
|
||||
success = run_individual_tests()
|
||||
elif test_type in ["help", "-h", "--help"]:
|
||||
print(__doc__)
|
||||
return 0
|
||||
else:
|
||||
logger.error(f"Unknown test type: {test_type}")
|
||||
print(__doc__)
|
||||
return 1
|
||||
else:
|
||||
success = run_all_tests()
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,197 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Kill Stale Processes Script
|
||||
|
||||
Safely terminates stale Python processes related to the trading dashboard
|
||||
with proper error handling and graceful termination.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import signal
|
||||
from pathlib import Path
|
||||
import threading
|
||||
|
||||
# Global timeout flag
|
||||
timeout_reached = False
|
||||
|
||||
def timeout_handler():
|
||||
"""Handler for overall script timeout"""
|
||||
global timeout_reached
|
||||
timeout_reached = True
|
||||
print("\n⚠️ WARNING: Script timeout reached (10s) - forcing exit")
|
||||
os._exit(0) # Force exit
|
||||
|
||||
def kill_stale_processes():
|
||||
"""Kill stale trading dashboard processes safely"""
|
||||
global timeout_reached
|
||||
|
||||
# Set up overall timeout (10 seconds)
|
||||
timer = threading.Timer(10.0, timeout_handler)
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
print("psutil not available - using fallback method")
|
||||
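# Note (added for clarity): the 10s timeout timer started above stays armed on
# this path; if the fallback takes longer than 10 seconds, timeout_handler
# will force-exit the script.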
return kill_stale_fallback()
|
||||
|
||||
current_pid = os.getpid()
|
||||
killed_processes = []
|
||||
failed_processes = []
|
||||
|
||||
# Keywords to identify trading dashboard processes
|
||||
target_keywords = [
|
||||
'dashboard', 'scalping', 'trading', 'tensorboard',
|
||||
'run_clean', 'run_main', 'gogo2', 'mexc'
|
||||
]
|
||||
|
||||
try:
|
||||
print("Scanning for stale processes...")
|
||||
|
||||
# Get all Python processes with timeout
|
||||
python_processes = []
|
||||
scan_start = time.time()
|
||||
|
||||
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
|
||||
if timeout_reached or (time.time() - scan_start) > 3.0: # 3s max for scanning
|
||||
print("Process scanning timeout - proceeding with found processes")
|
||||
break
|
||||
|
||||
try:
|
||||
if proc.info['pid'] == current_pid:
|
||||
continue
|
||||
|
||||
name = proc.info['name'].lower()
|
||||
if 'python' in name or 'tensorboard' in name:
|
||||
cmdline_str = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else ''
|
||||
|
||||
# Check if this is a target process
|
||||
if any(keyword in cmdline_str.lower() for keyword in target_keywords):
|
||||
python_processes.append({
|
||||
'proc': proc,
|
||||
'pid': proc.info['pid'],
|
||||
'name': proc.info['name'],
|
||||
'cmdline': cmdline_str
|
||||
})
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
continue
|
||||
|
||||
if not python_processes:
|
||||
print("No stale processes found")
|
||||
timer.cancel() # Cancel the timeout
|
||||
return True
|
||||
|
||||
print(f"Found {len(python_processes)} target processes to terminate:")
|
||||
for p in python_processes[:5]: # Show max 5 to save time
|
||||
print(f" - PID {p['pid']}: {p['name']} - {p['cmdline'][:80]}...")
|
||||
if len(python_processes) > 5:
|
||||
print(f" ... and {len(python_processes) - 5} more")
|
||||
|
||||
# Graceful termination first (with reduced wait time)
|
||||
print("\nAttempting graceful termination...")
|
||||
termination_start = time.time()
|
||||
|
||||
for p in python_processes:
|
||||
if timeout_reached or (time.time() - termination_start) > 2.0:
|
||||
print("Termination timeout - moving to force kill")
|
||||
break
|
||||
|
||||
try:
|
||||
proc = p['proc']
|
||||
if proc.is_running():
|
||||
proc.terminate()
|
||||
print(f" Sent SIGTERM to PID {p['pid']}")
|
||||
except Exception as e:
|
||||
failed_processes.append(f"Failed to terminate PID {p['pid']}: {e}")
|
||||
|
||||
# Wait for graceful shutdown (reduced from 2.0 to 1.0)
|
||||
time.sleep(1.0)
|
||||
|
||||
# Force kill remaining processes
|
||||
print("\nChecking for remaining processes...")
|
||||
kill_start = time.time()
|
||||
|
||||
for p in python_processes:
|
||||
if timeout_reached or (time.time() - kill_start) > 2.0:
|
||||
print("Force kill timeout - exiting")
|
||||
break
|
||||
|
||||
try:
|
||||
proc = p['proc']
|
||||
if proc.is_running():
|
||||
print(f" Force killing PID {p['pid']} ({p['name']})")
|
||||
proc.kill()
|
||||
killed_processes.append(f"Force killed PID {p['pid']} ({p['name']})")
|
||||
else:
|
||||
killed_processes.append(f"Gracefully terminated PID {p['pid']} ({p['name']})")
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
killed_processes.append(f"Process PID {p['pid']} already terminated")
|
||||
except Exception as e:
|
||||
failed_processes.append(f"Failed to kill PID {p['pid']}: {e}")
|
||||
|
||||
# Results (quick summary)
|
||||
print(f"\n=== Quick Results ===")
|
||||
print(f"✓ Cleaned up {len(killed_processes)} processes")
|
||||
if failed_processes:
|
||||
print(f"✗ Failed: {len(failed_processes)} processes")
|
||||
|
||||
timer.cancel() # Cancel the timeout if we finished early
|
||||
return len(failed_processes) == 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during process cleanup: {e}")
|
||||
timer.cancel()
|
||||
return False
|
||||
|
||||
def kill_stale_fallback():
|
||||
"""Fallback method using basic OS commands"""
|
||||
print("Using fallback process killing method...")
|
||||
|
||||
try:
|
||||
if os.name == 'nt': # Windows
|
||||
import subprocess
|
||||
# Windows fallback: kills all python.exe processes (no keyword filtering), with a 5s timeout
|
||||
result = subprocess.run([
|
||||
'taskkill', '/f', '/im', 'python.exe'
|
||||
], capture_output=True, text=True, timeout=5.0)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("Windows: Killed all Python processes")
|
||||
else:
|
||||
print("Windows: No Python processes to kill or access denied")
|
||||
|
||||
else: # Unix/Linux
|
||||
import subprocess
|
||||
# More targeted approach for Unix (with timeouts)
|
||||
subprocess.run(['pkill', '-f', 'dashboard'], capture_output=True, timeout=2.0)
|
||||
subprocess.run(['pkill', '-f', 'scalping'], capture_output=True, timeout=2.0)
|
||||
subprocess.run(['pkill', '-f', 'tensorboard'], capture_output=True, timeout=2.0)
|
||||
print("Unix: Killed dashboard-related processes")
|
||||
|
||||
return True
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print("Fallback method timed out")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Fallback method failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 50)
|
||||
print("STALE PROCESS CLEANUP (10s timeout)")
|
||||
print("=" * 50)
|
||||
|
||||
start_time = time.time()
|
||||
success = kill_stale_processes()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
exit_code = 0 if success else 1
|
||||
|
||||
print(f"Completed in {elapsed:.1f}s")
|
||||
print("=" * 50)
|
||||
sys.exit(exit_code)
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Restart Dashboard with Forced Learning Enabled
|
||||
Simple script to start dashboard with all learning features enabled
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
"""Start dashboard with forced learning"""
|
||||
logger.info("🚀 Starting Dashboard with FORCED LEARNING ENABLED")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Timestamp: {datetime.now()}")
|
||||
logger.info("Fixes Applied:")
|
||||
logger.info("✅ Enhanced RL: FORCED ENABLED")
|
||||
logger.info("✅ CNN Training: FORCED ENABLED")
|
||||
logger.info("✅ Williams Pivots: CNN INTEGRATED")
|
||||
logger.info("✅ Learning Pipeline: ACTIVE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Import and run main
|
||||
from main_clean import run_web_dashboard
|
||||
logger.info("Starting web dashboard...")
|
||||
run_web_dashboard()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,188 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Overnight Training Restart Script
|
||||
Keeps main.py running continuously, restarting it if it crashes.
|
||||
Designed for overnight training sessions with unstable code.
|
||||
|
||||
Usage:
|
||||
python restart_main_overnight.py
|
||||
|
||||
Press Ctrl+C to stop the restart loop.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import signal
|
||||
import os
|
||||
|
||||
# Setup logging for the restart script
|
||||
def setup_restart_logging():
|
||||
"""Setup logging for restart events"""
|
||||
log_dir = Path("logs")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create restart log file with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
log_file = log_dir / f"restart_main_{timestamp}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file, encoding='utf-8'),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Restart script logging to: {log_file}")
|
||||
return logger
|
||||
|
||||
def kill_existing_processes(logger):
|
||||
"""Kill any existing main.py processes to avoid conflicts"""
|
||||
try:
|
||||
if os.name == 'nt': # Windows
|
||||
# Kill all python.exe/pythonw.exe processes; this is broader than just main.py and can also match this restart script or unrelated Python programs
|
||||
subprocess.run(['taskkill', '/f', '/im', 'python.exe'],
|
||||
capture_output=True, check=False)
|
||||
subprocess.run(['taskkill', '/f', '/im', 'pythonw.exe'],
|
||||
capture_output=True, check=False)
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not kill existing processes: {e}")
|
||||
|
||||
def run_main_with_restart(logger):
|
||||
"""Main restart loop"""
|
||||
restart_count = 0
|
||||
consecutive_fast_exits = 0
|
||||
start_time = datetime.now()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Press Ctrl+C to stop the restart loop")
|
||||
logger.info("Main script: main.py")
|
||||
logger.info("Restart delay on crash: 10 seconds")
|
||||
logger.info("Fast exit protection: Enabled")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Kill any existing processes
|
||||
kill_existing_processes(logger)
|
||||
|
||||
while True:
|
||||
try:
|
||||
restart_count += 1
|
||||
run_start_time = datetime.now()
|
||||
|
||||
logger.info(f"[RESTART #{restart_count}] Starting main.py at {run_start_time.strftime('%H:%M:%S')}")
|
||||
|
||||
# Start main.py as subprocess
|
||||
process = subprocess.Popen([
|
||||
sys.executable, "main.py"
|
||||
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
universal_newlines=True, bufsize=1)
|
||||
|
||||
logger.info(f"[PROCESS] main.py started with PID: {process.pid}")
|
||||
|
||||
# Stream output from main.py
|
||||
try:
|
||||
if process.stdout:
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == '' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
# Forward output from main.py (remove extra newlines)
|
||||
print(f"[MAIN] {output.rstrip()}")
|
||||
else:
|
||||
# If no stdout, just wait for process to complete
|
||||
process.wait()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...")
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("[FORCE KILL] Process didn't terminate, force killing...")
|
||||
process.kill()
|
||||
raise
|
||||
|
||||
# Process has exited
|
||||
exit_code = process.poll()
|
||||
run_end_time = datetime.now()
|
||||
run_duration = (run_end_time - run_start_time).total_seconds()
|
||||
|
||||
logger.info(f"[EXIT] main.py exited with code {exit_code}")
|
||||
logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds")
|
||||
|
||||
# Check for fast exits (potential configuration issues)
|
||||
if run_duration < 30: # Less than 30 seconds
|
||||
consecutive_fast_exits += 1
|
||||
logger.warning(f"[FAST EXIT] Process exited quickly ({consecutive_fast_exits} consecutive)")
|
||||
|
||||
if consecutive_fast_exits >= 5:
|
||||
logger.error("[ABORT] Too many consecutive fast exits (5+)")
|
||||
logger.error("This indicates a configuration or startup problem")
|
||||
logger.error("Please check the main.py script manually")
|
||||
break
|
||||
|
||||
# Longer delay for fast exits
|
||||
delay = min(60, 10 * consecutive_fast_exits)
|
||||
logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
consecutive_fast_exits = 0 # Reset counter
|
||||
logger.info("[DELAY] Waiting 10 seconds before restart...")
|
||||
time.sleep(10)
|
||||
|
||||
# Log session statistics every 10 restarts
|
||||
if restart_count % 10 == 0:
|
||||
total_duration = (datetime.now() - start_time).total_seconds()
|
||||
logger.info(f"[STATS] Session: {restart_count} restarts in {total_duration/3600:.1f} hours")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("[SHUTDOWN] Restart loop interrupted by user")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] Unexpected error in restart loop: {e}")
|
||||
logger.error("Continuing restart loop after 30 second delay...")
|
||||
time.sleep(30)
|
||||
|
||||
total_duration = (datetime.now() - start_time).total_seconds()
|
||||
logger.info("=" * 60)
|
||||
logger.info("OVERNIGHT TRAINING SESSION COMPLETE")
|
||||
logger.info(f"Total restarts: {restart_count}")
|
||||
logger.info(f"Total session time: {total_duration/3600:.1f} hours")
|
||||
logger.info("=" * 60)
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
# Setup signal handlers for clean shutdown
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"[SIGNAL] Received signal {signum}, shutting down...")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if hasattr(signal, 'SIGTERM'):
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Setup logging
|
||||
global logger
|
||||
logger = setup_restart_logging()
|
||||
|
||||
try:
|
||||
run_main_with_restart(logger)
|
||||
except Exception as e:
|
||||
logger.error(f"[FATAL] Fatal error in restart script: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MEXC Browser Setup & Runner
|
||||
|
||||
This script automatically installs dependencies and runs the MEXC browser automation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import importlib
|
||||
|
||||
def check_and_install_requirements():
|
||||
"""Check and install required packages"""
|
||||
required_packages = [
|
||||
'selenium',
|
||||
'webdriver-manager',
|
||||
'requests'
|
||||
]
|
||||
|
||||
print("🔍 Checking required packages...")
|
||||
|
||||
missing_packages = []
|
||||
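# Assumption noted for clarity: the import name is taken to be the pip package
# name with dashes replaced by underscores (e.g. webdriver-manager -> webdriver_manager).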
for package in required_packages:
|
||||
try:
|
||||
importlib.import_module(package.replace('-', '_'))
|
||||
print(f"✅ {package} - already installed")
|
||||
except ImportError:
|
||||
missing_packages.append(package)
|
||||
print(f"❌ {package} - missing")
|
||||
|
||||
if missing_packages:
|
||||
print(f"\n📦 Installing missing packages: {', '.join(missing_packages)}")
|
||||
|
||||
for package in missing_packages:
|
||||
try:
|
||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
|
||||
print(f"✅ Successfully installed {package}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ Failed to install {package}: {e}")
|
||||
return False
|
||||
|
||||
print("✅ All requirements satisfied!")
|
||||
return True
|
||||
|
||||
def run_browser_automation():
|
||||
"""Run the MEXC browser automation"""
|
||||
try:
|
||||
# Import and run the auto browser
|
||||
from core.mexc_webclient.auto_browser import main as auto_browser_main
|
||||
auto_browser_main()
|
||||
except ImportError:
|
||||
print("❌ Could not import auto browser module")
|
||||
print("Make sure core/mexc_webclient/auto_browser.py exists")
|
||||
except Exception as e:
|
||||
print(f"❌ Error running browser automation: {e}")
|
||||
|
||||
def main():
|
||||
"""Main setup and run function"""
|
||||
print("🚀 MEXC Browser Automation Setup")
|
||||
print("=" * 40)
|
||||
|
||||
# Check Python version
|
||||
if sys.version_info < (3, 7):
|
||||
print("❌ Python 3.7+ required")
|
||||
return
|
||||
|
||||
print(f"✅ Python {sys.version.split()[0]} detected")
|
||||
|
||||
# Install requirements
|
||||
if not check_and_install_requirements():
|
||||
print("❌ Failed to install requirements")
|
||||
return
|
||||
|
||||
print("\n🌐 Starting browser automation...")
|
||||
print("This will:")
|
||||
print("• Download ChromeDriver automatically")
|
||||
print("• Open MEXC futures page")
|
||||
print("• Capture all trading requests")
|
||||
print("• Extract session cookies")
|
||||
|
||||
input("\nPress Enter to continue...")
|
||||
|
||||
# Run the automation
|
||||
run_browser_automation()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,160 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Helper script to start monitoring services for RL training
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Available ports to try for TensorBoard
|
||||
TENSORBOARD_PORTS = [6006, 6007, 6008, 6009, 6010, 6011, 6012]
|
||||
|
||||
def check_port(port, service_name):
|
||||
"""Check if a service is running on the specified port"""
|
||||
try:
|
||||
response = requests.get(f"http://localhost:{port}", timeout=3)
|
||||
print(f"✅ {service_name} is running on port {port}")
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
def is_port_in_use(port):
|
||||
"""Check if a port is already in use"""
|
||||
import socket
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind(('localhost', port))
|
||||
return False
|
||||
except OSError:
|
||||
return True
|
||||
|
||||
def find_available_port(ports_list, service_name):
|
||||
"""Find an available port from the list"""
|
||||
for port in ports_list:
|
||||
if not is_port_in_use(port):
|
||||
print(f"🔍 Found available port {port} for {service_name}")
|
||||
return port
|
||||
else:
|
||||
print(f"⚠️ Port {port} is already in use")
|
||||
return None
|
||||
|
||||
def save_port_config(tensorboard_port):
|
||||
"""Save the port configuration to a file"""
|
||||
config = {
|
||||
"tensorboard_port": tensorboard_port,
|
||||
"web_dashboard_port": 8051
|
||||
}
|
||||
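# Note (added for clarity): web_dashboard_port is hard-coded to 8051 here even
# though main() may fall back to an alternative port (8052+), so the saved
# value can be stale.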
with open("monitoring_ports.json", "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
print(f"💾 Port configuration saved to monitoring_ports.json")
|
||||
|
||||
def start_tensorboard():
|
||||
"""Start TensorBoard in background on an available port"""
|
||||
try:
|
||||
# First check if TensorBoard is already running on any of our ports
|
||||
for port in TENSORBOARD_PORTS:
|
||||
if check_port(port, "TensorBoard"):
|
||||
print(f"✅ TensorBoard already running on port {port}")
|
||||
save_port_config(port)
|
||||
return port
|
||||
|
||||
# Find an available port
|
||||
port = find_available_port(TENSORBOARD_PORTS, "TensorBoard")
|
||||
if port is None:
|
||||
print(f"❌ No available ports found in range {TENSORBOARD_PORTS}")
|
||||
return None
|
||||
|
||||
print(f"🚀 Starting TensorBoard on port {port}...")
|
||||
|
||||
# Create runs directory if it doesn't exist
|
||||
Path("runs").mkdir(exist_ok=True)
|
||||
|
||||
# Start TensorBoard
|
||||
if os.name == 'nt': # Windows
|
||||
subprocess.Popen([
|
||||
sys.executable, "-m", "tensorboard",
|
||||
"--logdir=runs", f"--port={port}", "--reload_interval=1"
|
||||
], creationflags=subprocess.CREATE_NEW_CONSOLE)
|
||||
else: # Linux/Mac
|
||||
subprocess.Popen([
|
||||
sys.executable, "-m", "tensorboard",
|
||||
"--logdir=runs", f"--port={port}", "--reload_interval=1"
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
# Wait for TensorBoard to start
|
||||
print(f"⏳ Waiting for TensorBoard to start on port {port}...")
|
||||
for i in range(15):
|
||||
time.sleep(2)
|
||||
if check_port(port, "TensorBoard"):
|
||||
save_port_config(port)
|
||||
return port
|
||||
|
||||
print(f"⚠️ TensorBoard failed to start on port {port} within 30 seconds")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error starting TensorBoard: {e}")
|
||||
return None
|
||||
|
||||
def check_web_dashboard_port():
|
||||
"""Check if web dashboard port is available"""
|
||||
port = 8051
|
||||
if is_port_in_use(port):
|
||||
print(f"⚠️ Web dashboard port {port} is in use")
|
||||
# Try alternative ports
|
||||
for alt_port in [8052, 8053, 8054, 8055]:
|
||||
if not is_port_in_use(alt_port):
|
||||
print(f"🔍 Alternative port {alt_port} available for web dashboard")
|
||||
return alt_port
|
||||
print("❌ No alternative ports found for web dashboard")
|
||||
return port
|
||||
else:
|
||||
print(f"✅ Web dashboard port {port} is available")
|
||||
return port
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("=" * 60)
|
||||
print("🎯 RL TRAINING MONITORING SETUP")
|
||||
print("=" * 60)
|
||||
|
||||
# Check web dashboard port
|
||||
web_port = check_web_dashboard_port()
|
||||
|
||||
# Start TensorBoard
|
||||
tensorboard_port = start_tensorboard()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 MONITORING STATUS")
|
||||
print("=" * 60)
|
||||
|
||||
if tensorboard_port:
|
||||
print(f"✅ TensorBoard: http://localhost:{tensorboard_port}")
|
||||
# Update port config
|
||||
save_port_config(tensorboard_port)
|
||||
else:
|
||||
print("❌ TensorBoard: Failed to start")
|
||||
print(" Manual start: python -m tensorboard --logdir=runs --port=6007")
|
||||
|
||||
if web_port:
|
||||
print(f"✅ Web Dashboard: Ready on port {web_port}")
|
||||
|
||||
print(f"\n🎯 Ready to start RL training!")
|
||||
if tensorboard_port and web_port != 8051:
|
||||
print(f"Run: python train_realtime_with_tensorboard.py --episodes 10 --web-port {web_port}")
|
||||
else:
|
||||
print("Run: python train_realtime_with_tensorboard.py --episodes 10")
|
||||
|
||||
print(f"\n📋 Available URLs:")
|
||||
if tensorboard_port:
|
||||
print(f" 📊 TensorBoard: http://localhost:{tensorboard_port}")
|
||||
if web_port:
|
||||
print(f" 🌐 Web Dashboard: http://localhost:{web_port} (starts with training)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,179 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Start Overnight Training Session
|
||||
|
||||
This script starts a comprehensive overnight training session that:
|
||||
1. Ensures CNN and COB RL training processes are implemented and running
|
||||
2. Executes training passes on each signal when predictions change
|
||||
3. Calculates PnL and records trades in SIM mode
|
||||
4. Tracks model performance statistics
|
||||
5. Converts signals to actual trades for performance tracking
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'overnight_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
"""Start the overnight training session"""
|
||||
try:
|
||||
logger.info("🌙 STARTING OVERNIGHT TRAINING SESSION")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Import required components
|
||||
from core.config import get_config, setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import CleanTradingDashboard
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
|
||||
# Initialize components
|
||||
logger.info("Initializing components...")
|
||||
|
||||
# Create data provider
|
||||
data_provider = DataProvider()
|
||||
logger.info("✅ Data Provider initialized")
|
||||
|
||||
# Create trading executor in simulation mode
|
||||
trading_executor = TradingExecutor()
|
||||
trading_executor.simulation_mode = True # Ensure we're in simulation mode
|
||||
logger.info("✅ Trading Executor initialized (SIMULATION MODE)")
|
||||
|
||||
# Create orchestrator with enhanced training
|
||||
orchestrator = TradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
logger.info("✅ Trading Orchestrator initialized")
|
||||
|
||||
# Connect trading executor to orchestrator
|
||||
if hasattr(orchestrator, 'set_trading_executor'):
|
||||
orchestrator.set_trading_executor(trading_executor)
|
||||
logger.info("✅ Trading Executor connected to Orchestrator")
|
||||
|
||||
# Create dashboard (this initializes the overnight training coordinator)
|
||||
dashboard = CleanTradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
logger.info("✅ Dashboard initialized with Overnight Training Coordinator")
|
||||
|
||||
# Start the overnight training session
|
||||
logger.info("Starting overnight training session...")
|
||||
success = dashboard.start_overnight_training()
|
||||
|
||||
if success:
|
||||
logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED SUCCESSFULLY")
|
||||
logger.info("=" * 80)
|
||||
logger.info("Training Features Active:")
|
||||
logger.info("✅ CNN training on signal changes")
|
||||
logger.info("✅ COB RL training on market microstructure")
|
||||
logger.info("✅ DQN training on trading decisions")
|
||||
logger.info("✅ Trade execution and recording (SIMULATION)")
|
||||
logger.info("✅ Performance tracking and statistics")
|
||||
logger.info("✅ Model checkpointing every 50 trades")
|
||||
logger.info("✅ Signal-to-trade conversion with PnL calculation")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Monitor training progress
|
||||
logger.info("Monitoring training progress...")
|
||||
logger.info("Press Ctrl+C to stop the training session")
|
||||
|
||||
# Keep the session running and periodically report progress
|
||||
start_time = datetime.now()
|
||||
last_report_time = start_time
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(60) # Check every minute
|
||||
|
||||
current_time = datetime.now()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# Get performance summary every 10 minutes
|
||||
if (current_time - last_report_time).total_seconds() >= 600: # 10 minutes
|
||||
performance = dashboard.get_training_performance_summary()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"🌙 TRAINING PROGRESS REPORT - {elapsed_time}")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Total Signals: {performance.get('total_signals', 0)}")
|
||||
logger.info(f"Total Trades: {performance.get('total_trades', 0)}")
|
||||
logger.info(f"Successful Trades: {performance.get('successful_trades', 0)}")
|
||||
logger.info(f"Success Rate: {performance.get('success_rate', 0):.1%}")
|
||||
logger.info(f"Total P&L: ${performance.get('total_pnl', 0):.2f}")
|
||||
logger.info(f"Models Trained: {', '.join(performance.get('models_trained', []))}")
|
||||
logger.info(f"Training Status: {'ACTIVE' if performance.get('is_running', False) else 'INACTIVE'}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
last_report_time = current_time
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n🛑 Training session interrupted by user")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training monitoring: {e}")
|
||||
time.sleep(30) # Wait 30 seconds before retrying
|
||||
|
||||
# Stop the training session
|
||||
logger.info("Stopping overnight training session...")
|
||||
dashboard.stop_overnight_training()
|
||||
|
||||
# Final report
|
||||
final_performance = dashboard.get_training_performance_summary()
|
||||
total_time = datetime.now() - start_time
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"Total Duration: {total_time}")
|
||||
logger.info(f"Final Statistics:")
|
||||
logger.info(f" Total Signals: {final_performance.get('total_signals', 0)}")
|
||||
logger.info(f" Total Trades: {final_performance.get('total_trades', 0)}")
|
||||
logger.info(f" Successful Trades: {final_performance.get('successful_trades', 0)}")
|
||||
logger.info(f" Success Rate: {final_performance.get('success_rate', 0):.1%}")
|
||||
logger.info(f" Total P&L: ${final_performance.get('total_pnl', 0):.2f}")
|
||||
logger.info(f" Models Trained: {', '.join(final_performance.get('models_trained', []))}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
else:
|
||||
logger.error("❌ Failed to start overnight training session")
|
||||
return 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n🛑 Training session interrupted by user")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in overnight training session: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||
@@ -1,426 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
System Stability Audit and Monitoring
|
||||
|
||||
This script performs a comprehensive audit of the trading system to identify
|
||||
and fix stability issues, memory leaks, and performance bottlenecks.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import psutil
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import gc
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
import traceback
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging, get_config
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SystemStabilityAuditor:
|
||||
"""
|
||||
Comprehensive system stability auditor and monitor
|
||||
|
||||
Monitors:
|
||||
- Memory usage and leaks
|
||||
- CPU usage and performance
|
||||
- Thread health and deadlocks
|
||||
- Model performance and stability
|
||||
- Dashboard responsiveness
|
||||
- Data provider health
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the stability auditor"""
|
||||
self.config = get_config()
|
||||
self.monitoring_active = False
|
||||
self.monitoring_thread = None
|
||||
|
||||
# Performance baselines
|
||||
self.baseline_memory = psutil.virtual_memory().used
|
||||
self.baseline_cpu = psutil.cpu_percent()
|
||||
|
||||
# Monitoring data
|
||||
self.memory_history = []
|
||||
self.cpu_history = []
|
||||
self.thread_history = []
|
||||
self.error_history = []
|
||||
|
||||
# Stability metrics
|
||||
self.stability_score = 100.0
|
||||
self.critical_issues = []
|
||||
self.warnings = []
|
||||
|
||||
logger.info("System Stability Auditor initialized")
|
||||
|
||||
def start_monitoring(self):
|
||||
"""Start continuous system monitoring"""
|
||||
if self.monitoring_active:
|
||||
logger.warning("Monitoring already active")
|
||||
return
|
||||
|
||||
self.monitoring_active = True
|
||||
self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
|
||||
self.monitoring_thread.start()
|
||||
|
||||
logger.info("System stability monitoring started")
|
||||
|
||||
def stop_monitoring(self):
|
||||
"""Stop system monitoring"""
|
||||
self.monitoring_active = False
|
||||
if self.monitoring_thread:
|
||||
self.monitoring_thread.join(timeout=5)
|
||||
|
||||
logger.info("System stability monitoring stopped")
|
||||
|
||||
def _monitoring_loop(self):
|
||||
"""Main monitoring loop"""
|
||||
while self.monitoring_active:
|
||||
try:
|
||||
# Collect system metrics
|
||||
self._collect_system_metrics()
|
||||
|
||||
# Check for memory leaks
|
||||
self._check_memory_leaks()
|
||||
|
||||
# Check CPU usage
|
||||
self._check_cpu_usage()
|
||||
|
||||
# Check thread health
|
||||
self._check_thread_health()
|
||||
|
||||
# Check for deadlocks
|
||||
self._check_for_deadlocks()
|
||||
|
||||
# Update stability score
|
||||
self._update_stability_score()
|
||||
|
||||
# Log status every 60 seconds
|
||||
if len(self.memory_history) % 12 == 0: # Every 12 * 5s = 60s
|
||||
self._log_stability_status()
|
||||
|
||||
time.sleep(5) # Check every 5 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitoring loop: {e}")
|
||||
self.error_history.append({
|
||||
'timestamp': datetime.now(),
|
||||
'error': str(e),
|
||||
'traceback': traceback.format_exc()
|
||||
})
|
||||
time.sleep(10) # Wait longer on error
|
||||
|
||||
def _collect_system_metrics(self):
|
||||
"""Collect system performance metrics"""
|
||||
try:
|
||||
# Memory metrics
|
||||
memory = psutil.virtual_memory()
|
||||
memory_data = {
|
||||
'timestamp': datetime.now(),
|
||||
'used_gb': memory.used / (1024**3),
|
||||
'available_gb': memory.available / (1024**3),
|
||||
'percent': memory.percent
|
||||
}
|
||||
self.memory_history.append(memory_data)
|
||||
|
||||
# Keep only last 720 entries (1 hour at 5s intervals)
|
||||
if len(self.memory_history) > 720:
|
||||
self.memory_history = self.memory_history[-720:]
|
||||
|
||||
# CPU metrics
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
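# Note (added for clarity): interval=1 blocks for one second, so each pass of
# the 5-second monitoring loop actually takes about six seconds.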
cpu_data = {
|
||||
'timestamp': datetime.now(),
|
||||
'percent': cpu_percent,
|
||||
'cores': psutil.cpu_count()
|
||||
}
|
||||
self.cpu_history.append(cpu_data)
|
||||
|
||||
# Keep only last 720 entries
|
||||
if len(self.cpu_history) > 720:
|
||||
self.cpu_history = self.cpu_history[-720:]
|
||||
|
||||
# Thread metrics
|
||||
thread_count = threading.active_count()
|
||||
thread_data = {
|
||||
'timestamp': datetime.now(),
|
||||
'count': thread_count,
|
||||
'threads': [t.name for t in threading.enumerate()]
|
||||
}
|
||||
self.thread_history.append(thread_data)
|
||||
|
||||
# Keep only last 720 entries
|
||||
if len(self.thread_history) > 720:
|
||||
self.thread_history = self.thread_history[-720:]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting system metrics: {e}")
|
||||
|
||||
def _check_memory_leaks(self):
|
||||
"""Check for memory leaks"""
|
||||
try:
|
||||
if len(self.memory_history) < 10:
|
||||
return
|
||||
|
||||
# Check if memory usage is consistently increasing
|
||||
recent_memory = [m['used_gb'] for m in self.memory_history[-10:]]
|
||||
memory_trend = sum(recent_memory[-5:]) / 5 - sum(recent_memory[:5]) / 5
|
||||
|
||||
# If memory increased by more than 100MB in last 10 checks
|
||||
if memory_trend > 0.1:
|
||||
warning = f"Potential memory leak detected: +{memory_trend:.2f}GB in last 50s"
|
||||
if warning not in self.warnings:
|
||||
self.warnings.append(warning)
|
||||
logger.warning(warning)
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
logger.info("Forced garbage collection to free memory")
|
||||
|
||||
# Check for excessive memory usage
|
||||
current_memory = self.memory_history[-1]['percent']
|
||||
if current_memory > 85:
|
||||
critical = f"High memory usage: {current_memory:.1f}%"
|
||||
if critical not in self.critical_issues:
|
||||
self.critical_issues.append(critical)
|
||||
logger.error(critical)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking memory leaks: {e}")
|
||||
|
||||
def _check_cpu_usage(self):
|
||||
"""Check CPU usage patterns"""
|
||||
try:
|
||||
if len(self.cpu_history) < 10:
|
||||
return
|
||||
|
||||
# Check for sustained high CPU usage
|
||||
recent_cpu = [c['percent'] for c in self.cpu_history[-10:]]
|
||||
avg_cpu = sum(recent_cpu) / len(recent_cpu)
|
||||
|
||||
if avg_cpu > 90:
|
||||
critical = f"Sustained high CPU usage: {avg_cpu:.1f}%"
|
||||
if critical not in self.critical_issues:
|
||||
self.critical_issues.append(critical)
|
||||
logger.error(critical)
|
||||
elif avg_cpu > 75:
|
||||
warning = f"High CPU usage: {avg_cpu:.1f}%"
|
||||
if warning not in self.warnings:
|
||||
self.warnings.append(warning)
|
||||
logger.warning(warning)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking CPU usage: {e}")
|
||||
|
||||
def _check_thread_health(self):
|
||||
"""Check thread health and detect issues"""
|
||||
try:
|
||||
if len(self.thread_history) < 5:
|
||||
return
|
||||
|
||||
current_threads = self.thread_history[-1]['count']
|
||||
|
||||
# Check for thread explosion
|
||||
if current_threads > 50:
|
||||
critical = f"Thread explosion detected: {current_threads} active threads"
|
||||
if critical not in self.critical_issues:
|
||||
self.critical_issues.append(critical)
|
||||
logger.error(critical)
|
||||
|
||||
# Log thread names for debugging
|
||||
thread_names = self.thread_history[-1]['threads']
|
||||
logger.error(f"Active threads: {thread_names}")
|
||||
|
||||
# Check for thread leaks (gradually increasing thread count)
|
||||
if len(self.thread_history) >= 10:
|
||||
thread_counts = [t['count'] for t in self.thread_history[-10:]]
|
||||
thread_trend = sum(thread_counts[-5:]) / 5 - sum(thread_counts[:5]) / 5
|
||||
|
||||
if thread_trend > 2: # More than 2 threads increase on average
|
||||
warning = f"Potential thread leak: +{thread_trend:.1f} threads in last 50s"
|
||||
if warning not in self.warnings:
|
||||
self.warnings.append(warning)
|
||||
logger.warning(warning)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking thread health: {e}")
|
||||
|
||||
def _check_for_deadlocks(self):
|
||||
"""Check for potential deadlocks"""
|
||||
try:
|
||||
# Simple deadlock detection based on thread states
|
||||
all_threads = threading.enumerate()
|
||||
blocked_threads = []
|
||||
|
||||
for thread in all_threads:
|
||||
if hasattr(thread, '_is_stopped') and not thread._is_stopped:
|
||||
# Thread is running but might be blocked
|
||||
# This is a simplified check - real deadlock detection is complex
|
||||
pass
|
||||
|
||||
# For now, just check if we have threads that haven't been active
|
||||
# More sophisticated deadlock detection would require thread state analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for deadlocks: {e}")
|
||||
|
||||
def _update_stability_score(self):
|
||||
"""Update overall system stability score"""
|
||||
try:
|
||||
score = 100.0
|
||||
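# Weighting as implemented below: -20 per critical issue, -5 per warning,
# -10 per error in the last 10 minutes, plus penalties for memory/CPU above 80%.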
|
||||
# Deduct points for critical issues
|
||||
score -= len(self.critical_issues) * 20
|
||||
|
||||
# Deduct points for warnings
|
||||
score -= len(self.warnings) * 5
|
||||
|
||||
# Deduct points for recent errors
|
||||
recent_errors = [e for e in self.error_history
|
||||
if e['timestamp'] > datetime.now() - timedelta(minutes=10)]
|
||||
score -= len(recent_errors) * 10
|
||||
|
||||
# Deduct points for high resource usage
|
||||
if self.memory_history:
|
||||
current_memory = self.memory_history[-1]['percent']
|
||||
if current_memory > 80:
|
||||
score -= (current_memory - 80) * 2
|
||||
|
||||
if self.cpu_history:
|
||||
current_cpu = self.cpu_history[-1]['percent']
|
||||
if current_cpu > 80:
|
||||
score -= (current_cpu - 80) * 1
|
||||
|
||||
self.stability_score = max(0, score)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating stability score: {e}")
|
||||
|
||||
def _log_stability_status(self):
|
||||
"""Log current stability status"""
|
||||
try:
|
||||
logger.info("=" * 50)
|
||||
logger.info("SYSTEM STABILITY STATUS")
|
||||
logger.info("=" * 50)
|
||||
logger.info(f"Stability Score: {self.stability_score:.1f}/100")
|
||||
|
||||
if self.memory_history:
|
||||
mem = self.memory_history[-1]
|
||||
logger.info(f"Memory: {mem['used_gb']:.1f}GB used ({mem['percent']:.1f}%)")
|
||||
|
||||
if self.cpu_history:
|
||||
cpu = self.cpu_history[-1]
|
||||
logger.info(f"CPU: {cpu['percent']:.1f}%")
|
||||
|
||||
if self.thread_history:
|
||||
threads = self.thread_history[-1]
|
||||
logger.info(f"Threads: {threads['count']} active")
|
||||
|
||||
if self.critical_issues:
|
||||
logger.error(f"Critical Issues ({len(self.critical_issues)}):")
|
||||
for issue in self.critical_issues[-5:]: # Show last 5
|
||||
logger.error(f" - {issue}")
|
||||
|
||||
if self.warnings:
|
||||
logger.warning(f"Warnings ({len(self.warnings)}):")
|
||||
for warning in self.warnings[-5:]: # Show last 5
|
||||
logger.warning(f" - {warning}")
|
||||
|
||||
logger.info("=" * 50)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging stability status: {e}")
|
||||
|
||||
    def get_stability_report(self) -> Dict[str, Any]:
        """Get comprehensive stability report"""
        try:
            return {
                'stability_score': self.stability_score,
                'critical_issues': self.critical_issues,
                'warnings': self.warnings,
                'memory_usage': self.memory_history[-1] if self.memory_history else None,
                'cpu_usage': self.cpu_history[-1] if self.cpu_history else None,
                'thread_count': self.thread_history[-1]['count'] if self.thread_history else 0,
                'recent_errors': len([e for e in self.error_history
                                      if e['timestamp'] > datetime.now() - timedelta(minutes=10)]),
                'monitoring_active': self.monitoring_active
            }
        except Exception as e:
            logger.error(f"Error generating stability report: {e}")
            return {'error': str(e)}

    def fix_common_issues(self):
        """Attempt to fix common stability issues"""
        try:
            logger.info("Attempting to fix common stability issues...")

            # Force garbage collection
            gc.collect()
            logger.info("✓ Forced garbage collection")

            # Clear old history to free memory
            if len(self.memory_history) > 360:  # Keep only 30 minutes
                self.memory_history = self.memory_history[-360:]
            if len(self.cpu_history) > 360:
                self.cpu_history = self.cpu_history[-360:]
            if len(self.thread_history) > 360:
                self.thread_history = self.thread_history[-360:]

            logger.info("✓ Cleared old monitoring history")

            # Clear old errors
            cutoff_time = datetime.now() - timedelta(hours=1)
            self.error_history = [e for e in self.error_history if e['timestamp'] > cutoff_time]
            logger.info("✓ Cleared old error history")

            # Reset warnings and critical issues that might be stale
            self.warnings = []
            self.critical_issues = []
            logger.info("✓ Reset stale warnings and critical issues")

            logger.info("Common stability fixes applied")

        except Exception as e:
            logger.error(f"Error fixing common issues: {e}")


def main():
    """Main function for standalone execution"""
    try:
        logger.info("Starting System Stability Audit")

        auditor = SystemStabilityAuditor()
        auditor.start_monitoring()

        # Run for 5 minutes then generate report
        time.sleep(300)

        report = auditor.get_stability_report()
        logger.info("FINAL STABILITY REPORT:")
        logger.info(f"Stability Score: {report['stability_score']:.1f}/100")
        logger.info(f"Critical Issues: {len(report['critical_issues'])}")
        logger.info(f"Warnings: {len(report['warnings'])}")

        # Attempt fixes if needed
        if report['stability_score'] < 80:
            auditor.fix_common_issues()

        auditor.stop_monitoring()

    except KeyboardInterrupt:
        logger.info("Audit interrupted by user")
    except Exception as e:
        logger.error(f"Error in stability audit: {e}")


if __name__ == "__main__":
    main()
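For reference, the stability score above is a simple linear penalty model; a minimal standalone sketch of the same arithmetic (weights copied from _update_stability_score, function name illustrative only):

def stability_score(criticals, warnings, recent_errors, mem_pct, cpu_pct):
    # Same penalty weights as SystemStabilityAuditor._update_stability_score
    score = 100.0
    score -= criticals * 20       # each critical issue costs 20 points
    score -= warnings * 5         # each warning costs 5 points
    score -= recent_errors * 10   # each error in the last 10 minutes costs 10 points
    if mem_pct > 80:
        score -= (mem_pct - 80) * 2
    if cpu_pct > 80:
        score -= (cpu_pct - 80) * 1
    return max(0, score)

# Example: 1 critical, 2 warnings, 0 recent errors, 85% memory, 70% CPU
# -> 100 - 20 - 10 - 0 - 10 - 0 = 60
print(stability_score(1, 2, 0, 85, 70))  # 60.0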
@@ -1,191 +0,0 @@
#!/usr/bin/env python3
"""
Test Build Base Data Performance

This script tests the performance of build_base_data_input to ensure it's instantaneous.
"""

import sys
import os
import time
import logging
from datetime import datetime

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.orchestrator import TradingOrchestrator
from core.config import get_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_build_base_data_performance():
    """Test the performance of build_base_data_input"""

    logger.info("=== Testing Build Base Data Performance ===")

    try:
        # Initialize orchestrator
        config = get_config()
        orchestrator = TradingOrchestrator(
            symbol="ETH/USDT",
            config=config
        )

        # Start the orchestrator to initialize data
        orchestrator.start()
        logger.info("✅ Orchestrator started")

        # Wait a bit for data to be populated
        time.sleep(2)

        # Test performance of build_base_data_input
        symbol = "ETH/USDT"
        num_tests = 10
        total_time = 0

        logger.info(f"Running {num_tests} performance tests...")

        for i in range(num_tests):
            start_time = time.time()

            base_data = orchestrator.build_base_data_input(symbol)

            end_time = time.time()
            duration = (end_time - start_time) * 1000  # Convert to milliseconds
            total_time += duration

            if base_data:
                logger.info(f"Test {i+1}: {duration:.2f}ms - ✅ Success")
            else:
                logger.warning(f"Test {i+1}: {duration:.2f}ms - ❌ Failed (no data)")

        avg_time = total_time / num_tests

        logger.info("=== Performance Results ===")
        logger.info(f"Average time: {avg_time:.2f}ms")
        logger.info(f"Total time: {total_time:.2f}ms")

        # Performance thresholds
        if avg_time < 10:  # Less than 10ms is excellent
            logger.info("🎉 EXCELLENT: Build time is under 10ms")
        elif avg_time < 50:  # Less than 50ms is good
            logger.info("✅ GOOD: Build time is under 50ms")
        elif avg_time < 100:  # Less than 100ms is acceptable
            logger.info("⚠️ ACCEPTABLE: Build time is under 100ms")
        else:
            logger.error("❌ SLOW: Build time is over 100ms - needs optimization")

        # Test with multiple symbols
        logger.info("Testing with multiple symbols...")
        symbols = ["ETH/USDT", "BTC/USDT"]

        for symbol in symbols:
            start_time = time.time()
            base_data = orchestrator.build_base_data_input(symbol)
            end_time = time.time()
            duration = (end_time - start_time) * 1000

            logger.info(f"{symbol}: {duration:.2f}ms")

        # Stop orchestrator
        orchestrator.stop()
        logger.info("✅ Orchestrator stopped")

        return avg_time < 100  # Return True if performance is acceptable

    except Exception as e:
        logger.error(f"❌ Performance test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_cache_effectiveness():
    """Test that caching is working effectively"""

    logger.info("=== Testing Cache Effectiveness ===")

    try:
        # Initialize orchestrator
        config = get_config()
        orchestrator = TradingOrchestrator(
            symbol="ETH/USDT",
            config=config
        )

        orchestrator.start()
        time.sleep(2)  # Let data populate

        symbol = "ETH/USDT"

        # First call (should build cache)
        start_time = time.time()
        base_data1 = orchestrator.build_base_data_input(symbol)
        first_call_time = (time.time() - start_time) * 1000

        # Second call (should use cache)
        start_time = time.time()
        base_data2 = orchestrator.build_base_data_input(symbol)
        second_call_time = (time.time() - start_time) * 1000

        # Third call (should still use cache)
        start_time = time.time()
        base_data3 = orchestrator.build_base_data_input(symbol)
        third_call_time = (time.time() - start_time) * 1000

        logger.info(f"First call (build cache): {first_call_time:.2f}ms")
        logger.info(f"Second call (use cache): {second_call_time:.2f}ms")
        logger.info(f"Third call (use cache): {third_call_time:.2f}ms")

        # Cache should make subsequent calls faster
        if second_call_time < first_call_time * 0.5:
            logger.info("✅ Cache is working effectively")
            cache_effective = True
        else:
            logger.warning("⚠️ Cache may not be working as expected")
            cache_effective = False

        # Verify data consistency
        if base_data1 and base_data2 and base_data3:
            # Check that we get consistent data structure
            if len(base_data1.ohlcv_1s) == len(base_data2.ohlcv_1s) == len(base_data3.ohlcv_1s):
                logger.info("✅ Data consistency maintained")
            else:
                logger.warning("⚠️ Data consistency issues detected")

        orchestrator.stop()

        return cache_effective

    except Exception as e:
        logger.error(f"❌ Cache effectiveness test failed: {e}")
        return False


def main():
    """Run all performance tests"""

    logger.info("Starting Build Base Data Performance Tests")

    # Test 1: Basic performance
    test1_passed = test_build_base_data_performance()

    # Test 2: Cache effectiveness
    test2_passed = test_cache_effectiveness()

    # Summary
    logger.info("=== Test Summary ===")
    logger.info(f"Performance Test: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
    logger.info(f"Cache Effectiveness: {'✅ PASSED' if test2_passed else '❌ FAILED'}")

    if test1_passed and test2_passed:
        logger.info("🎉 All tests passed! build_base_data_input is optimized.")
        logger.info("The system now:")
        logger.info("  - Builds BaseDataInput in under 100ms")
        logger.info("  - Uses effective caching for repeated calls")
        logger.info("  - Maintains data consistency")
    else:
        logger.error("❌ Some tests failed. Performance optimization needed.")


if __name__ == "__main__":
    main()
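One caveat on the benchmark above: time.time() can have coarse resolution on some platforms, so sub-millisecond readings may be unreliable. A minimal sketch of the same timing loop using time.perf_counter(), a monotonic high-resolution clock; the orchestrator API is assumed to match the script above:

import time

def time_build_base_data(orchestrator, symbol="ETH/USDT", runs=10):
    """Return the average build time in milliseconds using a monotonic clock."""
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        orchestrator.build_base_data_input(symbol)  # assumed API from the script above
        durations.append((time.perf_counter() - start) * 1000.0)
    return sum(durations) / len(durations)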
@@ -1,348 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Bybit ETH futures position opening/closing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Load environment variables from .env file
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
# If dotenv is not available, try to load .env manually
|
||||
if os.path.exists('.env'):
|
||||
with open('.env', 'r') as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith('#'):
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BybitEthFuturesTest:
|
||||
"""Test class for Bybit ETH futures trading"""
|
||||
|
||||
def __init__(self, test_mode=True):
|
||||
self.test_mode = test_mode
|
||||
self.bybit = BybitInterface(test_mode=test_mode)
|
||||
self.test_symbol = 'ETHUSDT'
|
||||
self.test_quantity = 0.01 # Small test amount
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all tests"""
|
||||
print("=" * 60)
|
||||
print("BYBIT ETH FUTURES POSITION TESTING")
|
||||
print("=" * 60)
|
||||
print(f"Test mode: {'TESTNET' if self.test_mode else 'LIVE'}")
|
||||
print(f"Symbol: {self.test_symbol}")
|
||||
print(f"Test quantity: {self.test_quantity} ETH")
|
||||
print("=" * 60)
|
||||
|
||||
# Test 1: Connection
|
||||
if not self.test_connection():
|
||||
print("❌ Connection failed - stopping tests")
|
||||
return False
|
||||
|
||||
# Test 2: Check balance
|
||||
if not self.test_balance():
|
||||
print("❌ Balance check failed - stopping tests")
|
||||
return False
|
||||
|
||||
# Test 3: Check current positions
|
||||
self.test_current_positions()
|
||||
|
||||
# Test 4: Get ticker
|
||||
if not self.test_ticker():
|
||||
print("❌ Ticker test failed - stopping tests")
|
||||
return False
|
||||
|
||||
# Test 5: Open a long position
|
||||
long_order = self.test_open_long_position()
|
||||
if not long_order:
|
||||
print("❌ Open long position failed")
|
||||
return False
|
||||
|
||||
# Test 6: Check position after opening
|
||||
time.sleep(2) # Wait for position to be reflected
|
||||
if not self.test_position_after_open():
|
||||
print("❌ Position check after opening failed")
|
||||
return False
|
||||
|
||||
# Test 7: Close the position
|
||||
if not self.test_close_position():
|
||||
print("❌ Close position failed")
|
||||
return False
|
||||
|
||||
# Test 8: Check position after closing
|
||||
time.sleep(2) # Wait for position to be reflected
|
||||
self.test_position_after_close()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ ALL TESTS COMPLETED SUCCESSFULLY")
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
def test_connection(self):
|
||||
"""Test connection to Bybit"""
|
||||
print("\n📡 Testing connection to Bybit...")
|
||||
|
||||
# First test simple connectivity without auth
|
||||
print("Testing basic API connectivity...")
|
||||
try:
|
||||
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||
client = BybitRestClient(
|
||||
api_key="dummy",
|
||||
api_secret="dummy",
|
||||
testnet=True
|
||||
)
|
||||
|
||||
# Test public endpoint (server time)
|
||||
server_time = client.get_server_time()
|
||||
print(f"✅ Public API working - Server time: {server_time.get('result', {}).get('timeSecond')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Public API failed: {e}")
|
||||
return False
|
||||
|
||||
# Now test with actual credentials
|
||||
print("Testing with API credentials...")
|
||||
try:
|
||||
connected = self.bybit.connect()
|
||||
if connected:
|
||||
print("✅ Successfully connected to Bybit with credentials")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to connect to Bybit with credentials")
|
||||
print("This might be due to:")
|
||||
print("- Invalid API credentials")
|
||||
print("- Credentials not enabled for testnet")
|
||||
print("- Missing required permissions")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Connection error: {e}")
|
||||
return False
|
||||
|
||||
def test_balance(self):
|
||||
"""Test getting account balance"""
|
||||
print("\n💰 Testing account balance...")
|
||||
|
||||
try:
|
||||
# Get USDT balance (for margin)
|
||||
usdt_balance = self.bybit.get_balance('USDT')
|
||||
print(f"USDT Balance: {usdt_balance}")
|
||||
|
||||
# Get all balances
|
||||
all_balances = self.bybit.get_all_balances()
|
||||
print("All balances:")
|
||||
for asset, balance in all_balances.items():
|
||||
if balance['total'] > 0:
|
||||
print(f" {asset}: Free={balance['free']}, Locked={balance['locked']}, Total={balance['total']}")
|
||||
|
||||
if usdt_balance > 10: # Need at least $10 for testing
|
||||
print("✅ Sufficient balance for testing")
|
||||
return True
|
||||
else:
|
||||
print("❌ Insufficient USDT balance for testing (need at least $10)")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Balance check error: {e}")
|
||||
return False
|
||||
|
||||
def test_current_positions(self):
|
||||
"""Test getting current positions"""
|
||||
print("\n📊 Checking current positions...")
|
||||
|
||||
try:
|
||||
positions = self.bybit.get_positions()
|
||||
if positions:
|
||||
print(f"Found {len(positions)} open positions:")
|
||||
for pos in positions:
|
||||
print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
|
||||
print(f" PnL: ${pos['unrealized_pnl']:.2f} ({pos['percentage']:.2f}%)")
|
||||
else:
|
||||
print("No open positions found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
|
||||
def test_ticker(self):
|
||||
"""Test getting ticker information"""
|
||||
print(f"\n📈 Testing ticker for {self.test_symbol}...")
|
||||
|
||||
try:
|
||||
ticker = self.bybit.get_ticker(self.test_symbol)
|
||||
if ticker:
|
||||
print(f"✅ Ticker data received:")
|
||||
print(f" Last Price: ${ticker['last_price']:.2f}")
|
||||
print(f" Bid: ${ticker['bid_price']:.2f}")
|
||||
print(f" Ask: ${ticker['ask_price']:.2f}")
|
||||
print(f" 24h Volume: {ticker['volume_24h']:.2f}")
|
||||
print(f" 24h Change: {ticker['change_24h']:.4f}%")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to get ticker data")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ticker error: {e}")
|
||||
return False
|
||||
|
||||
def test_open_long_position(self):
|
||||
"""Test opening a long position"""
|
||||
print(f"\n🚀 Opening long position for {self.test_quantity} {self.test_symbol}...")
|
||||
|
||||
try:
|
||||
# Place market buy order
|
||||
order = self.bybit.place_order(
|
||||
symbol=self.test_symbol,
|
||||
side='buy',
|
||||
order_type='market',
|
||||
quantity=self.test_quantity
|
||||
)
|
||||
|
||||
if 'error' in order:
|
||||
print(f"❌ Order failed: {order['error']}")
|
||||
return None
|
||||
|
||||
print("✅ Long position opened successfully:")
|
||||
print(f" Order ID: {order['order_id']}")
|
||||
print(f" Symbol: {order['symbol']}")
|
||||
print(f" Side: {order['side']}")
|
||||
print(f" Quantity: {order['quantity']}")
|
||||
print(f" Status: {order['status']}")
|
||||
|
||||
return order
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Open position error: {e}")
|
||||
return None
|
||||
|
||||
def test_position_after_open(self):
|
||||
"""Test checking position after opening"""
|
||||
print(f"\n📊 Checking position after opening...")
|
||||
|
||||
try:
|
||||
positions = self.bybit.get_positions(self.test_symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("✅ Position found:")
|
||||
print(f" Symbol: {position['symbol']}")
|
||||
print(f" Side: {position['side']}")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" Entry Price: ${position['entry_price']:.2f}")
|
||||
print(f" Mark Price: ${position['mark_price']:.2f}")
|
||||
print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
|
||||
print(f" Percentage: {position['percentage']:.2f}%")
|
||||
print(f" Leverage: {position['leverage']}x")
|
||||
return True
|
||||
else:
|
||||
print("❌ No position found after opening")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
return False
|
||||
|
||||
def test_close_position(self):
|
||||
"""Test closing the position"""
|
||||
print(f"\n🔄 Closing position for {self.test_symbol}...")
|
||||
|
||||
try:
|
||||
# Close the position
|
||||
close_order = self.bybit.close_position(self.test_symbol)
|
||||
|
||||
if 'error' in close_order:
|
||||
print(f"❌ Close order failed: {close_order['error']}")
|
||||
return False
|
||||
|
||||
print("✅ Position closed successfully:")
|
||||
print(f" Order ID: {close_order['order_id']}")
|
||||
print(f" Symbol: {close_order['symbol']}")
|
||||
print(f" Side: {close_order['side']}")
|
||||
print(f" Quantity: {close_order['quantity']}")
|
||||
print(f" Status: {close_order['status']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Close position error: {e}")
|
||||
return False
|
||||
|
||||
def test_position_after_close(self):
|
||||
"""Test checking position after closing"""
|
||||
print(f"\n📊 Checking position after closing...")
|
||||
|
||||
try:
|
||||
positions = self.bybit.get_positions(self.test_symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("⚠️ Position still exists (may be partially closed):")
|
||||
print(f" Symbol: {position['symbol']}")
|
||||
print(f" Side: {position['side']}")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" Entry Price: ${position['entry_price']:.2f}")
|
||||
print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
|
||||
else:
|
||||
print("✅ Position successfully closed - no open positions")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
|
||||
def test_order_history(self):
|
||||
"""Test getting order history"""
|
||||
print(f"\n📋 Checking recent orders...")
|
||||
|
||||
try:
|
||||
# Get open orders
|
||||
open_orders = self.bybit.get_open_orders(self.test_symbol)
|
||||
print(f"Open orders: {len(open_orders)}")
|
||||
for order in open_orders:
|
||||
print(f" {order['order_id']}: {order['side']} {order['quantity']} @ ${order['price']:.2f} - {order['status']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Order history error: {e}")
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("Starting Bybit ETH Futures Test...")
|
||||
|
||||
# Check if API credentials are set
|
||||
api_key = os.getenv('BYBIT_API_KEY')
|
||||
api_secret = os.getenv('BYBIT_API_SECRET')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ Please set BYBIT_API_KEY and BYBIT_API_SECRET environment variables")
|
||||
return False
|
||||
|
||||
# Create test instance
|
||||
test = BybitEthFuturesTest(test_mode=True) # Always use testnet for safety
|
||||
|
||||
# Run tests
|
||||
success = test.run_tests()
|
||||
|
||||
if success:
|
||||
print("\n🎉 All tests passed!")
|
||||
else:
|
||||
print("\n💥 Some tests failed!")
|
||||
|
||||
return success
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,304 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fixed Bybit ETH futures trading test with proper minimum order size handling
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Load environment variables
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
if os.path.exists('.env'):
|
||||
with open('.env', 'r') as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith('#'):
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_instrument_info(bybit: BybitInterface, symbol: str) -> dict:
|
||||
"""Get instrument information including minimum order size"""
|
||||
try:
|
||||
instruments = bybit.get_instruments("linear")
|
||||
for instrument in instruments:
|
||||
if instrument.get('symbol') == symbol:
|
||||
return instrument
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting instrument info: {e}")
|
||||
return {}
|
||||
|
||||
def test_eth_futures_trading():
|
||||
"""Test ETH futures trading with proper minimum order size"""
|
||||
print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
|
||||
print("=" * 60)
|
||||
print("BYBIT ETH FUTURES LIVE TRADING TEST (FIXED)")
|
||||
print("=" * 60)
|
||||
print("⚠️ This uses LIVE environment with real money!")
|
||||
print("⚠️ Will check minimum order size first")
|
||||
print("=" * 60)
|
||||
|
||||
# Check if API credentials are set
|
||||
api_key = os.getenv('BYBIT_API_KEY')
|
||||
api_secret = os.getenv('BYBIT_API_SECRET')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ API credentials not found in environment")
|
||||
return False
|
||||
|
||||
# Create Bybit interface with live environment
|
||||
bybit = BybitInterface(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
test_mode=False # Use live environment
|
||||
)
|
||||
|
||||
symbol = 'ETHUSDT'
|
||||
|
||||
# Test 1: Connection
|
||||
print(f"\n📡 Testing connection to Bybit live environment...")
|
||||
try:
|
||||
if not bybit.connect():
|
||||
print("❌ Failed to connect to Bybit")
|
||||
return False
|
||||
print("✅ Successfully connected to Bybit live environment")
|
||||
except Exception as e:
|
||||
print(f"❌ Connection error: {e}")
|
||||
return False
|
||||
|
||||
# Test 2: Get instrument information to check minimum order size
|
||||
print(f"\n📋 Getting instrument information for {symbol}...")
|
||||
try:
|
||||
instrument_info = get_instrument_info(bybit, symbol)
|
||||
if not instrument_info:
|
||||
print(f"❌ Failed to get instrument info for {symbol}")
|
||||
return False
|
||||
|
||||
print("✅ Instrument information retrieved:")
|
||||
print(f" Symbol: {instrument_info.get('symbol')}")
|
||||
print(f" Status: {instrument_info.get('status')}")
|
||||
print(f" Base Coin: {instrument_info.get('baseCoin')}")
|
||||
print(f" Quote Coin: {instrument_info.get('quoteCoin')}")
|
||||
|
||||
# Extract minimum order size
|
||||
lot_size_filter = instrument_info.get('lotSizeFilter', {})
|
||||
min_order_qty = float(lot_size_filter.get('minOrderQty', 0.01))
|
||||
max_order_qty = float(lot_size_filter.get('maxOrderQty', 10000))
|
||||
qty_step = float(lot_size_filter.get('qtyStep', 0.01))
|
||||
|
||||
print(f" Minimum Order Qty: {min_order_qty}")
|
||||
print(f" Maximum Order Qty: {max_order_qty}")
|
||||
print(f" Quantity Step: {qty_step}")
|
||||
|
||||
# Use minimum order size for testing
|
||||
test_quantity = min_order_qty
|
||||
print(f" Using test quantity: {test_quantity} ETH")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Instrument info error: {e}")
|
||||
return False
|
||||
|
||||
# Test 3: Get account balance
|
||||
print(f"\n💰 Checking account balance...")
|
||||
try:
|
||||
usdt_balance = bybit.get_balance('USDT')
|
||||
print(f"USDT Balance: ${usdt_balance:.2f}")
|
||||
|
||||
# Calculate required balance (with some buffer)
|
||||
current_price_data = bybit.get_ticker(symbol)
|
||||
if not current_price_data:
|
||||
print("❌ Failed to get current ETH price")
|
||||
return False
|
||||
|
||||
current_price = current_price_data['last_price']
|
||||
required_balance = current_price * test_quantity * 1.1 # 10% buffer
|
||||
|
||||
print(f"Current ETH price: ${current_price:.2f}")
|
||||
print(f"Required balance: ${required_balance:.2f}")
|
||||
|
||||
if usdt_balance < required_balance:
|
||||
print(f"❌ Insufficient USDT balance for testing (need at least ${required_balance:.2f})")
|
||||
return False
|
||||
|
||||
print("✅ Sufficient balance for testing")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Balance check error: {e}")
|
||||
return False
|
||||
|
||||
# Test 4: Check existing positions
|
||||
print(f"\n📊 Checking existing positions...")
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
print(f"Found {len(positions)} existing positions:")
|
||||
for pos in positions:
|
||||
print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
|
||||
print(f" PnL: ${pos['unrealized_pnl']:.2f}")
|
||||
else:
|
||||
print("No existing positions found")
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
return False
|
||||
|
||||
# Test 5: Ask user confirmation before trading
|
||||
print(f"\n⚠️ TRADING CONFIRMATION")
|
||||
print(f" Symbol: {symbol}")
|
||||
print(f" Quantity: {test_quantity} ETH")
|
||||
print(f" Estimated cost: ${current_price * test_quantity:.2f}")
|
||||
print(f" Environment: LIVE (real money)")
|
||||
print(f" Minimum order size confirmed: {min_order_qty}")
|
||||
|
||||
response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower()
|
||||
if response != 'y' and response != 'yes':
|
||||
print("❌ Trading test cancelled by user")
|
||||
return False
|
||||
|
||||
# Test 6: Open a small long position
|
||||
print(f"\n🚀 Opening small long position...")
|
||||
try:
|
||||
order = bybit.place_order(
|
||||
symbol=symbol,
|
||||
side='buy',
|
||||
order_type='market',
|
||||
quantity=test_quantity
|
||||
)
|
||||
|
||||
if 'error' in order:
|
||||
print(f"❌ Order failed: {order['error']}")
|
||||
return False
|
||||
|
||||
print("✅ Long position opened successfully:")
|
||||
print(f" Order ID: {order['order_id']}")
|
||||
print(f" Symbol: {order['symbol']}")
|
||||
print(f" Side: {order['side']}")
|
||||
print(f" Quantity: {order['quantity']}")
|
||||
print(f" Status: {order['status']}")
|
||||
|
||||
order_id = order['order_id']
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Order placement error: {e}")
|
||||
return False
|
||||
|
||||
# Test 7: Wait a moment and check position
|
||||
print(f"\n⏳ Waiting 5 seconds for position to be reflected...")
|
||||
time.sleep(5)
|
||||
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("✅ Position confirmed:")
|
||||
print(f" Symbol: {position['symbol']}")
|
||||
print(f" Side: {position['side']}")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" Entry Price: ${position['entry_price']:.2f}")
|
||||
print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
|
||||
print(f" Leverage: {position['leverage']}x")
|
||||
else:
|
||||
print("⚠️ No position found (may already be closed)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
|
||||
# Test 8: Close the position
|
||||
print(f"\n🔄 Closing the position...")
|
||||
try:
|
||||
close_order = bybit.close_position(symbol)
|
||||
|
||||
if 'error' in close_order:
|
||||
print(f"❌ Close order failed: {close_order['error']}")
|
||||
# Don't return False here, as the position might still exist
|
||||
print("⚠️ You may need to manually close the position")
|
||||
else:
|
||||
print("✅ Position closed successfully:")
|
||||
print(f" Order ID: {close_order['order_id']}")
|
||||
print(f" Symbol: {close_order['symbol']}")
|
||||
print(f" Side: {close_order['side']}")
|
||||
print(f" Quantity: {close_order['quantity']}")
|
||||
print(f" Status: {close_order['status']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Close position error: {e}")
|
||||
print("⚠️ You may need to manually close the position")
|
||||
|
||||
# Test 9: Final position check
|
||||
print(f"\n📊 Final position check...")
|
||||
time.sleep(3)
|
||||
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("⚠️ Position still exists:")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" PnL: ${position['unrealized_pnl']:.2f}")
|
||||
print("💡 You may want to manually close this position")
|
||||
else:
|
||||
print("✅ No open positions - trading test completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Final position check error: {e}")
|
||||
|
||||
# Test 10: Final balance check
|
||||
print(f"\n💰 Final balance check...")
|
||||
try:
|
||||
final_balance = bybit.get_balance('USDT')
|
||||
print(f"Final USDT Balance: ${final_balance:.2f}")
|
||||
|
||||
balance_change = final_balance - usdt_balance
|
||||
if balance_change > 0:
|
||||
print(f"💰 Profit: +${balance_change:.2f}")
|
||||
elif balance_change < 0:
|
||||
print(f"📉 Loss: ${balance_change:.2f}")
|
||||
else:
|
||||
print(f"🔄 No change: ${balance_change:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Final balance check error: {e}")
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
|
||||
|
||||
success = test_eth_futures_trading()
|
||||
|
||||
if success:
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED")
|
||||
print("=" * 60)
|
||||
print("🎯 Your Bybit integration is fully functional!")
|
||||
print("🔄 Position opening and closing works correctly")
|
||||
print("💰 Account balance integration works")
|
||||
print("📊 All trading functions are operational")
|
||||
print("📏 Minimum order size handling works")
|
||||
print("=" * 60)
|
||||
else:
|
||||
print("\n💥 Trading test failed!")
|
||||
print("🔍 Check the error messages above for details")
|
||||
|
||||
return success
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,249 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Bybit ETH futures trading with live environment
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Load environment variables
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
if os.path.exists('.env'):
|
||||
with open('.env', 'r') as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith('#'):
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_eth_futures_trading():
|
||||
"""Test ETH futures trading with live environment"""
|
||||
print("=" * 60)
|
||||
print("BYBIT ETH FUTURES LIVE TRADING TEST")
|
||||
print("=" * 60)
|
||||
print("⚠️ This uses LIVE environment with real money!")
|
||||
print("⚠️ Test amount: 0.001 ETH (very small)")
|
||||
print("=" * 60)
|
||||
|
||||
# Check if API credentials are set
|
||||
api_key = os.getenv('BYBIT_API_KEY')
|
||||
api_secret = os.getenv('BYBIT_API_SECRET')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ API credentials not found in environment")
|
||||
return False
|
||||
|
||||
# Create Bybit interface with live environment
|
||||
bybit = BybitInterface(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
test_mode=False # Use live environment
|
||||
)
|
||||
|
||||
symbol = 'ETHUSDT'
|
||||
test_quantity = 0.01 # Minimum order size for ETH futures
|
||||
|
||||
# Test 1: Connection
|
||||
print(f"\n📡 Testing connection to Bybit live environment...")
|
||||
try:
|
||||
if not bybit.connect():
|
||||
print("❌ Failed to connect to Bybit")
|
||||
return False
|
||||
print("✅ Successfully connected to Bybit live environment")
|
||||
except Exception as e:
|
||||
print(f"❌ Connection error: {e}")
|
||||
return False
|
||||
|
||||
# Test 2: Get account balance
|
||||
print(f"\n💰 Checking account balance...")
|
||||
try:
|
||||
usdt_balance = bybit.get_balance('USDT')
|
||||
print(f"USDT Balance: ${usdt_balance:.2f}")
|
||||
|
||||
if usdt_balance < 5:
|
||||
print("❌ Insufficient USDT balance for testing (need at least $5)")
|
||||
return False
|
||||
|
||||
print("✅ Sufficient balance for testing")
|
||||
except Exception as e:
|
||||
print(f"❌ Balance check error: {e}")
|
||||
return False
|
||||
|
||||
# Test 3: Get current ETH price
|
||||
print(f"\n📈 Getting current ETH price...")
|
||||
try:
|
||||
ticker = bybit.get_ticker(symbol)
|
||||
if not ticker:
|
||||
print("❌ Failed to get ticker")
|
||||
return False
|
||||
|
||||
current_price = ticker['last_price']
|
||||
print(f"Current ETH price: ${current_price:.2f}")
|
||||
print(f"Test order value: ${current_price * test_quantity:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ticker error: {e}")
|
||||
return False
|
||||
|
||||
# Test 4: Check existing positions
|
||||
print(f"\n📊 Checking existing positions...")
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
print(f"Found {len(positions)} existing positions:")
|
||||
for pos in positions:
|
||||
print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
|
||||
print(f" PnL: ${pos['unrealized_pnl']:.2f}")
|
||||
else:
|
||||
print("No existing positions found")
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
return False
|
||||
|
||||
# Test 5: Ask user confirmation before trading
|
||||
print(f"\n⚠️ TRADING CONFIRMATION")
|
||||
print(f" Symbol: {symbol}")
|
||||
print(f" Quantity: {test_quantity} ETH")
|
||||
print(f" Estimated cost: ${current_price * test_quantity:.2f}")
|
||||
print(f" Environment: LIVE (real money)")
|
||||
|
||||
response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower()
|
||||
if response != 'y' and response != 'yes':
|
||||
print("❌ Trading test cancelled by user")
|
||||
return False
|
||||
|
||||
# Test 6: Open a small long position
|
||||
print(f"\n🚀 Opening small long position...")
|
||||
try:
|
||||
order = bybit.place_order(
|
||||
symbol=symbol,
|
||||
side='buy',
|
||||
order_type='market',
|
||||
quantity=test_quantity
|
||||
)
|
||||
|
||||
if 'error' in order:
|
||||
print(f"❌ Order failed: {order['error']}")
|
||||
return False
|
||||
|
||||
print("✅ Long position opened successfully:")
|
||||
print(f" Order ID: {order['order_id']}")
|
||||
print(f" Symbol: {order['symbol']}")
|
||||
print(f" Side: {order['side']}")
|
||||
print(f" Quantity: {order['quantity']}")
|
||||
print(f" Status: {order['status']}")
|
||||
|
||||
order_id = order['order_id']
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Order placement error: {e}")
|
||||
return False
|
||||
|
||||
# Test 7: Wait a moment and check position
|
||||
print(f"\n⏳ Waiting 3 seconds for position to be reflected...")
|
||||
time.sleep(3)
|
||||
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("✅ Position confirmed:")
|
||||
print(f" Symbol: {position['symbol']}")
|
||||
print(f" Side: {position['side']}")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" Entry Price: ${position['entry_price']:.2f}")
|
||||
print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
|
||||
print(f" Leverage: {position['leverage']}x")
|
||||
else:
|
||||
print("⚠️ No position found (may already be closed)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Position check error: {e}")
|
||||
|
||||
# Test 8: Close the position
|
||||
print(f"\n🔄 Closing the position...")
|
||||
try:
|
||||
close_order = bybit.close_position(symbol)
|
||||
|
||||
if 'error' in close_order:
|
||||
print(f"❌ Close order failed: {close_order['error']}")
|
||||
return False
|
||||
|
||||
print("✅ Position closed successfully:")
|
||||
print(f" Order ID: {close_order['order_id']}")
|
||||
print(f" Symbol: {close_order['symbol']}")
|
||||
print(f" Side: {close_order['side']}")
|
||||
print(f" Quantity: {close_order['quantity']}")
|
||||
print(f" Status: {close_order['status']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Close position error: {e}")
|
||||
return False
|
||||
|
||||
# Test 9: Final position check
|
||||
print(f"\n📊 Final position check...")
|
||||
time.sleep(2)
|
||||
|
||||
try:
|
||||
positions = bybit.get_positions(symbol)
|
||||
if positions:
|
||||
position = positions[0]
|
||||
print("⚠️ Position still exists:")
|
||||
print(f" Size: {position['size']}")
|
||||
print(f" PnL: ${position['unrealized_pnl']:.2f}")
|
||||
else:
|
||||
print("✅ No open positions - trading test completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Final position check error: {e}")
|
||||
|
||||
# Test 10: Final balance check
|
||||
print(f"\n💰 Final balance check...")
|
||||
try:
|
||||
final_balance = bybit.get_balance('USDT')
|
||||
print(f"Final USDT Balance: ${final_balance:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Final balance check error: {e}")
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("🚀 Starting Bybit ETH Futures Live Trading Test...")
|
||||
|
||||
success = test_eth_futures_trading()
|
||||
|
||||
if success:
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED")
|
||||
print("=" * 60)
|
||||
print("🎯 Your Bybit integration is fully functional!")
|
||||
print("🔄 Position opening and closing works correctly")
|
||||
print("💰 Account balance integration works")
|
||||
print("📊 All trading functions are operational")
|
||||
print("=" * 60)
|
||||
else:
|
||||
print("\n💥 Trading test failed!")
|
||||
|
||||
return success
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Bybit public API functionality (no authentication required)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_public_api():
|
||||
"""Test public API endpoints"""
|
||||
print("=" * 60)
|
||||
print("BYBIT PUBLIC API TEST")
|
||||
print("=" * 60)
|
||||
|
||||
# Test both testnet and live for public endpoints
|
||||
for testnet in [True, False]:
|
||||
env_name = "TESTNET" if testnet else "LIVE"
|
||||
print(f"\n🔄 Testing {env_name} environment...")
|
||||
|
||||
client = BybitRestClient(
|
||||
api_key="dummy",
|
||||
api_secret="dummy",
|
||||
testnet=testnet
|
||||
)
|
||||
|
||||
# Test 1: Server time
|
||||
try:
|
||||
server_time = client.get_server_time()
|
||||
time_second = server_time.get('result', {}).get('timeSecond')
|
||||
print(f"✅ Server time: {time_second}")
|
||||
except Exception as e:
|
||||
print(f"❌ Server time failed: {e}")
|
||||
continue
|
||||
|
||||
# Test 2: Get ticker for ETHUSDT
|
||||
try:
|
||||
ticker = client.get_ticker('ETHUSDT', 'linear')
|
||||
ticker_data = ticker.get('result', {}).get('list', [])
|
||||
if ticker_data:
|
||||
data = ticker_data[0]
|
||||
print(f"✅ ETH/USDT ticker:")
|
||||
print(f" Last Price: ${float(data.get('lastPrice', 0)):.2f}")
|
||||
print(f" 24h Volume: {float(data.get('volume24h', 0)):.2f}")
|
||||
print(f" 24h Change: {float(data.get('price24hPcnt', 0)) * 100:.2f}%")
|
||||
else:
|
||||
print("❌ No ticker data received")
|
||||
except Exception as e:
|
||||
print(f"❌ Ticker failed: {e}")
|
||||
|
||||
# Test 3: Get instruments info
|
||||
try:
|
||||
instruments = client.get_instruments_info('linear')
|
||||
instruments_list = instruments.get('result', {}).get('list', [])
|
||||
eth_instruments = [i for i in instruments_list if 'ETH' in i.get('symbol', '')]
|
||||
print(f"✅ Found {len(eth_instruments)} ETH instruments")
|
||||
for instr in eth_instruments[:3]: # Show first 3
|
||||
print(f" {instr.get('symbol')} - Status: {instr.get('status')}")
|
||||
except Exception as e:
|
||||
print(f"❌ Instruments failed: {e}")
|
||||
|
||||
# Test 4: Get orderbook
|
||||
try:
|
||||
orderbook = client.get_orderbook('ETHUSDT', 'linear', 5)
|
||||
ob_data = orderbook.get('result', {})
|
||||
bids = ob_data.get('b', [])
|
||||
asks = ob_data.get('a', [])
|
||||
|
||||
if bids and asks:
|
||||
print(f"✅ Orderbook (top 3):")
|
||||
print(f" Best bid: ${float(bids[0][0]):.2f} (qty: {float(bids[0][1]):.4f})")
|
||||
print(f" Best ask: ${float(asks[0][0]):.2f} (qty: {float(asks[0][1]):.4f})")
|
||||
spread = float(asks[0][0]) - float(bids[0][0])
|
||||
print(f" Spread: ${spread:.2f}")
|
||||
else:
|
||||
print("❌ No orderbook data received")
|
||||
except Exception as e:
|
||||
print(f"❌ Orderbook failed: {e}")
|
||||
|
||||
print(f"📊 {env_name} environment test completed")
|
||||
|
||||
def test_live_authentication():
|
||||
"""Test live authentication (if user wants to test with live credentials)"""
|
||||
print("\n" + "=" * 60)
|
||||
print("BYBIT LIVE AUTHENTICATION TEST")
|
||||
print("=" * 60)
|
||||
print("⚠️ This will test with LIVE credentials (not testnet)")
|
||||
|
||||
# Load environment variables
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
# If dotenv is not available, try to load .env manually
|
||||
if os.path.exists('.env'):
|
||||
with open('.env', 'r') as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith('#'):
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
api_key = os.getenv('BYBIT_API_KEY')
|
||||
api_secret = os.getenv('BYBIT_API_SECRET')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ No API credentials found in environment")
|
||||
return
|
||||
|
||||
print(f"🔑 Using API key: {api_key[:8]}...")
|
||||
|
||||
# Test with live environment (testnet=False)
|
||||
client = BybitRestClient(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
testnet=False # Use live environment
|
||||
)
|
||||
|
||||
# Test connectivity
|
||||
try:
|
||||
if client.test_connectivity():
|
||||
print("✅ Basic connectivity OK")
|
||||
else:
|
||||
print("❌ Basic connectivity failed")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"❌ Connectivity error: {e}")
|
||||
return
|
||||
|
||||
# Test authentication
|
||||
try:
|
||||
if client.test_authentication():
|
||||
print("✅ Authentication successful!")
|
||||
|
||||
# Get account info
|
||||
account_info = client.get_account_info()
|
||||
accounts = account_info.get('result', {}).get('list', [])
|
||||
|
||||
if accounts:
|
||||
print("📊 Account information:")
|
||||
for account in accounts:
|
||||
account_type = account.get('accountType', 'Unknown')
|
||||
print(f" Account Type: {account_type}")
|
||||
|
||||
coins = account.get('coin', [])
|
||||
usdt_balance = None
|
||||
for coin in coins:
|
||||
if coin.get('coin') == 'USDT':
|
||||
usdt_balance = float(coin.get('walletBalance', 0))
|
||||
break
|
||||
|
||||
if usdt_balance:
|
||||
print(f" USDT Balance: ${usdt_balance:.2f}")
|
||||
|
||||
# Show positions if any
|
||||
try:
|
||||
positions = client.get_positions('linear')
|
||||
pos_list = positions.get('result', {}).get('list', [])
|
||||
active_positions = [p for p in pos_list if float(p.get('size', 0)) != 0]
|
||||
|
||||
if active_positions:
|
||||
print(f" Active Positions: {len(active_positions)}")
|
||||
for pos in active_positions:
|
||||
symbol = pos.get('symbol')
|
||||
side = pos.get('side')
|
||||
size = float(pos.get('size', 0))
|
||||
pnl = float(pos.get('unrealisedPnl', 0))
|
||||
print(f" {symbol}: {side} {size} (PnL: ${pnl:.2f})")
|
||||
else:
|
||||
print(" No active positions")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Could not get positions: {e}")
|
||||
|
||||
return True
|
||||
else:
|
||||
print("❌ Authentication failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Authentication error: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("🚀 Starting Bybit API Tests...")
|
||||
|
||||
# Test public API
|
||||
test_public_api()
|
||||
|
||||
# Ask user if they want to test live authentication
|
||||
print("\n" + "=" * 60)
|
||||
response = input("Do you want to test live authentication? (y/N): ").lower()
|
||||
|
||||
if response == 'y' or response == 'yes':
|
||||
success = test_live_authentication()
|
||||
if success:
|
||||
print("\n✅ Live authentication test passed!")
|
||||
print("🎯 Your Bybit integration is working!")
|
||||
else:
|
||||
print("\n❌ Live authentication test failed")
|
||||
else:
|
||||
print("\n📋 Skipping live authentication test")
|
||||
|
||||
print("\n🎉 Public API tests completed successfully!")
|
||||
print("📈 Bybit integration is functional for market data")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,137 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Cache Fix
|
||||
|
||||
Creates a corrupted Parquet file to test the fix mechanism
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from utils.cache_manager import get_cache_manager
|
||||
|
||||
def create_test_data():
|
||||
"""Create test cache files including a corrupted one"""
|
||||
print("Creating test cache files...")
|
||||
|
||||
# Ensure cache directory exists
|
||||
cache_dir = Path("data/cache")
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a valid Parquet file
|
||||
valid_data = pd.DataFrame({
|
||||
'timestamp': pd.date_range('2025-01-01', periods=100, freq='1min'),
|
||||
'open': [100.0 + i for i in range(100)],
|
||||
'high': [101.0 + i for i in range(100)],
|
||||
'low': [99.0 + i for i in range(100)],
|
||||
'close': [100.5 + i for i in range(100)],
|
||||
'volume': [1000 + i*10 for i in range(100)]
|
||||
})
|
||||
|
||||
valid_file = cache_dir / "ETHUSDT_1m.parquet"
|
||||
valid_data.to_parquet(valid_file, index=False)
|
||||
print(f"Created valid file: {valid_file}")
|
||||
|
||||
# Create a corrupted Parquet file by writing invalid data
|
||||
corrupted_file = cache_dir / "BTCUSDT_1m.parquet"
|
||||
with open(corrupted_file, 'wb') as f:
|
||||
f.write(b"This is not a valid Parquet file - corrupted data")
|
||||
print(f"Created corrupted file: {corrupted_file}")
|
||||
|
||||
# Create an empty file
|
||||
empty_file = cache_dir / "SOLUSDT_1m.parquet"
|
||||
empty_file.touch()
|
||||
print(f"Created empty file: {empty_file}")
|
||||
|
||||
def test_cache_manager():
|
||||
"""Test the cache manager's ability to detect and fix issues"""
|
||||
print("\n=== Testing Cache Manager ===")
|
||||
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# Scan health
|
||||
print("1. Scanning cache health...")
|
||||
health_summary = cache_manager.get_cache_summary()
|
||||
|
||||
print(f"Total files: {health_summary['total_files']}")
|
||||
print(f"Valid files: {health_summary['valid_files']}")
|
||||
print(f"Corrupted files: {health_summary['corrupted_files']}")
|
||||
print(f"Health percentage: {health_summary['health_percentage']:.1f}%")
|
||||
|
||||
# Show corrupted files
|
||||
for cache_dir, report in health_summary['directories'].items():
|
||||
if report['corrupted_files'] > 0:
|
||||
print(f"\nCorrupted files in {cache_dir}:")
|
||||
for corrupted in report['corrupted_files_list']:
|
||||
print(f" - {corrupted['file']}: {corrupted['error']}")
|
||||
|
||||
# Test cleanup
|
||||
print("\n2. Testing cleanup...")
|
||||
deleted_files = cache_manager.cleanup_corrupted_files(dry_run=False)
|
||||
|
||||
deleted_count = 0
|
||||
for cache_dir, files in deleted_files.items():
|
||||
for file_info in files:
|
||||
if "DELETED:" in file_info:
|
||||
deleted_count += 1
|
||||
print(f" {file_info}")
|
||||
|
||||
print(f"Deleted {deleted_count} corrupted files")
|
||||
|
||||
# Verify cleanup
|
||||
print("\n3. Verifying cleanup...")
|
||||
health_summary_after = cache_manager.get_cache_summary()
|
||||
print(f"Corrupted files after cleanup: {health_summary_after['corrupted_files']}")
|
||||
|
||||
def test_data_provider_integration():
|
||||
"""Test that the data provider can handle corrupted cache gracefully"""
|
||||
print("\n=== Testing Data Provider Integration ===")
|
||||
|
||||
# Create another corrupted file
|
||||
cache_dir = Path("data/cache")
|
||||
corrupted_file = cache_dir / "ETHUSDT_5m.parquet"
|
||||
with open(corrupted_file, 'wb') as f:
|
||||
f.write(b"PAR1\x00\x00corrupted thrift data that will cause deserialization error")
|
||||
print(f"Created corrupted file with thrift error: {corrupted_file}")
|
||||
|
||||
# Try to import and use data provider
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Create data provider (should auto-fix corrupted cache)
|
||||
data_provider = DataProvider()
|
||||
print("Data provider created successfully - auto-fix worked!")
|
||||
|
||||
# Try to load data (should handle corruption gracefully)
|
||||
try:
|
||||
data = data_provider._load_from_cache("ETH/USDT", "5m")
|
||||
if data is None:
|
||||
print("Cache loading returned None (expected for corrupted file)")
|
||||
else:
|
||||
print(f"Loaded {len(data)} rows from cache")
|
||||
except Exception as e:
|
||||
print(f"Cache loading failed: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Data provider test failed: {e}")
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("=== Cache Fix Test Suite ===")
|
||||
|
||||
# Clean up any existing test files
|
||||
cache_dir = Path("data/cache")
|
||||
if cache_dir.exists():
|
||||
for file in cache_dir.glob("*.parquet"):
|
||||
file.unlink()
|
||||
|
||||
# Run tests
|
||||
create_test_data()
|
||||
test_cache_manager()
|
||||
test_data_provider_integration()
|
||||
|
||||
print("\n=== Test Complete ===")
|
||||
print("The cache fix system is working correctly!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,175 +0,0 @@
#!/usr/bin/env python3
"""
Test CNN Integration

This script tests if the CNN adapter is working properly and identifies issues.
"""

import logging
import sys
import os
from datetime import datetime

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def test_cnn_adapter():
    """Test CNN adapter initialization and basic functionality"""
    try:
        logger.info("Testing CNN adapter initialization...")

        # Test 1: Import CNN adapter
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        logger.info("✅ CNN adapter import successful")

        # Test 2: Initialize CNN adapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialization successful")

        # Test 3: Check adapter attributes
        logger.info(f"CNN adapter model: {cnn_adapter.model}")
        logger.info(f"CNN adapter device: {cnn_adapter.device}")
        logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}")

        # Test 4: Check metrics tracking
        logger.info(f"Inference count: {cnn_adapter.inference_count}")
        logger.info(f"Training count: {cnn_adapter.training_count}")
        logger.info(f"Training data length: {len(cnn_adapter.training_data)}")

        # Test 5: Test simple training sample addition
        cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
        logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}")

        # Test 6: Test training if we have enough samples
        if len(cnn_adapter.training_data) >= 2:
            # Add another sample to have minimum for training
            cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05)

            # Try training
            training_result = cnn_adapter.train(epochs=1)
            logger.info(f"✅ Training successful: {training_result}")

            # Check if metrics were updated
            logger.info(f"Last training time: {cnn_adapter.last_training_time}")
            logger.info(f"Last training loss: {cnn_adapter.last_training_loss}")
            logger.info(f"Training count: {cnn_adapter.training_count}")
        else:
            logger.info("⚠️ Not enough training samples for training test")

        return True

    except Exception as e:
        logger.error(f"❌ CNN adapter test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_base_data_input():
    """Test BaseDataInput creation"""
    try:
        logger.info("Testing BaseDataInput creation...")

        # Test 1: Import BaseDataInput
        from core.data_models import BaseDataInput, OHLCVBar, COBData
        logger.info("✅ BaseDataInput import successful")

        # Test 2: Create sample OHLCV bars
        sample_bars = []
        for i in range(10):  # Create 10 sample bars
            bar = OHLCVBar(
                symbol="ETH/USDT",
                timestamp=datetime.now(),
                open=3500.0 + i,
                high=3510.0 + i,
                low=3490.0 + i,
                close=3505.0 + i,
                volume=1000.0,
                timeframe="1s"
            )
            sample_bars.append(bar)

        logger.info(f"✅ Created {len(sample_bars)} sample OHLCV bars")

        # Test 3: Create BaseDataInput
        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=datetime.now(),
            ohlcv_1s=sample_bars,
            ohlcv_1m=sample_bars,
            ohlcv_1h=sample_bars,
            ohlcv_1d=sample_bars,
            btc_ohlcv_1s=sample_bars
        )

        logger.info("✅ BaseDataInput created successfully")

        # Test 4: Validate BaseDataInput
        is_valid = base_data.validate()
        logger.info(f"BaseDataInput validation: {is_valid}")

        # Test 5: Get feature vector
        feature_vector = base_data.get_feature_vector()
        logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}")

        return base_data

    except Exception as e:
        logger.error(f"❌ BaseDataInput test failed: {e}")
        import traceback
        traceback.print_exc()
        return None


def test_cnn_prediction():
    """Test CNN prediction with BaseDataInput"""
    try:
        logger.info("Testing CNN prediction...")

        # Get CNN adapter and base data
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        base_data = test_base_data_input()
        if not base_data:
            logger.error("❌ Cannot test prediction without valid BaseDataInput")
            return False

        # Test prediction
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})")

        # Check if metrics were updated
        logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}")
        logger.info(f"Last inference time: {cnn_adapter.last_inference_time}")
        logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}")

        return True

    except Exception as e:
        logger.error(f"❌ CNN prediction test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run all tests"""
    logger.info("🧪 Starting CNN Integration Tests")

    # Test 1: CNN Adapter
    if not test_cnn_adapter():
        logger.error("❌ CNN adapter test failed - stopping")
        return False

    # Test 2: CNN Prediction
    if not test_cnn_prediction():
        logger.error("❌ CNN prediction test failed - stopping")
        return False

    logger.info("✅ All CNN integration tests passed!")
    logger.info("🎯 The CNN adapter should now work properly in the dashboard")

    return True


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
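The script above only touches a narrow surface of the adapter: a few counters, add_training_sample(symbol, action, reward), train(epochs=...), and predict() returning an object carrying predictions['action'] and a confidence float. A stub with that shape can be useful for dry-running the test flow without the real network; the attribute names below are taken from the calls in the script, while the defaults and return values are assumptions:

# Stand-in exposing only the attributes this test script reads; not the real
# EnhancedCNNAdapter - return values and defaults here are assumptions.
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

@dataclass
class StubOutput:
    predictions: Dict[str, Any]
    confidence: float

class StubCNNAdapter:
    def __init__(self, checkpoint_dir: str = "models/enhanced_cnn"):
        self.model = None
        self.device = "cpu"
        self.model_name = "stub_cnn"
        self.inference_count = 0
        self.training_count = 0
        self.training_data: List[Tuple[str, str, float]] = []
        self.last_training_time: Optional[datetime] = None
        self.last_training_loss: Optional[float] = None
        self.last_inference_time: Optional[datetime] = None
        self.last_prediction_output: Optional[StubOutput] = None

    def add_training_sample(self, symbol: str, action: str, reward: float) -> None:
        self.training_data.append((symbol, action, reward))

    def train(self, epochs: int = 1) -> Dict[str, float]:
        self.training_count += 1
        self.last_training_time = datetime.now()
        self.last_training_loss = 0.0
        return {"epochs": float(epochs), "loss": self.last_training_loss}

    def predict(self, base_data) -> StubOutput:
        self.inference_count += 1
        self.last_inference_time = datetime.now()
        self.last_prediction_output = StubOutput({"action": "HOLD"}, 0.33)
        return self.last_prediction_output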
@@ -1,22 +0,0 @@
#!/usr/bin/env python3
"""
Test COB Dashboard with Enhanced WebSocket
"""

import asyncio
import logging
from web.cob_realtime_dashboard import COBDashboardServer

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)


async def main():
    """Test the COB dashboard"""
    dashboard = COBDashboardServer(host='localhost', port=8053)
    await dashboard.start()


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,118 +0,0 @@
#!/usr/bin/env python3
"""
Test script for COB data quality and imbalance indicators
"""

import time
import logging
from core.data_provider import DataProvider

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_cob_data_quality():
    """Test COB data quality and imbalance indicators"""
    logger.info("Testing COB data quality and imbalance indicators...")

    # Initialize data provider
    dp = DataProvider()

    # Wait for initial data load and COB connection
    logger.info("Waiting for initial data load and COB connection...")
    time.sleep(20)

    # Test 1: Check cached data summary
    logger.info("\n=== Test 1: Cached Data Summary ===")
    cache_summary = dp.get_cached_data_summary()
    for symbol in cache_summary['cached_data']:
        logger.info(f"\n{symbol}:")
        for timeframe, info in cache_summary['cached_data'][symbol].items():
            if 'candle_count' in info and info['candle_count'] > 0:
                logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
            else:
                logger.info(f" {timeframe}: {info.get('status', 'no data')}")

    # Test 2: Check COB data quality
    logger.info("\n=== Test 2: COB Data Quality ===")
    cob_quality = dp.get_cob_data_quality()

    for symbol in cob_quality['symbols']:
        logger.info(f"\n{symbol} COB Data:")

        # Raw ticks
        raw_info = cob_quality['raw_ticks'].get(symbol, {})
        logger.info(f" Raw ticks: {raw_info.get('count', 0)} ticks")
        if raw_info.get('age_seconds') is not None:
            logger.info(f" Raw data age: {raw_info['age_seconds']:.1f} seconds")

        # Aggregated 1s data
        agg_info = cob_quality['aggregated_1s'].get(symbol, {})
        logger.info(f" Aggregated 1s: {agg_info.get('count', 0)} records")
        if agg_info.get('age_seconds') is not None:
            logger.info(f" Aggregated data age: {agg_info['age_seconds']:.1f} seconds")

        # Imbalance indicators
        imbalance_info = cob_quality['imbalance_indicators'].get(symbol, {})
        if imbalance_info:
            logger.info(f" Imbalance 1s: {imbalance_info.get('imbalance_1s', 0):.4f}")
            logger.info(f" Imbalance 5s: {imbalance_info.get('imbalance_5s', 0):.4f}")
            logger.info(f" Imbalance 15s: {imbalance_info.get('imbalance_15s', 0):.4f}")
            logger.info(f" Imbalance 60s: {imbalance_info.get('imbalance_60s', 0):.4f}")
            logger.info(f" Total volume: {imbalance_info.get('total_volume', 0):.2f}")
            logger.info(f" Price buckets: {imbalance_info.get('bucket_count', 0)}")

        # Data freshness
        freshness = cob_quality['data_freshness'].get(symbol, 'unknown')
        logger.info(f" Data freshness: {freshness}")

    # Test 3: Get recent COB aggregated data
    logger.info("\n=== Test 3: Recent COB Aggregated Data ===")
    for symbol in ['ETH/USDT', 'BTC/USDT']:
        recent_cob = dp.get_cob_1s_aggregated(symbol, count=5)
        logger.info(f"\n{symbol} - Last 5 aggregated records:")

        for i, record in enumerate(recent_cob[-5:]):
            timestamp = record.get('timestamp', 0)
            imbalance_1s = record.get('imbalance_1s', 0)
            imbalance_5s = record.get('imbalance_5s', 0)
            total_volume = record.get('total_volume', 0)
            bucket_count = len(record.get('bid_buckets', {})) + len(record.get('ask_buckets', {}))

            logger.info(f" [{i+1}] Time: {timestamp}, Imb1s: {imbalance_1s:.4f}, "
                        f"Imb5s: {imbalance_5s:.4f}, Vol: {total_volume:.2f}, Buckets: {bucket_count}")

    # Test 4: Monitor real-time updates
    logger.info("\n=== Test 4: Real-time Updates (30 seconds) ===")
    logger.info("Monitoring COB data updates...")

    initial_quality = dp.get_cob_data_quality()
    time.sleep(30)
    updated_quality = dp.get_cob_data_quality()

    for symbol in ['ETH/USDT', 'BTC/USDT']:
        initial_count = initial_quality['raw_ticks'].get(symbol, {}).get('count', 0)
        updated_count = updated_quality['raw_ticks'].get(symbol, {}).get('count', 0)
        new_ticks = updated_count - initial_count

        initial_agg = initial_quality['aggregated_1s'].get(symbol, {}).get('count', 0)
        updated_agg = updated_quality['aggregated_1s'].get(symbol, {}).get('count', 0)
        new_agg = updated_agg - initial_agg

        logger.info(f"{symbol}: +{new_ticks} raw ticks, +{new_agg} aggregated records")

        # Show latest imbalances
        latest_imbalances = updated_quality['imbalance_indicators'].get(symbol, {})
        if latest_imbalances:
            logger.info(f" Latest imbalances: 1s={latest_imbalances.get('imbalance_1s', 0):.4f}, "
                        f"5s={latest_imbalances.get('imbalance_5s', 0):.4f}, "
                        f"15s={latest_imbalances.get('imbalance_15s', 0):.4f}, "
                        f"60s={latest_imbalances.get('imbalance_60s', 0):.4f}")

    # Clean shutdown
    logger.info("\n=== Shutting Down ===")
    dp.stop_automatic_data_maintenance()
    logger.info("COB data quality test completed")


if __name__ == "__main__":
    test_cob_data_quality()
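The imbalance_* values logged above are not defined in this script. A common convention for order-book imbalance, and an assumption about what DataProvider computes per window, is a signed bid/ask volume ratio in (-1, +1), assuming each bucket maps a price level to its resting volume:

# Sketch of windowed order-book imbalance; whether DataProvider uses exactly
# this formula is an assumption.
def book_imbalance(bid_buckets: dict, ask_buckets: dict) -> float:
    bid_vol = sum(bid_buckets.values())
    ask_vol = sum(ask_buckets.values())
    total = bid_vol + ask_vol
    return (bid_vol - ask_vol) / total if total else 0.0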
@@ -1,131 +0,0 @@
#!/usr/bin/env python3
"""
Test COB WebSocket Only Integration

This script tests that COB integration works with Enhanced WebSocket only,
without falling back to REST API calls.
"""

import asyncio
import time
from datetime import datetime
from typing import Dict
from core.cob_integration import COBIntegration


async def test_cob_websocket_only():
    """Test COB integration with WebSocket only"""
    print("=== Testing COB WebSocket Only Integration ===")

    # Initialize COB integration
    print("1. Initializing COB integration...")
    symbols = ['ETH/USDT', 'BTC/USDT']
    cob_integration = COBIntegration(symbols=symbols)

    # Track updates
    update_count = 0
    last_update_time = None

    def dashboard_callback(symbol: str, data: Dict):
        nonlocal update_count, last_update_time
        update_count += 1
        last_update_time = datetime.now()

        if update_count <= 5:  # Show first 5 updates
            data_type = data.get('type', 'unknown')
            if data_type == 'cob_update':
                stats = data.get('data', {}).get('stats', {})
                mid_price = stats.get('mid_price', 0)
                spread_bps = stats.get('spread_bps', 0)
                source = stats.get('source', 'unknown')
                print(f" Update #{update_count}: {symbol} - Price: ${mid_price:.2f}, Spread: {spread_bps:.1f}bps, Source: {source}")
            elif data_type == 'websocket_status':
                status_data = data.get('data', {})
                status = status_data.get('status', 'unknown')
                print(f" Status #{update_count}: {symbol} - WebSocket: {status}")

    # Add dashboard callback
    cob_integration.add_dashboard_callback(dashboard_callback)

    # Start COB integration
    print("2. Starting COB integration...")
    try:
        # Start in background
        start_task = asyncio.create_task(cob_integration.start())

        # Wait for initialization
        await asyncio.sleep(3)

        # Check if COB provider is disabled
        print("3. Checking COB provider status:")
        if cob_integration.cob_provider is None:
            print(" ✅ COB provider is disabled (using Enhanced WebSocket only)")
        else:
            print(" ❌ COB provider is still active (may cause REST API fallback)")

        # Check Enhanced WebSocket status
        print("4. Checking Enhanced WebSocket status:")
        if cob_integration.enhanced_websocket:
            print(" ✅ Enhanced WebSocket is initialized")

            # Check WebSocket status for each symbol
            websocket_status = cob_integration.get_websocket_status()
            for symbol, status in websocket_status.items():
                print(f" {symbol}: {status}")
        else:
            print(" ❌ Enhanced WebSocket is not initialized")

        # Monitor updates for a few seconds
        print("5. Monitoring COB updates...")
        initial_count = update_count
        monitor_start = time.time()

        # Wait for updates
        await asyncio.sleep(5)

        monitor_duration = time.time() - monitor_start
        updates_received = update_count - initial_count
        update_rate = updates_received / monitor_duration

        print(f" Received {updates_received} updates in {monitor_duration:.1f}s")
        print(f" Update rate: {update_rate:.1f} updates/second")

        if update_rate >= 8:  # Should be around 10 updates/second
            print(" ✅ Update rate is excellent (8+ updates/second)")
        elif update_rate >= 5:
            print(" ✅ Update rate is good (5+ updates/second)")
        elif update_rate >= 1:
            print(" ⚠️ Update rate is low (1+ updates/second)")
        else:
            print(" ❌ Update rate is too low (<1 update/second)")

        # Check data quality
        print("6. Data quality check:")
        if last_update_time:
            time_since_last = (datetime.now() - last_update_time).total_seconds()
            if time_since_last < 1:
                print(f" ✅ Recent data (last update {time_since_last:.1f}s ago)")
            else:
                print(f" ⚠️ Stale data (last update {time_since_last:.1f}s ago)")
        else:
            print(" ❌ No updates received")

        # Stop the integration
        print("7. Stopping COB integration...")
        await cob_integration.stop()

        # Cancel the start task
        start_task.cancel()
        try:
            await start_task
        except asyncio.CancelledError:
            pass

    except Exception as e:
        print(f" ❌ Error during COB integration test: {e}")

    print("\n✅ COB WebSocket only test completed!")
    print(f"Total updates received: {update_count}")
    print("Enhanced WebSocket is now the sole data source (no REST API fallback)")


if __name__ == "__main__":
    asyncio.run(test_cob_websocket_only())
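For reference, the mid_price and spread_bps fields consumed by the callback follow the usual top-of-book definitions; the sketch below shows the assumed convention (the exact computation inside the Enhanced WebSocket layer is not part of this script):

# Assumed convention: mid price and spread in basis points from best bid/ask.
def mid_and_spread_bps(best_bid: float, best_ask: float) -> tuple:
    mid = (best_bid + best_ask) / 2.0
    spread_bps = (best_ask - best_bid) / mid * 10_000 if mid else 0.0
    return mid, spread_bps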
Some files were not shown because too many files have changed in this diff.