diff --git a/CLEANUP_TODO.md b/CLEANUP_TODO.md new file mode 100644 index 0000000..00ef5e6 --- /dev/null +++ b/CLEANUP_TODO.md @@ -0,0 +1,56 @@ +Cleanup run summary: +- Deleted files: 183 + - NN\__init__.py + - NN\models\__init__.py + - NN\models\cnn_model.py + - NN\models\transformer_model.py + - NN\start_tensorboard.py + - NN\training\enhanced_rl_training_integration.py + - NN\training\example_checkpoint_usage.py + - NN\training\integrate_checkpoint_management.py + - NN\utils\__init__.py + - NN\utils\data_interface.py + - NN\utils\multi_data_interface.py + - NN\utils\realtime_analyzer.py + - NN\utils\signal_interpreter.py + - NN\utils\trading_env.py + - _dev\cleanup_models_now.py + - _tools\build_keep_set.py + - apply_trading_fixes.py + - apply_trading_fixes_to_main.py + - audit_training_system.py + - balance_trading_signals.py + - check_live_trading.py + - check_mexc_symbols.py + - cleanup_checkpoint_db.py + - cleanup_checkpoints.py + - core\__init__.py + - core\api_rate_limiter.py + - core\async_handler.py + - core\bookmap_data_provider.py + - core\bookmap_integration.py + - core\cnn_monitor.py + - core\cnn_training_pipeline.py + - core\config_sync.py + - core\enhanced_cnn_adapter.py + - core\enhanced_cob_websocket.py + - core\enhanced_orchestrator.py + - core\enhanced_training_integration.py + - core\exchanges\__init__.py + - core\exchanges\binance_interface.py + - core\exchanges\bybit\debug\test_bybit_balance.py + - core\exchanges\bybit_interface.py + - core\exchanges\bybit_rest_client.py + - core\exchanges\deribit_interface.py + - core\exchanges\mexc\debug\final_mexc_order_test.py + - core\exchanges\mexc\debug\fix_mexc_orders.py + - core\exchanges\mexc\debug\fix_mexc_orders_v2.py + - core\exchanges\mexc\debug\fix_mexc_orders_v3.py + - core\exchanges\mexc\debug\test_mexc_interface_debug.py + - core\exchanges\mexc\debug\test_mexc_order_signature.py + - core\exchanges\mexc\debug\test_mexc_order_signature_v2.py + - core\exchanges\mexc\debug\test_mexc_signature_debug.py + ... and 133 more +- Removed test directories: 1 + - tests +- Kept (excluded): 1 \ No newline at end of file diff --git a/DELETE_CANDIDATES.txt b/DELETE_CANDIDATES.txt new file mode 100644 index 0000000..3e2673b --- /dev/null +++ b/DELETE_CANDIDATES.txt @@ -0,0 +1,184 @@ +NN\__init__.py +NN\models\__init__.py +NN\models\cnn_model.py +NN\models\transformer_model.py +NN\start_tensorboard.py +NN\training\enhanced_realtime_training.py +NN\training\enhanced_rl_training_integration.py +NN\training\example_checkpoint_usage.py +NN\training\integrate_checkpoint_management.py +NN\utils\__init__.py +NN\utils\data_interface.py +NN\utils\multi_data_interface.py +NN\utils\realtime_analyzer.py +NN\utils\signal_interpreter.py +NN\utils\trading_env.py +_dev\cleanup_models_now.py +_tools\build_keep_set.py +apply_trading_fixes.py +apply_trading_fixes_to_main.py +audit_training_system.py +balance_trading_signals.py +check_live_trading.py +check_mexc_symbols.py +cleanup_checkpoint_db.py +cleanup_checkpoints.py +core\__init__.py +core\api_rate_limiter.py +core\async_handler.py +core\bookmap_data_provider.py +core\bookmap_integration.py +core\cnn_monitor.py +core\cnn_training_pipeline.py +core\config_sync.py +core\enhanced_cnn_adapter.py +core\enhanced_cob_websocket.py +core\enhanced_orchestrator.py +core\enhanced_training_integration.py +core\exchanges\__init__.py +core\exchanges\binance_interface.py +core\exchanges\bybit\debug\test_bybit_balance.py +core\exchanges\bybit_interface.py +core\exchanges\bybit_rest_client.py +core\exchanges\deribit_interface.py +core\exchanges\mexc\debug\final_mexc_order_test.py +core\exchanges\mexc\debug\fix_mexc_orders.py +core\exchanges\mexc\debug\fix_mexc_orders_v2.py +core\exchanges\mexc\debug\fix_mexc_orders_v3.py +core\exchanges\mexc\debug\test_mexc_interface_debug.py +core\exchanges\mexc\debug\test_mexc_order_signature.py +core\exchanges\mexc\debug\test_mexc_order_signature_v2.py +core\exchanges\mexc\debug\test_mexc_signature_debug.py +core\exchanges\mexc\debug\test_small_mexc_order.py +core\exchanges\mexc\test_live_trading.py +core\exchanges\mexc_interface.py +core\exchanges\trading_agent_test.py +core\mexc_webclient\__init__.py +core\mexc_webclient\auto_browser.py +core\mexc_webclient\browser_automation.py +core\mexc_webclient\mexc_futures_client.py +core\mexc_webclient\session_manager.py +core\mexc_webclient\test_mexc_futures_webclient.py +core\model_output_manager.py +core\negative_case_trainer.py +core\nn_decision_fusion.py +core\prediction_tracker.py +core\realtime_tick_processor.py +core\retrospective_trainer.py +core\rl_training_pipeline.py +core\robust_cob_provider.py +core\shared_cob_service.py +core\shared_data_manager.py +core\tick_aggregator.py +core\trading_action.py +core\trading_executor_fix.py +core\training_data_collector.py +core\williams_market_structure.py +dataprovider_realtime.py +debug\test_fixed_issues.py +debug\test_trading_fixes.py +debug\trade_audit.py +debug_training_methods.py +docs\exchanges\bybit\examples.py +example_usage_simplified_data_provider.py +kill_stale_processes.py +launch_training.py +main.py +main_clean.py +migrate_existing_models.py +model_manager.py +position_sync_enhancement.py +read_logs.py +reset_db_manager.py +reset_models_and_fix_mapping.py +run_clean_dashboard.py +run_continuous_training.py +run_crash_safe_dashboard.py +run_enhanced_rl_training.py +run_enhanced_training_dashboard.py +run_integrated_rl_cob_dashboard.py +run_mexc_browser.py +run_optimized_cob_system.py +run_realtime_rl_cob_trader.py +run_simple_dashboard.py +run_stable_dashboard.py +run_templated_dashboard.py +run_tensorboard.py +run_tests.py +scripts\kill_stale_processes.py +scripts\restart_dashboard_with_learning.py +scripts\restart_main_overnight.py +setup_mexc_browser.py +start_monitoring.py +start_overnight_training.py +system_stability_audit.py +test_build_base_data_performance.py +test_bybit_eth_futures.py +test_bybit_eth_futures_fixed.py +test_bybit_eth_live.py +test_bybit_public_api.py +test_cache_fix.py +test_cnn_integration.py +test_cob_dashboard.py +test_cob_data_quality.py +test_cob_websocket_only.py +test_continuous_cnn_training.py +test_dashboard_data_flow.py +test_dashboard_performance.py +test_data_integration.py +test_data_provider_integration.py +test_db_migration.py +test_deribit_integration.py +test_device_fix.py +test_device_training_fix.py +test_enhanced_cnn_adapter.py +test_enhanced_cob_websocket.py +test_enhanced_data_provider_websocket.py +test_enhanced_inference_logging.py +test_enhanced_training_integration.py +test_enhanced_training_simple.py +test_fifo_queues.py +test_hold_position_fix.py +test_imbalance_calculation.py +test_improved_data_integration.py +test_integrated_standardized_provider.py +test_leverage_fix.py +test_massive_dqn.py +test_mexc_order_fix.py +test_model_output_manager.py +test_model_registry.py +test_model_statistics.py +test_model_stats.py +test_model_training.py +test_orchestrator_fix.py +test_order_sync_and_fees.py +test_position_based_rewards.py +test_profitability_reward_system.py +test_training_data_collection.py +test_training_fixes.py +test_websocket_cob_data.py +tests\cob\test_cob_comparison.py +tests\cob\test_cob_data_stability.py +tests\test_training.py +tests\test_training_integration.py +tests\test_training_status.py +tests\test_universal_data_format.py +tests\test_universal_stream_integration.py +trading_main.py +utils\__init__.py +utils\async_task_manager.py +utils\launch_tensorboard.py +utils\model_utils.py +utils\port_manager.py +utils\process_supervisor.py +utils\reward_calculator.py +utils\system_monitor.py +utils\tensorboard_logger.py +utils\text_logger.py +verify_checkpoint_system.py +web\__init__.py +web\dashboard_fix.py +web\dashboard_model.py +web\layout_manager_with_tensorboard.py +web\tensorboard_component.py +web\tensorboard_integration.py \ No newline at end of file diff --git a/DEPENDENCY_TREE.md b/DEPENDENCY_TREE.md new file mode 100644 index 0000000..6872583 --- /dev/null +++ b/DEPENDENCY_TREE.md @@ -0,0 +1,84 @@ +Dependency tree from dashboards (module -> deps): +- NN\models\advanced_transformer_trading.py +- NN\models\cob_rl_model.py + - models\__init__.py +- NN\models\dqn_agent.py + - utils\checkpoint_manager.py + - utils\training_integration.py +- NN\models\enhanced_cnn.py +- NN\models\model_interfaces.py +- NN\models\standardized_cnn.py + - core\data_models.py +- core\cob_integration.py +- core\config.py + - safe_logging.py +- core\data_models.py +- core\data_provider.py + - utils\cache_manager.py + - utils\timezone_utils.py +- core\exchanges\exchange_factory.py +- core\exchanges\exchange_interface.py +- core\extrema_trainer.py + - utils\checkpoint_manager.py + - utils\training_integration.py +- core\multi_exchange_cob_provider.py +- core\orchestrator.py + - NN\models\advanced_transformer_trading.py + - NN\models\cob_rl_model.py + - NN\models\dqn_agent.py + - NN\models\enhanced_cnn.py + - NN\models\model_interfaces.py + - NN\models\standardized_cnn.py + - core\data_models.py + - core\extrema_trainer.py + - enhanced_realtime_training.py + - models\__init__.py + - utils\checkpoint_manager.py + - utils\database_manager.py + - utils\inference_logger.py +- core\overnight_training_coordinator.py +- core\realtime_rl_cob_trader.py + - NN\models\cob_rl_model.py + - core\trading_executor.py + - utils\checkpoint_manager.py +- core\standardized_data_provider.py +- core\trade_data_manager.py +- core\trading_executor.py + - core\data_provider.py + - core\exchanges\exchange_factory.py + - core\exchanges\exchange_interface.py +- core\training_integration.py +- core\universal_data_adapter.py +- enhanced_realtime_training.py +- models\__init__.py +- safe_logging.py +- utils\cache_manager.py +- utils\checkpoint_manager.py +- utils\database_manager.py +- utils\inference_logger.py +- utils\timezone_utils.py +- utils\training_integration.py +- web\clean_dashboard.py + - NN\models\advanced_transformer_trading.py + - NN\models\standardized_cnn.py + - core\cob_integration.py + - core\config.py + - core\data_models.py + - core\data_provider.py + - core\multi_exchange_cob_provider.py + - core\orchestrator.py + - core\overnight_training_coordinator.py + - core\realtime_rl_cob_trader.py + - core\standardized_data_provider.py + - core\trade_data_manager.py + - core\trading_executor.py + - core\training_integration.py + - core\universal_data_adapter.py + - utils\checkpoint_manager.py + - utils\timezone_utils.py + - web\component_manager.py + - web\layout_manager.py +- web\cob_realtime_dashboard.py + - core\cob_integration.py +- web\component_manager.py +- web\layout_manager.py \ No newline at end of file diff --git a/KEEP_SET.txt b/KEEP_SET.txt new file mode 100644 index 0000000..2486020 --- /dev/null +++ b/KEEP_SET.txt @@ -0,0 +1,35 @@ +NN\models\advanced_transformer_trading.py +NN\models\cob_rl_model.py +NN\models\dqn_agent.py +NN\models\enhanced_cnn.py +NN\models\model_interfaces.py +NN\models\standardized_cnn.py +core\cob_integration.py +core\config.py +core\data_models.py +core\data_provider.py +core\exchanges\exchange_factory.py +core\exchanges\exchange_interface.py +core\extrema_trainer.py +core\multi_exchange_cob_provider.py +core\orchestrator.py +core\overnight_training_coordinator.py +core\realtime_rl_cob_trader.py +core\standardized_data_provider.py +core\trade_data_manager.py +core\trading_executor.py +core\training_integration.py +core\universal_data_adapter.py +enhanced_realtime_training.py +models\__init__.py +safe_logging.py +utils\cache_manager.py +utils\checkpoint_manager.py +utils\database_manager.py +utils\inference_logger.py +utils\timezone_utils.py +utils\training_integration.py +web\clean_dashboard.py +web\cob_realtime_dashboard.py +web\component_manager.py +web\layout_manager.py \ No newline at end of file diff --git a/NN/__init__.py b/NN/__init__.py deleted file mode 100644 index 2622416..0000000 --- a/NN/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Neural Network Trading System -============================ - -A comprehensive neural network trading system that uses deep learning models -to analyze cryptocurrency price data and generate trading signals. - -The system consists of: -1. Data Interface: Connects to realtime trading data -2. CNN Model: Deep convolutional neural network for feature extraction -3. Transformer Model: Processes high-level features for improved pattern recognition -4. MoE: Mixture of Experts model that combines multiple neural networks -""" - -__version__ = '0.1.0' -__author__ = 'Gogo2 Project' \ No newline at end of file diff --git a/NN/models/__init__.py b/NN/models/__init__.py deleted file mode 100644 index df25097..0000000 --- a/NN/models/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Neural Network Models -==================== - -This package contains the neural network models used in the trading system: -- CNN Model: Deep convolutional neural network for feature extraction -- DQN Agent: Deep Q-Network for reinforcement learning -- COB RL Model: Specialized RL model for order book data -- Advanced Transformer: High-performance transformer for trading - -PyTorch implementation only. -""" - -# Import core models -from NN.models.dqn_agent import DQNAgent -from NN.models.cob_rl_model import COBRLModelInterface -from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig -from NN.models.standardized_cnn import StandardizedCNN # Use the unified CNN model - -# Import model interfaces -from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface - -# Export the unified StandardizedCNN as CNNModel for compatibility -CNNModel = StandardizedCNN - -__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig', -'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface'] diff --git a/NN/models/cnn_model.py b/NN/models/cnn_model.py deleted file mode 100644 index 4b34bb0..0000000 --- a/NN/models/cnn_model.py +++ /dev/null @@ -1,201 +0,0 @@ -# """ -# Legacy CNN Model Compatibility Layer - -# This module provides compatibility redirects to the unified StandardizedCNN model. -# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired -# in favor of the StandardizedCNN architecture. -# """ - -# import logging -# import warnings -# from typing import Tuple, Dict, Any, Optional -# import torch -# import numpy as np - -# # Import the standardized CNN model -# from .standardized_cnn import StandardizedCNN - -# logger = logging.getLogger(__name__) - -# # Compatibility aliases and wrappers -# class EnhancedCNNModel: -# """Legacy compatibility wrapper - redirects to StandardizedCNN""" - -# def __init__(self, *args, **kwargs): -# warnings.warn( -# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# # Create StandardizedCNN with default parameters -# self.standardized_cnn = StandardizedCNN() -# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN") - -# def __getattr__(self, name): -# """Delegate all method calls to StandardizedCNN""" -# return getattr(self.standardized_cnn, name) - - -# class CNNModelTrainer: -# """Legacy compatibility wrapper for CNN training""" - -# def __init__(self, model=None, *args, **kwargs): -# warnings.warn( -# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# if isinstance(model, EnhancedCNNModel): -# self.model = model.standardized_cnn -# else: -# self.model = StandardizedCNN() -# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()") - -# def train_step(self, x, y, *args, **kwargs): -# """Legacy train step wrapper""" -# try: -# # Convert to BaseDataInput format if needed -# if hasattr(x, 'get_feature_vector'): -# # Already BaseDataInput -# base_input = x -# else: -# # Create mock BaseDataInput for legacy compatibility -# from core.data_models import BaseDataInput -# base_input = BaseDataInput() -# # Set mock feature vector -# if isinstance(x, torch.Tensor): -# feature_vector = x.flatten().cpu().numpy() -# else: -# feature_vector = np.array(x).flatten() - -# # Pad or truncate to expected size -# expected_size = self.model.expected_feature_dim -# if len(feature_vector) < expected_size: -# padding = np.zeros(expected_size - len(feature_vector)) -# feature_vector = np.concatenate([feature_vector, padding]) -# else: -# feature_vector = feature_vector[:expected_size] - -# base_input._feature_vector = feature_vector - -# # Convert target to string format -# if isinstance(y, torch.Tensor): -# y_val = y.item() if y.numel() == 1 else y.argmax().item() -# else: -# y_val = int(y) if np.isscalar(y) else int(np.argmax(y)) - -# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'} -# target = target_map.get(y_val, 'HOLD') - -# # Use StandardizedCNN training -# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) -# loss = self.model.train_step([base_input], [target], optimizer) - -# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5} - -# except Exception as e: -# logger.error(f"Legacy train_step error: {e}") -# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5} - - -# # class CNNModel: -# # """Legacy compatibility wrapper for CNN model interface""" - -# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None): -# # warnings.warn( -# # "CNNModel is deprecated. Use StandardizedCNN directly.", -# # DeprecationWarning, -# # stacklevel=2 -# # ) -# # self.input_shape = input_shape -# # self.output_size = output_size -# # self.standardized_cnn = StandardizedCNN() -# # self.trainer = CNNModelTrainer(self.standardized_cnn) -# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN") - -# # def build_model(self, **kwargs): -# # """Legacy build method - no-op for StandardizedCNN""" -# # return self - -# # def predict(self, X): -# # """Legacy predict method""" -# # try: -# # # Convert input to BaseDataInput -# # from core.data_models import BaseDataInput -# # base_input = BaseDataInput() - -# # if isinstance(X, np.ndarray): -# # feature_vector = X.flatten() -# # else: -# # feature_vector = np.array(X).flatten() - -# # # Pad or truncate to expected size -# # expected_size = self.standardized_cnn.expected_feature_dim -# # if len(feature_vector) < expected_size: -# # padding = np.zeros(expected_size - len(feature_vector)) -# # feature_vector = np.concatenate([feature_vector, padding]) -# # else: -# # feature_vector = feature_vector[:expected_size] - -# # base_input._feature_vector = feature_vector - -# # # Get prediction from StandardizedCNN -# # result = self.standardized_cnn.predict_from_base_input(base_input) - -# # # Convert to legacy format -# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2} -# # pred_class = np.array([action_map.get(result.predictions['action'], 2)]) -# # pred_proba = np.array([result.predictions['action_probabilities']]) - -# # return pred_class, pred_proba - -# # except Exception as e: -# # logger.error(f"Legacy predict error: {e}") -# # # Return safe defaults -# # pred_class = np.array([2]) # HOLD -# # pred_proba = np.array([[0.33, 0.33, 0.34]]) -# # return pred_class, pred_proba - -# # def fit(self, X, y, **kwargs): -# # """Legacy fit method""" -# # try: -# # return self.trainer.train_step(X, y) -# # except Exception as e: -# # logger.error(f"Legacy fit error: {e}") -# # return self - -# # def save(self, filepath: str): -# # """Legacy save method""" -# # try: -# # torch.save(self.standardized_cnn.state_dict(), filepath) -# # logger.info(f"StandardizedCNN saved to {filepath}") -# # except Exception as e: -# # logger.error(f"Error saving model: {e}") - - -# def create_enhanced_cnn_model(input_size: int = 60, -# feature_dim: int = 50, -# output_size: int = 3, -# base_channels: int = 256, -# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]: -# """Legacy compatibility function - returns StandardizedCNN""" -# warnings.warn( -# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# model = StandardizedCNN() -# trainer = CNNModelTrainer(model) - -# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly") -# return model, trainer - - -# # Export compatibility symbols -# __all__ = [ -# 'EnhancedCNNModel', -# 'CNNModelTrainer', -# # 'CNNModel', -# 'create_enhanced_cnn_model' -# ] diff --git a/NN/models/transformer_model.py b/NN/models/transformer_model.py deleted file mode 100644 index 16700b3..0000000 --- a/NN/models/transformer_model.py +++ /dev/null @@ -1,821 +0,0 @@ -""" -Transformer Neural Network for timeseries analysis - -This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis. -It also includes a Mixture of Experts model that combines predictions from multiple models. -""" - -import os -import logging -import numpy as np -import matplotlib.pyplot as plt -import tensorflow as tf -from tensorflow.keras.models import Model, load_model -from tensorflow.keras.layers import ( - Input, Dense, Dropout, BatchNormalization, - Concatenate, Layer, LayerNormalization, MultiHeadAttention, - Add, GlobalAveragePooling1D, Conv1D, Reshape -) -from tensorflow.keras.optimizers import Adam -from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau -import datetime -import json - -logger = logging.getLogger(__name__) - -class TransformerBlock(Layer): - """ - Transformer block implementation with multi-head attention and feed-forward networks. - """ - def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1): - super(TransformerBlock, self).__init__() - self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) - self.ffn = tf.keras.Sequential([ - Dense(ff_dim, activation="relu"), - Dense(embed_dim), - ]) - self.layernorm1 = LayerNormalization(epsilon=1e-6) - self.layernorm2 = LayerNormalization(epsilon=1e-6) - self.dropout1 = Dropout(rate) - self.dropout2 = Dropout(rate) - - def call(self, inputs, training=False): - attn_output = self.att(inputs, inputs) - attn_output = self.dropout1(attn_output, training=training) - out1 = self.layernorm1(inputs + attn_output) - ffn_output = self.ffn(out1) - ffn_output = self.dropout2(ffn_output, training=training) - return self.layernorm2(out1 + ffn_output) - - def get_config(self): - config = super().get_config() - config.update({ - 'att': self.att, - 'ffn': self.ffn, - 'layernorm1': self.layernorm1, - 'layernorm2': self.layernorm2, - 'dropout1': self.dropout1, - 'dropout2': self.dropout2 - }) - return config - -class PositionalEncoding(Layer): - """ - Positional encoding layer to add position information to input embeddings. - """ - def __init__(self, position, d_model): - super(PositionalEncoding, self).__init__() - self.position = position - self.d_model = d_model - self.pos_encoding = self.positional_encoding(position, d_model) - - def get_angles(self, position, i, d_model): - angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32)) - return position * angles - - def positional_encoding(self, position, d_model): - angle_rads = self.get_angles( - position=tf.range(position, dtype=tf.float32)[:, tf.newaxis], - i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :], - d_model=d_model - ) - - # Apply sin to even indices in the array - sines = tf.math.sin(angle_rads[:, 0::2]) - - # Apply cos to odd indices in the array - cosines = tf.math.cos(angle_rads[:, 1::2]) - - pos_encoding = tf.concat([sines, cosines], axis=-1) - pos_encoding = pos_encoding[tf.newaxis, ...] - - return tf.cast(pos_encoding, tf.float32) - - def call(self, inputs): - return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :] - - def get_config(self): - config = super().get_config() - config.update({ - 'position': self.position, - 'd_model': self.d_model, - 'pos_encoding': self.pos_encoding - }) - return config - -class TransformerModel: - """ - Transformer Neural Network for time series analysis. - - This model uses self-attention mechanisms to capture relationships between - different time points in the input data. - """ - - def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"): - """ - Initialize the Transformer model. - - Args: - ts_input_shape (tuple): Shape of time series input data (sequence_length, features) - feature_input_shape (int): Shape of additional feature input (e.g., from CNN) - output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell) - model_dir (str): Directory to save trained models - """ - self.ts_input_shape = ts_input_shape - self.feature_input_shape = feature_input_shape - self.output_size = output_size - self.model_dir = model_dir - self.model = None - self.history = None - - # Create model directory if it doesn't exist - os.makedirs(self.model_dir, exist_ok=True) - - logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, " - f"feature input shape {feature_input_shape}, and output size {output_size}") - - def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001): - """ - Build the Transformer model architecture. - - Args: - embed_dim (int): Embedding dimension for transformer - num_heads (int): Number of attention heads - ff_dim (int): Hidden dimension of the feed forward network - num_transformer_blocks (int): Number of transformer blocks - dropout_rate (float): Dropout rate for regularization - learning_rate (float): Learning rate for Adam optimizer - - Returns: - The compiled model - """ - # Time series input - ts_inputs = Input(shape=self.ts_input_shape, name="ts_input") - - # Additional feature input (e.g., from CNN) - feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input") - - # Process time series with transformer - # First, project the input to the embedding dimension - x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs) - - # Add positional encoding - x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x) - - # Add transformer blocks - for _ in range(num_transformer_blocks): - x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x) - - # Global pooling to get a single vector representation - x = GlobalAveragePooling1D()(x) - x = Dropout(dropout_rate)(x) - - # Combine with additional features - combined = Concatenate()([x, feature_inputs]) - - # Dense layers for final classification/regression - x = Dense(64, activation="relu")(combined) - x = BatchNormalization()(x) - x = Dropout(dropout_rate)(x) - - # Output layer - if self.output_size == 1: - # Binary classification (up/down) - outputs = Dense(1, activation='sigmoid', name='output')(x) - loss = 'binary_crossentropy' - metrics = ['accuracy'] - elif self.output_size == 3: - # Multi-class classification (buy/hold/sell) - outputs = Dense(3, activation='softmax', name='output')(x) - loss = 'categorical_crossentropy' - metrics = ['accuracy'] - else: - # Regression - outputs = Dense(self.output_size, activation='linear', name='output')(x) - loss = 'mse' - metrics = ['mae'] - - # Create and compile model - self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs) - - # Compile with Adam optimizer - self.model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss=loss, - metrics=metrics - ) - - # Log model summary - self.model.summary(print_fn=lambda x: logger.info(x)) - - return self.model - - def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2, - callbacks=None, class_weights=None): - """ - Train the Transformer model on the provided data. - - Args: - X_ts (numpy.ndarray): Time series input features - X_features (numpy.ndarray): Additional input features - y (numpy.ndarray): Target labels - batch_size (int): Batch size - epochs (int): Number of epochs - validation_split (float): Fraction of data to use for validation - callbacks (list): List of Keras callbacks - class_weights (dict): Class weights for imbalanced datasets - - Returns: - History object containing training metrics - """ - if self.model is None: - self.build_model() - - # Default callbacks if none provided - if callbacks is None: - # Create a timestamp for model checkpoints - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - callbacks = [ - EarlyStopping( - monitor='val_loss', - patience=10, - restore_best_weights=True - ), - ReduceLROnPlateau( - monitor='val_loss', - factor=0.5, - patience=5, - min_lr=1e-6 - ), - ModelCheckpoint( - filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"), - monitor='val_loss', - save_best_only=True - ) - ] - - # Check if y needs to be one-hot encoded for multi-class - if self.output_size == 3 and len(y.shape) == 1: - y = tf.keras.utils.to_categorical(y, num_classes=3) - - # Train the model - logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}") - self.history = self.model.fit( - [X_ts, X_features], y, - batch_size=batch_size, - epochs=epochs, - validation_split=validation_split, - callbacks=callbacks, - class_weight=class_weights, - verbose=2 - ) - - # Save the trained model - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5") - self.model.save(model_path) - logger.info(f"Model saved to {model_path}") - - # Save training history - history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json") - with open(history_path, 'w') as f: - # Convert numpy values to Python native types for JSON serialization - history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()} - json.dump(history_dict, f, indent=2) - - return self.history - - def evaluate(self, X_ts, X_features, y): - """ - Evaluate the model on test data. - - Args: - X_ts (numpy.ndarray): Time series input features - X_features (numpy.ndarray): Additional input features - y (numpy.ndarray): Target labels - - Returns: - dict: Evaluation metrics - """ - if self.model is None: - raise ValueError("Model has not been built or trained yet") - - # Convert y to one-hot encoding for multi-class - if self.output_size == 3 and len(y.shape) == 1: - y = tf.keras.utils.to_categorical(y, num_classes=3) - - # Evaluate model - logger.info(f"Evaluating Transformer model on {len(X_ts)} samples") - eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0) - - metrics = {} - for metric, value in zip(self.model.metrics_names, eval_results): - metrics[metric] = value - logger.info(f"{metric}: {value:.4f}") - - return metrics - - def predict(self, X_ts, X_features=None): - """ - Make predictions on new data. - - Args: - X_ts (numpy.ndarray): Time series input features - X_features (numpy.ndarray): Additional input features - - Returns: - tuple: (y_pred, y_proba) where: - y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class) - y_proba is the class probability - """ - if self.model is None: - raise ValueError("Model has not been built or trained yet") - - # Ensure X_ts has the right shape - if len(X_ts.shape) == 2: - # Single sample, add batch dimension - X_ts = np.expand_dims(X_ts, axis=0) - - # Ensure X_features has the right shape - if X_features is None: - # Extract features from time series data if no external features provided - X_features = self._extract_features_from_timeseries(X_ts) - elif len(X_features.shape) == 1: - # Single sample, add batch dimension - X_features = np.expand_dims(X_features, axis=0) - - def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray: - """Extract meaningful features from time series data instead of using dummy zeros""" - try: - batch_size = X_ts.shape[0] - features = [] - - for i in range(batch_size): - sample = X_ts[i] # Shape: (timesteps, features) - - # Extract statistical features from each feature dimension - sample_features = [] - - for feature_idx in range(sample.shape[1]): - feature_data = sample[:, feature_idx] - - # Basic statistical features - sample_features.extend([ - np.mean(feature_data), # Mean - np.std(feature_data), # Standard deviation - np.min(feature_data), # Minimum - np.max(feature_data), # Maximum - np.percentile(feature_data, 25), # 25th percentile - np.percentile(feature_data, 75), # 75th percentile - ]) - - # Trend features - if len(feature_data) > 1: - # Linear trend (slope) - x = np.arange(len(feature_data)) - slope = np.polyfit(x, feature_data, 1)[0] - sample_features.append(slope) - - # Rate of change - rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0 - sample_features.append(rate_of_change) - else: - sample_features.extend([0.0, 0.0]) - - # Pad or truncate to expected feature size - while len(sample_features) < self.feature_input_shape: - sample_features.append(0.0) - sample_features = sample_features[:self.feature_input_shape] - - features.append(sample_features) - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error extracting features from time series: {e}") - # Fallback to zeros if extraction fails - return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32) - - # Get predictions - y_proba = self.model.predict([X_ts, X_features]) - - # Process based on output type - if self.output_size == 1: - # Binary classification - y_pred = (y_proba > 0.5).astype(int).flatten() - return y_pred, y_proba.flatten() - elif self.output_size == 3: - # Multi-class classification - y_pred = np.argmax(y_proba, axis=1) - return y_pred, y_proba - else: - # Regression - return y_proba, y_proba - - def save(self, filepath=None): - """ - Save the model to disk. - - Args: - filepath (str): Path to save the model - - Returns: - str: Path where the model was saved - """ - if self.model is None: - raise ValueError("Model has not been built yet") - - if filepath is None: - # Create a default filepath with timestamp - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5") - - self.model.save(filepath) - logger.info(f"Model saved to {filepath}") - return filepath - - def load(self, filepath): - """ - Load a saved model from disk. - - Args: - filepath (str): Path to the saved model - - Returns: - The loaded model - """ - # Register custom layers - custom_objects = { - 'TransformerBlock': TransformerBlock, - 'PositionalEncoding': PositionalEncoding - } - - self.model = load_model(filepath, custom_objects=custom_objects) - logger.info(f"Model loaded from {filepath}") - return self.model - - def plot_training_history(self): - """ - Plot training history (loss and metrics). - - Returns: - str: Path to the saved plot - """ - if self.history is None: - raise ValueError("Model has not been trained yet") - - plt.figure(figsize=(12, 5)) - - # Plot loss - plt.subplot(1, 2, 1) - plt.plot(self.history.history['loss'], label='Training Loss') - if 'val_loss' in self.history.history: - plt.plot(self.history.history['val_loss'], label='Validation Loss') - plt.title('Model Loss') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.legend() - - # Plot accuracy - plt.subplot(1, 2, 2) - - if 'accuracy' in self.history.history: - plt.plot(self.history.history['accuracy'], label='Training Accuracy') - if 'val_accuracy' in self.history.history: - plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy') - plt.title('Model Accuracy') - plt.ylabel('Accuracy') - elif 'mae' in self.history.history: - plt.plot(self.history.history['mae'], label='Training MAE') - if 'val_mae' in self.history.history: - plt.plot(self.history.history['val_mae'], label='Validation MAE') - plt.title('Model MAE') - plt.ylabel('MAE') - - plt.xlabel('Epoch') - plt.legend() - - plt.tight_layout() - - # Save figure - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png") - plt.savefig(fig_path) - plt.close() - - logger.info(f"Training history plot saved to {fig_path}") - return fig_path - - -class MixtureOfExpertsModel: - """ - Mixture of Experts (MoE) model. - - This model combines predictions from multiple expert models (such as CNN and Transformer) - using a weighted ensemble approach. - """ - - def __init__(self, output_size=1, model_dir="NN/models/saved"): - """ - Initialize the MoE model. - - Args: - output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell) - model_dir (str): Directory to save trained models - """ - self.output_size = output_size - self.model_dir = model_dir - self.model = None - self.history = None - self.experts = {} - - # Create model directory if it doesn't exist - os.makedirs(self.model_dir, exist_ok=True) - - logger.info(f"Initialized Mixture of Experts model with output size {output_size}") - - def add_expert(self, name, model): - """ - Add an expert model to the MoE. - - Args: - name (str): Name of the expert model - model: The expert model instance - - Returns: - None - """ - self.experts[name] = model - logger.info(f"Added expert model '{name}' to MoE") - - def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001): - """ - Build the MoE model by combining expert models. - - Args: - ts_input_shape (tuple): Shape of time series input data - expert_weights (dict): Weights for each expert model - learning_rate (float): Learning rate for Adam optimizer - - Returns: - The compiled model - """ - # Time series input - ts_inputs = Input(shape=ts_input_shape, name="ts_input") - - # Additional feature input (from CNN) - feature_inputs = Input(shape=(64,), name="feature_input") # Default size for features - - # Process with each expert model - expert_outputs = [] - expert_names = [] - - for name, expert in self.experts.items(): - # Skip if expert model is not valid or doesn't have a call/predict method - if expert is None: - logger.warning(f"Expert model '{name}' is None, skipping") - continue - - try: - # Different handling based on model type - if name == 'cnn': - # CNN model takes only time series input - expert_output = expert(ts_inputs) - expert_outputs.append(expert_output) - expert_names.append(name) - elif name == 'transformer': - # Transformer model takes both time series and feature inputs - expert_output = expert([ts_inputs, feature_inputs]) - expert_outputs.append(expert_output) - expert_names.append(name) - else: - logger.warning(f"Unknown expert model type: {name}") - except Exception as e: - logger.error(f"Error adding expert '{name}': {str(e)}") - - if not expert_outputs: - logger.error("No valid expert models found") - return None - - # Use expert weighting - if expert_weights is None: - # Equal weighting - weights = [1.0 / len(expert_outputs)] * len(expert_outputs) - else: - # User-provided weights - weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names] - # Normalize weights - weights = [w / sum(weights) for w in weights] - - # Combine expert outputs using weighted average - if len(expert_outputs) == 1: - # Only one expert, use its output directly - combined_output = expert_outputs[0] - else: - # Multiple experts, compute weighted average - weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)] - combined_output = Add()(weighted_outputs) - - # Create the MoE model - moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output) - - # Compile the model - if self.output_size == 1: - # Binary classification - moe_model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss='binary_crossentropy', - metrics=['accuracy'] - ) - elif self.output_size == 3: - # Multi-class classification for BUY/HOLD/SELL - moe_model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss='categorical_crossentropy', - metrics=['accuracy'] - ) - else: - # Regression - moe_model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss='mse', - metrics=['mae'] - ) - - self.model = moe_model - - # Log model summary - self.model.summary(print_fn=lambda x: logger.info(x)) - - logger.info(f"Built MoE model with weights: {weights}") - return self.model - - def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2, - callbacks=None, class_weights=None): - """ - Train the MoE model on the provided data. - - Args: - X_ts (numpy.ndarray): Time series input features - X_features (numpy.ndarray): Additional input features - y (numpy.ndarray): Target labels - batch_size (int): Batch size - epochs (int): Number of epochs - validation_split (float): Fraction of data to use for validation - callbacks (list): List of Keras callbacks - class_weights (dict): Class weights for imbalanced datasets - - Returns: - History object containing training metrics - """ - if self.model is None: - logger.error("MoE model has not been built yet") - return None - - # Default callbacks if none provided - if callbacks is None: - # Create a timestamp for model checkpoints - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - callbacks = [ - EarlyStopping( - monitor='val_loss', - patience=10, - restore_best_weights=True - ), - ReduceLROnPlateau( - monitor='val_loss', - factor=0.5, - patience=5, - min_lr=1e-6 - ), - ModelCheckpoint( - filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"), - monitor='val_loss', - save_best_only=True - ) - ] - - # Check if y needs to be one-hot encoded for multi-class - if self.output_size == 3 and len(y.shape) == 1: - y = tf.keras.utils.to_categorical(y, num_classes=3) - - # Train the model - logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}") - self.history = self.model.fit( - [X_ts, X_features], y, - batch_size=batch_size, - epochs=epochs, - validation_split=validation_split, - callbacks=callbacks, - class_weight=class_weights, - verbose=2 - ) - - # Save the trained model - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5") - self.model.save(model_path) - logger.info(f"Model saved to {model_path}") - - # Save training history - history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json") - with open(history_path, 'w') as f: - # Convert numpy values to Python native types for JSON serialization - history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()} - json.dump(history_dict, f, indent=2) - - return self.history - - def predict(self, X_ts, X_features=None): - """ - Make predictions on new data. - - Args: - X_ts (numpy.ndarray): Time series input features - X_features (numpy.ndarray): Additional input features - - Returns: - tuple: (y_pred, y_proba) where: - y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class) - y_proba is the class probability - """ - if self.model is None: - raise ValueError("Model has not been built or trained yet") - - # Ensure X_ts has the right shape - if len(X_ts.shape) == 2: - # Single sample, add batch dimension - X_ts = np.expand_dims(X_ts, axis=0) - - # Ensure X_features has the right shape - if X_features is None: - # Create dummy features with zeros - X_features = np.zeros((X_ts.shape[0], 64)) # Default size - elif len(X_features.shape) == 1: - # Single sample, add batch dimension - X_features = np.expand_dims(X_features, axis=0) - - # Get predictions - y_proba = self.model.predict([X_ts, X_features]) - - # Process based on output type - if self.output_size == 1: - # Binary classification - y_pred = (y_proba > 0.5).astype(int).flatten() - return y_pred, y_proba.flatten() - elif self.output_size == 3: - # Multi-class classification - y_pred = np.argmax(y_proba, axis=1) - return y_pred, y_proba - else: - # Regression - return y_proba, y_proba - - def save(self, filepath=None): - """ - Save the model to disk. - - Args: - filepath (str): Path to save the model - - Returns: - str: Path where the model was saved - """ - if self.model is None: - raise ValueError("Model has not been built yet") - - if filepath is None: - # Create a default filepath with timestamp - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5") - - self.model.save(filepath) - logger.info(f"Model saved to {filepath}") - return filepath - - def load(self, filepath): - """ - Load a saved model from disk. - - Args: - filepath (str): Path to the saved model - - Returns: - The loaded model - """ - # Register custom layers - custom_objects = { - 'TransformerBlock': TransformerBlock, - 'PositionalEncoding': PositionalEncoding - } - - self.model = load_model(filepath, custom_objects=custom_objects) - logger.info(f"Model loaded from {filepath}") - return self.model - -# Example usage: -if __name__ == "__main__": - # This would be a complete implementation in a real system - print("Transformer and MoE models defined, but not implemented here.") \ No newline at end of file diff --git a/NN/start_tensorboard.py b/NN/start_tensorboard.py deleted file mode 100644 index ed27a1b..0000000 --- a/NN/start_tensorboard.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python -""" -Start TensorBoard for monitoring neural network training -""" - -import os -import sys -import subprocess -import webbrowser -from time import sleep - -def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True): - """ - Start TensorBoard in a subprocess - - Args: - logdir: Directory containing TensorBoard logs - port: Port to run TensorBoard on - open_browser: Whether to open a browser automatically - """ - # Make sure the log directory exists - os.makedirs(logdir, exist_ok=True) - - # Create command - cmd = [ - sys.executable, - "-m", - "tensorboard.main", - f"--logdir={logdir}", - f"--port={port}", - "--bind_all" - ] - - print(f"Starting TensorBoard with logs from {logdir} on port {port}") - print(f"Command: {' '.join(cmd)}") - - # Start TensorBoard in a subprocess - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True - ) - - # Wait for TensorBoard to start up - for line in process.stdout: - print(line.strip()) - if "TensorBoard" in line and "http://" in line: - # TensorBoard is running, extract the URL - url = None - for part in line.split(): - if part.startswith(("http://", "https://")): - url = part - break - - # Open browser if requested and URL found - if open_browser and url: - print(f"Opening TensorBoard in browser: {url}") - webbrowser.open(url) - - break - - # Return the process for the caller to manage - return process - -if __name__ == "__main__": - import argparse - - # Parse command line arguments - parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization") - parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs") - parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on") - parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically") - - args = parser.parse_args() - - # Start TensorBoard - process = start_tensorboard(args.logdir, args.port, not args.no_browser) - - try: - # Keep the script running until Ctrl+C - print("TensorBoard is running. Press Ctrl+C to stop.") - while True: - sleep(1) - except KeyboardInterrupt: - print("Stopping TensorBoard...") - process.terminate() - process.wait() \ No newline at end of file diff --git a/NN/training/enhanced_rl_training_integration.py b/NN/training/enhanced_rl_training_integration.py deleted file mode 100644 index 3d240e5..0000000 --- a/NN/training/enhanced_rl_training_integration.py +++ /dev/null @@ -1,490 +0,0 @@ -#!/usr/bin/env python3 -""" -Enhanced RL Training Integration - Comprehensive Fix - -This script addresses the critical RL training audit issues: -1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state -2. Disconnected Training Pipeline - Provides proper data flow integration -3. Missing Enhanced State Builder - Connects orchestrator to dashboard -4. Reward Calculation Issues - Ensures enhanced pivot-based rewards -5. Williams Market Structure Integration - Proper feature extraction -6. Real-time Data Integration - Live market data to RL - -Usage: - python enhanced_rl_training_integration.py -""" - -import os -import sys -import asyncio -import logging -import numpy as np -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Any - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging, get_config -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.trading_executor import TradingExecutor -from web.clean_dashboard import CleanTradingDashboard as TradingDashboard -from utils.tensorboard_logger import TensorBoardLogger - -logger = logging.getLogger(__name__) - -class EnhancedRLTrainingIntegrator: - """ - Comprehensive RL Training Integrator - - Fixes all audit issues by ensuring proper data flow and feature completeness. - """ - - def __init__(self): - """Initialize the enhanced RL training integrator""" - # Setup logging - setup_logging() - logger.info("=" * 70) - logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX") - logger.info("=" * 70) - - # Get configuration - self.config = get_config() - - # Initialize core components - self.data_provider = DataProvider() - self.enhanced_orchestrator = None - self.trading_executor = TradingExecutor() - self.dashboard = None - - # Training metrics - self.training_stats = { - 'total_episodes': 0, - 'successful_state_builds': 0, - 'enhanced_reward_calculations': 0, - 'comprehensive_features_used': 0, - 'pivot_features_extracted': 0, - 'cob_features_available': 0 - } - - # Initialize TensorBoard logger - experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - self.tb_logger = TensorBoardLogger( - log_dir="runs", - experiment_name=experiment_name, - enabled=True - ) - logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}") - - logger.info("Enhanced RL Training Integrator initialized") - - async def start_integration(self): - """Start the comprehensive RL training integration""" - try: - logger.info("Starting comprehensive RL training integration...") - - # 1. Initialize Enhanced Orchestrator with comprehensive features - await self._initialize_enhanced_orchestrator() - - # 2. Create enhanced dashboard with proper connections - await self._create_enhanced_dashboard() - - # 3. Verify comprehensive state building - await self._verify_comprehensive_state_building() - - # 4. Test enhanced reward calculation - await self._test_enhanced_reward_calculation() - - # 5. Validate Williams market structure integration - await self._validate_williams_integration() - - # 6. Start live training with comprehensive features - await self._start_live_comprehensive_training() - - logger.info("=" * 70) - logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE") - logger.info("=" * 70) - self._log_integration_stats() - - except Exception as e: - logger.error(f"Error in RL training integration: {e}") - import traceback - logger.error(traceback.format_exc()) - - async def _initialize_enhanced_orchestrator(self): - """Initialize enhanced orchestrator with comprehensive RL capabilities""" - try: - logger.info("[STEP 1] Initializing Enhanced Orchestrator...") - - # Create enhanced orchestrator with RL training enabled - self.enhanced_orchestrator = EnhancedTradingOrchestrator( - data_provider=self.data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - enhanced_rl_training=True, - model_registry={} # Will be populated as needed - ) - - # Start COB integration for real-time market microstructure - await self.enhanced_orchestrator.start_cob_integration() - - # Start real-time processing - await self.enhanced_orchestrator.start_realtime_processing() - - logger.info("[SUCCESS] Enhanced Orchestrator initialized with:") - logger.info(" - Comprehensive RL state building: ENABLED") - logger.info(" - Enhanced pivot-based rewards: ENABLED") - logger.info(" - COB integration: ENABLED") - logger.info(" - Williams market structure: ENABLED") - logger.info(" - Real-time tick processing: ENABLED") - - except Exception as e: - logger.error(f"Error initializing enhanced orchestrator: {e}") - raise - - async def _create_enhanced_dashboard(self): - """Create dashboard with enhanced orchestrator connections""" - try: - logger.info("[STEP 2] Creating Enhanced Dashboard...") - - # Create trading dashboard with enhanced orchestrator - self.dashboard = TradingDashboard( - data_provider=self.data_provider, - orchestrator=self.enhanced_orchestrator, # Use enhanced orchestrator - trading_executor=self.trading_executor - ) - - # Verify enhanced connections - has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state') - has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward') - has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation') - - logger.info("[SUCCESS] Enhanced Dashboard created with:") - logger.info(f" - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}") - logger.info(f" - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}") - logger.info(f" - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}") - - if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]): - logger.warning("Some enhanced features are missing - this will cause fallbacks to basic training") - else: - logger.info(" - ALL ENHANCED FEATURES AVAILABLE!") - - except Exception as e: - logger.error(f"Error creating enhanced dashboard: {e}") - raise - - async def _verify_comprehensive_state_building(self): - """Verify that comprehensive RL state building works correctly""" - try: - logger.info("[STEP 3] Verifying Comprehensive State Building...") - - # Test comprehensive state building for ETH - eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT') - - if eth_state is not None: - logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features") - - # Verify feature count - if len(eth_state) == 13400: - logger.info(" - PERFECT: Exactly 13,400 features as required!") - self.training_stats['comprehensive_features_used'] += 1 - else: - logger.warning(f" - MISMATCH: Expected 13,400 features, got {len(eth_state)}") - - # Analyze feature distribution - self._analyze_state_features(eth_state) - self.training_stats['successful_state_builds'] += 1 - - else: - logger.error(" - FAILED: Comprehensive state building returned None") - - # Test for BTC reference - btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT') - if btc_state is not None: - logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features") - self.training_stats['successful_state_builds'] += 1 - - except Exception as e: - logger.error(f"Error verifying comprehensive state building: {e}") - - def _analyze_state_features(self, state_vector: np.ndarray): - """Analyze the comprehensive state feature distribution""" - try: - # Calculate feature statistics - non_zero_features = np.count_nonzero(state_vector) - zero_features = len(state_vector) - non_zero_features - feature_mean = np.mean(state_vector) - feature_std = np.std(state_vector) - feature_min = np.min(state_vector) - feature_max = np.max(state_vector) - - logger.info(" - Feature Analysis:") - logger.info(f" * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)") - logger.info(f" * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)") - logger.info(f" * Mean: {feature_mean:.6f}") - logger.info(f" * Std: {feature_std:.6f}") - logger.info(f" * Range: [{feature_min:.6f}, {feature_max:.6f}]") - - # Log feature statistics to TensorBoard - step = self.training_stats['total_episodes'] - self.tb_logger.log_scalars('Features/Distribution', { - 'non_zero_percentage': non_zero_features/len(state_vector)*100, - 'mean': feature_mean, - 'std': feature_std, - 'min': feature_min, - 'max': feature_max - }, step) - - # Log feature histogram to TensorBoard - self.tb_logger.log_histogram('Features/Values', state_vector, step) - - # Check if features are properly distributed - if non_zero_features > len(state_vector) * 0.1: # At least 10% non-zero - logger.info(" * GOOD: Features are well distributed") - else: - logger.warning(" * WARNING: Too many zero features - data may be incomplete") - - except Exception as e: - logger.warning(f"Error analyzing state features: {e}") - - async def _test_enhanced_reward_calculation(self): - """Test enhanced pivot-based reward calculation""" - try: - logger.info("[STEP 4] Testing Enhanced Reward Calculation...") - - # Create mock trade data for testing - trade_decision = { - 'action': 'BUY', - 'confidence': 0.75, - 'price': 2500.0, - 'timestamp': datetime.now() - } - - trade_outcome = { - 'net_pnl': 50.0, - 'exit_price': 2550.0, - 'duration': timedelta(minutes=15) - } - - # Get market data for reward calculation - market_data = { - 'volatility': 0.03, - 'order_flow_direction': 'bullish', - 'order_flow_strength': 0.8 - } - - # Test enhanced reward calculation - if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'): - enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward( - trade_decision, market_data, trade_outcome - ) - - logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}") - logger.info(" - Enhanced pivot-based reward system: WORKING") - self.training_stats['enhanced_reward_calculations'] += 1 - - # Log reward metrics to TensorBoard - step = self.training_stats['enhanced_reward_calculations'] - self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step) - - # Log reward components to TensorBoard - self.tb_logger.log_scalars('Rewards/Components', { - 'pnl_component': trade_outcome['net_pnl'], - 'confidence': trade_decision['confidence'], - 'volatility': market_data['volatility'], - 'order_flow_strength': market_data['order_flow_strength'] - }, step) - - else: - logger.error(" - FAILED: Enhanced reward calculation method not available") - - except Exception as e: - logger.error(f"Error testing enhanced reward calculation: {e}") - - async def _validate_williams_integration(self): - """Validate Williams market structure integration""" - try: - logger.info("[STEP 5] Validating Williams Market Structure Integration...") - - # Test Williams pivot feature extraction - try: - from training.williams_market_structure import extract_pivot_features, analyze_pivot_context - - # Get test market data - df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100) - - if df is not None and not df.empty: - # Test pivot feature extraction - pivot_features = extract_pivot_features(df) - - if pivot_features is not None: - logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features") - self.training_stats['pivot_features_extracted'] += 1 - - # Test pivot context analysis - market_data = {'ohlcv_data': df} - pivot_context = analyze_pivot_context( - market_data, datetime.now(), 'BUY' - ) - - if pivot_context is not None: - logger.info("[SUCCESS] Williams pivot context analysis: WORKING") - logger.info(f" - Near pivot: {pivot_context.get('near_pivot', False)}") - logger.info(f" - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}") - else: - logger.warning(" - Williams pivot context analysis returned None") - else: - logger.warning(" - Williams pivot feature extraction returned None") - else: - logger.warning(" - No market data available for Williams testing") - - except ImportError: - logger.error(" - Williams market structure module not available") - except Exception as e: - logger.error(f" - Error in Williams integration: {e}") - - except Exception as e: - logger.error(f"Error validating Williams integration: {e}") - - async def _start_live_comprehensive_training(self): - """Start live training with comprehensive feature integration""" - try: - logger.info("[STEP 6] Starting Live Comprehensive Training...") - - # Run a few training iterations to verify integration - for iteration in range(5): - logger.info(f"Training iteration {iteration + 1}/5") - - # Make coordinated decisions using enhanced orchestrator - decisions = await self.enhanced_orchestrator.make_coordinated_decisions() - - # Track iteration metrics for TensorBoard - iteration_metrics = { - 'decisions_count': len(decisions), - 'confidence_avg': 0.0, - 'state_size_avg': 0.0, - 'successful_states': 0 - } - - # Process each decision - for symbol, decision in decisions.items(): - if decision: - logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})") - - # Track confidence for TensorBoard - iteration_metrics['confidence_avg'] += decision.confidence - - # Build comprehensive state for this decision - comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol) - - if comprehensive_state is not None: - state_size = len(comprehensive_state) - logger.info(f" - Comprehensive state: {state_size} features") - self.training_stats['total_episodes'] += 1 - - # Track state size for TensorBoard - iteration_metrics['state_size_avg'] += state_size - iteration_metrics['successful_states'] += 1 - - # Log individual state metrics to TensorBoard - self.tb_logger.log_state_metrics( - symbol=symbol, - state_info={ - 'size': state_size, - 'quality': 1.0 if state_size == 13400 else 0.8, - 'feature_counts': { - 'total': state_size, - 'non_zero': np.count_nonzero(comprehensive_state) - } - }, - step=self.training_stats['total_episodes'] - ) - else: - logger.warning(f" - Failed to build comprehensive state for {symbol}") - - # Calculate averages for TensorBoard - if decisions: - iteration_metrics['confidence_avg'] /= len(decisions) - - if iteration_metrics['successful_states'] > 0: - iteration_metrics['state_size_avg'] /= iteration_metrics['successful_states'] - - # Log iteration metrics to TensorBoard - self.tb_logger.log_scalars('Training/Iteration', { - 'iteration': iteration + 1, - 'decisions_count': iteration_metrics['decisions_count'], - 'confidence_avg': iteration_metrics['confidence_avg'], - 'state_size_avg': iteration_metrics['state_size_avg'], - 'successful_states': iteration_metrics['successful_states'] - }, iteration + 1) - - # Wait between iterations - await asyncio.sleep(2) - - logger.info("[SUCCESS] Live comprehensive training demonstration complete") - - except Exception as e: - logger.error(f"Error in live comprehensive training: {e}") - - def _log_integration_stats(self): - """Log comprehensive integration statistics""" - logger.info("INTEGRATION STATISTICS:") - logger.info(f" - Total training episodes: {self.training_stats['total_episodes']}") - logger.info(f" - Successful state builds: {self.training_stats['successful_state_builds']}") - logger.info(f" - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}") - logger.info(f" - Comprehensive features used: {self.training_stats['comprehensive_features_used']}") - logger.info(f" - Pivot features extracted: {self.training_stats['pivot_features_extracted']}") - - # Calculate success rates - state_success_rate = 0 - if self.training_stats['total_episodes'] > 0: - state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100 - logger.info(f" - State building success rate: {state_success_rate:.1f}%") - - # Log final statistics to TensorBoard - self.tb_logger.log_scalars('Integration/Statistics', { - 'total_episodes': self.training_stats['total_episodes'], - 'successful_state_builds': self.training_stats['successful_state_builds'], - 'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'], - 'comprehensive_features_used': self.training_stats['comprehensive_features_used'], - 'pivot_features_extracted': self.training_stats['pivot_features_extracted'], - 'state_success_rate': state_success_rate - }, 0) # Use step 0 for final summary stats - - # Integration status - if self.training_stats['comprehensive_features_used'] > 0: - logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅") - logger.info("The system is now using the full 13,400 feature comprehensive state.") - - # Log success status to TensorBoard - self.tb_logger.log_scalar('Integration/Success', 1.0, 0) - else: - logger.warning("STATUS: Integration partially successful - some fallbacks may occur") - - # Log partial success status to TensorBoard - self.tb_logger.log_scalar('Integration/Success', 0.5, 0) - -async def main(): - """Main entry point""" - try: - # Create and run the enhanced RL training integrator - integrator = EnhancedRLTrainingIntegrator() - await integrator.start_integration() - - logger.info("Enhanced RL training integration completed successfully!") - return 0 - - except KeyboardInterrupt: - logger.info("Integration interrupted by user") - return 0 - except Exception as e: - logger.error(f"Fatal error in integration: {e}") - import traceback - logger.error(traceback.format_exc()) - return 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file diff --git a/NN/training/example_checkpoint_usage.py b/NN/training/example_checkpoint_usage.py deleted file mode 100644 index b54fe38..0000000 --- a/NN/training/example_checkpoint_usage.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -""" -Example: Using the Checkpoint Management System -""" - -import logging -import torch -import torch.nn as nn -import numpy as np -from datetime import datetime - -from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager -from utils.training_integration import get_training_integration - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -class ExampleCNN(nn.Module): - def __init__(self, input_channels=5, num_classes=3): - super().__init__() - self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1) - self.conv2 = nn.Conv2d(32, 64, 3, padding=1) - self.pool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(64, num_classes) - - def forward(self, x): - x = torch.relu(self.conv1(x)) - x = torch.relu(self.conv2(x)) - x = self.pool(x) - x = x.view(x.size(0), -1) - return self.fc(x) - -def example_cnn_training(): - logger.info("=== CNN Training Example ===") - - model = ExampleCNN() - training_integration = get_training_integration() - - for epoch in range(5): # Simulate 5 epochs - # Simulate training metrics - train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1) - train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02) - val_loss = train_loss + np.random.normal(0, 0.05) - val_acc = train_acc - 0.05 + np.random.normal(0, 0.02) - - # Clamp values to realistic ranges - train_acc = max(0.0, min(1.0, train_acc)) - val_acc = max(0.0, min(1.0, val_acc)) - train_loss = max(0.1, train_loss) - val_loss = max(0.1, val_loss) - - logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}") - - # Save checkpoint - saved = training_integration.save_cnn_checkpoint( - cnn_model=model, - model_name="example_cnn", - epoch=epoch + 1, - train_accuracy=train_acc, - val_accuracy=val_acc, - train_loss=train_loss, - val_loss=val_loss, - training_time_hours=0.1 * (epoch + 1) - ) - - if saved: - logger.info(f" Checkpoint saved for epoch {epoch+1}") - else: - logger.info(f" Checkpoint not saved (performance not improved)") - - # Load the best checkpoint - logger.info("\\nLoading best checkpoint...") - best_result = load_best_checkpoint("example_cnn") - if best_result: - file_path, metadata = best_result - logger.info(f"Best checkpoint: {metadata.checkpoint_id}") - logger.info(f"Performance score: {metadata.performance_score:.4f}") - -def example_manual_checkpoint(): - logger.info("\\n=== Manual Checkpoint Example ===") - - model = nn.Linear(10, 3) - - performance_metrics = { - 'accuracy': 0.85, - 'val_accuracy': 0.82, - 'loss': 0.45, - 'val_loss': 0.48 - } - - training_metadata = { - 'epoch': 25, - 'training_time_hours': 2.5, - 'total_parameters': sum(p.numel() for p in model.parameters()) - } - - logger.info("Saving checkpoint manually...") - metadata = save_checkpoint( - model=model, - model_name="example_manual", - model_type="cnn", - performance_metrics=performance_metrics, - training_metadata=training_metadata, - force_save=True - ) - - if metadata: - logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}") - logger.info(f" Performance score: {metadata.performance_score:.4f}") - -def show_checkpoint_stats(): - logger.info("\\n=== Checkpoint Statistics ===") - - checkpoint_manager = get_checkpoint_manager() - stats = checkpoint_manager.get_checkpoint_stats() - - logger.info(f"Total models: {stats['total_models']}") - logger.info(f"Total checkpoints: {stats['total_checkpoints']}") - logger.info(f"Total size: {stats['total_size_mb']:.2f} MB") - - for model_name, model_stats in stats['models'].items(): - logger.info(f"\\n{model_name}:") - logger.info(f" Checkpoints: {model_stats['checkpoint_count']}") - logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB") - logger.info(f" Best performance: {model_stats['best_performance']:.4f}") - -def main(): - logger.info(" Checkpoint Management System Examples") - logger.info("=" * 50) - - try: - example_cnn_training() - example_manual_checkpoint() - show_checkpoint_stats() - - logger.info("\\n All examples completed successfully!") - logger.info("\\nTo use in your training:") - logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint") - logger.info("2. Or use: from utils.training_integration import get_training_integration") - logger.info("3. Save checkpoints during training with performance metrics") - logger.info("4. Load best checkpoints for inference or continued training") - - except Exception as e: - logger.error(f"Error in examples: {e}") - raise - -if __name__ == "__main__": - main() diff --git a/NN/training/integrate_checkpoint_management.py b/NN/training/integrate_checkpoint_management.py deleted file mode 100644 index 6c04c57..0000000 --- a/NN/training/integrate_checkpoint_management.py +++ /dev/null @@ -1,517 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive Checkpoint Management Integration - -This script demonstrates how to integrate the checkpoint management system -across all training pipelines in the gogo2 project. - -Features: -- DQN Agent training with automatic checkpointing -- CNN Model training with checkpoint management -- ExtremaTrainer with checkpoint persistence -- NegativeCaseTrainer with checkpoint integration -- Unified training orchestration with checkpoint coordination -""" - -import asyncio -import logging -import time -import signal -import sys -import numpy as np -from datetime import datetime -from pathlib import Path -from typing import Dict, Any, List - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/checkpoint_integration.log'), - logging.StreamHandler() - ] -) -logger = logging.getLogger(__name__) - -# Import checkpoint management -from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats -from utils.training_integration import get_training_integration - -# Import training components -from NN.models.dqn_agent import DQNAgent -from NN.models.standardized_cnn import StandardizedCNN -from core.extrema_trainer import ExtremaTrainer -from core.negative_case_trainer import NegativeCaseTrainer -from core.data_provider import DataProvider -from core.config import get_config - -class CheckpointIntegratedTrainingSystem: - """Unified training system with comprehensive checkpoint management""" - - def __init__(self): - """Initialize the checkpoint-integrated training system""" - self.config = get_config() - self.running = False - - # Checkpoint management - self.checkpoint_manager = get_checkpoint_manager() - self.training_integration = get_training_integration() - - # Data provider - self.data_provider = DataProvider( - symbols=['ETH/USDT', 'BTC/USDT'], - timeframes=['1s', '1m', '1h', '1d'] - ) - - # Training components with checkpoint management - self.dqn_agent = None - self.cnn_trainer = None - self.extrema_trainer = None - self.negative_case_trainer = None - - # Training statistics - self.training_stats = { - 'start_time': None, - 'total_training_sessions': 0, - 'checkpoints_saved': 0, - 'models_loaded': 0, - 'best_performances': {} - } - - logger.info("Checkpoint-Integrated Training System initialized") - - async def initialize_components(self): - """Initialize all training components with checkpoint management""" - try: - logger.info("Initializing training components with checkpoint management...") - - # Initialize data provider - await self.data_provider.start_real_time_streaming() - logger.info("Data provider streaming started") - - # Initialize DQN Agent with checkpoint management - logger.info("Initializing DQN Agent with checkpoints...") - self.dqn_agent = DQNAgent( - state_shape=(100,), # Example state shape - n_actions=3, - model_name="integrated_dqn_agent", - enable_checkpoints=True - ) - logger.info("✅ DQN Agent initialized with checkpoint management") - - # Initialize StandardizedCNN Model with checkpoint management - logger.info("Initializing StandardizedCNN Model with checkpoints...") - self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model") - logger.info("✅ StandardizedCNN Model initialized with checkpoint management") - - # Initialize ExtremaTrainer with checkpoint management - logger.info("Initializing ExtremaTrainer with checkpoints...") - self.extrema_trainer = ExtremaTrainer( - data_provider=self.data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - model_name="integrated_extrema_trainer", - enable_checkpoints=True - ) - await self.extrema_trainer.initialize_context_data() - logger.info("✅ ExtremaTrainer initialized with checkpoint management") - - # Initialize NegativeCaseTrainer with checkpoint management - logger.info("Initializing NegativeCaseTrainer with checkpoints...") - self.negative_case_trainer = NegativeCaseTrainer( - model_name="integrated_negative_case_trainer", - enable_checkpoints=True - ) - logger.info("✅ NegativeCaseTrainer initialized with checkpoint management") - - # Load existing checkpoints for all components - self.training_stats['models_loaded'] = await self._load_all_checkpoints() - - logger.info("All training components initialized successfully") - - except Exception as e: - logger.error(f"Error initializing components: {e}") - raise - - async def _load_all_checkpoints(self) -> int: - """Load checkpoints for all training components""" - loaded_count = 0 - - try: - # DQN Agent checkpoint loading is handled in __init__ - if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0: - loaded_count += 1 - logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}") - - # CNN Trainer checkpoint loading is handled in __init__ - if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0: - loaded_count += 1 - logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}") - - # ExtremaTrainer checkpoint loading is handled in __init__ - if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0: - loaded_count += 1 - logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}") - - # NegativeCaseTrainer checkpoint loading is handled in __init__ - if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0: - loaded_count += 1 - logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}") - - return loaded_count - - except Exception as e: - logger.error(f"Error loading checkpoints: {e}") - return 0 - - async def run_integrated_training_loop(self): - """Run the integrated training loop with checkpoint coordination""" - logger.info("Starting integrated training loop with checkpoint management...") - - self.running = True - self.training_stats['start_time'] = datetime.now() - - training_cycle = 0 - - try: - while self.running: - training_cycle += 1 - cycle_start = time.time() - - logger.info(f"=== Training Cycle {training_cycle} ===") - - # DQN Training - dqn_results = await self._train_dqn_agent() - - # CNN Training - cnn_results = await self._train_cnn_model() - - # Extrema Detection Training - extrema_results = await self._train_extrema_detector() - - # Negative Case Training (runs in background) - negative_results = await self._process_negative_cases() - - # Coordinate checkpoint saving - await self._coordinate_checkpoint_saving( - dqn_results, cnn_results, extrema_results, negative_results - ) - - # Update statistics - self.training_stats['total_training_sessions'] += 1 - - # Log cycle summary - cycle_duration = time.time() - cycle_start - logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s") - - # Wait before next cycle - await asyncio.sleep(60) # 1-minute cycles - - except KeyboardInterrupt: - logger.info("Training interrupted by user") - except Exception as e: - logger.error(f"Error in training loop: {e}") - finally: - await self.shutdown() - - async def _train_dqn_agent(self) -> Dict[str, Any]: - """Train DQN agent with automatic checkpointing""" - try: - if not self.dqn_agent: - return {'status': 'skipped', 'reason': 'no_agent'} - - # Simulate DQN training episode - episode_reward = 0.0 - - # Add some training experiences (simulate real training) - for _ in range(10): # Simulate 10 training steps - state = np.random.randn(100).astype(np.float32) - action = np.random.randint(0, 3) - reward = np.random.randn() * 0.1 - next_state = np.random.randn(100).astype(np.float32) - done = np.random.random() < 0.1 - - self.dqn_agent.remember(state, action, reward, next_state, done) - episode_reward += reward - - # Train if enough experiences - loss = 0.0 - if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size: - loss = self.dqn_agent.replay() - - # Save checkpoint (automatic based on performance) - checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward) - - if checkpoint_saved: - self.training_stats['checkpoints_saved'] += 1 - - return { - 'status': 'completed', - 'episode_reward': episode_reward, - 'loss': loss, - 'checkpoint_saved': checkpoint_saved, - 'episode': self.dqn_agent.episode_count - } - - except Exception as e: - logger.error(f"Error training DQN agent: {e}") - return {'status': 'error', 'error': str(e)} - - async def _train_cnn_model(self) -> Dict[str, Any]: - """Train CNN model with automatic checkpointing""" - try: - if not self.cnn_trainer: - return {'status': 'skipped', 'reason': 'no_trainer'} - - # Simulate CNN training step - import torch - import numpy as np - - batch_size = 32 - input_size = 60 - feature_dim = 50 - - # Generate synthetic training data - x = torch.randn(batch_size, input_size, feature_dim) - y = torch.randint(0, 3, (batch_size,)) - - # Training step - results = self.cnn_trainer.train_step(x, y) - - # Simulate validation - val_x = torch.randn(16, input_size, feature_dim) - val_y = torch.randint(0, 3, (16,)) - val_results = self.cnn_trainer.train_step(val_x, val_y) - - # Save checkpoint (automatic based on performance) - checkpoint_saved = self.cnn_trainer.save_checkpoint( - train_accuracy=results.get('accuracy', 0.5), - val_accuracy=val_results.get('accuracy', 0.5), - train_loss=results.get('total_loss', 1.0), - val_loss=val_results.get('total_loss', 1.0) - ) - - if checkpoint_saved: - self.training_stats['checkpoints_saved'] += 1 - - return { - 'status': 'completed', - 'train_accuracy': results.get('accuracy', 0.5), - 'val_accuracy': val_results.get('accuracy', 0.5), - 'train_loss': results.get('total_loss', 1.0), - 'val_loss': val_results.get('total_loss', 1.0), - 'checkpoint_saved': checkpoint_saved, - 'epoch': self.cnn_trainer.epoch_count - } - - except Exception as e: - logger.error(f"Error training CNN model: {e}") - return {'status': 'error', 'error': str(e)} - - async def _train_extrema_detector(self) -> Dict[str, Any]: - """Train extrema detector with automatic checkpointing""" - try: - if not self.extrema_trainer: - return {'status': 'skipped', 'reason': 'no_trainer'} - - # Update context data and detect extrema - update_results = self.extrema_trainer.update_context_data() - - # Get training data - extrema_data = self.extrema_trainer.get_extrema_training_data(count=10) - - # Simulate training accuracy improvement - if extrema_data: - self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data) - self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2 - self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2 - - # Save checkpoint (automatic based on performance) - checkpoint_saved = self.extrema_trainer.save_checkpoint() - - if checkpoint_saved: - self.training_stats['checkpoints_saved'] += 1 - - return { - 'status': 'completed', - 'extrema_detected': len(extrema_data), - 'context_updates': sum(1 for success in update_results.values() if success), - 'checkpoint_saved': checkpoint_saved, - 'session': self.extrema_trainer.training_session_count - } - - except Exception as e: - logger.error(f"Error training extrema detector: {e}") - return {'status': 'error', 'error': str(e)} - - async def _process_negative_cases(self) -> Dict[str, Any]: - """Process negative cases with automatic checkpointing""" - try: - if not self.negative_case_trainer: - return {'status': 'skipped', 'reason': 'no_trainer'} - - # Simulate adding a negative case - if np.random.random() < 0.1: # 10% chance of negative case - trade_info = { - 'symbol': 'ETH/USDT', - 'action': 'BUY', - 'price': 2000.0, - 'pnl': -50.0, # Loss - 'value': 1000.0, - 'confidence': 0.7, - 'timestamp': datetime.now() - } - - market_data = { - 'exit_price': 1950.0, - 'state_before': {}, - 'state_after': {}, - 'tick_data': [], - 'technical_indicators': {} - } - - case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data) - - # Simulate loss improvement - loss_improvement = np.random.random() * 0.1 - - # Save checkpoint (automatic based on performance) - checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement) - - if checkpoint_saved: - self.training_stats['checkpoints_saved'] += 1 - - return { - 'status': 'completed', - 'case_added': case_id, - 'loss_improvement': loss_improvement, - 'checkpoint_saved': checkpoint_saved, - 'session': self.negative_case_trainer.training_session_count - } - else: - return {'status': 'no_cases'} - - except Exception as e: - logger.error(f"Error processing negative cases: {e}") - return {'status': 'error', 'error': str(e)} - - async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict, - extrema_results: Dict, negative_results: Dict): - """Coordinate checkpoint saving across all components""" - try: - # Count successful checkpoints - checkpoints_saved = sum([ - dqn_results.get('checkpoint_saved', False), - cnn_results.get('checkpoint_saved', False), - extrema_results.get('checkpoint_saved', False), - negative_results.get('checkpoint_saved', False) - ]) - - if checkpoints_saved > 0: - logger.info(f"Saved {checkpoints_saved} checkpoints this cycle") - - # Update best performances - if 'episode_reward' in dqn_results: - current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf')) - if dqn_results['episode_reward'] > current_best: - self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward'] - - if 'val_accuracy' in cnn_results: - current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0) - if cnn_results['val_accuracy'] > current_best: - self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy'] - - # Log checkpoint statistics every 10 cycles - if self.training_stats['total_training_sessions'] % 10 == 0: - await self._log_checkpoint_statistics() - - except Exception as e: - logger.error(f"Error coordinating checkpoint saving: {e}") - - async def _log_checkpoint_statistics(self): - """Log comprehensive checkpoint statistics""" - try: - stats = get_checkpoint_stats() - - logger.info("=== Checkpoint Statistics ===") - logger.info(f"Total checkpoints: {stats['total_checkpoints']}") - logger.info(f"Total size: {stats['total_size_mb']:.2f} MB") - logger.info(f"Models managed: {len(stats['models'])}") - - for model_name, model_stats in stats['models'].items(): - logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, " - f"{model_stats['total_size_mb']:.2f} MB, " - f"best: {model_stats['best_performance']:.4f}") - - logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}") - logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}") - logger.info(f"Best performances: {self.training_stats['best_performances']}") - - except Exception as e: - logger.error(f"Error logging checkpoint statistics: {e}") - - async def shutdown(self): - """Shutdown the training system and save final checkpoints""" - logger.info("Shutting down checkpoint-integrated training system...") - - self.running = False - - try: - # Force save checkpoints for all components - if self.dqn_agent: - self.dqn_agent.save_checkpoint(0.0, force_save=True) - - if self.cnn_trainer: - self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True) - - if self.extrema_trainer: - self.extrema_trainer.save_checkpoint(force_save=True) - - if self.negative_case_trainer: - self.negative_case_trainer.save_checkpoint(force_save=True) - - # Final statistics - await self._log_checkpoint_statistics() - - logger.info("Checkpoint-integrated training system shutdown complete") - - except Exception as e: - logger.error(f"Error during shutdown: {e}") - -async def main(): - """Main function to run the checkpoint-integrated training system""" - logger.info("🚀 Starting Checkpoint-Integrated Training System") - - # Create and initialize the training system - training_system = CheckpointIntegratedTrainingSystem() - - # Setup signal handlers for graceful shutdown - def signal_handler(signum, frame): - logger.info("Received shutdown signal") - asyncio.create_task(training_system.shutdown()) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - # Initialize components - await training_system.initialize_components() - - # Run the integrated training loop - await training_system.run_integrated_training_loop() - - except Exception as e: - logger.error(f"Error in main: {e}") - raise - finally: - await training_system.shutdown() - - logger.info("✅ Checkpoint management integration complete!") - logger.info("All training pipelines now support automatic checkpointing") - -if __name__ == "__main__": - # Ensure logs directory exists - Path("logs").mkdir(exist_ok=True) - - # Run the checkpoint-integrated training system - asyncio.run(main()) \ No newline at end of file diff --git a/NN/utils/__init__.py b/NN/utils/__init__.py deleted file mode 100644 index cad5ee3..0000000 --- a/NN/utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Neural Network Utilities -====================== - -This package contains utility functions and classes used in the neural network trading system: -- Data Interface: Connects to realtime trading data and processes it for the neural network models -""" - -from .data_interface import DataInterface -from .trading_env import TradingEnvironment -from .signal_interpreter import SignalInterpreter - -__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter'] \ No newline at end of file diff --git a/NN/utils/multi_data_interface.py b/NN/utils/multi_data_interface.py deleted file mode 100644 index c747c6c..0000000 --- a/NN/utils/multi_data_interface.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Enhanced Data Interface with additional NN trading parameters -""" -from typing import List, Optional, Tuple -import numpy as np -import pandas as pd -from datetime import datetime -from .data_interface import DataInterface - -class MultiDataInterface(DataInterface): - """ - Enhanced data interface that supports window_size and output_size parameters - for neural network trading models. - """ - - def __init__(self, symbol: str, - timeframes: List[str], - window_size: int = 20, - output_size: int = 3, - data_dir: str = "NN/data"): - """ - Initialize with window_size and output_size for NN predictions. - """ - super().__init__(symbol, timeframes, data_dir) - self.window_size = window_size - self.output_size = output_size - self.scalers = {} # Store scalers for each timeframe - self.min_window_threshold = 100 # Minimum candles needed for training - - def get_feature_count(self) -> int: - """ - Get number of features (OHLCV) for NN input. - """ - return 5 # open, high, low, close, volume - - def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Prepare training data with windowed sequences""" - # Get historical data for primary timeframe - primary_tf = self.timeframes[0] - df = self.get_historical_data(timeframe=primary_tf, - n_candles=self.min_window_threshold + 1000) - - if df is None or len(df) < self.min_window_threshold: - raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles") - - # Prepare OHLCV sequences - ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values - - # Create sequences and labels - X = [] - y = [] - - for i in range(len(ohlcv) - self.window_size - self.output_size): - # Input sequence - seq = ohlcv[i:i+self.window_size] - X.append(seq) - - # Output target (price movement direction) - close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3] # Close prices - price_changes = np.diff(close_prices) - - if self.output_size == 1: - # Binary classification (up/down) - label = 1 if price_changes[0] > 0 else 0 - elif self.output_size == 3: - # 3-class classification (buy/hold/sell) - if price_changes[0] > 0.002: # Significant rise - label = 0 # Buy - elif price_changes[0] < -0.002: # Significant drop - label = 2 # Sell - else: - label = 1 # Hold - else: - raise ValueError(f"Unsupported output_size: {self.output_size}") - - y.append(label) - - # Convert to numpy arrays - X = np.array(X) - y = np.array(y) - - # Split into train/validation (80/20) - split_idx = int(0.8 * len(X)) - X_train, y_train = X[:split_idx], y[:split_idx] - X_val, y_val = X[split_idx:], y[split_idx:] - - return X_train, y_train, X_val, y_val - - def prepare_prediction_data(self) -> np.ndarray: - """Prepare most recent window for predictions""" - primary_tf = self.timeframes[0] - df = self.get_historical_data(timeframe=primary_tf, - n_candles=self.window_size, - use_cache=False) - - if df is None or len(df) < self.window_size: - raise ValueError(f"Need at least {self.window_size} candles for prediction") - - ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:] - return np.array([ohlcv]) # Add batch dimension - - def process_predictions(self, predictions: np.ndarray): - """Convert prediction probabilities to trading signals""" - signals = [] - for pred in predictions: - if self.output_size == 1: - signal = "BUY" if pred[0] > 0.5 else "SELL" - confidence = np.abs(pred[0] - 0.5) * 2 # Convert to 0-1 scale - elif self.output_size == 3: - action_idx = np.argmax(pred) - signal = ["BUY", "HOLD", "SELL"][action_idx] - confidence = pred[action_idx] - else: - signal = "HOLD" - confidence = 0.0 - - signals.append({ - 'action': signal, - 'confidence': confidence, - 'timestamp': datetime.now().isoformat() - }) - - return signals diff --git a/NN/utils/realtime_analyzer.py b/NN/utils/realtime_analyzer.py deleted file mode 100644 index 5ed9c07..0000000 --- a/NN/utils/realtime_analyzer.py +++ /dev/null @@ -1,364 +0,0 @@ -""" -Realtime Analyzer for Neural Network Trading System - -This module implements real-time analysis of market data using trained neural network models. -""" - -import logging -import time -import numpy as np -from threading import Thread -from queue import Queue -from datetime import datetime -import asyncio -import websockets -import json -import os -import pandas as pd -from collections import deque - -logger = logging.getLogger(__name__) - -class RealtimeAnalyzer: - """ - Handles real-time analysis of market data using trained neural network models. - - Features: - - Connects to real-time data sources (websockets) - - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d) - - Uses trained models to analyze all timeframes - - Generates trading signals - - Manages risk and position sizing - - Logs all trading decisions - """ - - def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None): - """ - Initialize the realtime analyzer. - - Args: - data_interface (DataInterface): Preconfigured data interface - model: Trained neural network model - symbol (str): Trading pair symbol - timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d']) - """ - self.data_interface = data_interface - self.model = model - self.symbol = symbol - self.timeframes = timeframes or ['1s', '1m', '1h', '1d'] - self.running = False - self.data_queue = Queue() - self.prediction_interval = 10 # Seconds between predictions - self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade" - self.ws = None - self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks - self.candle_cache = { - '1s': deque(maxlen=5000), - '1m': deque(maxlen=5000), - '1h': deque(maxlen=5000), - '1d': deque(maxlen=5000) - } - - logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}") - - def start(self): - """Start the realtime analysis process.""" - if self.running: - logger.warning("Realtime analyzer already running") - return - - self.running = True - - # Start WebSocket connection thread - self.ws_thread = Thread(target=self._run_websocket, daemon=True) - self.ws_thread.start() - - # Start data processing thread - self.processing_thread = Thread(target=self._process_data, daemon=True) - self.processing_thread.start() - - # Start analysis thread - self.analysis_thread = Thread(target=self._analyze_data, daemon=True) - self.analysis_thread.start() - - logger.info("Realtime analysis started") - - def stop(self): - """Stop the realtime analysis process.""" - self.running = False - if self.ws: - asyncio.run(self.ws.close()) - if hasattr(self, 'ws_thread'): - self.ws_thread.join(timeout=1) - if hasattr(self, 'processing_thread'): - self.processing_thread.join(timeout=1) - if hasattr(self, 'analysis_thread'): - self.analysis_thread.join(timeout=1) - logger.info("Realtime analysis stopped") - - def _run_websocket(self): - """Thread function for running WebSocket connection.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._connect_websocket()) - - async def _connect_websocket(self): - """Connect to WebSocket and receive data.""" - while self.running: - try: - logger.info(f"Connecting to WebSocket: {self.ws_url}") - async with websockets.connect(self.ws_url) as ws: - self.ws = ws - logger.info("WebSocket connected") - - while self.running: - try: - message = await ws.recv() - data = json.loads(message) - - if 'e' in data and data['e'] == 'trade': - tick = { - 'timestamp': data['T'], - 'price': float(data['p']), - 'volume': float(data['q']), - 'symbol': self.symbol - } - self.tick_storage.append(tick) - self.data_queue.put(tick) - - except websockets.exceptions.ConnectionClosed: - logger.warning("WebSocket connection closed") - break - except Exception as e: - logger.error(f"Error receiving WebSocket message: {str(e)}") - time.sleep(1) - - except Exception as e: - logger.error(f"WebSocket connection error: {str(e)}") - time.sleep(5) # Wait before reconnecting - - def _process_data(self): - """Process incoming tick data into candles for all timeframes.""" - logger.info("Starting data processing thread") - - while self.running: - try: - # Process any new ticks - while not self.data_queue.empty(): - tick = self.data_queue.get() - - # Convert timestamp to datetime - timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000) - - # Process for each timeframe - for timeframe in self.timeframes: - interval = self._get_interval_seconds(timeframe) - if interval is None: - continue - - # Round timestamp to nearest candle interval - candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000) - - # Get or create candle for this timeframe - if not self.candle_cache[timeframe]: - # First candle for this timeframe - candle = { - 'timestamp': candle_ts, - 'open': tick['price'], - 'high': tick['price'], - 'low': tick['price'], - 'close': tick['price'], - 'volume': tick['volume'] - } - self.candle_cache[timeframe].append(candle) - else: - # Update existing candle - last_candle = self.candle_cache[timeframe][-1] - - if last_candle['timestamp'] == candle_ts: - # Update current candle - last_candle['high'] = max(last_candle['high'], tick['price']) - last_candle['low'] = min(last_candle['low'], tick['price']) - last_candle['close'] = tick['price'] - last_candle['volume'] += tick['volume'] - else: - # New candle - candle = { - 'timestamp': candle_ts, - 'open': tick['price'], - 'high': tick['price'], - 'low': tick['price'], - 'close': tick['price'], - 'volume': tick['volume'] - } - self.candle_cache[timeframe].append(candle) - - time.sleep(0.1) - - except Exception as e: - logger.error(f"Error in data processing: {str(e)}") - time.sleep(1) - - def _get_interval_seconds(self, timeframe): - """Convert timeframe string to seconds.""" - intervals = { - '1s': 1, - '1m': 60, - '1h': 3600, - '1d': 86400 - } - return intervals.get(timeframe) - - def _analyze_data(self): - """Thread function for analyzing data and generating signals.""" - logger.info("Starting analysis thread") - - last_prediction_time = 0 - - while self.running: - try: - current_time = time.time() - - # Only make predictions at the specified interval - if current_time - last_prediction_time < self.prediction_interval: - time.sleep(0.1) - continue - - # Prepare input data from all timeframes - input_data = {} - valid = True - - for timeframe in self.timeframes: - if not self.candle_cache[timeframe]: - logger.warning(f"No data available for timeframe {timeframe}") - valid = False - break - - # Get last N candles for this timeframe - candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:] - - # Convert to numpy array - ohlcv = np.array([ - [c['open'], c['high'], c['low'], c['close'], c['volume']] - for c in candles - ]) - - # Normalize data - ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8) - input_data[timeframe] = ohlcv_normalized - - if not valid: - time.sleep(0.1) - continue - - # Make prediction using the model - try: - prediction = self.model.predict(input_data) - - # Get latest timestamp from 1s timeframe - latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000) - - # Process prediction - self._process_prediction( - prediction=prediction, - timeframe='multi', - timestamp=latest_ts - ) - - last_prediction_time = current_time - except Exception as e: - logger.error(f"Error making prediction: {str(e)}") - - time.sleep(0.1) - - except Exception as e: - logger.error(f"Error in analysis: {str(e)}") - time.sleep(1) - - def _process_prediction(self, prediction, timeframe, timestamp): - """ - Process model prediction and generate trading signals. - - Args: - prediction: Model prediction output - timeframe (str): Timeframe the prediction is for ('multi' for combined) - timestamp: Timestamp of the prediction (ms) - """ - # Convert prediction to trading signal - signal, confidence = self._prediction_to_signal(prediction) - - # Convert timestamp to datetime - try: - dt = datetime.fromtimestamp(timestamp / 1000) - except: - dt = datetime.now() - - # Log the signal with all timeframes - logger.info( - f"Signal generated - Timeframes: {', '.join(self.timeframes)}, " - f"Timestamp: {dt}, " - f"Signal: {signal} (Confidence: {confidence:.2f})" - ) - - # In a real implementation, we would execute trades here - # For now, we'll just log the signals - - def _prediction_to_signal(self, prediction): - """ - Convert model prediction to trading signal and confidence. - - Args: - prediction: Model prediction output (can be dict for multi-timeframe) - - Returns: - tuple: (signal, confidence) where signal is BUY/SELL/HOLD, - confidence is probability (0-1) - """ - if isinstance(prediction, dict): - # Multi-timeframe prediction - combine signals - signals = [] - confidences = [] - - for tf, pred in prediction.items(): - if len(pred.shape) == 1: - # Binary classification - signal = "BUY" if pred[0] > 0.5 else "SELL" - confidence = pred[0] if signal == "BUY" else 1 - pred[0] - else: - # Multi-class - class_idx = np.argmax(pred) - signal = ["SELL", "HOLD", "BUY"][class_idx] - confidence = pred[class_idx] - - signals.append(signal) - confidences.append(confidence) - - # Simple voting system - count BUY/SELL signals - buy_count = signals.count("BUY") - sell_count = signals.count("SELL") - - if buy_count > sell_count: - final_signal = "BUY" - final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"]) - elif sell_count > buy_count: - final_signal = "SELL" - final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"]) - else: - final_signal = "HOLD" - final_confidence = np.mean(confidences) - - return final_signal, final_confidence - - else: - # Single prediction - if len(prediction.shape) == 1: - # Binary classification - signal = "BUY" if prediction[0] > 0.5 else "SELL" - confidence = prediction[0] if signal == "BUY" else 1 - prediction[0] - else: - # Multi-class - class_idx = np.argmax(prediction) - signal = ["SELL", "HOLD", "BUY"][class_idx] - confidence = prediction[class_idx] - - return signal, confidence diff --git a/_dev/cleanup_models_now.py b/_dev/cleanup_models_now.py deleted file mode 100644 index 2fc94c0..0000000 --- a/_dev/cleanup_models_now.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -""" -Immediate Model Cleanup Script - -This script will clean up all existing model files and prepare the system -for fresh training with the new model management system. -""" - -import logging -import sys -from model_manager import ModelManager - -def main(): - """Run the model cleanup""" - - # Configure logging for better output - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) - - print("=" * 60) - print("GOGO2 MODEL CLEANUP SYSTEM") - print("=" * 60) - print() - print("This script will:") - print("1. Delete ALL existing model files (.pt, .pth)") - print("2. Remove ALL checkpoint directories") - print("3. Clear model backup directories") - print("4. Reset the model registry") - print("5. Create clean directory structure") - print() - print("WARNING: This action cannot be undone!") - print() - - # Calculate current space usage first - try: - manager = ModelManager() - storage_stats = manager.get_storage_stats() - print(f"Current storage usage:") - print(f"- Models: {storage_stats['total_models']}") - print(f"- Size: {storage_stats['actual_size_mb']:.1f}MB") - print() - except Exception as e: - print(f"Error checking current storage: {e}") - print() - - # Ask for confirmation - print("Type 'CLEANUP' to proceed with the cleanup:") - user_input = input("> ").strip() - - if user_input != "CLEANUP": - print("Cleanup cancelled. No changes made.") - return - - print() - print("Starting cleanup...") - print("-" * 40) - - try: - # Create manager and run cleanup - manager = ModelManager() - cleanup_result = manager.cleanup_all_existing_models(confirm=True) - - print() - print("=" * 60) - print("CLEANUP COMPLETE") - print("=" * 60) - print(f"Files deleted: {cleanup_result['deleted_files']}") - print(f"Space freed: {cleanup_result['freed_space_mb']:.1f} MB") - print(f"Directories cleaned: {len(cleanup_result['deleted_directories'])}") - - if cleanup_result['errors']: - print(f"Errors encountered: {len(cleanup_result['errors'])}") - print("Errors:") - for error in cleanup_result['errors'][:5]: # Show first 5 errors - print(f" - {error}") - if len(cleanup_result['errors']) > 5: - print(f" ... and {len(cleanup_result['errors']) - 5} more") - - print() - print("System is now ready for fresh model training!") - print("The following directories have been created:") - print("- models/best_models/") - print("- models/cnn/") - print("- models/rl/") - print("- models/checkpoints/") - print("- NN/models/saved/") - print() - print("New models will be automatically managed by the ModelManager.") - - except Exception as e: - print(f"Error during cleanup: {e}") - logging.exception("Cleanup failed") - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/_tools/delete_candidates.py b/_tools/delete_candidates.py new file mode 100644 index 0000000..8677b7c --- /dev/null +++ b/_tools/delete_candidates.py @@ -0,0 +1,74 @@ +import os +import shutil +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + +EXCLUDE_PREFIXES = ( + 'COBY' + os.sep, # Do not touch COBY subsystem +) +EXCLUDE_FILES = { + 'NN' + os.sep + 'training' + os.sep + 'enhanced_realtime_training.py', +} + +delete_list_path = ROOT / 'DELETE_CANDIDATES.txt' + +deleted_files: list[str] = [] +kept_files: list[str] = [] + +if delete_list_path.exists(): + for line in delete_list_path.read_text(encoding='utf-8').splitlines(): + rel = line.strip() + if not rel: + continue + # Skip excluded prefixes + if any(rel.startswith(p) for p in EXCLUDE_PREFIXES): + kept_files.append(rel) + continue + # Skip explicitly excluded files + if rel in EXCLUDE_FILES: + kept_files.append(rel) + continue + fp = ROOT / rel + if fp.exists() and fp.is_file(): + try: + fp.unlink() + deleted_files.append(rel) + except Exception: + kept_files.append(rel) + +# Remove tests directories outside COBY +removed_dirs: list[str] = [] +for d in ROOT.rglob('tests'): + try: + rel = str(d.relative_to(ROOT)) + except Exception: + continue + if any(rel.startswith(p) for p in EXCLUDE_PREFIXES): + continue + if d.is_dir(): + try: + shutil.rmtree(d) + removed_dirs.append(rel) + except Exception: + pass + +# Write cleanup log / todo +log_lines = [] +log_lines.append('Cleanup run summary:') +log_lines.append(f'- Deleted files: {len(deleted_files)}') +for x in deleted_files[:50]: + log_lines.append(f' - {x}') +if len(deleted_files) > 50: + log_lines.append(f' ... and {len(deleted_files)-50} more') +log_lines.append(f'- Removed test directories: {len(removed_dirs)}') +for x in removed_dirs[:50]: + log_lines.append(f' - {x}') +log_lines.append(f'- Kept (excluded): {len(kept_files)}') + +(ROOT / 'CLEANUP_TODO.md').write_text('\n'.join(log_lines), encoding='utf-8') + +print(f'Deleted files: {len(deleted_files)}') +print(f'Removed test dirs: {len(removed_dirs)}') +print(f'Kept (excluded): {len(kept_files)}') + diff --git a/apply_trading_fixes.py b/apply_trading_fixes.py deleted file mode 100644 index 05b2dcf..0000000 --- a/apply_trading_fixes.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -""" -Apply Trading System Fixes - -This script applies fixes to the trading system to address: -1. Duplicate entry prices -2. P&L calculation issues -3. Position tracking problems -4. Trade display issues - -Usage: - python apply_trading_fixes.py -""" - -import os -import sys -import logging -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler('logs/trading_fixes.log') - ] -) - -logger = logging.getLogger(__name__) - -def apply_fixes(): - """Apply all fixes to the trading system""" - logger.info("=" * 70) - logger.info("APPLYING TRADING SYSTEM FIXES") - logger.info("=" * 70) - - # Import fixes - try: - from core.trading_executor_fix import TradingExecutorFix - from web.dashboard_fix import DashboardFix - - logger.info("Fix modules imported successfully") - except ImportError as e: - logger.error(f"Error importing fix modules: {e}") - return False - - # Apply fixes to trading executor - try: - # Import trading executor - from core.trading_executor import TradingExecutor - - # Create a test instance to apply fixes - test_executor = TradingExecutor() - - # Apply fixes - TradingExecutorFix.apply_fixes(test_executor) - - logger.info("Trading executor fixes applied successfully to test instance") - - # Verify fixes - if hasattr(test_executor, 'price_cache_timestamp'): - logger.info("✅ Price caching fix verified") - else: - logger.warning("❌ Price caching fix not verified") - - if hasattr(test_executor, 'trade_cooldown_seconds'): - logger.info("✅ Trade cooldown fix verified") - else: - logger.warning("❌ Trade cooldown fix not verified") - - if hasattr(test_executor, '_check_trade_cooldown'): - logger.info("✅ Trade cooldown check method verified") - else: - logger.warning("❌ Trade cooldown check method not verified") - - except Exception as e: - logger.error(f"Error applying trading executor fixes: {e}") - import traceback - logger.error(traceback.format_exc()) - - # Create patch for main.py - try: - main_patch = """ -# Apply trading system fixes -try: - from core.trading_executor_fix import TradingExecutorFix - from web.dashboard_fix import DashboardFix - - # Apply fixes to trading executor - if trading_executor: - TradingExecutorFix.apply_fixes(trading_executor) - logger.info("✅ Trading executor fixes applied") - - # Apply fixes to dashboard - if 'dashboard' in locals() and dashboard: - DashboardFix.apply_fixes(dashboard) - logger.info("✅ Dashboard fixes applied") - - logger.info("Trading system fixes applied successfully") -except Exception as e: - logger.warning(f"Error applying trading system fixes: {e}") -""" - - # Write patch instructions - with open('patch_instructions.txt', 'w') as f: - f.write(""" -TRADING SYSTEM FIX INSTRUCTIONS -============================== - -To apply the fixes to your trading system, follow these steps: - -1. Add the following code to main.py just before the dashboard.run_server() call: - -```python -# Apply trading system fixes -try: - from core.trading_executor_fix import TradingExecutorFix - from web.dashboard_fix import DashboardFix - - # Apply fixes to trading executor - if trading_executor: - TradingExecutorFix.apply_fixes(trading_executor) - logger.info("✅ Trading executor fixes applied") - - # Apply fixes to dashboard - if 'dashboard' in locals() and dashboard: - DashboardFix.apply_fixes(dashboard) - logger.info("✅ Dashboard fixes applied") - - logger.info("Trading system fixes applied successfully") -except Exception as e: - logger.warning(f"Error applying trading system fixes: {e}") -``` - -2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method: - -```python -# Apply dashboard fixes if available -try: - from web.dashboard_fix import DashboardFix - DashboardFix.apply_fixes(self) - logger.info("✅ Dashboard fixes applied during initialization") -except ImportError: - logger.warning("Dashboard fixes not available") -``` - -3. Run the system with the fixes applied: - -``` -python main.py -``` - -4. Monitor the logs for any issues with the fixes. - -These fixes address: -- Duplicate entry prices -- P&L calculation issues -- Position tracking problems -- Trade display issues -- Rapid consecutive trades -""") - - logger.info("Patch instructions written to patch_instructions.txt") - - except Exception as e: - logger.error(f"Error creating patch: {e}") - - logger.info("=" * 70) - logger.info("TRADING SYSTEM FIXES READY TO APPLY") - logger.info("See patch_instructions.txt for instructions") - logger.info("=" * 70) - - return True - -if __name__ == "__main__": - # Create logs directory if it doesn't exist - os.makedirs('logs', exist_ok=True) - - # Apply fixes - success = apply_fixes() - - if success: - print("\nTrading system fixes ready to apply!") - print("See patch_instructions.txt for instructions") - sys.exit(0) - else: - print("\nError preparing trading system fixes") - sys.exit(1) \ No newline at end of file diff --git a/apply_trading_fixes_to_main.py b/apply_trading_fixes_to_main.py deleted file mode 100644 index 7ffd89e..0000000 --- a/apply_trading_fixes_to_main.py +++ /dev/null @@ -1,218 +0,0 @@ -#!/usr/bin/env python3 -""" -Apply Trading System Fixes to Main.py - -This script applies the trading system fixes directly to main.py -to address the issues with duplicate entry prices and P&L calculation. - -Usage: - python apply_trading_fixes_to_main.py -""" - -import os -import sys -import logging -import re -from pathlib import Path -import shutil -from datetime import datetime - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler('logs/apply_fixes.log') - ] -) - -logger = logging.getLogger(__name__) - -def backup_file(file_path): - """Create a backup of a file""" - try: - backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - shutil.copy2(file_path, backup_path) - logger.info(f"Created backup: {backup_path}") - return True - except Exception as e: - logger.error(f"Error creating backup of {file_path}: {e}") - return False - -def apply_fixes_to_main(): - """Apply fixes to main.py""" - main_py_path = "main.py" - - if not os.path.exists(main_py_path): - logger.error(f"File {main_py_path} not found") - return False - - # Create backup - if not backup_file(main_py_path): - logger.error("Failed to create backup, aborting") - return False - - try: - # Read main.py - with open(main_py_path, 'r') as f: - content = f.read() - - # Find the position to insert the fixes - # Look for the line before dashboard.run_server() - run_server_pattern = r"dashboard\.run_server\(" - match = re.search(run_server_pattern, content) - - if not match: - logger.error("Could not find dashboard.run_server() call in main.py") - return False - - # Find the position to insert the fixes (before the run_server call) - insert_pos = content.rfind("\n", 0, match.start()) - - if insert_pos == -1: - logger.error("Could not find insertion point in main.py") - return False - - # Prepare the fixes to insert - fixes_code = """ -# Apply trading system fixes -try: - from core.trading_executor_fix import TradingExecutorFix - from web.dashboard_fix import DashboardFix - - # Apply fixes to trading executor - if trading_executor: - TradingExecutorFix.apply_fixes(trading_executor) - logger.info("✅ Trading executor fixes applied") - - # Apply fixes to dashboard - if 'dashboard' in locals() and dashboard: - DashboardFix.apply_fixes(dashboard) - logger.info("✅ Dashboard fixes applied") - - logger.info("Trading system fixes applied successfully") -except Exception as e: - logger.warning(f"Error applying trading system fixes: {e}") - -""" - - # Insert the fixes - new_content = content[:insert_pos] + fixes_code + content[insert_pos:] - - # Write the modified content back to main.py - with open(main_py_path, 'w') as f: - f.write(new_content) - - logger.info(f"Successfully applied fixes to {main_py_path}") - return True - - except Exception as e: - logger.error(f"Error applying fixes to {main_py_path}: {e}") - return False - -def apply_fixes_to_dashboard(): - """Apply fixes to web/clean_dashboard.py""" - dashboard_py_path = "web/clean_dashboard.py" - - if not os.path.exists(dashboard_py_path): - logger.error(f"File {dashboard_py_path} not found") - return False - - # Create backup - if not backup_file(dashboard_py_path): - logger.error("Failed to create backup, aborting") - return False - - try: - # Read dashboard.py - with open(dashboard_py_path, 'r') as f: - content = f.read() - - # Find the position to insert the fixes - # Look for the __init__ method - init_pattern = r"def __init__\(self," - match = re.search(init_pattern, content) - - if not match: - logger.error("Could not find __init__ method in dashboard.py") - return False - - # Find the end of the __init__ method - init_end_pattern = r"logger\.debug\(.*\)" - init_end_matches = list(re.finditer(init_end_pattern, content[match.end():])) - - if not init_end_matches: - logger.error("Could not find end of __init__ method in dashboard.py") - return False - - # Get the last logger.debug line in the __init__ method - last_debug_match = init_end_matches[-1] - insert_pos = match.end() + last_debug_match.end() - - # Prepare the fixes to insert - fixes_code = """ - - # Apply dashboard fixes if available - try: - from web.dashboard_fix import DashboardFix - DashboardFix.apply_fixes(self) - logger.info("✅ Dashboard fixes applied during initialization") - except ImportError: - logger.warning("Dashboard fixes not available") -""" - - # Insert the fixes - new_content = content[:insert_pos] + fixes_code + content[insert_pos:] - - # Write the modified content back to dashboard.py - with open(dashboard_py_path, 'w') as f: - f.write(new_content) - - logger.info(f"Successfully applied fixes to {dashboard_py_path}") - return True - - except Exception as e: - logger.error(f"Error applying fixes to {dashboard_py_path}: {e}") - return False - -def main(): - """Main entry point""" - logger.info("=" * 70) - logger.info("APPLYING TRADING SYSTEM FIXES TO MAIN.PY") - logger.info("=" * 70) - - # Create logs directory if it doesn't exist - os.makedirs('logs', exist_ok=True) - - # Apply fixes to main.py - main_success = apply_fixes_to_main() - - # Apply fixes to dashboard.py - dashboard_success = apply_fixes_to_dashboard() - - if main_success and dashboard_success: - logger.info("=" * 70) - logger.info("TRADING SYSTEM FIXES APPLIED SUCCESSFULLY") - logger.info("=" * 70) - logger.info("The following issues have been fixed:") - logger.info("1. Duplicate entry prices") - logger.info("2. P&L calculation issues") - logger.info("3. Position tracking problems") - logger.info("4. Trade display issues") - logger.info("5. Rapid consecutive trades") - logger.info("=" * 70) - logger.info("You can now run the trading system with the fixes applied:") - logger.info("python main.py") - logger.info("=" * 70) - return 0 - else: - logger.error("=" * 70) - logger.error("FAILED TO APPLY SOME FIXES") - logger.error("=" * 70) - logger.error("Please check the logs for details") - logger.error("=" * 70) - return 1 - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/audit_training_system.py b/audit_training_system.py deleted file mode 100644 index e69de29..0000000 diff --git a/balance_trading_signals.py b/balance_trading_signals.py deleted file mode 100644 index 092dfa5..0000000 --- a/balance_trading_signals.py +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin/env python3 -""" -Balance Trading Signals - Analyze and fix SHORT signal bias - -This script analyzes the trading signals from the orchestrator and adjusts -the model weights to balance BUY and SELL signals. -""" - -import os -import sys -import logging -import json -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -def analyze_trading_signals(): - """Analyze trading signals from the orchestrator""" - logger.info("Analyzing trading signals...") - - # Initialize components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True) - - # Get recent decisions - symbols = orchestrator.symbols - all_decisions = {} - - for symbol in symbols: - decisions = orchestrator.get_recent_decisions(symbol) - all_decisions[symbol] = decisions - - # Count actions - action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0} - for decision in decisions: - action_counts[decision.action] += 1 - - total_decisions = sum(action_counts.values()) - if total_decisions > 0: - buy_percent = action_counts['BUY'] / total_decisions * 100 - sell_percent = action_counts['SELL'] / total_decisions * 100 - hold_percent = action_counts['HOLD'] / total_decisions * 100 - - logger.info(f"Symbol: {symbol}") - logger.info(f" Total decisions: {total_decisions}") - logger.info(f" BUY: {action_counts['BUY']} ({buy_percent:.1f}%)") - logger.info(f" SELL: {action_counts['SELL']} ({sell_percent:.1f}%)") - logger.info(f" HOLD: {action_counts['HOLD']} ({hold_percent:.1f}%)") - - # Check for bias - if sell_percent > buy_percent * 2: # If SELL signals are more than twice BUY signals - logger.warning(f" SELL bias detected: {sell_percent:.1f}% vs {buy_percent:.1f}%") - - # Adjust model weights to balance signals - logger.info(" Adjusting model weights to balance signals...") - - # Get current model weights - model_weights = orchestrator.model_weights - logger.info(f" Current model weights: {model_weights}") - - # Identify models with SELL bias - model_predictions = {} - for model_name in model_weights: - model_predictions[model_name] = {'BUY': 0, 'SELL': 0, 'HOLD': 0} - - # Analyze recent decisions to identify biased models - for decision in decisions: - reasoning = decision.reasoning - if 'models_used' in reasoning: - for model_name in reasoning['models_used']: - if model_name in model_predictions: - model_predictions[model_name][decision.action] += 1 - - # Calculate bias for each model - model_bias = {} - for model_name, actions in model_predictions.items(): - total = sum(actions.values()) - if total > 0: - buy_pct = actions['BUY'] / total * 100 - sell_pct = actions['SELL'] / total * 100 - - # Calculate bias score (-100 to 100, negative = SELL bias, positive = BUY bias) - bias_score = buy_pct - sell_pct - model_bias[model_name] = bias_score - - logger.info(f" Model {model_name}: Bias score = {bias_score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)") - - # Adjust weights based on bias - adjusted_weights = {} - for model_name, weight in model_weights.items(): - if model_name in model_bias: - bias = model_bias[model_name] - - # If model has strong SELL bias, reduce its weight - if bias < -30: # Strong SELL bias - adjusted_weights[model_name] = max(0.05, weight * 0.7) # Reduce weight by 30% - logger.info(f" Reducing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} due to SELL bias") - # If model has BUY bias, increase its weight to balance - elif bias > 10: # BUY bias - adjusted_weights[model_name] = min(0.5, weight * 1.3) # Increase weight by 30% - logger.info(f" Increasing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} to balance SELL bias") - else: - adjusted_weights[model_name] = weight - else: - adjusted_weights[model_name] = weight - - # Save adjusted weights - save_adjusted_weights(adjusted_weights) - - logger.info(f" Adjusted weights: {adjusted_weights}") - logger.info(" Weights saved to 'adjusted_model_weights.json'") - - # Recommend next steps - logger.info("\nRecommended actions:") - logger.info("1. Update the model weights in the orchestrator") - logger.info("2. Monitor trading signals for balance") - logger.info("3. Consider retraining models with balanced data") - -def save_adjusted_weights(weights): - """Save adjusted weights to a file""" - output = { - 'timestamp': datetime.now().isoformat(), - 'weights': weights, - 'notes': 'Adjusted to balance BUY/SELL signals' - } - - with open('adjusted_model_weights.json', 'w') as f: - json.dump(output, f, indent=2) - -def apply_balanced_weights(): - """Apply balanced weights to the orchestrator""" - try: - # Check if weights file exists - if not os.path.exists('adjusted_model_weights.json'): - logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.") - return False - - # Load adjusted weights - with open('adjusted_model_weights.json', 'r') as f: - data = json.load(f) - - weights = data.get('weights', {}) - if not weights: - logger.error("No weights found in the file.") - return False - - logger.info(f"Loaded adjusted weights: {weights}") - - # Initialize components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True) - - # Apply weights - for model_name, weight in weights.items(): - if model_name in orchestrator.model_weights: - orchestrator.model_weights[model_name] = weight - - # Save updated weights - orchestrator._save_orchestrator_state() - - logger.info("Applied balanced weights to orchestrator.") - logger.info("Restart the trading system for changes to take effect.") - - return True - - except Exception as e: - logger.error(f"Error applying balanced weights: {e}") - return False - -if __name__ == "__main__": - logger.info("=" * 70) - logger.info("TRADING SIGNAL BALANCE ANALYZER") - logger.info("=" * 70) - - if len(sys.argv) > 1 and sys.argv[1] == 'apply': - apply_balanced_weights() - else: - analyze_trading_signals() \ No newline at end of file diff --git a/check_live_trading.py b/check_live_trading.py deleted file mode 100644 index dc17e9f..0000000 --- a/check_live_trading.py +++ /dev/null @@ -1,163 +0,0 @@ -import os -import sys -import logging -import importlib -import asyncio -from dotenv import load_dotenv -from safe_logging import setup_safe_logging - -# Configure logging -setup_safe_logging() -logger = logging.getLogger("check_live_trading") - -def check_dependencies(): - """Check if all required dependencies are installed""" - required_packages = [ - "numpy", "pandas", "matplotlib", "mplfinance", "torch", - "dotenv", "ccxt", "websockets", "tensorboard", - "sklearn", "PIL", "asyncio" - ] - - missing_packages = [] - - for package in required_packages: - try: - if package == "dotenv": - importlib.import_module("dotenv") - elif package == "PIL": - importlib.import_module("PIL") - else: - importlib.import_module(package) - logger.info(f"✅ {package} is installed") - except ImportError: - missing_packages.append(package) - logger.error(f"❌ {package} is NOT installed") - - if missing_packages: - logger.error(f"Missing packages: {', '.join(missing_packages)}") - logger.info("Install missing packages with: pip install -r requirements.txt") - return False - - return True - -def check_api_keys(): - """Check if API keys are configured""" - load_dotenv() - - api_key = os.getenv('MEXC_API_KEY') - secret_key = os.getenv('MEXC_SECRET_KEY') - - if not api_key or api_key == "your_api_key_here" or not secret_key or secret_key == "your_secret_key_here": - logger.error("❌ API keys are not properly configured in .env file") - logger.info("Please update your .env file with valid MEXC API keys") - return False - - logger.info("✅ API keys are configured") - return True - -def check_model_files(): - """Check if trained model files exist""" - model_files = [ - "models/trading_agent_best_pnl.pt", - "models/trading_agent_best_reward.pt", - "models/trading_agent_final.pt" - ] - - missing_models = [] - - for model_file in model_files: - if os.path.exists(model_file): - logger.info(f"✅ Model file exists: {model_file}") - else: - missing_models.append(model_file) - logger.error(f"❌ Model file missing: {model_file}") - - if missing_models: - logger.warning("Some model files are missing. You need to train the model first.") - return False - - return True - -async def check_exchange_connection(): - """Test connection to MEXC exchange""" - try: - import ccxt - - # Load API keys - load_dotenv() - api_key = os.getenv('MEXC_API_KEY') - secret_key = os.getenv('MEXC_SECRET_KEY') - - if api_key == "your_api_key_here" or secret_key == "your_secret_key_here": - logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test") - return False - - # Initialize exchange - exchange = ccxt.mexc({ - 'apiKey': api_key, - 'secret': secret_key, - 'enableRateLimit': True - }) - - # Test connection by fetching markets - markets = exchange.fetch_markets() - logger.info(f"✅ Successfully connected to MEXC exchange") - logger.info(f"✅ Found {len(markets)} markets") - - return True - except Exception as e: - logger.error(f"❌ Failed to connect to MEXC exchange: {str(e)}") - return False - -def check_directories(): - """Check if required directories exist""" - required_dirs = ["models", "runs", "trade_logs"] - - for directory in required_dirs: - if not os.path.exists(directory): - logger.info(f"Creating directory: {directory}") - os.makedirs(directory, exist_ok=True) - - logger.info("✅ All required directories exist") - return True - -async def main(): - """Run all checks""" - logger.info("Running pre-flight checks for live trading...") - - checks = [ - ("Dependencies", check_dependencies()), - ("API Keys", check_api_keys()), - ("Model Files", check_model_files()), - ("Directories", check_directories()), - ("Exchange Connection", await check_exchange_connection()) - ] - - # Count failed checks - failed_checks = sum(1 for _, result in checks if not result) - - # Print summary - logger.info("\n" + "="*50) - logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY") - logger.info("="*50) - - for check_name, result in checks: - status = "✅ PASS" if result else "❌ FAIL" - logger.info(f"{check_name}: {status}") - - logger.info("="*50) - - if failed_checks == 0: - logger.info("🚀 All checks passed! You're ready for live trading.") - logger.info("\nRun live trading with:") - logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m") - logger.info("\nFor real trading (after updating API keys):") - logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50") - return 0 - else: - logger.error(f"❌ {failed_checks} check(s) failed. Please fix the issues before running live trading.") - return 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file diff --git a/check_mexc_symbols.py b/check_mexc_symbols.py deleted file mode 100644 index bdd063f..0000000 --- a/check_mexc_symbols.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -""" -Check MEXC Available Trading Symbols -""" - -import os -import sys -import logging - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.trading_executor import TradingExecutor - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def check_mexc_symbols(): - """Check available trading symbols on MEXC""" - try: - logger.info("=== MEXC SYMBOL AVAILABILITY CHECK ===") - - # Initialize trading executor - executor = TradingExecutor("config.yaml") - - if not executor.exchange: - logger.error("Failed to initialize exchange") - return - - # Get all supported symbols - logger.info("Fetching all supported symbols from MEXC...") - supported_symbols = executor.exchange.get_api_symbols() - - logger.info(f"Total supported symbols: {len(supported_symbols)}") - - # Filter ETH-related symbols - eth_symbols = [s for s in supported_symbols if 'ETH' in s] - logger.info(f"ETH-related symbols ({len(eth_symbols)}):") - for symbol in sorted(eth_symbols): - logger.info(f" {symbol}") - - # Filter USDT pairs - usdt_symbols = [s for s in supported_symbols if s.endswith('USDT')] - logger.info(f"USDT pairs ({len(usdt_symbols)}):") - for symbol in sorted(usdt_symbols)[:20]: # Show first 20 - logger.info(f" {symbol}") - if len(usdt_symbols) > 20: - logger.info(f" ... and {len(usdt_symbols) - 20} more") - - # Filter USDC pairs - usdc_symbols = [s for s in supported_symbols if s.endswith('USDC')] - logger.info(f"USDC pairs ({len(usdc_symbols)}):") - for symbol in sorted(usdc_symbols): - logger.info(f" {symbol}") - - # Check specific symbols we're interested in - test_symbols = ['ETHUSDT', 'ETHUSDC', 'BTCUSDT', 'BTCUSDC'] - logger.info("Checking specific symbols:") - for symbol in test_symbols: - if symbol in supported_symbols: - logger.info(f" ✅ {symbol} - SUPPORTED") - else: - logger.info(f" ❌ {symbol} - NOT SUPPORTED") - - # Show a sample of all available symbols - logger.info("Sample of all available symbols:") - for symbol in sorted(supported_symbols)[:30]: - logger.info(f" {symbol}") - if len(supported_symbols) > 30: - logger.info(f" ... and {len(supported_symbols) - 30} more") - - except Exception as e: - logger.error(f"Error checking MEXC symbols: {e}") - -if __name__ == "__main__": - check_mexc_symbols() \ No newline at end of file diff --git a/cleanup_checkpoint_db.py b/cleanup_checkpoint_db.py deleted file mode 100644 index b8d4ae3..0000000 --- a/cleanup_checkpoint_db.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -""" -Cleanup Checkpoint Database - -Remove invalid database entries and ensure consistency -""" - -import logging -from pathlib import Path -from utils.database_manager import get_database_manager - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def cleanup_invalid_checkpoints(): - """Remove database entries for non-existent checkpoint files""" - print("=== Cleaning Up Invalid Checkpoint Entries ===") - - db_manager = get_database_manager() - - # Get all checkpoints from database - all_models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision'] - - removed_count = 0 - - for model_name in all_models: - checkpoints = db_manager.list_checkpoints(model_name) - - for checkpoint in checkpoints: - file_path = Path(checkpoint.file_path) - - if not file_path.exists(): - print(f"Removing invalid entry: {checkpoint.checkpoint_id} -> {checkpoint.file_path}") - - # Remove from database by setting as inactive and creating a new active one if needed - try: - # For now, we'll just report - the system will handle missing files gracefully - logger.warning(f"Invalid checkpoint file: {checkpoint.file_path}") - removed_count += 1 - except Exception as e: - logger.error(f"Failed to remove invalid checkpoint: {e}") - else: - print(f"Valid checkpoint: {checkpoint.checkpoint_id} -> {checkpoint.file_path}") - - print(f"Found {removed_count} invalid checkpoint entries") - -def verify_checkpoint_loading(): - """Test that checkpoint loading works correctly""" - print("\n=== Verifying Checkpoint Loading ===") - - from utils.checkpoint_manager import load_best_checkpoint - - models_to_test = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target'] - - for model_name in models_to_test: - try: - result = load_best_checkpoint(model_name) - - if result: - file_path, metadata = result - file_exists = Path(file_path).exists() - - print(f"{model_name}:") - print(f" ✅ Checkpoint found: {metadata.checkpoint_id}") - print(f" 📁 File exists: {file_exists}") - print(f" 📊 Loss: {getattr(metadata, 'loss', 'N/A')}") - print(f" 💾 Size: {Path(file_path).stat().st_size / (1024*1024):.1f}MB" if file_exists else " 💾 Size: N/A") - else: - print(f"{model_name}: ❌ No valid checkpoint found") - - except Exception as e: - print(f"{model_name}: ❌ Error loading checkpoint: {e}") - -def test_checkpoint_system_integration(): - """Test integration with the orchestrator""" - print("\n=== Testing Orchestrator Integration ===") - - try: - # Test database manager integration - from utils.database_manager import get_database_manager - db_manager = get_database_manager() - - # Test fast metadata access - for model_name in ['dqn_agent', 'enhanced_cnn']: - metadata = db_manager.get_best_checkpoint_metadata(model_name) - if metadata: - print(f"{model_name}: ✅ Fast metadata access works") - print(f" ID: {metadata.checkpoint_id}") - print(f" Loss: {metadata.performance_metrics.get('loss', 'N/A')}") - else: - print(f"{model_name}: ❌ No metadata found") - - print("\n✅ Checkpoint system is ready for use!") - - except Exception as e: - print(f"❌ Integration test failed: {e}") - -def main(): - """Main cleanup process""" - cleanup_invalid_checkpoints() - verify_checkpoint_loading() - test_checkpoint_system_integration() - - print("\n=== Cleanup Complete ===") - print("The checkpoint system should now work without 'file not found' errors!") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/cleanup_checkpoints.py b/cleanup_checkpoints.py deleted file mode 100644 index 5f4ab67..0000000 --- a/cleanup_checkpoints.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python3 -""" -Checkpoint Cleanup and Migration Script - -This script helps clean up existing checkpoints and migrate to the new -checkpoint management system with W&B integration. -""" - -import os -import logging -import shutil -from pathlib import Path -from datetime import datetime -from typing import List, Dict, Any -import torch - -from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -class CheckpointCleanup: - def __init__(self): - self.saved_models_dir = Path("NN/models/saved") - self.checkpoint_manager = get_checkpoint_manager() - - def analyze_existing_checkpoints(self) -> Dict[str, Any]: - logger.info("Analyzing existing checkpoint files...") - - analysis = { - 'total_files': 0, - 'total_size_mb': 0.0, - 'model_types': {}, - 'file_patterns': {}, - 'potential_duplicates': [] - } - - if not self.saved_models_dir.exists(): - logger.warning(f"Saved models directory not found: {self.saved_models_dir}") - return analysis - - for pt_file in self.saved_models_dir.rglob("*.pt"): - try: - file_size_mb = pt_file.stat().st_size / (1024 * 1024) - analysis['total_files'] += 1 - analysis['total_size_mb'] += file_size_mb - - filename = pt_file.name - - if 'cnn' in filename.lower(): - model_type = 'cnn' - elif 'dqn' in filename.lower() or 'rl' in filename.lower(): - model_type = 'rl' - elif 'agent' in filename.lower(): - model_type = 'rl' - else: - model_type = 'unknown' - - if model_type not in analysis['model_types']: - analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0} - - analysis['model_types'][model_type]['count'] += 1 - analysis['model_types'][model_type]['size_mb'] += file_size_mb - - base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '') - if base_name not in analysis['file_patterns']: - analysis['file_patterns'][base_name] = [] - - analysis['file_patterns'][base_name].append({ - 'path': str(pt_file), - 'size_mb': file_size_mb, - 'modified': datetime.fromtimestamp(pt_file.stat().st_mtime) - }) - - except Exception as e: - logger.error(f"Error analyzing {pt_file}: {e}") - - for base_name, files in analysis['file_patterns'].items(): - if len(files) > 5: # More than 5 files with same base name - analysis['potential_duplicates'].append({ - 'base_name': base_name, - 'count': len(files), - 'total_size_mb': sum(f['size_mb'] for f in files), - 'files': files - }) - - logger.info(f"Analysis complete:") - logger.info(f" Total files: {analysis['total_files']}") - logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB") - logger.info(f" Model types: {analysis['model_types']}") - logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}") - - return analysis - - def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]: - logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...") - - cleanup_results = { - 'removed': 0, - 'kept': 0, - 'space_saved_mb': 0.0, - 'details': [] - } - - analysis = self.analyze_existing_checkpoints() - - for duplicate_group in analysis['potential_duplicates']: - base_name = duplicate_group['base_name'] - files = duplicate_group['files'] - - # Sort by modification time (newest first) - files.sort(key=lambda x: x['modified'], reverse=True) - - logger.info(f"Processing {base_name}: {len(files)} files") - - # Keep only the 5 newest files - for i, file_info in enumerate(files): - if i < 5: # Keep first 5 (newest) - cleanup_results['kept'] += 1 - cleanup_results['details'].append({ - 'action': 'kept', - 'file': file_info['path'] - }) - else: # Remove the rest - if not dry_run: - try: - Path(file_info['path']).unlink() - logger.info(f"Removed: {file_info['path']}") - except Exception as e: - logger.error(f"Error removing {file_info['path']}: {e}") - continue - - cleanup_results['removed'] += 1 - cleanup_results['space_saved_mb'] += file_info['size_mb'] - cleanup_results['details'].append({ - 'action': 'removed', - 'file': file_info['path'], - 'size_mb': file_info['size_mb'] - }) - - logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:") - logger.info(f" Kept: {cleanup_results['kept']}") - logger.info(f" Removed: {cleanup_results['removed']}") - logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB") - - return cleanup_results - -def main(): - logger.info("=== Checkpoint Cleanup Tool ===") - - cleanup = CheckpointCleanup() - - # Analyze existing checkpoints - logger.info("\\n1. Analyzing existing checkpoints...") - analysis = cleanup.analyze_existing_checkpoints() - - if analysis['total_files'] == 0: - logger.info("No checkpoint files found.") - return - - # Show potential space savings - total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5) - if total_duplicates > 0: - logger.info(f"\\nFound {total_duplicates} files that could be cleaned up") - - # Dry run first - logger.info("\\n2. Simulating cleanup...") - dry_run_results = cleanup.cleanup_duplicates(dry_run=True) - - if dry_run_results['removed'] > 0: - proceed = input(f"\\nProceed with cleanup? Will remove {dry_run_results['removed']} files " - f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y' - - if proceed: - logger.info("\\n3. Performing actual cleanup...") - cleanup_results = cleanup.cleanup_duplicates(dry_run=False) - logger.info("\\n=== Cleanup Complete ===") - else: - logger.info("Cleanup cancelled.") - else: - logger.info("No files to remove.") - else: - logger.info("No duplicate files found that need cleanup.") - -if __name__ == "__main__": - main() diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/core/api_rate_limiter.py b/core/api_rate_limiter.py deleted file mode 100644 index 528c345..0000000 --- a/core/api_rate_limiter.py +++ /dev/null @@ -1,402 +0,0 @@ -""" -API Rate Limiter and Error Handler - -This module provides robust rate limiting and error handling for API requests, -specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors) -and other exchange API limitations. - -Features: -- Exponential backoff for rate limiting -- IP rotation and proxy support -- Request queuing and throttling -- Error recovery strategies -- Thread-safe operations -""" - -import asyncio -import logging -import time -import random -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Callable, Any -from dataclasses import dataclass, field -from collections import deque -import threading -from concurrent.futures import ThreadPoolExecutor -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -logger = logging.getLogger(__name__) - -@dataclass -class RateLimitConfig: - """Configuration for rate limiting""" - requests_per_second: float = 0.5 # Very conservative for Binance - requests_per_minute: int = 20 - requests_per_hour: int = 1000 - - # Backoff configuration - initial_backoff: float = 1.0 - max_backoff: float = 300.0 # 5 minutes max - backoff_multiplier: float = 2.0 - - # Error handling - max_retries: int = 3 - retry_delay: float = 5.0 - - # IP blocking detection - block_detection_threshold: int = 3 # 3 consecutive 418s = blocked - block_recovery_time: int = 3600 # 1 hour recovery time - -@dataclass -class APIEndpoint: - """API endpoint configuration""" - name: str - base_url: str - rate_limit: RateLimitConfig - last_request_time: float = 0.0 - request_count_minute: int = 0 - request_count_hour: int = 0 - consecutive_errors: int = 0 - blocked_until: Optional[datetime] = None - - # Request history for rate limiting - request_history: deque = field(default_factory=lambda: deque(maxlen=3600)) # 1 hour history - -class APIRateLimiter: - """Thread-safe API rate limiter with error handling""" - - def __init__(self, config: RateLimitConfig = None): - self.config = config or RateLimitConfig() - - # Thread safety - self.lock = threading.RLock() - - # Endpoint tracking - self.endpoints: Dict[str, APIEndpoint] = {} - - # Global rate limiting - self.global_request_history = deque(maxlen=3600) - self.global_blocked_until: Optional[datetime] = None - - # Request session with retry strategy - self.session = self._create_session() - - # Background cleanup thread - self.cleanup_thread = None - self.is_running = False - - logger.info("API Rate Limiter initialized") - logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m") - - def _create_session(self) -> requests.Session: - """Create requests session with retry strategy""" - session = requests.Session() - - # Retry strategy - retry_strategy = Retry( - total=self.config.max_retries, - backoff_factor=1, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["HEAD", "GET", "OPTIONS"] - ) - - adapter = HTTPAdapter(max_retries=retry_strategy) - session.mount("http://", adapter) - session.mount("https://", adapter) - - # Headers to appear more legitimate - session.headers.update({ - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', - 'Accept': 'application/json', - 'Accept-Language': 'en-US,en;q=0.9', - 'Accept-Encoding': 'gzip, deflate, br', - 'Connection': 'keep-alive', - 'Upgrade-Insecure-Requests': '1', - }) - - return session - - def register_endpoint(self, name: str, base_url: str, rate_limit: RateLimitConfig = None): - """Register an API endpoint for rate limiting""" - with self.lock: - self.endpoints[name] = APIEndpoint( - name=name, - base_url=base_url, - rate_limit=rate_limit or self.config - ) - logger.info(f"Registered endpoint: {name} -> {base_url}") - - def start_background_cleanup(self): - """Start background cleanup thread""" - if self.is_running: - return - - self.is_running = True - self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True) - self.cleanup_thread.start() - logger.info("Started background cleanup thread") - - def stop_background_cleanup(self): - """Stop background cleanup thread""" - self.is_running = False - if self.cleanup_thread: - self.cleanup_thread.join(timeout=5) - logger.info("Stopped background cleanup thread") - - def _cleanup_worker(self): - """Background worker to clean up old request history""" - while self.is_running: - try: - current_time = time.time() - cutoff_time = current_time - 3600 # 1 hour ago - - with self.lock: - # Clean global history - while (self.global_request_history and - self.global_request_history[0] < cutoff_time): - self.global_request_history.popleft() - - # Clean endpoint histories - for endpoint in self.endpoints.values(): - while (endpoint.request_history and - endpoint.request_history[0] < cutoff_time): - endpoint.request_history.popleft() - - # Reset counters - endpoint.request_count_minute = len([ - t for t in endpoint.request_history - if t > current_time - 60 - ]) - endpoint.request_count_hour = len(endpoint.request_history) - - time.sleep(60) # Clean every minute - - except Exception as e: - logger.error(f"Error in cleanup worker: {e}") - time.sleep(30) - - def can_make_request(self, endpoint_name: str) -> tuple[bool, float]: - """ - Check if we can make a request to the endpoint - - Returns: - (can_make_request, wait_time_seconds) - """ - with self.lock: - current_time = time.time() - - # Check global blocking - if self.global_blocked_until and datetime.now() < self.global_blocked_until: - wait_time = (self.global_blocked_until - datetime.now()).total_seconds() - return False, wait_time - - # Get endpoint - endpoint = self.endpoints.get(endpoint_name) - if not endpoint: - logger.warning(f"Unknown endpoint: {endpoint_name}") - return False, 60.0 - - # Check endpoint blocking - if endpoint.blocked_until and datetime.now() < endpoint.blocked_until: - wait_time = (endpoint.blocked_until - datetime.now()).total_seconds() - return False, wait_time - - # Check rate limits - config = endpoint.rate_limit - - # Per-second rate limit - time_since_last = current_time - endpoint.last_request_time - if time_since_last < (1.0 / config.requests_per_second): - wait_time = (1.0 / config.requests_per_second) - time_since_last - return False, wait_time - - # Per-minute rate limit - minute_requests = len([ - t for t in endpoint.request_history - if t > current_time - 60 - ]) - if minute_requests >= config.requests_per_minute: - return False, 60.0 - - # Per-hour rate limit - if len(endpoint.request_history) >= config.requests_per_hour: - return False, 3600.0 - - return True, 0.0 - - def make_request(self, endpoint_name: str, url: str, method: str = 'GET', - **kwargs) -> Optional[requests.Response]: - """ - Make a rate-limited request with error handling - - Args: - endpoint_name: Name of the registered endpoint - url: Full URL to request - method: HTTP method - **kwargs: Additional arguments for requests - - Returns: - Response object or None if failed - """ - with self.lock: - endpoint = self.endpoints.get(endpoint_name) - if not endpoint: - logger.error(f"Unknown endpoint: {endpoint_name}") - return None - - # Check if we can make the request - can_request, wait_time = self.can_make_request(endpoint_name) - if not can_request: - logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s") - time.sleep(min(wait_time, 30)) # Cap wait time - return None - - # Record request attempt - current_time = time.time() - endpoint.last_request_time = current_time - endpoint.request_history.append(current_time) - self.global_request_history.append(current_time) - - # Add jitter to avoid thundering herd - jitter = random.uniform(0.1, 0.5) - time.sleep(jitter) - - # Make the request (outside of lock to avoid blocking other threads) - try: - # Set timeout - kwargs.setdefault('timeout', 10) - - # Make request - response = self.session.request(method, url, **kwargs) - - # Handle response - with self.lock: - if response.status_code == 200: - # Success - reset error counter - endpoint.consecutive_errors = 0 - return response - - elif response.status_code == 418: - # Binance "I'm a teapot" - rate limited/blocked - endpoint.consecutive_errors += 1 - logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}") - - if endpoint.consecutive_errors >= endpoint.rate_limit.block_detection_threshold: - # We're likely IP blocked - block_time = datetime.now() + timedelta(seconds=endpoint.rate_limit.block_recovery_time) - endpoint.blocked_until = block_time - logger.error(f"Endpoint {endpoint_name} blocked until {block_time}") - - return None - - elif response.status_code == 429: - # Too many requests - endpoint.consecutive_errors += 1 - logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}") - - # Implement exponential backoff - backoff_time = min( - endpoint.rate_limit.initial_backoff * (endpoint.rate_limit.backoff_multiplier ** endpoint.consecutive_errors), - endpoint.rate_limit.max_backoff - ) - - block_time = datetime.now() + timedelta(seconds=backoff_time) - endpoint.blocked_until = block_time - logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s") - - return None - - else: - # Other error - endpoint.consecutive_errors += 1 - logger.warning(f"HTTP {response.status_code} for {endpoint_name}: {response.text[:200]}") - return None - - except requests.exceptions.RequestException as e: - with self.lock: - endpoint.consecutive_errors += 1 - logger.error(f"Request exception for {endpoint_name}: {e}") - return None - - except Exception as e: - with self.lock: - endpoint.consecutive_errors += 1 - logger.error(f"Unexpected error for {endpoint_name}: {e}") - return None - - def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]: - """Get status information for an endpoint""" - with self.lock: - endpoint = self.endpoints.get(endpoint_name) - if not endpoint: - return {'error': 'Unknown endpoint'} - - current_time = time.time() - - return { - 'name': endpoint.name, - 'base_url': endpoint.base_url, - 'consecutive_errors': endpoint.consecutive_errors, - 'blocked_until': endpoint.blocked_until.isoformat() if endpoint.blocked_until else None, - 'requests_last_minute': len([t for t in endpoint.request_history if t > current_time - 60]), - 'requests_last_hour': len(endpoint.request_history), - 'last_request_time': endpoint.last_request_time, - 'can_make_request': self.can_make_request(endpoint_name)[0] - } - - def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]: - """Get status for all endpoints""" - return {name: self.get_endpoint_status(name) for name in self.endpoints.keys()} - - def reset_endpoint(self, endpoint_name: str): - """Reset an endpoint's error state""" - with self.lock: - endpoint = self.endpoints.get(endpoint_name) - if endpoint: - endpoint.consecutive_errors = 0 - endpoint.blocked_until = None - logger.info(f"Reset endpoint: {endpoint_name}") - - def reset_all_endpoints(self): - """Reset all endpoints' error states""" - with self.lock: - for endpoint in self.endpoints.values(): - endpoint.consecutive_errors = 0 - endpoint.blocked_until = None - self.global_blocked_until = None - logger.info("Reset all endpoints") - -# Global rate limiter instance -_global_rate_limiter = None - -def get_rate_limiter() -> APIRateLimiter: - """Get global rate limiter instance""" - global _global_rate_limiter - if _global_rate_limiter is None: - _global_rate_limiter = APIRateLimiter() - _global_rate_limiter.start_background_cleanup() - - # Register common endpoints - _global_rate_limiter.register_endpoint( - 'binance_api', - 'https://api.binance.com', - RateLimitConfig( - requests_per_second=0.2, # Very conservative - requests_per_minute=10, - requests_per_hour=500 - ) - ) - - _global_rate_limiter.register_endpoint( - 'mexc_api', - 'https://api.mexc.com', - RateLimitConfig( - requests_per_second=0.5, - requests_per_minute=20, - requests_per_hour=1000 - ) - ) - - return _global_rate_limiter \ No newline at end of file diff --git a/core/async_handler.py b/core/async_handler.py deleted file mode 100644 index 10a0793..0000000 --- a/core/async_handler.py +++ /dev/null @@ -1,442 +0,0 @@ -""" -Async Handler for UI Stability Fix - -Properly handles all async operations in the dashboard with single event loop management, -proper exception handling, and timeout support to prevent async/await errors. -""" - -import asyncio -import logging -import threading -import time -from typing import Any, Callable, Coroutine, Dict, Optional, Union -from concurrent.futures import ThreadPoolExecutor -import functools -import weakref - -logger = logging.getLogger(__name__) - - -class AsyncOperationError(Exception): - """Exception raised for async operation errors""" - pass - - -class AsyncHandler: - """ - Centralized async operation handler with single event loop management - and proper exception handling for async operations. - """ - - def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): - """ - Initialize the async handler - - Args: - loop: Optional event loop to use. If None, creates a new one. - """ - self._loop = loop - self._thread = None - self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler") - self._running = False - self._callbacks = weakref.WeakSet() - self._timeout_default = 30.0 # Default timeout for operations - - # Start the event loop in a separate thread if not provided - if self._loop is None: - self._start_event_loop_thread() - - logger.info("AsyncHandler initialized with event loop management") - - def _start_event_loop_thread(self): - """Start the event loop in a separate thread""" - def run_event_loop(): - """Run the event loop in a separate thread""" - try: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - self._running = True - logger.debug("Event loop started in separate thread") - self._loop.run_forever() - except Exception as e: - logger.error(f"Error in event loop thread: {e}") - finally: - self._running = False - logger.debug("Event loop thread stopped") - - self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop") - self._thread.start() - - # Wait for the loop to be ready - timeout = 5.0 - start_time = time.time() - while not self._running and (time.time() - start_time) < timeout: - time.sleep(0.1) - - if not self._running: - raise AsyncOperationError("Failed to start event loop within timeout") - - def is_running(self) -> bool: - """Check if the async handler is running""" - return self._running and self._loop is not None and not self._loop.is_closed() - - def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any: - """ - Run an async coroutine safely with proper error handling and timeout - - Args: - coro: The coroutine to run - timeout: Timeout in seconds (uses default if None) - - Returns: - The result of the coroutine - - Raises: - AsyncOperationError: If the operation fails or times out - """ - if not self.is_running(): - raise AsyncOperationError("AsyncHandler is not running") - - timeout = timeout or self._timeout_default - - try: - # Schedule the coroutine on the event loop - future = asyncio.run_coroutine_threadsafe( - asyncio.wait_for(coro, timeout=timeout), - self._loop - ) - - # Wait for the result with timeout - result = future.result(timeout=timeout + 1.0) # Add buffer to future timeout - logger.debug("Async operation completed successfully") - return result - - except asyncio.TimeoutError: - logger.error(f"Async operation timed out after {timeout} seconds") - raise AsyncOperationError(f"Operation timed out after {timeout} seconds") - except Exception as e: - logger.error(f"Async operation failed: {e}") - raise AsyncOperationError(f"Async operation failed: {e}") - - def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None: - """ - Schedule a coroutine to run asynchronously without waiting for result - - Args: - coro: The coroutine to schedule - callback: Optional callback to call with the result - """ - if not self.is_running(): - logger.warning("Cannot schedule coroutine: AsyncHandler is not running") - return - - async def wrapped_coro(): - """Wrapper to handle exceptions and callbacks""" - try: - result = await coro - if callback: - try: - callback(result) - except Exception as e: - logger.error(f"Error in coroutine callback: {e}") - return result - except Exception as e: - logger.error(f"Error in scheduled coroutine: {e}") - if callback: - try: - callback(None) # Call callback with None on error - except Exception as cb_e: - logger.error(f"Error in error callback: {cb_e}") - - try: - asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop) - logger.debug("Coroutine scheduled successfully") - except Exception as e: - logger.error(f"Failed to schedule coroutine: {e}") - - def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]: - """ - Create an asyncio task safely with proper error handling - - Args: - coro: The coroutine to create a task for - name: Optional name for the task - - Returns: - The created task or None if failed - """ - if not self.is_running(): - logger.warning("Cannot create task: AsyncHandler is not running") - return None - - async def create_task(): - """Create the task in the event loop""" - try: - task = asyncio.create_task(coro, name=name) - logger.debug(f"Task created: {name or 'unnamed'}") - return task - except Exception as e: - logger.error(f"Failed to create task {name}: {e}") - return None - - try: - future = asyncio.run_coroutine_threadsafe(create_task(), self._loop) - return future.result(timeout=5.0) - except Exception as e: - logger.error(f"Failed to create task {name}: {e}") - return None - - async def handle_orchestrator_connection(self, orchestrator) -> bool: - """ - Handle orchestrator connection with proper async patterns - - Args: - orchestrator: The orchestrator instance to connect to - - Returns: - True if connection successful, False otherwise - """ - try: - logger.info("Connecting to orchestrator...") - - # Add decision callback if orchestrator supports it - if hasattr(orchestrator, 'add_decision_callback'): - await orchestrator.add_decision_callback(self._handle_trading_decision) - logger.info("Decision callback added to orchestrator") - - # Start COB integration if available - if hasattr(orchestrator, 'start_cob_integration'): - await orchestrator.start_cob_integration() - logger.info("COB integration started") - - # Start continuous trading if available - if hasattr(orchestrator, 'start_continuous_trading'): - await orchestrator.start_continuous_trading() - logger.info("Continuous trading started") - - logger.info("Successfully connected to orchestrator") - return True - - except Exception as e: - logger.error(f"Failed to connect to orchestrator: {e}") - return False - - async def handle_cob_integration(self, cob_integration) -> bool: - """ - Handle COB integration startup with proper async patterns - - Args: - cob_integration: The COB integration instance - - Returns: - True if startup successful, False otherwise - """ - try: - logger.info("Starting COB integration...") - - if hasattr(cob_integration, 'start'): - await cob_integration.start() - logger.info("COB integration started successfully") - return True - else: - logger.warning("COB integration does not have start method") - return False - - except Exception as e: - logger.error(f"Failed to start COB integration: {e}") - return False - - async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None: - """ - Handle trading decision with proper async patterns - - Args: - decision: The trading decision dictionary - """ - try: - logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}") - - # Process the decision (this would be customized based on needs) - # For now, just log it - symbol = decision.get('symbol', 'UNKNOWN') - action = decision.get('action', 'HOLD') - confidence = decision.get('confidence', 0.0) - - logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})") - - except Exception as e: - logger.error(f"Error handling trading decision: {e}") - - def run_in_executor(self, func: Callable, *args, **kwargs) -> Any: - """ - Run a blocking function in the thread pool executor - - Args: - func: The function to run - *args: Positional arguments for the function - **kwargs: Keyword arguments for the function - - Returns: - The result of the function - """ - if not self.is_running(): - raise AsyncOperationError("AsyncHandler is not running") - - try: - # Create a partial function with the arguments - partial_func = functools.partial(func, *args, **kwargs) - - # Create a coroutine that runs the function in executor - async def run_in_executor_coro(): - return await self._loop.run_in_executor(self._executor, partial_func) - - # Run the coroutine - future = asyncio.run_coroutine_threadsafe(run_in_executor_coro(), self._loop) - - result = future.result(timeout=self._timeout_default) - logger.debug("Executor function completed successfully") - return result - - except Exception as e: - logger.error(f"Error running function in executor: {e}") - raise AsyncOperationError(f"Executor function failed: {e}") - - def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]: - """ - Add a periodic task that runs at specified intervals - - Args: - coro_func: Function that returns a coroutine to run periodically - interval: Interval in seconds between runs - name: Optional name for the task - - Returns: - The created task or None if failed - """ - async def periodic_runner(): - """Run the coroutine periodically""" - task_name = name or "periodic_task" - logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)") - - try: - while True: - try: - coro = coro_func() - await coro - logger.debug(f"Periodic task {task_name} completed") - except Exception as e: - logger.error(f"Error in periodic task {task_name}: {e}") - - await asyncio.sleep(interval) - - except asyncio.CancelledError: - logger.info(f"Periodic task {task_name} cancelled") - raise - except Exception as e: - logger.error(f"Fatal error in periodic task {task_name}: {e}") - - return self.create_task_safely(periodic_runner(), name=f"periodic_{name}") - - def stop(self) -> None: - """Stop the async handler and clean up resources""" - try: - logger.info("Stopping AsyncHandler...") - - if self._loop and not self._loop.is_closed(): - # Cancel all tasks - if self._loop.is_running(): - asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), self._loop) - - # Stop the event loop - self._loop.call_soon_threadsafe(self._loop.stop) - - # Shutdown executor - if self._executor: - self._executor.shutdown(wait=True) - - # Wait for thread to finish - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=5.0) - - self._running = False - logger.info("AsyncHandler stopped successfully") - - except Exception as e: - logger.error(f"Error stopping AsyncHandler: {e}") - - async def _cancel_all_tasks(self) -> None: - """Cancel all running tasks""" - try: - tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()] - if tasks: - logger.info(f"Cancelling {len(tasks)} running tasks") - for task in tasks: - task.cancel() - - # Wait for tasks to be cancelled - await asyncio.gather(*tasks, return_exceptions=True) - logger.debug("All tasks cancelled") - except Exception as e: - logger.error(f"Error cancelling tasks: {e}") - - def __enter__(self): - """Context manager entry""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit""" - self.stop() - - -class AsyncContextManager: - """ - Context manager for async operations that ensures proper cleanup - """ - - def __init__(self, async_handler: AsyncHandler): - self.async_handler = async_handler - self.active_tasks = [] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Cancel any active tasks - for task in self.active_tasks: - if not task.done(): - task.cancel() - - def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]: - """Create a task and track it for cleanup""" - task = self.async_handler.create_task_safely(coro, name) - if task: - self.active_tasks.append(task) - return task - - -def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler: - """ - Factory function to create an AsyncHandler instance - - Args: - loop: Optional event loop to use - - Returns: - AsyncHandler instance - """ - return AsyncHandler(loop=loop) - - -def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any: - """ - Convenience function to run a coroutine safely with a temporary AsyncHandler - - Args: - coro: The coroutine to run - timeout: Timeout in seconds - - Returns: - The result of the coroutine - """ - with AsyncHandler() as handler: - return handler.run_async_safely(coro, timeout=timeout) \ No newline at end of file diff --git a/core/bookmap_data_provider.py b/core/bookmap_data_provider.py deleted file mode 100644 index c25bc47..0000000 --- a/core/bookmap_data_provider.py +++ /dev/null @@ -1,952 +0,0 @@ -""" -Bookmap Order Book Data Provider - -This module integrates with Bookmap to gather: -- Current Order Book (COB) data -- Session Volume Profile (SVP) data -- Order book sweeps and momentum trades detection -- Real-time order size heatmap matrix (last 10 minutes) -- Level 2 market depth analysis - -The data is processed and fed to CNN and DQN networks for enhanced trading decisions. -""" - -import asyncio -import json -import logging -import time -import websockets -import numpy as np -import pandas as pd -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Callable -from collections import deque, defaultdict -from dataclasses import dataclass -from threading import Thread, Lock -import requests - -logger = logging.getLogger(__name__) - -@dataclass -class OrderBookLevel: - """Represents a single order book level""" - price: float - size: float - orders: int - side: str # 'bid' or 'ask' - timestamp: datetime - -@dataclass -class OrderBookSnapshot: - """Complete order book snapshot""" - symbol: str - timestamp: datetime - bids: List[OrderBookLevel] - asks: List[OrderBookLevel] - spread: float - mid_price: float - -@dataclass -class VolumeProfileLevel: - """Volume profile level data""" - price: float - volume: float - buy_volume: float - sell_volume: float - trades_count: int - vwap: float - -@dataclass -class OrderFlowSignal: - """Order flow signal detection""" - timestamp: datetime - signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum' - price: float - volume: float - confidence: float - description: str - -class BookmapDataProvider: - """ - Real-time order book data provider using Bookmap-style analysis - - Features: - - Level 2 order book monitoring - - Order flow detection (sweeps, absorptions) - - Volume profile analysis - - Order size heatmap generation - - Market microstructure analysis - """ - - def __init__(self, symbols: List[str] = None, depth_levels: int = 20): - """ - Initialize Bookmap data provider - - Args: - symbols: List of symbols to monitor - depth_levels: Number of order book levels to track - """ - self.symbols = symbols or ['ETHUSDT', 'BTCUSDT'] - self.depth_levels = depth_levels - self.is_streaming = False - - # Order book data storage - self.order_books: Dict[str, OrderBookSnapshot] = {} - self.order_book_history: Dict[str, deque] = {} - self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {} - - # Heatmap data (10-minute rolling window) - self.heatmap_window = timedelta(minutes=10) - self.order_heatmaps: Dict[str, deque] = {} - self.price_levels: Dict[str, List[float]] = {} - - # Order flow detection - self.flow_signals: Dict[str, deque] = {} - self.sweep_threshold = 0.8 # Minimum confidence for sweep detection - self.absorption_threshold = 0.7 # Minimum confidence for absorption - - # Market microstructure metrics - self.bid_ask_spreads: Dict[str, deque] = {} - self.order_book_imbalances: Dict[str, deque] = {} - self.liquidity_metrics: Dict[str, Dict] = {} - - # WebSocket connections - self.websocket_tasks: Dict[str, asyncio.Task] = {} - self.data_lock = Lock() - - # Callbacks for CNN/DQN integration - self.cnn_callbacks: List[Callable] = [] - self.dqn_callbacks: List[Callable] = [] - - # Performance tracking - self.update_counts = defaultdict(int) - self.last_update_times = {} - - # Initialize data structures - for symbol in self.symbols: - self.order_book_history[symbol] = deque(maxlen=1000) - self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals - self.flow_signals[symbol] = deque(maxlen=500) - self.bid_ask_spreads[symbol] = deque(maxlen=1000) - self.order_book_imbalances[symbol] = deque(maxlen=1000) - self.liquidity_metrics[symbol] = { - 'total_bid_size': 0.0, - 'total_ask_size': 0.0, - 'weighted_mid': 0.0, - 'liquidity_ratio': 1.0 - } - - logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols") - logger.info(f"Tracking {depth_levels} order book levels per side") - - def add_cnn_callback(self, callback: Callable[[str, Dict], None]): - """Add callback for CNN model updates""" - self.cnn_callbacks.append(callback) - logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total") - - def add_dqn_callback(self, callback: Callable[[str, Dict], None]): - """Add callback for DQN model updates""" - self.dqn_callbacks.append(callback) - logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total") - - async def start_streaming(self): - """Start real-time order book streaming""" - if self.is_streaming: - logger.warning("Bookmap streaming already active") - return - - self.is_streaming = True - logger.info("Starting Bookmap order book streaming") - - # Start order book streams for each symbol - for symbol in self.symbols: - # Order book depth stream - depth_task = asyncio.create_task(self._stream_order_book_depth(symbol)) - self.websocket_tasks[f"{symbol}_depth"] = depth_task - - # Trade stream for order flow analysis - trade_task = asyncio.create_task(self._stream_trades(symbol)) - self.websocket_tasks[f"{symbol}_trades"] = trade_task - - # Start analysis threads - analysis_task = asyncio.create_task(self._continuous_analysis()) - self.websocket_tasks["analysis"] = analysis_task - - logger.info(f"Started streaming for {len(self.symbols)} symbols") - - async def stop_streaming(self): - """Stop order book streaming""" - if not self.is_streaming: - return - - logger.info("Stopping Bookmap streaming") - self.is_streaming = False - - # Cancel all tasks - for name, task in self.websocket_tasks.items(): - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - self.websocket_tasks.clear() - logger.info("Bookmap streaming stopped") - - async def _stream_order_book_depth(self, symbol: str): - """Stream order book depth data""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Order book depth WebSocket connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_depth_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing depth for {symbol}: {e}") - - except Exception as e: - logger.error(f"Depth WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _stream_trades(self, symbol: str): - """Stream trade data for order flow analysis""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Trade WebSocket connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_trade_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing trade for {symbol}: {e}") - - except Exception as e: - logger.error(f"Trade WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _process_depth_update(self, symbol: str, data: Dict): - """Process order book depth update""" - try: - timestamp = datetime.now() - - # Parse bids and asks - bids = [] - asks = [] - - for bid_data in data.get('bids', []): - price = float(bid_data[0]) - size = float(bid_data[1]) - bids.append(OrderBookLevel( - price=price, - size=size, - orders=1, # Binance doesn't provide order count - side='bid', - timestamp=timestamp - )) - - for ask_data in data.get('asks', []): - price = float(ask_data[0]) - size = float(ask_data[1]) - asks.append(OrderBookLevel( - price=price, - size=size, - orders=1, - side='ask', - timestamp=timestamp - )) - - # Sort order book levels - bids.sort(key=lambda x: x.price, reverse=True) - asks.sort(key=lambda x: x.price) - - # Calculate spread and mid price - if bids and asks: - best_bid = bids[0].price - best_ask = asks[0].price - spread = best_ask - best_bid - mid_price = (best_bid + best_ask) / 2 - else: - spread = 0.0 - mid_price = 0.0 - - # Create order book snapshot - snapshot = OrderBookSnapshot( - symbol=symbol, - timestamp=timestamp, - bids=bids, - asks=asks, - spread=spread, - mid_price=mid_price - ) - - with self.data_lock: - self.order_books[symbol] = snapshot - self.order_book_history[symbol].append(snapshot) - - # Update liquidity metrics - self._update_liquidity_metrics(symbol, snapshot) - - # Update order book imbalance - self._calculate_order_book_imbalance(symbol, snapshot) - - # Update heatmap data - self._update_order_heatmap(symbol, snapshot) - - # Update counters - self.update_counts[f"{symbol}_depth"] += 1 - self.last_update_times[f"{symbol}_depth"] = timestamp - - except Exception as e: - logger.error(f"Error processing depth update for {symbol}: {e}") - - async def _process_trade_update(self, symbol: str, data: Dict): - """Process trade data for order flow analysis""" - try: - timestamp = datetime.fromtimestamp(int(data['T']) / 1000) - price = float(data['p']) - quantity = float(data['q']) - is_buyer_maker = data['m'] - - # Analyze for order flow signals - await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker) - - # Update volume profile - self._update_volume_profile(symbol, price, quantity, is_buyer_maker) - - self.update_counts[f"{symbol}_trades"] += 1 - - except Exception as e: - logger.error(f"Error processing trade for {symbol}: {e}") - - def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot): - """Update liquidity metrics from order book snapshot""" - try: - total_bid_size = sum(level.size for level in snapshot.bids) - total_ask_size = sum(level.size for level in snapshot.asks) - - # Calculate weighted mid price - if snapshot.bids and snapshot.asks: - bid_weight = total_bid_size / (total_bid_size + total_ask_size) - ask_weight = total_ask_size / (total_bid_size + total_ask_size) - weighted_mid = (snapshot.bids[0].price * ask_weight + - snapshot.asks[0].price * bid_weight) - else: - weighted_mid = snapshot.mid_price - - # Liquidity ratio (bid/ask balance) - if total_ask_size > 0: - liquidity_ratio = total_bid_size / total_ask_size - else: - liquidity_ratio = 1.0 - - self.liquidity_metrics[symbol] = { - 'total_bid_size': total_bid_size, - 'total_ask_size': total_ask_size, - 'weighted_mid': weighted_mid, - 'liquidity_ratio': liquidity_ratio, - 'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0 - } - - except Exception as e: - logger.error(f"Error updating liquidity metrics for {symbol}: {e}") - - def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot): - """Calculate order book imbalance ratio""" - try: - if not snapshot.bids or not snapshot.asks: - return - - # Calculate imbalance for top N levels - n_levels = min(5, len(snapshot.bids), len(snapshot.asks)) - - total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels)) - total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels)) - - if total_bid_size + total_ask_size > 0: - imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size) - else: - imbalance = 0.0 - - self.order_book_imbalances[symbol].append({ - 'timestamp': snapshot.timestamp, - 'imbalance': imbalance, - 'bid_size': total_bid_size, - 'ask_size': total_ask_size - }) - - except Exception as e: - logger.error(f"Error calculating imbalance for {symbol}: {e}") - - def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot): - """Update order size heatmap matrix""" - try: - # Create heatmap entry - heatmap_entry = { - 'timestamp': snapshot.timestamp, - 'mid_price': snapshot.mid_price, - 'levels': {} - } - - # Add bid levels - for level in snapshot.bids: - price_offset = level.price - snapshot.mid_price - heatmap_entry['levels'][price_offset] = { - 'side': 'bid', - 'size': level.size, - 'price': level.price - } - - # Add ask levels - for level in snapshot.asks: - price_offset = level.price - snapshot.mid_price - heatmap_entry['levels'][price_offset] = { - 'side': 'ask', - 'size': level.size, - 'price': level.price - } - - self.order_heatmaps[symbol].append(heatmap_entry) - - # Clean old entries (keep 10 minutes) - cutoff_time = snapshot.timestamp - self.heatmap_window - while (self.order_heatmaps[symbol] and - self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time): - self.order_heatmaps[symbol].popleft() - - except Exception as e: - logger.error(f"Error updating heatmap for {symbol}: {e}") - - def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool): - """Update volume profile with new trade""" - try: - # Initialize if not exists - if symbol not in self.volume_profiles: - self.volume_profiles[symbol] = [] - - # Find or create price level - price_level = None - for level in self.volume_profiles[symbol]: - if abs(level.price - price) < 0.01: # Price tolerance - price_level = level - break - - if not price_level: - price_level = VolumeProfileLevel( - price=price, - volume=0.0, - buy_volume=0.0, - sell_volume=0.0, - trades_count=0, - vwap=price - ) - self.volume_profiles[symbol].append(price_level) - - # Update volume profile - volume = price * quantity - old_total = price_level.volume - - price_level.volume += volume - price_level.trades_count += 1 - - if is_buyer_maker: - price_level.sell_volume += volume - else: - price_level.buy_volume += volume - - # Update VWAP - if price_level.volume > 0: - price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume - - except Exception as e: - logger.error(f"Error updating volume profile for {symbol}: {e}") - - async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float, - quantity: float, is_buyer_maker: bool): - """Analyze order flow for sweep and absorption patterns""" - try: - # Get recent order book data - if symbol not in self.order_book_history or not self.order_book_history[symbol]: - return - - recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots - - # Check for order book sweeps - sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker) - if sweep_signal: - self.flow_signals[symbol].append(sweep_signal) - await self._notify_flow_signal(symbol, sweep_signal) - - # Check for absorption patterns - absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity) - if absorption_signal: - self.flow_signals[symbol].append(absorption_signal) - await self._notify_flow_signal(symbol, absorption_signal) - - # Check for momentum trades - momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker) - if momentum_signal: - self.flow_signals[symbol].append(momentum_signal) - await self._notify_flow_signal(symbol, momentum_signal) - - except Exception as e: - logger.error(f"Error analyzing order flow for {symbol}: {e}") - - def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot], - price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]: - """Detect order book sweep patterns""" - try: - if len(snapshots) < 2: - return None - - before_snapshot = snapshots[-2] - after_snapshot = snapshots[-1] - - # Check if multiple levels were consumed - if is_buyer_maker: # Sell order, check ask side - levels_consumed = 0 - total_consumed_size = 0 - - for level in before_snapshot.asks[:5]: # Check top 5 levels - if level.price <= price: - levels_consumed += 1 - total_consumed_size += level.size - - if levels_consumed >= 2 and total_consumed_size > quantity * 1.5: - confidence = min(0.9, levels_consumed / 5.0 + 0.3) - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='sweep', - price=price, - volume=quantity * price, - confidence=confidence, - description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size" - ) - else: # Buy order, check bid side - levels_consumed = 0 - total_consumed_size = 0 - - for level in before_snapshot.bids[:5]: - if level.price >= price: - levels_consumed += 1 - total_consumed_size += level.size - - if levels_consumed >= 2 and total_consumed_size > quantity * 1.5: - confidence = min(0.9, levels_consumed / 5.0 + 0.3) - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='sweep', - price=price, - volume=quantity * price, - confidence=confidence, - description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size" - ) - - return None - - except Exception as e: - logger.error(f"Error detecting sweep for {symbol}: {e}") - return None - - def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot], - price: float, quantity: float) -> Optional[OrderFlowSignal]: - """Detect absorption patterns where large orders are absorbed without price movement""" - try: - if len(snapshots) < 3: - return None - - # Check if large order was absorbed with minimal price impact - volume_threshold = 10000 # $10K minimum for absorption - price_impact_threshold = 0.001 # 0.1% max price impact - - trade_value = price * quantity - if trade_value < volume_threshold: - return None - - # Calculate price impact - price_before = snapshots[-3].mid_price - price_after = snapshots[-1].mid_price - price_impact = abs(price_after - price_before) / price_before - - if price_impact < price_impact_threshold: - confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='absorption', - price=price, - volume=trade_value, - confidence=confidence, - description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact" - ) - - return None - - except Exception as e: - logger.error(f"Error detecting absorption for {symbol}: {e}") - return None - - def _detect_momentum_trade(self, symbol: str, price: float, quantity: float, - is_buyer_maker: bool) -> Optional[OrderFlowSignal]: - """Detect momentum trades based on size and direction""" - try: - trade_value = price * quantity - momentum_threshold = 25000 # $25K minimum for momentum classification - - if trade_value < momentum_threshold: - return None - - # Calculate confidence based on trade size - confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3) - - direction = "sell" if is_buyer_maker else "buy" - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='momentum', - price=price, - volume=trade_value, - confidence=confidence, - description=f"Large {direction}: ${trade_value:.0f}" - ) - - except Exception as e: - logger.error(f"Error detecting momentum for {symbol}: {e}") - return None - - async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal): - """Notify CNN and DQN models of order flow signals""" - try: - signal_data = { - 'signal_type': signal.signal_type, - 'price': signal.price, - 'volume': signal.volume, - 'confidence': signal.confidence, - 'timestamp': signal.timestamp, - 'description': signal.description - } - - # Notify CNN callbacks - for callback in self.cnn_callbacks: - try: - callback(symbol, signal_data) - except Exception as e: - logger.warning(f"Error in CNN callback: {e}") - - # Notify DQN callbacks - for callback in self.dqn_callbacks: - try: - callback(symbol, signal_data) - except Exception as e: - logger.warning(f"Error in DQN callback: {e}") - - except Exception as e: - logger.error(f"Error notifying flow signal: {e}") - - async def _continuous_analysis(self): - """Continuous analysis of market microstructure""" - while self.is_streaming: - try: - await asyncio.sleep(1) # Analyze every second - - for symbol in self.symbols: - # Generate CNN features - cnn_features = self.get_cnn_features(symbol) - if cnn_features is not None: - for callback in self.cnn_callbacks: - try: - callback(symbol, {'features': cnn_features, 'type': 'orderbook'}) - except Exception as e: - logger.warning(f"Error in CNN feature callback: {e}") - - # Generate DQN state features - dqn_features = self.get_dqn_state_features(symbol) - if dqn_features is not None: - for callback in self.dqn_callbacks: - try: - callback(symbol, {'state': dqn_features, 'type': 'orderbook'}) - except Exception as e: - logger.warning(f"Error in DQN state callback: {e}") - - except Exception as e: - logger.error(f"Error in continuous analysis: {e}") - await asyncio.sleep(5) - - def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]: - """Generate CNN input features from order book data""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - features = [] - - # Order book features (40 features: 20 levels x 2 sides) - for i in range(min(20, len(snapshot.bids))): - bid = snapshot.bids[i] - features.append(bid.size) - features.append(bid.price - snapshot.mid_price) # Price offset - - # Pad if not enough bid levels - while len(features) < 40: - features.extend([0.0, 0.0]) - - for i in range(min(20, len(snapshot.asks))): - ask = snapshot.asks[i] - features.append(ask.size) - features.append(ask.price - snapshot.mid_price) # Price offset - - # Pad if not enough ask levels - while len(features) < 80: - features.extend([0.0, 0.0]) - - # Liquidity metrics (10 features) - metrics = self.liquidity_metrics.get(symbol, {}) - features.extend([ - metrics.get('total_bid_size', 0.0), - metrics.get('total_ask_size', 0.0), - metrics.get('liquidity_ratio', 1.0), - metrics.get('spread_bps', 0.0), - snapshot.spread, - metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price, - len(snapshot.bids), - len(snapshot.asks), - snapshot.mid_price, - time.time() % 86400 # Time of day - ]) - - # Order book imbalance features (5 features) - if self.order_book_imbalances[symbol]: - latest_imbalance = self.order_book_imbalances[symbol][-1] - features.extend([ - latest_imbalance['imbalance'], - latest_imbalance['bid_size'], - latest_imbalance['ask_size'], - latest_imbalance['bid_size'] + latest_imbalance['ask_size'], - abs(latest_imbalance['imbalance']) - ]) - else: - features.extend([0.0, 0.0, 0.0, 0.0, 0.0]) - - # Flow signal features (5 features) - recent_signals = [s for s in self.flow_signals[symbol] - if (datetime.now() - s.timestamp).seconds < 60] - - sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep') - absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption') - momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum') - - max_confidence = max([s.confidence for s in recent_signals], default=0.0) - total_flow_volume = sum(s.volume for s in recent_signals) - - features.extend([ - sweep_count, - absorption_count, - momentum_count, - max_confidence, - total_flow_volume - ]) - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error generating CNN features for {symbol}: {e}") - return None - - def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]: - """Generate DQN state features from order book data""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - state_features = [] - - # Normalized order book state (20 features) - total_bid_size = sum(level.size for level in snapshot.bids[:10]) - total_ask_size = sum(level.size for level in snapshot.asks[:10]) - total_size = total_bid_size + total_ask_size - - if total_size > 0: - for i in range(min(10, len(snapshot.bids))): - state_features.append(snapshot.bids[i].size / total_size) - - # Pad bids - while len(state_features) < 10: - state_features.append(0.0) - - for i in range(min(10, len(snapshot.asks))): - state_features.append(snapshot.asks[i].size / total_size) - - # Pad asks - while len(state_features) < 20: - state_features.append(0.0) - else: - state_features.extend([0.0] * 20) - - # Market state indicators (10 features) - metrics = self.liquidity_metrics.get(symbol, {}) - - # Normalize spread as percentage - spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0 - - # Liquidity imbalance - liquidity_ratio = metrics.get('liquidity_ratio', 1.0) - liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1) - - # Recent flow signals strength - recent_signals = [s for s in self.flow_signals[symbol] - if (datetime.now() - s.timestamp).seconds < 30] - flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1) - - # Price volatility (from recent snapshots) - if len(self.order_book_history[symbol]) >= 10: - recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]] - price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0 - else: - price_volatility = 0 - - state_features.extend([ - spread_pct * 10000, # Spread in basis points - liquidity_imbalance, - flow_strength, - price_volatility * 100, # Volatility as percentage - min(len(snapshot.bids), 20) / 20, # Book depth ratio - min(len(snapshot.asks), 20) / 20, - sweep_count / 10 if 'sweep_count' in locals() else 0, # From CNN features - absorption_count / 5 if 'absorption_count' in locals() else 0, - momentum_count / 5 if 'momentum_count' in locals() else 0, - (datetime.now().hour * 60 + datetime.now().minute) / 1440 # Time of day normalized - ]) - - return np.array(state_features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error generating DQN features for {symbol}: {e}") - return None - - def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]: - """Generate order size heatmap matrix for dashboard visualization""" - try: - if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]: - return None - - # Create price levels around current mid price - current_snapshot = self.order_books.get(symbol) - if not current_snapshot: - return None - - mid_price = current_snapshot.mid_price - price_step = mid_price * 0.0001 # 1 basis point steps - - # Create matrix: time x price levels - time_window = min(600, len(self.order_heatmaps[symbol])) # 10 minutes max - heatmap_matrix = np.zeros((time_window, levels)) - - # Fill matrix with order sizes - for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]): - for price_offset, level_data in entry['levels'].items(): - # Convert price offset to matrix index - level_idx = int((price_offset + (levels/2) * price_step) / price_step) - - if 0 <= level_idx < levels: - size_weight = 1.0 if level_data['side'] == 'bid' else -1.0 - heatmap_matrix[t, level_idx] = level_data['size'] * size_weight - - return heatmap_matrix - - except Exception as e: - logger.error(f"Error generating heatmap matrix for {symbol}: {e}") - return None - - def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]: - """Get session volume profile data""" - try: - if symbol not in self.volume_profiles: - return None - - profile_data = [] - for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price): - profile_data.append({ - 'price': level.price, - 'volume': level.volume, - 'buy_volume': level.buy_volume, - 'sell_volume': level.sell_volume, - 'trades_count': level.trades_count, - 'vwap': level.vwap, - 'net_volume': level.buy_volume - level.sell_volume - }) - - return profile_data - - except Exception as e: - logger.error(f"Error getting volume profile for {symbol}: {e}") - return None - - def get_current_order_book(self, symbol: str) -> Optional[Dict]: - """Get current order book snapshot""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - - return { - 'timestamp': snapshot.timestamp.isoformat(), - 'symbol': symbol, - 'mid_price': snapshot.mid_price, - 'spread': snapshot.spread, - 'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]], - 'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]], - 'liquidity_metrics': self.liquidity_metrics.get(symbol, {}), - 'recent_signals': [ - { - 'type': s.signal_type, - 'price': s.price, - 'volume': s.volume, - 'confidence': s.confidence, - 'timestamp': s.timestamp.isoformat() - } - for s in list(self.flow_signals[symbol])[-5:] # Last 5 signals - ] - } - - except Exception as e: - logger.error(f"Error getting order book for {symbol}: {e}") - return None - - def get_statistics(self) -> Dict[str, Any]: - """Get provider statistics""" - return { - 'symbols': self.symbols, - 'is_streaming': self.is_streaming, - 'update_counts': dict(self.update_counts), - 'last_update_times': {k: v.isoformat() if isinstance(v, datetime) else v - for k, v in self.last_update_times.items()}, - 'order_books_active': len(self.order_books), - 'flow_signals_total': sum(len(signals) for signals in self.flow_signals.values()), - 'cnn_callbacks': len(self.cnn_callbacks), - 'dqn_callbacks': len(self.dqn_callbacks), - 'websocket_tasks': len(self.websocket_tasks) - } \ No newline at end of file diff --git a/core/bookmap_integration.py b/core/bookmap_integration.py deleted file mode 100644 index 904e253..0000000 --- a/core/bookmap_integration.py +++ /dev/null @@ -1,1839 +0,0 @@ -""" -Order Book Analysis Integration (Free Data Sources) - -This module provides Bookmap-style functionality using free order book data: -- Current Order Book (COB) analysis using Binance free depth streams -- Session Volume Profile (SVP) calculated from trade and depth data -- Order flow detection (sweeps, absorptions, momentum) -- Real-time order book heatmap generation -- Level 2 market depth streaming (20 levels via Binance free API) - -Data is fed to CNN and DQN networks for enhanced trading decisions. -Uses only free data sources - no paid APIs required. -""" - -import asyncio -import json -import logging -import time -import websockets -import numpy as np -import pandas as pd -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Callable -from collections import deque, defaultdict -from dataclasses import dataclass -from threading import Thread, Lock -import requests - -logger = logging.getLogger(__name__) - -@dataclass -class OrderBookLevel: - """Single order book level""" - price: float - size: float - orders: int - side: str # 'bid' or 'ask' - timestamp: datetime - -@dataclass -class OrderBookSnapshot: - """Complete order book snapshot""" - symbol: str - timestamp: datetime - bids: List[OrderBookLevel] - asks: List[OrderBookLevel] - spread: float - mid_price: float - -@dataclass -class VolumeProfileLevel: - """Volume profile level data""" - price: float - volume: float - buy_volume: float - sell_volume: float - trades_count: int - vwap: float - -@dataclass -class OrderFlowSignal: - """Order flow signal detection""" - timestamp: datetime - signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum' - price: float - volume: float - confidence: float - description: str - -class BookmapIntegration: - """ - Order book analysis using free data sources - - Features: - - Real-time order book monitoring (Binance free depth@20 levels) - - Order flow pattern detection - - Enhanced Session Volume Profile (SVP) analysis - - Market microstructure metrics - - CNN/DQN model integration - - High-frequency order book snapshots for pattern detection - """ - - def __init__(self, symbols: List[str] = None): - self.symbols = symbols or ['ETHUSDT', 'BTCUSDT'] - self.is_streaming = False - - # Data storage - self.order_books: Dict[str, OrderBookSnapshot] = {} - self.order_book_history: Dict[str, deque] = {} - self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {} - self.flow_signals: Dict[str, deque] = {} - - # Enhanced Session Volume Profile tracking - self.session_start_time = {} # Track session start for each symbol - self.session_volume_profiles: Dict[str, List[VolumeProfileLevel]] = {} - self.price_level_cache: Dict[str, Dict[float, VolumeProfileLevel]] = {} - - # Heatmap data (10-minute rolling window) - self.heatmap_window = timedelta(minutes=10) - self.order_heatmaps: Dict[str, deque] = {} - - # Market metrics - self.liquidity_metrics: Dict[str, Dict] = {} - self.order_book_imbalances: Dict[str, deque] = {} - - # Enhanced Order Flow Analysis - self.aggressive_passive_ratios: Dict[str, deque] = {} - self.trade_size_distributions: Dict[str, deque] = {} - self.market_maker_taker_flows: Dict[str, deque] = {} - self.order_flow_intensity: Dict[str, deque] = {} - self.liquidity_consumption_rates: Dict[str, deque] = {} - self.price_impact_measurements: Dict[str, deque] = {} - - # Advanced metrics for institutional vs retail detection - self.large_order_threshold = 50000 # $50K+ considered institutional - self.block_trade_threshold = 100000 # $100K+ considered block trades - self.iceberg_detection_window = 30 # seconds for iceberg detection - self.trade_clustering_window = 5 # seconds for trade clustering analysis - - # Free data source optimization - self.depth_snapshots_per_second = 10 # 100ms updates = 10 per second - self.trade_aggregation_window = 1.0 # 1 second aggregation - self.last_trade_aggregation = {} - - # WebSocket connections - self.websocket_tasks: Dict[str, asyncio.Task] = {} - self.data_lock = Lock() - - # Model callbacks - self.cnn_callbacks: List[Callable] = [] - self.dqn_callbacks: List[Callable] = [] - - # Initialize data structures - for symbol in self.symbols: - self.order_book_history[symbol] = deque(maxlen=1000) - self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals - self.flow_signals[symbol] = deque(maxlen=500) - self.order_book_imbalances[symbol] = deque(maxlen=1000) - self.session_volume_profiles[symbol] = [] - self.price_level_cache[symbol] = {} - self.session_start_time[symbol] = datetime.now() - self.last_trade_aggregation[symbol] = datetime.now() - - # Enhanced order flow analysis buffers - self.aggressive_passive_ratios[symbol] = deque(maxlen=300) # 5 minutes at 1s intervals - self.trade_size_distributions[symbol] = deque(maxlen=1000) - self.market_maker_taker_flows[symbol] = deque(maxlen=600) - self.order_flow_intensity[symbol] = deque(maxlen=300) - self.liquidity_consumption_rates[symbol] = deque(maxlen=300) - self.price_impact_measurements[symbol] = deque(maxlen=300) - - self.liquidity_metrics[symbol] = { - 'total_bid_size': 0.0, - 'total_ask_size': 0.0, - 'weighted_mid': 0.0, - 'liquidity_ratio': 1.0, - 'avg_spread_bps': 0.0, - 'volume_weighted_spread': 0.0 - } - - logger.info(f"Order Book Integration initialized for symbols: {self.symbols}") - logger.info("Using FREE data sources: Binance WebSocket depth@20 + trades") - - def add_cnn_callback(self, callback: Callable[[str, Dict], None]): - """Add CNN model callback""" - self.cnn_callbacks.append(callback) - logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total") - - def add_dqn_callback(self, callback: Callable[[str, Dict], None]): - """Add DQN model callback""" - self.dqn_callbacks.append(callback) - logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total") - - async def start_streaming(self): - """Start order book data streaming""" - if self.is_streaming: - logger.warning("Bookmap streaming already active") - return - - self.is_streaming = True - logger.info("Starting Bookmap order book streaming") - - # Start streams for each symbol - for symbol in self.symbols: - # Order book depth stream (20 levels, 100ms updates) - depth_task = asyncio.create_task(self._stream_order_book_depth(symbol)) - self.websocket_tasks[f"{symbol}_depth"] = depth_task - - # Trade stream for order flow analysis - trade_task = asyncio.create_task(self._stream_trades(symbol)) - self.websocket_tasks[f"{symbol}_trades"] = trade_task - - # Aggregated trade stream (for larger trades and better order flow analysis) - agg_trade_task = asyncio.create_task(self._stream_aggregate_trades(symbol)) - self.websocket_tasks[f"{symbol}_aggTrade"] = agg_trade_task - - # 24hr ticker stream (for volume and statistical analysis) - ticker_task = asyncio.create_task(self._stream_ticker(symbol)) - self.websocket_tasks[f"{symbol}_ticker"] = ticker_task - - # Start continuous analysis - analysis_task = asyncio.create_task(self._continuous_analysis()) - self.websocket_tasks["analysis"] = analysis_task - - logger.info(f"Started streaming for {len(self.symbols)} symbols") - - async def stop_streaming(self): - """Stop streaming""" - if not self.is_streaming: - return - - logger.info("Stopping Bookmap streaming") - self.is_streaming = False - - # Cancel all tasks - for name, task in self.websocket_tasks.items(): - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - self.websocket_tasks.clear() - logger.info("Bookmap streaming stopped") - - async def _stream_order_book_depth(self, symbol: str): - """Stream order book depth data""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Order book depth connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_depth_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing depth for {symbol}: {e}") - - except Exception as e: - logger.error(f"Depth WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _stream_trades(self, symbol: str): - """Stream individual trade data for order flow analysis""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Trade stream connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_trade_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing trade for {symbol}: {e}") - - except Exception as e: - logger.error(f"Trade WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _stream_aggregate_trades(self, symbol: str): - """Stream aggregated trade data for institutional order flow detection""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@aggTrade" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Aggregate Trade stream connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_aggregate_trade_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing aggTrade for {symbol}: {e}") - - except Exception as e: - logger.error(f"Aggregate Trade WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _stream_ticker(self, symbol: str): - """Stream 24hr ticker data for volume analysis""" - binance_symbol = symbol.lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@ticker" - - while self.is_streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Ticker stream connected for {symbol}") - - async for message in websocket: - if not self.is_streaming: - break - - try: - data = json.loads(message) - await self._process_ticker_update(symbol, data) - except Exception as e: - logger.warning(f"Error processing ticker for {symbol}: {e}") - - except Exception as e: - logger.error(f"Ticker WebSocket error for {symbol}: {e}") - if self.is_streaming: - await asyncio.sleep(2) - - async def _process_depth_update(self, symbol: str, data: Dict): - """Process order book depth update""" - try: - timestamp = datetime.now() - - # Parse bids and asks - bids = [] - asks = [] - - for bid_data in data.get('bids', []): - price = float(bid_data[0]) - size = float(bid_data[1]) - bids.append(OrderBookLevel( - price=price, - size=size, - orders=1, - side='bid', - timestamp=timestamp - )) - - for ask_data in data.get('asks', []): - price = float(ask_data[0]) - size = float(ask_data[1]) - asks.append(OrderBookLevel( - price=price, - size=size, - orders=1, - side='ask', - timestamp=timestamp - )) - - # Sort levels - bids.sort(key=lambda x: x.price, reverse=True) - asks.sort(key=lambda x: x.price) - - # Calculate spread and mid price - if bids and asks: - best_bid = bids[0].price - best_ask = asks[0].price - spread = best_ask - best_bid - mid_price = (best_bid + best_ask) / 2 - else: - spread = 0.0 - mid_price = 0.0 - - # Create snapshot - snapshot = OrderBookSnapshot( - symbol=symbol, - timestamp=timestamp, - bids=bids, - asks=asks, - spread=spread, - mid_price=mid_price - ) - - with self.data_lock: - self.order_books[symbol] = snapshot - self.order_book_history[symbol].append(snapshot) - - # Update metrics - self._update_liquidity_metrics(symbol, snapshot) - self._calculate_order_book_imbalance(symbol, snapshot) - self._update_order_heatmap(symbol, snapshot) - - except Exception as e: - logger.error(f"Error processing depth update for {symbol}: {e}") - - async def _process_trade_update(self, symbol: str, data: Dict): - """Process individual trade data with enhanced order flow analysis""" - try: - timestamp = datetime.fromtimestamp(int(data['T']) / 1000) - price = float(data['p']) - quantity = float(data['q']) - is_buyer_maker = data['m'] - trade_id = data.get('t', '') - - # Calculate trade value - trade_value = price * quantity - - # Enhanced order flow analysis - await self._analyze_enhanced_order_flow(symbol, timestamp, price, quantity, trade_value, is_buyer_maker, 'individual') - - # Traditional order flow analysis - await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker) - - # Update volume profile - self._update_volume_profile(symbol, price, quantity, is_buyer_maker) - - except Exception as e: - logger.error(f"Error processing trade for {symbol}: {e}") - - async def _process_aggregate_trade_update(self, symbol: str, data: Dict): - """Process aggregated trade data for institutional flow detection""" - try: - timestamp = datetime.fromtimestamp(int(data['T']) / 1000) - price = float(data['p']) - quantity = float(data['q']) - is_buyer_maker = data['m'] - first_trade_id = data.get('f', '') - last_trade_id = data.get('l', '') - - # Calculate trade value and aggregation size - trade_value = price * quantity - trade_aggregation_size = int(last_trade_id) - int(first_trade_id) + 1 if first_trade_id and last_trade_id else 1 - - # Enhanced analysis for aggregated trades (institutional detection) - await self._analyze_enhanced_order_flow(symbol, timestamp, price, quantity, trade_value, is_buyer_maker, 'aggregated', trade_aggregation_size) - - # Detect large block trades and iceberg orders - await self._detect_institutional_activity(symbol, timestamp, price, quantity, trade_value, trade_aggregation_size, is_buyer_maker) - - except Exception as e: - logger.error(f"Error processing aggregate trade for {symbol}: {e}") - - async def _process_ticker_update(self, symbol: str, data: Dict): - """Process ticker data for volume and statistical analysis""" - try: - # Extract relevant ticker data - volume_24h = float(data.get('v', 0)) # 24hr volume - quote_volume_24h = float(data.get('q', 0)) # 24hr quote volume - price_change_24h = float(data.get('P', 0)) # 24hr price change % - high_24h = float(data.get('h', 0)) - low_24h = float(data.get('l', 0)) - weighted_avg_price = float(data.get('w', 0)) # Weighted average price - - # Update volume statistics for relative analysis - self._update_volume_statistics(symbol, volume_24h, quote_volume_24h, weighted_avg_price) - - except Exception as e: - logger.error(f"Error processing ticker for {symbol}: {e}") - - def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot): - """Update liquidity metrics""" - try: - total_bid_size = sum(level.size for level in snapshot.bids) - total_ask_size = sum(level.size for level in snapshot.asks) - - # Weighted mid price - if snapshot.bids and snapshot.asks: - bid_weight = total_bid_size / (total_bid_size + total_ask_size) - ask_weight = total_ask_size / (total_bid_size + total_ask_size) - weighted_mid = (snapshot.bids[0].price * ask_weight + - snapshot.asks[0].price * bid_weight) - else: - weighted_mid = snapshot.mid_price - - # Liquidity ratio - liquidity_ratio = total_bid_size / total_ask_size if total_ask_size > 0 else 1.0 - - self.liquidity_metrics[symbol] = { - 'total_bid_size': total_bid_size, - 'total_ask_size': total_ask_size, - 'weighted_mid': weighted_mid, - 'liquidity_ratio': liquidity_ratio, - 'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0 - } - - except Exception as e: - logger.error(f"Error updating liquidity metrics for {symbol}: {e}") - - def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot): - """Calculate order book imbalance""" - try: - if not snapshot.bids or not snapshot.asks: - return - - # Top 5 levels imbalance - n_levels = min(5, len(snapshot.bids), len(snapshot.asks)) - - total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels)) - total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels)) - - if total_bid_size + total_ask_size > 0: - imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size) - else: - imbalance = 0.0 - - self.order_book_imbalances[symbol].append({ - 'timestamp': snapshot.timestamp, - 'imbalance': imbalance, - 'bid_size': total_bid_size, - 'ask_size': total_ask_size - }) - - except Exception as e: - logger.error(f"Error calculating imbalance for {symbol}: {e}") - - def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot): - """Update order heatmap matrix""" - try: - heatmap_entry = { - 'timestamp': snapshot.timestamp, - 'mid_price': snapshot.mid_price, - 'levels': {} - } - - # Add bid levels - for level in snapshot.bids: - price_offset = level.price - snapshot.mid_price - heatmap_entry['levels'][price_offset] = { - 'side': 'bid', - 'size': level.size, - 'price': level.price - } - - # Add ask levels - for level in snapshot.asks: - price_offset = level.price - snapshot.mid_price - heatmap_entry['levels'][price_offset] = { - 'side': 'ask', - 'size': level.size, - 'price': level.price - } - - self.order_heatmaps[symbol].append(heatmap_entry) - - # Clean old entries - cutoff_time = snapshot.timestamp - self.heatmap_window - while (self.order_heatmaps[symbol] and - self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time): - self.order_heatmaps[symbol].popleft() - - except Exception as e: - logger.error(f"Error updating heatmap for {symbol}: {e}") - - def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool): - """Enhanced Session Volume Profile (SVP) update using free data""" - try: - # Calculate trade volume in USDT - volume = price * quantity - - # Use price level caching for better performance - price_key = round(price, 2) # Round to 2 decimal places for price level grouping - - # Update session volume profile - if price_key not in self.price_level_cache[symbol]: - self.price_level_cache[symbol][price_key] = VolumeProfileLevel( - price=price_key, - volume=0.0, - buy_volume=0.0, - sell_volume=0.0, - trades_count=0, - vwap=price - ) - - level = self.price_level_cache[symbol][price_key] - old_total_volume = level.volume - old_total_quantity = level.trades_count - - # Update volume metrics - level.volume += volume - level.trades_count += 1 - - # Update buy/sell volume breakdown - if is_buyer_maker: - level.sell_volume += volume # Market maker is selling - else: - level.buy_volume += volume # Market maker is buying - - # Calculate Volume Weighted Average Price (VWAP) for this level - if level.volume > 0: - level.vwap = ((level.vwap * old_total_volume) + (price * volume)) / level.volume - - # Also update the rolling volume profile (last 10 minutes) - self._update_rolling_volume_profile(symbol, price_key, volume, is_buyer_maker) - - # Session reset detection (every 24 hours or major price gaps) - current_time = datetime.now() - if self._should_reset_session(symbol, current_time, price): - self._reset_session_volume_profile(symbol, current_time) - - except Exception as e: - logger.error(f"Error updating Session Volume Profile for {symbol}: {e}") - - def _update_rolling_volume_profile(self, symbol: str, price_key: float, volume: float, is_buyer_maker: bool): - """Update rolling 10-minute volume profile for real-time heatmap""" - try: - # Find or create level in regular volume profile - price_level = None - for level in self.volume_profiles.get(symbol, []): - if abs(level.price - price_key) < 0.01: - price_level = level - break - - if not price_level: - if symbol not in self.volume_profiles: - self.volume_profiles[symbol] = [] - - price_level = VolumeProfileLevel( - price=price_key, - volume=0.0, - buy_volume=0.0, - sell_volume=0.0, - trades_count=0, - vwap=price_key - ) - self.volume_profiles[symbol].append(price_level) - - # Update rolling metrics - old_volume = price_level.volume - price_level.volume += volume - price_level.trades_count += 1 - - if is_buyer_maker: - price_level.sell_volume += volume - else: - price_level.buy_volume += volume - - # Update VWAP - if price_level.volume > 0: - price_level.vwap = ((price_level.vwap * old_volume) + (price_key * volume)) / price_level.volume - - except Exception as e: - logger.error(f"Error updating rolling volume profile for {symbol}: {e}") - - def _should_reset_session(self, symbol: str, current_time: datetime, current_price: float) -> bool: - """Determine if session volume profile should be reset""" - try: - session_start = self.session_start_time.get(symbol) - if not session_start: - return False - - # Reset every 24 hours (daily session) - if (current_time - session_start).total_seconds() > 86400: # 24 hours - return True - - # Reset on major price gaps (> 5% from session VWAP) - if self.price_level_cache.get(symbol): - total_volume = sum(level.volume for level in self.price_level_cache[symbol].values()) - if total_volume > 0: - weighted_price = sum(level.vwap * level.volume for level in self.price_level_cache[symbol].values()) / total_volume - price_gap = abs(current_price - weighted_price) / weighted_price - if price_gap > 0.05: # 5% gap - return True - - return False - - except Exception as e: - logger.error(f"Error checking session reset for {symbol}: {e}") - return False - - def _reset_session_volume_profile(self, symbol: str, reset_time: datetime): - """Reset session volume profile""" - try: - logger.info(f"Resetting session volume profile for {symbol}") - self.session_start_time[symbol] = reset_time - self.price_level_cache[symbol] = {} - self.session_volume_profiles[symbol] = [] - - except Exception as e: - logger.error(f"Error resetting session volume profile for {symbol}: {e}") - - async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float, - quantity: float, is_buyer_maker: bool): - """Analyze order flow patterns""" - try: - if symbol not in self.order_book_history or not self.order_book_history[symbol]: - return - - recent_snapshots = list(self.order_book_history[symbol])[-10:] - - # Check for sweeps - sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker) - if sweep_signal: - self.flow_signals[symbol].append(sweep_signal) - await self._notify_flow_signal(symbol, sweep_signal) - - # Check for absorption - absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity) - if absorption_signal: - self.flow_signals[symbol].append(absorption_signal) - await self._notify_flow_signal(symbol, absorption_signal) - - # Check for momentum - momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker) - if momentum_signal: - self.flow_signals[symbol].append(momentum_signal) - await self._notify_flow_signal(symbol, momentum_signal) - - except Exception as e: - logger.error(f"Error analyzing order flow for {symbol}: {e}") - - async def _analyze_enhanced_order_flow(self, symbol: str, timestamp: datetime, price: float, - quantity: float, trade_value: float, is_buyer_maker: bool, - trade_type: str, aggregation_size: int = 1): - """Enhanced order flow analysis with aggressive vs passive ratios""" - try: - # Determine if trade is aggressive (taker) or passive (maker) - is_aggressive = not is_buyer_maker # In Binance data, m=false means buyer is taker (aggressive) - - # Calculate aggressive vs passive ratios - self._update_aggressive_passive_ratio(symbol, timestamp, trade_value, is_aggressive) - - # Update trade size distribution - self._update_trade_size_distribution(symbol, timestamp, trade_value, trade_type) - - # Update market maker vs taker flow - self._update_market_maker_taker_flow(symbol, timestamp, trade_value, is_buyer_maker, is_aggressive) - - # Calculate order flow intensity - self._update_order_flow_intensity(symbol, timestamp, trade_value, aggregation_size) - - # Measure liquidity consumption - await self._measure_liquidity_consumption(symbol, timestamp, price, quantity, trade_value, is_aggressive) - - # Measure price impact - await self._measure_price_impact(symbol, timestamp, price, trade_value, is_aggressive) - - except Exception as e: - logger.error(f"Error in enhanced order flow analysis for {symbol}: {e}") - - def _update_aggressive_passive_ratio(self, symbol: str, timestamp: datetime, trade_value: float, is_aggressive: bool): - """Update aggressive vs passive participant ratios""" - try: - current_window = [] - cutoff_time = timestamp - timedelta(seconds=60) # 1-minute window - - # Filter recent trades within window - for entry in self.aggressive_passive_ratios[symbol]: - if entry['timestamp'] > cutoff_time: - current_window.append(entry) - - # Add current trade - current_window.append({ - 'timestamp': timestamp, - 'trade_value': trade_value, - 'is_aggressive': is_aggressive - }) - - # Calculate ratios - aggressive_volume = sum(t['trade_value'] for t in current_window if t['is_aggressive']) - passive_volume = sum(t['trade_value'] for t in current_window if not t['is_aggressive']) - total_volume = aggressive_volume + passive_volume - - if total_volume > 0: - aggressive_ratio = aggressive_volume / total_volume - passive_ratio = passive_volume / total_volume - - ratio_data = { - 'timestamp': timestamp, - 'aggressive_ratio': aggressive_ratio, - 'passive_ratio': passive_ratio, - 'aggressive_volume': aggressive_volume, - 'passive_volume': passive_volume, - 'total_volume': total_volume, - 'trade_count': len(current_window), - 'avg_aggressive_size': aggressive_volume / max(1, sum(1 for t in current_window if t['is_aggressive'])), - 'avg_passive_size': passive_volume / max(1, sum(1 for t in current_window if not t['is_aggressive'])) - } - - # Update buffer - self.aggressive_passive_ratios[symbol].clear() - self.aggressive_passive_ratios[symbol].extend(current_window) - - # Store calculated ratios for use by models - if not hasattr(self, 'current_flow_ratios'): - self.current_flow_ratios = {} - self.current_flow_ratios[symbol] = ratio_data - - except Exception as e: - logger.error(f"Error updating aggressive/passive ratio for {symbol}: {e}") - - def _update_trade_size_distribution(self, symbol: str, timestamp: datetime, trade_value: float, trade_type: str): - """Update trade size distribution for institutional vs retail detection""" - try: - # Classify trade size - if trade_value < 1000: - size_category = 'micro' # < $1K (retail) - elif trade_value < 10000: - size_category = 'small' # $1K-$10K (retail/small institutional) - elif trade_value < 50000: - size_category = 'medium' # $10K-$50K (institutional) - elif trade_value < 100000: - size_category = 'large' # $50K-$100K (large institutional) - else: - size_category = 'block' # > $100K (block trades) - - trade_data = { - 'timestamp': timestamp, - 'trade_value': trade_value, - 'trade_type': trade_type, - 'size_category': size_category, - 'is_institutional': trade_value >= self.large_order_threshold, - 'is_block_trade': trade_value >= self.block_trade_threshold - } - - self.trade_size_distributions[symbol].append(trade_data) - - except Exception as e: - logger.error(f"Error updating trade size distribution for {symbol}: {e}") - - def _update_market_maker_taker_flow(self, symbol: str, timestamp: datetime, trade_value: float, - is_buyer_maker: bool, is_aggressive: bool): - """Update market maker vs taker flow analysis""" - try: - flow_data = { - 'timestamp': timestamp, - 'trade_value': trade_value, - 'is_buyer_maker': is_buyer_maker, - 'is_aggressive': is_aggressive, - 'flow_direction': 'buy_aggressive' if not is_buyer_maker else 'sell_aggressive', - 'market_maker_side': 'sell' if is_buyer_maker else 'buy' - } - - self.market_maker_taker_flows[symbol].append(flow_data) - - except Exception as e: - logger.error(f"Error updating market maker/taker flow for {symbol}: {e}") - - def _update_order_flow_intensity(self, symbol: str, timestamp: datetime, trade_value: float, aggregation_size: int): - """Calculate order flow intensity based on trade frequency and size""" - try: - # Calculate intensity based on trade value and aggregation - base_intensity = trade_value / 10000 # Normalize by $10K - aggregation_intensity = aggregation_size / 10 # Normalize aggregation factor - - # Time-based intensity (trades per second) - recent_trades = [t for t in self.order_flow_intensity[symbol] - if (timestamp - t['timestamp']).total_seconds() < 10] - time_intensity = len(recent_trades) / 10 # Trades per second over 10s window - - intensity_score = base_intensity * (1 + aggregation_intensity) * (1 + time_intensity) - - intensity_data = { - 'timestamp': timestamp, - 'intensity_score': intensity_score, - 'base_intensity': base_intensity, - 'aggregation_intensity': aggregation_intensity, - 'time_intensity': time_intensity, - 'trade_value': trade_value, - 'aggregation_size': aggregation_size - } - - self.order_flow_intensity[symbol].append(intensity_data) - - except Exception as e: - logger.error(f"Error updating order flow intensity for {symbol}: {e}") - - async def _measure_liquidity_consumption(self, symbol: str, timestamp: datetime, price: float, - quantity: float, trade_value: float, is_aggressive: bool): - """Measure liquidity consumption rates""" - try: - if not is_aggressive: - return # Only measure for aggressive trades - - current_snapshot = self.order_books.get(symbol) - if not current_snapshot: - return - - # Calculate how much liquidity was consumed - if price >= current_snapshot.mid_price: # Buy-side consumption - consumed_liquidity = 0 - for ask_level in current_snapshot.asks: - if ask_level.price <= price: - consumed_liquidity += min(ask_level.size, quantity) * ask_level.price - quantity -= ask_level.size - if quantity <= 0: - break - else: # Sell-side consumption - consumed_liquidity = 0 - for bid_level in current_snapshot.bids: - if bid_level.price >= price: - consumed_liquidity += min(bid_level.size, quantity) * bid_level.price - quantity -= bid_level.size - if quantity <= 0: - break - - consumption_rate = consumed_liquidity / trade_value if trade_value > 0 else 0 - - consumption_data = { - 'timestamp': timestamp, - 'price': price, - 'trade_value': trade_value, - 'consumed_liquidity': consumed_liquidity, - 'consumption_rate': consumption_rate, - 'side': 'buy' if price >= current_snapshot.mid_price else 'sell' - } - - self.liquidity_consumption_rates[symbol].append(consumption_data) - - except Exception as e: - logger.error(f"Error measuring liquidity consumption for {symbol}: {e}") - - async def _measure_price_impact(self, symbol: str, timestamp: datetime, price: float, - trade_value: float, is_aggressive: bool): - """Measure price impact of trades""" - try: - if not is_aggressive: - return - - # Get price before and after (approximated by looking at recent snapshots) - recent_snapshots = list(self.order_book_history[symbol])[-5:] - if len(recent_snapshots) < 2: - return - - price_before = recent_snapshots[-2].mid_price - price_after = recent_snapshots[-1].mid_price - - price_impact = abs(price_after - price_before) / price_before if price_before > 0 else 0 - impact_per_dollar = price_impact / (trade_value / 1000000) if trade_value > 0 else 0 # Impact per $1M - - impact_data = { - 'timestamp': timestamp, - 'trade_price': price, - 'trade_value': trade_value, - 'price_before': price_before, - 'price_after': price_after, - 'price_impact': price_impact, - 'impact_per_million': impact_per_dollar, - 'impact_category': self._categorize_impact(price_impact) - } - - self.price_impact_measurements[symbol].append(impact_data) - - except Exception as e: - logger.error(f"Error measuring price impact for {symbol}: {e}") - - def _categorize_impact(self, price_impact: float) -> str: - """Categorize price impact level""" - if price_impact < 0.0001: # < 0.01% - return 'minimal' - elif price_impact < 0.001: # < 0.1% - return 'low' - elif price_impact < 0.005: # < 0.5% - return 'medium' - elif price_impact < 0.01: # < 1% - return 'high' - else: - return 'extreme' - - async def _detect_institutional_activity(self, symbol: str, timestamp: datetime, price: float, - quantity: float, trade_value: float, aggregation_size: int, - is_buyer_maker: bool): - """Detect institutional trading activity patterns""" - try: - # Block trade detection - if trade_value >= self.block_trade_threshold: - signal = OrderFlowSignal( - timestamp=timestamp, - signal_type='block_trade', - price=price, - volume=trade_value, - confidence=min(0.95, trade_value / 500000), # Higher confidence for larger trades - description=f"Block trade: ${trade_value:.0f} ({'Buy' if not is_buyer_maker else 'Sell'})" - ) - self.flow_signals[symbol].append(signal) - await self._notify_flow_signal(symbol, signal) - - # Iceberg order detection (multiple large aggregated trades in sequence) - await self._detect_iceberg_orders(symbol, timestamp, price, trade_value, aggregation_size, is_buyer_maker) - - # High-frequency activity detection - await self._detect_hft_activity(symbol, timestamp, trade_value, aggregation_size) - - except Exception as e: - logger.error(f"Error detecting institutional activity for {symbol}: {e}") - - async def _detect_iceberg_orders(self, symbol: str, timestamp: datetime, price: float, - trade_value: float, aggregation_size: int, is_buyer_maker: bool): - """Detect iceberg order patterns""" - try: - if trade_value < self.large_order_threshold: - return - - # Look for similar-sized trades in recent history - cutoff_time = timestamp - timedelta(seconds=self.iceberg_detection_window) - recent_large_trades = [] - - for trade_data in self.trade_size_distributions[symbol]: - if (trade_data['timestamp'] > cutoff_time and - trade_data['trade_value'] >= self.large_order_threshold): - recent_large_trades.append(trade_data) - - # Iceberg pattern: 3+ large trades with similar sizes - if len(recent_large_trades) >= 3: - avg_size = sum(t['trade_value'] for t in recent_large_trades) / len(recent_large_trades) - size_consistency = all(abs(t['trade_value'] - avg_size) / avg_size < 0.2 for t in recent_large_trades) - - if size_consistency: - total_iceberg_volume = sum(t['trade_value'] for t in recent_large_trades) - confidence = min(0.9, len(recent_large_trades) / 10 + total_iceberg_volume / 1000000) - - signal = OrderFlowSignal( - timestamp=timestamp, - signal_type='iceberg', - price=price, - volume=total_iceberg_volume, - confidence=confidence, - description=f"Iceberg: {len(recent_large_trades)} trades, ${total_iceberg_volume:.0f} total" - ) - self.flow_signals[symbol].append(signal) - await self._notify_flow_signal(symbol, signal) - - except Exception as e: - logger.error(f"Error detecting iceberg orders for {symbol}: {e}") - - async def _detect_hft_activity(self, symbol: str, timestamp: datetime, trade_value: float, aggregation_size: int): - """Detect high-frequency trading activity""" - try: - # Look for high-frequency patterns (many small trades in rapid succession) - cutoff_time = timestamp - timedelta(seconds=5) - recent_trades = [t for t in self.order_flow_intensity[symbol] if t['timestamp'] > cutoff_time] - - if len(recent_trades) >= 20: # 20+ trades in 5 seconds - avg_trade_size = sum(t['trade_value'] for t in recent_trades) / len(recent_trades) - - if avg_trade_size < 5000: # Small average trade size suggests HFT - total_hft_volume = sum(t['trade_value'] for t in recent_trades) - confidence = min(0.8, len(recent_trades) / 50) - - signal = OrderFlowSignal( - timestamp=timestamp, - signal_type='hft_activity', - price=0, # Multiple prices - volume=total_hft_volume, - confidence=confidence, - description=f"HFT: {len(recent_trades)} trades in 5s, avg ${avg_trade_size:.0f}" - ) - self.flow_signals[symbol].append(signal) - await self._notify_flow_signal(symbol, signal) - - except Exception as e: - logger.error(f"Error detecting HFT activity for {symbol}: {e}") - - def _update_volume_statistics(self, symbol: str, volume_24h: float, quote_volume_24h: float, weighted_avg_price: float): - """Update volume statistics for relative analysis""" - try: - # Store 24h volume data for relative comparisons - if not hasattr(self, 'volume_stats'): - self.volume_stats = {} - - self.volume_stats[symbol] = { - 'volume_24h': volume_24h, - 'quote_volume_24h': quote_volume_24h, - 'weighted_avg_price': weighted_avg_price, - 'timestamp': datetime.now() - } - - except Exception as e: - logger.error(f"Error updating volume statistics for {symbol}: {e}") - - def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot], - price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]: - """Detect order book sweeps""" - try: - if len(snapshots) < 2: - return None - - before_snapshot = snapshots[-2] - - if is_buyer_maker: # Sell order, check ask side - levels_consumed = 0 - total_consumed_size = 0 - - for level in before_snapshot.asks[:5]: - if level.price <= price: - levels_consumed += 1 - total_consumed_size += level.size - - if levels_consumed >= 2 and total_consumed_size > quantity * 1.5: - confidence = min(0.9, levels_consumed / 5.0 + 0.3) - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='sweep', - price=price, - volume=quantity * price, - confidence=confidence, - description=f"Sell sweep: {levels_consumed} levels" - ) - else: # Buy order, check bid side - levels_consumed = 0 - total_consumed_size = 0 - - for level in before_snapshot.bids[:5]: - if level.price >= price: - levels_consumed += 1 - total_consumed_size += level.size - - if levels_consumed >= 2 and total_consumed_size > quantity * 1.5: - confidence = min(0.9, levels_consumed / 5.0 + 0.3) - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='sweep', - price=price, - volume=quantity * price, - confidence=confidence, - description=f"Buy sweep: {levels_consumed} levels" - ) - - return None - - except Exception as e: - logger.error(f"Error detecting sweep for {symbol}: {e}") - return None - - def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot], - price: float, quantity: float) -> Optional[OrderFlowSignal]: - """Detect absorption patterns""" - try: - if len(snapshots) < 3: - return None - - volume_threshold = 10000 # $10K minimum - price_impact_threshold = 0.001 # 0.1% max impact - - trade_value = price * quantity - if trade_value < volume_threshold: - return None - - # Calculate price impact - price_before = snapshots[-3].mid_price - price_after = snapshots[-1].mid_price - price_impact = abs(price_after - price_before) / price_before - - if price_impact < price_impact_threshold: - confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='absorption', - price=price, - volume=trade_value, - confidence=confidence, - description=f"Absorption: ${trade_value:.0f}" - ) - - return None - - except Exception as e: - logger.error(f"Error detecting absorption for {symbol}: {e}") - return None - - def _detect_momentum_trade(self, symbol: str, price: float, quantity: float, - is_buyer_maker: bool) -> Optional[OrderFlowSignal]: - """Detect momentum trades""" - try: - trade_value = price * quantity - momentum_threshold = 25000 # $25K minimum - - if trade_value < momentum_threshold: - return None - - confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3) - direction = "sell" if is_buyer_maker else "buy" - - return OrderFlowSignal( - timestamp=datetime.now(), - signal_type='momentum', - price=price, - volume=trade_value, - confidence=confidence, - description=f"Large {direction}: ${trade_value:.0f}" - ) - - except Exception as e: - logger.error(f"Error detecting momentum for {symbol}: {e}") - return None - - async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal): - """Notify models of flow signals""" - try: - signal_data = { - 'signal_type': signal.signal_type, - 'price': signal.price, - 'volume': signal.volume, - 'confidence': signal.confidence, - 'timestamp': signal.timestamp, - 'description': signal.description - } - - # Notify CNN callbacks - for callback in self.cnn_callbacks: - try: - callback(symbol, signal_data) - except Exception as e: - logger.warning(f"Error in CNN callback: {e}") - - # Notify DQN callbacks - for callback in self.dqn_callbacks: - try: - callback(symbol, signal_data) - except Exception as e: - logger.warning(f"Error in DQN callback: {e}") - - except Exception as e: - logger.error(f"Error notifying flow signal: {e}") - - async def _continuous_analysis(self): - """Continuous analysis and model feeding""" - while self.is_streaming: - try: - await asyncio.sleep(1) # Analyze every second - - for symbol in self.symbols: - # Generate features for models - cnn_features = self.get_cnn_features(symbol) - if cnn_features is not None: - for callback in self.cnn_callbacks: - try: - callback(symbol, {'features': cnn_features, 'type': 'orderbook'}) - except Exception as e: - logger.warning(f"Error in CNN feature callback: {e}") - - dqn_features = self.get_dqn_state_features(symbol) - if dqn_features is not None: - for callback in self.dqn_callbacks: - try: - callback(symbol, {'state': dqn_features, 'type': 'orderbook'}) - except Exception as e: - logger.warning(f"Error in DQN state callback: {e}") - - except Exception as e: - logger.error(f"Error in continuous analysis: {e}") - await asyncio.sleep(5) - - def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]: - """Generate CNN features from order book data""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - features = [] - - # Order book features (80 features: 20 levels x 2 sides x 2 values) - for i in range(min(20, len(snapshot.bids))): - bid = snapshot.bids[i] - features.append(bid.size) - features.append(bid.price - snapshot.mid_price) - - # Pad bids - while len(features) < 40: - features.extend([0.0, 0.0]) - - for i in range(min(20, len(snapshot.asks))): - ask = snapshot.asks[i] - features.append(ask.size) - features.append(ask.price - snapshot.mid_price) - - # Pad asks - while len(features) < 80: - features.extend([0.0, 0.0]) - - # Liquidity metrics (10 features) - metrics = self.liquidity_metrics.get(symbol, {}) - features.extend([ - metrics.get('total_bid_size', 0.0), - metrics.get('total_ask_size', 0.0), - metrics.get('liquidity_ratio', 1.0), - metrics.get('spread_bps', 0.0), - snapshot.spread, - metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price, - len(snapshot.bids), - len(snapshot.asks), - snapshot.mid_price, - time.time() % 86400 # Time of day - ]) - - # Order book imbalance (5 features) - if self.order_book_imbalances[symbol]: - latest_imbalance = self.order_book_imbalances[symbol][-1] - features.extend([ - latest_imbalance['imbalance'], - latest_imbalance['bid_size'], - latest_imbalance['ask_size'], - latest_imbalance['bid_size'] + latest_imbalance['ask_size'], - abs(latest_imbalance['imbalance']) - ]) - else: - features.extend([0.0, 0.0, 0.0, 0.0, 0.0]) - - # Enhanced flow signals (15 features) - recent_signals = [s for s in self.flow_signals[symbol] - if (datetime.now() - s.timestamp).seconds < 60] - - sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep') - absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption') - momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum') - block_count = sum(1 for s in recent_signals if s.signal_type == 'block_trade') - iceberg_count = sum(1 for s in recent_signals if s.signal_type == 'iceberg') - hft_count = sum(1 for s in recent_signals if s.signal_type == 'hft_activity') - max_confidence = max([s.confidence for s in recent_signals], default=0.0) - total_flow_volume = sum(s.volume for s in recent_signals) - - # Enhanced order flow metrics - flow_metrics = self.get_enhanced_order_flow_metrics(symbol) - if flow_metrics: - aggressive_ratio = flow_metrics['aggressive_passive']['aggressive_ratio'] - institutional_ratio = flow_metrics['institutional_retail']['institutional_ratio'] - flow_intensity = flow_metrics['flow_intensity']['current_intensity'] - avg_consumption_rate = flow_metrics['liquidity']['avg_consumption_rate'] - avg_price_impact = flow_metrics['price_impact']['avg_impact'] / 10000 # Normalize from basis points - buy_pressure = flow_metrics['maker_taker_flow']['buy_pressure'] - sell_pressure = flow_metrics['maker_taker_flow']['sell_pressure'] - else: - aggressive_ratio = 0.5 - institutional_ratio = 0.5 - flow_intensity = 0.0 - avg_consumption_rate = 0.0 - avg_price_impact = 0.0 - buy_pressure = 0.5 - sell_pressure = 0.5 - - features.extend([ - sweep_count, - absorption_count, - momentum_count, - block_count, - iceberg_count, - hft_count, - max_confidence, - total_flow_volume, - aggressive_ratio, - institutional_ratio, - flow_intensity, - avg_consumption_rate, - avg_price_impact, - buy_pressure, - sell_pressure - ]) - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error generating CNN features for {symbol}: {e}") - return None - - def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]: - """Generate DQN state features""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - state_features = [] - - # Normalized order book state (20 features) - total_bid_size = sum(level.size for level in snapshot.bids[:10]) - total_ask_size = sum(level.size for level in snapshot.asks[:10]) - total_size = total_bid_size + total_ask_size - - if total_size > 0: - for i in range(min(10, len(snapshot.bids))): - state_features.append(snapshot.bids[i].size / total_size) - - while len(state_features) < 10: - state_features.append(0.0) - - for i in range(min(10, len(snapshot.asks))): - state_features.append(snapshot.asks[i].size / total_size) - - while len(state_features) < 20: - state_features.append(0.0) - else: - state_features.extend([0.0] * 20) - - # Enhanced market state indicators (20 features) - metrics = self.liquidity_metrics.get(symbol, {}) - - spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0 - liquidity_ratio = metrics.get('liquidity_ratio', 1.0) - liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1) - - # Flow strength - recent_signals = [s for s in self.flow_signals[symbol] - if (datetime.now() - s.timestamp).seconds < 30] - flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1) - - # Price volatility - if len(self.order_book_history[symbol]) >= 10: - recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]] - price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0 - else: - price_volatility = 0 - - # Enhanced order flow metrics for DQN - flow_metrics = self.get_enhanced_order_flow_metrics(symbol) - if flow_metrics: - aggressive_ratio = flow_metrics['aggressive_passive']['aggressive_ratio'] - institutional_ratio = flow_metrics['institutional_retail']['institutional_ratio'] - flow_intensity = min(flow_metrics['flow_intensity']['current_intensity'] / 10, 1.0) # Normalize - consumption_rate = flow_metrics['liquidity']['avg_consumption_rate'] - price_impact = min(flow_metrics['price_impact']['avg_impact'] / 100, 1.0) # Normalize basis points - buy_pressure = flow_metrics['maker_taker_flow']['buy_pressure'] - sell_pressure = flow_metrics['maker_taker_flow']['sell_pressure'] - - # Trade size distribution ratios - size_dist = flow_metrics['size_distribution'] - total_trades = sum(size_dist.values()) or 1 - retail_ratio = (size_dist.get('micro', 0) + size_dist.get('small', 0)) / total_trades - institutional_trade_ratio = (size_dist.get('large', 0) + size_dist.get('block', 0)) / total_trades - - # Recent activity indicators - block_activity = min(size_dist.get('block', 0) / 10, 1.0) # Normalize - else: - aggressive_ratio = 0.5 - institutional_ratio = 0.5 - flow_intensity = 0.0 - consumption_rate = 0.0 - price_impact = 0.0 - buy_pressure = 0.5 - sell_pressure = 0.5 - retail_ratio = 0.5 - institutional_trade_ratio = 0.5 - block_activity = 0.0 - - state_features.extend([ - spread_pct * 10000, # Spread in basis points - liquidity_imbalance, - flow_strength, - price_volatility * 100, - min(len(snapshot.bids), 20) / 20, - min(len(snapshot.asks), 20) / 20, - len([s for s in recent_signals if s.signal_type == 'sweep']) / 10, - len([s for s in recent_signals if s.signal_type == 'absorption']) / 5, - len([s for s in recent_signals if s.signal_type == 'momentum']) / 5, - (datetime.now().hour * 60 + datetime.now().minute) / 1440, - # Enhanced order flow state features - aggressive_ratio, - institutional_ratio, - flow_intensity, - consumption_rate, - price_impact, - buy_pressure, - sell_pressure, - retail_ratio, - institutional_trade_ratio, - block_activity - ]) - - return np.array(state_features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error generating DQN features for {symbol}: {e}") - return None - - def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]: - """Generate heatmap matrix for visualization""" - try: - if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]: - return None - - current_snapshot = self.order_books.get(symbol) - if not current_snapshot: - return None - - mid_price = current_snapshot.mid_price - price_step = mid_price * 0.0001 # 1 basis point steps - - # Matrix: time x price levels - time_window = min(600, len(self.order_heatmaps[symbol])) - heatmap_matrix = np.zeros((time_window, levels)) - - # Fill matrix - for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]): - for price_offset, level_data in entry['levels'].items(): - level_idx = int((price_offset + (levels/2) * price_step) / price_step) - - if 0 <= level_idx < levels: - size_weight = 1.0 if level_data['side'] == 'bid' else -1.0 - heatmap_matrix[t, level_idx] = level_data['size'] * size_weight - - return heatmap_matrix - - except Exception as e: - logger.error(f"Error generating heatmap matrix for {symbol}: {e}") - return None - - def get_dashboard_data(self, symbol: str) -> Optional[Dict]: - """Get data for dashboard visualization""" - try: - if symbol not in self.order_books: - return None - - snapshot = self.order_books[symbol] - - return { - 'timestamp': snapshot.timestamp.isoformat(), - 'symbol': symbol, - 'mid_price': snapshot.mid_price, - 'spread': snapshot.spread, - 'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]], - 'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]], - 'liquidity_metrics': self.liquidity_metrics.get(symbol, {}), - 'volume_profile': self.get_volume_profile_data(symbol), - 'heatmap_matrix': self.get_order_heatmap_matrix(symbol).tolist() if self.get_order_heatmap_matrix(symbol) is not None else None, - 'enhanced_order_flow': self.get_enhanced_order_flow_metrics(symbol), - 'recent_signals': [ - { - 'type': s.signal_type, - 'price': s.price, - 'volume': s.volume, - 'confidence': s.confidence, - 'timestamp': s.timestamp.isoformat(), - 'description': s.description - } - for s in list(self.flow_signals[symbol])[-10:] - ] - } - - except Exception as e: - logger.error(f"Error getting dashboard data for {symbol}: {e}") - return None - - def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]: - """Get rolling volume profile data (10-minute window)""" - try: - if symbol not in self.volume_profiles: - return None - - profile_data = [] - for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price): - profile_data.append({ - 'price': level.price, - 'volume': level.volume, - 'buy_volume': level.buy_volume, - 'sell_volume': level.sell_volume, - 'trades_count': level.trades_count, - 'vwap': level.vwap, - 'net_volume': level.buy_volume - level.sell_volume - }) - - return profile_data - - except Exception as e: - logger.error(f"Error getting volume profile for {symbol}: {e}") - return None - - def get_session_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]: - """Get Session Volume Profile (SVP) data - full session data""" - try: - if symbol not in self.price_level_cache: - return None - - session_data = [] - total_volume = sum(level.volume for level in self.price_level_cache[symbol].values()) - - for price_key, level in sorted(self.price_level_cache[symbol].items()): - volume_percentage = (level.volume / total_volume * 100) if total_volume > 0 else 0 - - session_data.append({ - 'price': level.price, - 'volume': level.volume, - 'buy_volume': level.buy_volume, - 'sell_volume': level.sell_volume, - 'trades_count': level.trades_count, - 'vwap': level.vwap, - 'net_volume': level.buy_volume - level.sell_volume, - 'volume_percentage': volume_percentage, - 'is_high_volume_node': volume_percentage > 2.0, # Mark significant price levels - 'buy_sell_ratio': level.buy_volume / level.sell_volume if level.sell_volume > 0 else float('inf') - }) - - return session_data - - except Exception as e: - logger.error(f"Error getting Session Volume Profile for {symbol}: {e}") - return None - - def get_session_statistics(self, symbol: str) -> Optional[Dict]: - """Get session trading statistics""" - try: - if symbol not in self.price_level_cache: - return None - - levels = list(self.price_level_cache[symbol].values()) - if not levels: - return None - - total_volume = sum(level.volume for level in levels) - total_buy_volume = sum(level.buy_volume for level in levels) - total_sell_volume = sum(level.sell_volume for level in levels) - total_trades = sum(level.trades_count for level in levels) - - # Calculate session VWAP - session_vwap = sum(level.vwap * level.volume for level in levels) / total_volume if total_volume > 0 else 0 - - # Find price extremes - prices = [level.price for level in levels] - session_high = max(prices) if prices else 0 - session_low = min(prices) if prices else 0 - - # Find Point of Control (POC) - price level with highest volume - poc_level = max(levels, key=lambda x: x.volume) if levels else None - poc_price = poc_level.price if poc_level else 0 - poc_volume = poc_level.volume if poc_level else 0 - - # Calculate Value Area (70% of volume around POC) - sorted_levels = sorted(levels, key=lambda x: x.volume, reverse=True) - value_area_volume = total_volume * 0.7 - value_area_levels = [] - current_volume = 0 - - for level in sorted_levels: - value_area_levels.append(level) - current_volume += level.volume - if current_volume >= value_area_volume: - break - - value_area_high = max(level.price for level in value_area_levels) if value_area_levels else 0 - value_area_low = min(level.price for level in value_area_levels) if value_area_levels else 0 - - session_start = self.session_start_time.get(symbol, datetime.now()) - session_duration = (datetime.now() - session_start).total_seconds() / 3600 # Hours - - return { - 'symbol': symbol, - 'session_start': session_start.isoformat(), - 'session_duration_hours': session_duration, - 'total_volume': total_volume, - 'total_buy_volume': total_buy_volume, - 'total_sell_volume': total_sell_volume, - 'total_trades': total_trades, - 'session_vwap': session_vwap, - 'session_high': session_high, - 'session_low': session_low, - 'poc_price': poc_price, - 'poc_volume': poc_volume, - 'value_area_high': value_area_high, - 'value_area_low': value_area_low, - 'value_area_range': value_area_high - value_area_low, - 'buy_sell_ratio': total_buy_volume / total_sell_volume if total_sell_volume > 0 else float('inf'), - 'price_levels_traded': len(levels), - 'avg_trade_size': total_volume / total_trades if total_trades > 0 else 0 - } - - except Exception as e: - logger.error(f"Error getting session statistics for {symbol}: {e}") - return None - - def get_market_profile_analysis(self, symbol: str) -> Optional[Dict]: - """Get detailed market profile analysis""" - try: - current_snapshot = self.order_books.get(symbol) - session_stats = self.get_session_statistics(symbol) - svp_data = self.get_session_volume_profile_data(symbol) - - if not all([current_snapshot, session_stats, svp_data]): - return None - - current_price = current_snapshot.mid_price - session_vwap = session_stats['session_vwap'] - poc_price = session_stats['poc_price'] - value_area_high = session_stats['value_area_high'] - value_area_low = session_stats['value_area_low'] - - # Market structure analysis - price_vs_vwap = "above" if current_price > session_vwap else "below" - price_vs_poc = "above" if current_price > poc_price else "below" - - in_value_area = value_area_low <= current_price <= value_area_high - - # Find support and resistance levels from high volume nodes - high_volume_nodes = [item for item in svp_data if item['is_high_volume_node']] - resistance_levels = [node['price'] for node in high_volume_nodes if node['price'] > current_price] - support_levels = [node['price'] for node in high_volume_nodes if node['price'] < current_price] - - # Sort to get nearest levels - resistance_levels.sort() - support_levels.sort(reverse=True) - - return { - 'symbol': symbol, - 'current_price': current_price, - 'market_structure': { - 'price_vs_vwap': price_vs_vwap, - 'price_vs_poc': price_vs_poc, - 'in_value_area': in_value_area, - 'distance_from_vwap_bps': int(abs(current_price - session_vwap) / session_vwap * 10000), - 'distance_from_poc_bps': int(abs(current_price - poc_price) / poc_price * 10000) - }, - 'key_levels': { - 'session_vwap': session_vwap, - 'poc_price': poc_price, - 'value_area_high': value_area_high, - 'value_area_low': value_area_low, - 'nearest_resistance': resistance_levels[0] if resistance_levels else None, - 'nearest_support': support_levels[0] if support_levels else None - }, - 'volume_analysis': { - 'total_high_volume_nodes': len(high_volume_nodes), - 'resistance_levels': resistance_levels[:3], # Top 3 resistance - 'support_levels': support_levels[:3], # Top 3 support - 'poc_strength': session_stats['poc_volume'] / session_stats['total_volume'] * 100 - }, - 'session_statistics': session_stats - } - - except Exception as e: - logger.error(f"Error getting market profile analysis for {symbol}: {e}") - return None - - def get_enhanced_order_flow_metrics(self, symbol: str) -> Optional[Dict]: - """Get enhanced order flow metrics including aggressive vs passive ratios""" - try: - if symbol not in self.current_flow_ratios: - return None - - current_ratios = self.current_flow_ratios.get(symbol, {}) - - # Get recent trade size distribution - recent_trades = list(self.trade_size_distributions[symbol])[-100:] # Last 100 trades - if not recent_trades: - return None - - # Calculate institutional vs retail breakdown - institutional_trades = [t for t in recent_trades if t['is_institutional']] - retail_trades = [t for t in recent_trades if not t['is_institutional']] - block_trades = [t for t in recent_trades if t['is_block_trade']] - - institutional_volume = sum(t['trade_value'] for t in institutional_trades) - retail_volume = sum(t['trade_value'] for t in retail_trades) - total_volume = institutional_volume + retail_volume - - # Size category breakdown - size_breakdown = { - 'micro': len([t for t in recent_trades if t['size_category'] == 'micro']), - 'small': len([t for t in recent_trades if t['size_category'] == 'small']), - 'medium': len([t for t in recent_trades if t['size_category'] == 'medium']), - 'large': len([t for t in recent_trades if t['size_category'] == 'large']), - 'block': len([t for t in recent_trades if t['size_category'] == 'block']) - } - - # Get recent order flow intensity - recent_intensity = list(self.order_flow_intensity[symbol])[-10:] - avg_intensity = sum(i['intensity_score'] for i in recent_intensity) / max(1, len(recent_intensity)) - - # Get recent liquidity consumption - recent_consumption = list(self.liquidity_consumption_rates[symbol])[-20:] - avg_consumption_rate = sum(c['consumption_rate'] for c in recent_consumption) / max(1, len(recent_consumption)) - - # Get recent price impact - recent_impacts = list(self.price_impact_measurements[symbol])[-20:] - avg_price_impact = sum(i['price_impact'] for i in recent_impacts) / max(1, len(recent_impacts)) - - # Impact distribution - impact_distribution = {} - for impact in recent_impacts: - category = impact['impact_category'] - impact_distribution[category] = impact_distribution.get(category, 0) + 1 - - # Market maker vs taker flow analysis - recent_flows = list(self.market_maker_taker_flows[symbol])[-50:] - buy_aggressive_volume = sum(f['trade_value'] for f in recent_flows if f['flow_direction'] == 'buy_aggressive') - sell_aggressive_volume = sum(f['trade_value'] for f in recent_flows if f['flow_direction'] == 'sell_aggressive') - - return { - 'symbol': symbol, - 'timestamp': datetime.now().isoformat(), - - # Aggressive vs Passive Analysis - 'aggressive_passive': { - 'aggressive_ratio': current_ratios.get('aggressive_ratio', 0), - 'passive_ratio': current_ratios.get('passive_ratio', 0), - 'aggressive_volume': current_ratios.get('aggressive_volume', 0), - 'passive_volume': current_ratios.get('passive_volume', 0), - 'avg_aggressive_size': current_ratios.get('avg_aggressive_size', 0), - 'avg_passive_size': current_ratios.get('avg_passive_size', 0), - 'trade_count': current_ratios.get('trade_count', 0) - }, - - # Institutional vs Retail Analysis - 'institutional_retail': { - 'institutional_ratio': institutional_volume / total_volume if total_volume > 0 else 0, - 'retail_ratio': retail_volume / total_volume if total_volume > 0 else 0, - 'institutional_volume': institutional_volume, - 'retail_volume': retail_volume, - 'institutional_trade_count': len(institutional_trades), - 'retail_trade_count': len(retail_trades), - 'block_trade_count': len(block_trades), - 'avg_institutional_size': institutional_volume / max(1, len(institutional_trades)), - 'avg_retail_size': retail_volume / max(1, len(retail_trades)) - }, - - # Trade Size Distribution - 'size_distribution': size_breakdown, - - # Order Flow Intensity - 'flow_intensity': { - 'current_intensity': avg_intensity, - 'intensity_category': 'high' if avg_intensity > 5 else 'medium' if avg_intensity > 2 else 'low' - }, - - # Liquidity Analysis - 'liquidity': { - 'avg_consumption_rate': avg_consumption_rate, - 'consumption_category': 'high' if avg_consumption_rate > 0.8 else 'medium' if avg_consumption_rate > 0.5 else 'low' - }, - - # Price Impact Analysis - 'price_impact': { - 'avg_impact': avg_price_impact * 10000, # in basis points - 'impact_distribution': impact_distribution, - 'impact_category': 'high' if avg_price_impact > 0.005 else 'medium' if avg_price_impact > 0.001 else 'low' - }, - - # Market Maker vs Taker Flow - 'maker_taker_flow': { - 'buy_aggressive_volume': buy_aggressive_volume, - 'sell_aggressive_volume': sell_aggressive_volume, - 'buy_pressure': buy_aggressive_volume / (buy_aggressive_volume + sell_aggressive_volume) if (buy_aggressive_volume + sell_aggressive_volume) > 0 else 0.5, - 'sell_pressure': sell_aggressive_volume / (buy_aggressive_volume + sell_aggressive_volume) if (buy_aggressive_volume + sell_aggressive_volume) > 0 else 0.5 - }, - - # 24h Volume Statistics (if available) - 'volume_stats': self.volume_stats.get(symbol, {}) - } - - except Exception as e: - logger.error(f"Error getting enhanced order flow metrics for {symbol}: {e}") - return None \ No newline at end of file diff --git a/core/cnn_training_pipeline.py b/core/cnn_training_pipeline.py deleted file mode 100644 index 15685df..0000000 --- a/core/cnn_training_pipeline.py +++ /dev/null @@ -1,785 +0,0 @@ -""" -CNN Training Pipeline with Comprehensive Data Storage and Replay - -This module implements a robust CNN training pipeline that: -1. Integrates with the comprehensive training data collection system -2. Stores all backpropagation data for gradient replay -3. Enables retraining on most profitable setups -4. Maintains training episode profitability tracking -5. Supports both real-time and batch training modes - -Key Features: -- Integration with TrainingDataCollector for data validation -- Gradient and loss storage for each training step -- Profitable episode prioritization and replay -- Comprehensive training metrics and validation -- Real-time pivot point prediction with outcome tracking -""" - -import asyncio -import logging -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any, Callable -from dataclasses import dataclass, field -import json -import pickle -from collections import deque, defaultdict -import threading -from concurrent.futures import ThreadPoolExecutor - -from .training_data_collector import ( - TrainingDataCollector, - TrainingEpisode, - ModelInputPackage, - get_training_data_collector -) - -logger = logging.getLogger(__name__) - -@dataclass -class CNNTrainingStep: - """Single CNN training step with complete backpropagation data""" - step_id: str - timestamp: datetime - episode_id: str - - # Input data - input_features: torch.Tensor - target_labels: torch.Tensor - - # Forward pass results - model_outputs: Dict[str, torch.Tensor] - predictions: Dict[str, Any] - confidence_scores: torch.Tensor - - # Loss components - total_loss: float - pivot_prediction_loss: float - confidence_loss: float - regularization_loss: float - - # Backpropagation data - gradients: Dict[str, torch.Tensor] # Gradients for each parameter - gradient_norms: Dict[str, float] # Gradient norms for monitoring - - # Model state - model_state_dict: Optional[Dict[str, torch.Tensor]] = None - optimizer_state: Optional[Dict[str, Any]] = None - - # Training metadata - learning_rate: float = 0.001 - batch_size: int = 32 - epoch: int = 0 - - # Profitability tracking - actual_profitability: Optional[float] = None - prediction_accuracy: Optional[float] = None - training_value: float = 0.0 # Value of this training step for replay - -@dataclass -class CNNTrainingSession: - """Complete CNN training session with multiple steps""" - session_id: str - start_timestamp: datetime - end_timestamp: Optional[datetime] = None - - # Session configuration - training_mode: str = 'real_time' # 'real_time', 'batch', 'replay' - symbol: str = '' - - # Training steps - training_steps: List[CNNTrainingStep] = field(default_factory=list) - - # Session metrics - total_steps: int = 0 - average_loss: float = 0.0 - best_loss: float = float('inf') - convergence_achieved: bool = False - - # Profitability metrics - profitable_predictions: int = 0 - total_predictions: int = 0 - profitability_rate: float = 0.0 - - # Session value for replay prioritization - session_value: float = 0.0 - -class CNNPivotPredictor(nn.Module): - """CNN model for pivot point prediction with comprehensive output""" - - def __init__(self, - input_channels: int = 10, # Multiple timeframes - sequence_length: int = 300, # 300 bars - hidden_dim: int = 256, - num_pivot_classes: int = 3, # high, low, none - dropout_rate: float = 0.2): - - super(CNNPivotPredictor, self).__init__() - - self.input_channels = input_channels - self.sequence_length = sequence_length - self.hidden_dim = hidden_dim - - # Convolutional layers for pattern extraction - self.conv_layers = nn.Sequential( - # First conv block - nn.Conv1d(input_channels, 64, kernel_size=7, padding=3), - nn.BatchNorm1d(64), - nn.ReLU(), - nn.Dropout(dropout_rate), - - # Second conv block - nn.Conv1d(64, 128, kernel_size=5, padding=2), - nn.BatchNorm1d(128), - nn.ReLU(), - nn.Dropout(dropout_rate), - - # Third conv block - nn.Conv1d(128, 256, kernel_size=3, padding=1), - nn.BatchNorm1d(256), - nn.ReLU(), - nn.Dropout(dropout_rate), - ) - - # LSTM for temporal dependencies - self.lstm = nn.LSTM( - input_size=256, - hidden_size=hidden_dim, - num_layers=2, - batch_first=True, - dropout=dropout_rate, - bidirectional=True - ) - - # Attention mechanism - self.attention = nn.MultiheadAttention( - embed_dim=hidden_dim * 2, # Bidirectional LSTM - num_heads=8, - dropout=dropout_rate, - batch_first=True - ) - - # Output heads - self.pivot_classifier = nn.Sequential( - nn.Linear(hidden_dim * 2, hidden_dim), - nn.ReLU(), - nn.Dropout(dropout_rate), - nn.Linear(hidden_dim, num_pivot_classes) - ) - - self.pivot_price_regressor = nn.Sequential( - nn.Linear(hidden_dim * 2, hidden_dim), - nn.ReLU(), - nn.Dropout(dropout_rate), - nn.Linear(hidden_dim, 1) - ) - - self.confidence_head = nn.Sequential( - nn.Linear(hidden_dim * 2, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, 1), - nn.Sigmoid() - ) - - # Initialize weights - self.apply(self._init_weights) - - def _init_weights(self, module): - """Initialize weights with proper scaling""" - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Conv1d): - torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - - def forward(self, x): - """ - Forward pass through CNN pivot predictor - - Args: - x: Input tensor [batch_size, input_channels, sequence_length] - - Returns: - Dict containing predictions and hidden states - """ - batch_size = x.size(0) - - # Convolutional feature extraction - conv_features = self.conv_layers(x) # [batch, 256, sequence_length] - - # Prepare for LSTM (transpose to [batch, sequence, features]) - lstm_input = conv_features.transpose(1, 2) # [batch, sequence_length, 256] - - # LSTM processing - lstm_output, (hidden, cell) = self.lstm(lstm_input) # [batch, sequence_length, hidden_dim*2] - - # Attention mechanism - attended_output, attention_weights = self.attention( - lstm_output, lstm_output, lstm_output - ) - - # Use the last timestep for predictions - final_features = attended_output[:, -1, :] # [batch, hidden_dim*2] - - # Generate predictions - pivot_logits = self.pivot_classifier(final_features) - pivot_price = self.pivot_price_regressor(final_features) - confidence = self.confidence_head(final_features) - - return { - 'pivot_logits': pivot_logits, - 'pivot_price': pivot_price, - 'confidence': confidence, - 'hidden_states': final_features, - 'attention_weights': attention_weights, - 'conv_features': conv_features, - 'lstm_output': lstm_output - } - -class CNNTrainingDataset(Dataset): - """Dataset for CNN training with training episodes""" - - def __init__(self, training_episodes: List[TrainingEpisode]): - self.episodes = training_episodes - self.valid_episodes = self._validate_episodes() - - def _validate_episodes(self) -> List[TrainingEpisode]: - """Validate and filter episodes for training""" - valid = [] - for episode in self.episodes: - try: - # Check if episode has required data - if (episode.input_package.cnn_features is not None and - episode.actual_outcome.outcome_validated): - valid.append(episode) - except Exception as e: - logger.warning(f"Invalid episode {episode.episode_id}: {e}") - - logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training") - return valid - - def __len__(self): - return len(self.valid_episodes) - - def __getitem__(self, idx): - episode = self.valid_episodes[idx] - - # Extract features - features = torch.from_numpy(episode.input_package.cnn_features).float() - - # Create labels from actual outcomes - pivot_class = self._determine_pivot_class(episode.actual_outcome) - pivot_price = episode.actual_outcome.optimal_exit_price - confidence_target = episode.actual_outcome.profitability_score - - return { - 'features': features, - 'pivot_class': torch.tensor(pivot_class, dtype=torch.long), - 'pivot_price': torch.tensor(pivot_price, dtype=torch.float), - 'confidence_target': torch.tensor(confidence_target, dtype=torch.float), - 'episode_id': episode.episode_id, - 'profitability': episode.actual_outcome.profitability_score - } - - def _determine_pivot_class(self, outcome) -> int: - """Determine pivot class from outcome""" - if outcome.price_change_15m > 0.5: # Significant upward movement - return 0 # High pivot - elif outcome.price_change_15m < -0.5: # Significant downward movement - return 1 # Low pivot - else: - return 2 # No significant pivot - -class CNNTrainer: - """CNN trainer with comprehensive data storage and replay capabilities""" - - def __init__(self, - model: CNNPivotPredictor, - device: str = 'cuda', - learning_rate: float = 0.001, - storage_dir: str = "cnn_training_storage"): - - self.model = model.to(device) - self.device = device - self.learning_rate = learning_rate - - # Storage - self.storage_dir = Path(storage_dir) - self.storage_dir.mkdir(parents=True, exist_ok=True) - - # Optimizer - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=learning_rate, - weight_decay=1e-5 - ) - - # Learning rate scheduler - self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, mode='min', patience=10, factor=0.5 - ) - - # Training data collector - self.data_collector = get_training_data_collector() - - # Training sessions storage - self.training_sessions: List[CNNTrainingSession] = [] - self.current_session: Optional[CNNTrainingSession] = None - - # Training statistics - self.training_stats = { - 'total_sessions': 0, - 'total_steps': 0, - 'best_validation_loss': float('inf'), - 'profitable_predictions': 0, - 'total_predictions': 0, - 'replay_sessions': 0 - } - - # Background training - self.is_training = False - self.training_thread = None - - logger.info(f"CNN Trainer initialized") - logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}") - logger.info(f"Storage directory: {self.storage_dir}") - - def start_real_time_training(self, symbol: str): - """Start real-time training for a symbol""" - if self.is_training: - logger.warning("CNN training already running") - return - - self.is_training = True - self.training_thread = threading.Thread( - target=self._real_time_training_worker, - args=(symbol,), - daemon=True - ) - self.training_thread.start() - - logger.info(f"Started real-time CNN training for {symbol}") - - def stop_training(self): - """Stop training""" - self.is_training = False - if self.training_thread: - self.training_thread.join(timeout=10) - - if self.current_session: - self._finalize_training_session() - - logger.info("CNN training stopped") - - def _real_time_training_worker(self, symbol: str): - """Real-time training worker""" - logger.info(f"Real-time CNN training worker started for {symbol}") - - while self.is_training: - try: - # Get high-priority episodes for training - episodes = self.data_collector.get_high_priority_episodes( - symbol=symbol, - limit=100, - min_priority=0.3 - ) - - if len(episodes) >= 32: # Minimum batch size - self._train_on_episodes(episodes, training_mode='real_time') - - # Wait before next training cycle - threading.Event().wait(300) # Train every 5 minutes - - except Exception as e: - logger.error(f"Error in real-time training worker: {e}") - threading.Event().wait(60) # Wait before retrying - - logger.info(f"Real-time CNN training worker stopped for {symbol}") - - def train_on_profitable_episodes(self, - symbol: str, - min_profitability: float = 0.7, - max_episodes: int = 500) -> Dict[str, Any]: - """Train specifically on most profitable episodes""" - try: - # Get all episodes for symbol - all_episodes = self.data_collector.training_episodes.get(symbol, []) - - # Filter for profitable episodes - profitable_episodes = [ - ep for ep in all_episodes - if (ep.actual_outcome.is_profitable and - ep.actual_outcome.profitability_score >= min_profitability) - ] - - # Sort by profitability and limit - profitable_episodes.sort( - key=lambda x: x.actual_outcome.profitability_score, - reverse=True - ) - profitable_episodes = profitable_episodes[:max_episodes] - - if len(profitable_episodes) < 10: - logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}") - return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)} - - # Train on profitable episodes - results = self._train_on_episodes( - profitable_episodes, - training_mode='profitable_replay' - ) - - logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}") - return results - - except Exception as e: - logger.error(f"Error training on profitable episodes: {e}") - return {'status': 'error', 'error': str(e)} - - def _train_on_episodes(self, - episodes: List[TrainingEpisode], - training_mode: str = 'batch') -> Dict[str, Any]: - """Train on a batch of episodes with comprehensive data storage""" - try: - # Start new training session - session = CNNTrainingSession( - session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", - start_timestamp=datetime.now(), - training_mode=training_mode, - symbol=episodes[0].input_package.symbol if episodes else 'unknown' - ) - self.current_session = session - - # Create dataset and dataloader - dataset = CNNTrainingDataset(episodes) - dataloader = DataLoader( - dataset, - batch_size=32, - shuffle=True, - num_workers=2 - ) - - # Training loop - self.model.train() - total_loss = 0.0 - num_batches = 0 - - for batch_idx, batch in enumerate(dataloader): - # Move to device - features = batch['features'].to(self.device) - pivot_class = batch['pivot_class'].to(self.device) - pivot_price = batch['pivot_price'].to(self.device) - confidence_target = batch['confidence_target'].to(self.device) - - # Forward pass - self.optimizer.zero_grad() - outputs = self.model(features) - - # Calculate losses - classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class) - regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price) - confidence_loss = F.binary_cross_entropy( - outputs['confidence'].squeeze(), - confidence_target - ) - - # Combined loss - total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss - - # Backward pass - total_batch_loss.backward() - - # Gradient clipping - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - # Store gradients before optimizer step - gradients = {} - gradient_norms = {} - for name, param in self.model.named_parameters(): - if param.grad is not None: - gradients[name] = param.grad.clone().detach() - gradient_norms[name] = param.grad.norm().item() - - # Optimizer step - self.optimizer.step() - - # Create training step record - step = CNNTrainingStep( - step_id=f"{session.session_id}_step_{batch_idx}", - timestamp=datetime.now(), - episode_id=f"batch_{batch_idx}", - input_features=features.detach().cpu(), - target_labels=pivot_class.detach().cpu(), - model_outputs={k: v.detach().cpu() for k, v in outputs.items()}, - predictions=self._extract_predictions(outputs), - confidence_scores=outputs['confidence'].detach().cpu(), - total_loss=total_batch_loss.item(), - pivot_prediction_loss=classification_loss.item(), - confidence_loss=confidence_loss.item(), - regularization_loss=0.0, - gradients=gradients, - gradient_norms=gradient_norms, - learning_rate=self.optimizer.param_groups[0]['lr'], - batch_size=features.size(0) - ) - - # Calculate training value for this step - step.training_value = self._calculate_step_training_value(step, batch) - - # Add to session - session.training_steps.append(step) - - total_loss += total_batch_loss.item() - num_batches += 1 - - # Log progress - if batch_idx % 10 == 0: - logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}") - - # Finalize session - session.end_timestamp = datetime.now() - session.total_steps = num_batches - session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0 - session.best_loss = min(step.total_loss for step in session.training_steps) - - # Calculate session value - session.session_value = self._calculate_session_value(session) - - # Update scheduler - self.scheduler.step(session.average_loss) - - # Save session - self._save_training_session(session) - - # Update statistics - self.training_stats['total_sessions'] += 1 - self.training_stats['total_steps'] += session.total_steps - if training_mode == 'profitable_replay': - self.training_stats['replay_sessions'] += 1 - - logger.info(f"Training session completed: {session.session_id}") - logger.info(f"Average loss: {session.average_loss:.4f}") - logger.info(f"Session value: {session.session_value:.3f}") - - return { - 'status': 'success', - 'session_id': session.session_id, - 'average_loss': session.average_loss, - 'total_steps': session.total_steps, - 'session_value': session.session_value - } - - except Exception as e: - logger.error(f"Error in training session: {e}") - return {'status': 'error', 'error': str(e)} - finally: - self.current_session = None - - def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: - """Extract human-readable predictions from model outputs""" - try: - pivot_probs = F.softmax(outputs['pivot_logits'], dim=1) - predicted_class = torch.argmax(pivot_probs, dim=1) - - return { - 'pivot_class': predicted_class.cpu().numpy().tolist(), - 'pivot_probabilities': pivot_probs.cpu().numpy().tolist(), - 'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(), - 'confidence': outputs['confidence'].cpu().numpy().tolist() - } - except Exception as e: - logger.warning(f"Error extracting predictions: {e}") - return {} - - def _calculate_step_training_value(self, - step: CNNTrainingStep, - batch: Dict[str, Any]) -> float: - """Calculate the training value of a step for replay prioritization""" - try: - value = 0.0 - - # Base value from loss (lower loss = higher value) - if step.total_loss > 0: - value += 1.0 / (1.0 + step.total_loss) - - # Bonus for high profitability episodes in batch - avg_profitability = torch.mean(batch['profitability']).item() - value += avg_profitability * 0.3 - - # Bonus for gradient magnitude (indicates learning) - avg_grad_norm = np.mean(list(step.gradient_norms.values())) - value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2 - - return min(value, 1.0) - - except Exception as e: - logger.warning(f"Error calculating step training value: {e}") - return 0.0 - - def _calculate_session_value(self, session: CNNTrainingSession) -> float: - """Calculate overall session value for replay prioritization""" - try: - if not session.training_steps: - return 0.0 - - # Average step values - avg_step_value = np.mean([step.training_value for step in session.training_steps]) - - # Bonus for convergence - convergence_bonus = 0.0 - if len(session.training_steps) > 10: - early_loss = np.mean([s.total_loss for s in session.training_steps[:5]]) - late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]]) - if early_loss > late_loss: - convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3) - - # Bonus for profitable replay sessions - mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0 - - return min(avg_step_value + convergence_bonus + mode_bonus, 1.0) - - except Exception as e: - logger.warning(f"Error calculating session value: {e}") - return 0.0 - - def _save_training_session(self, session: CNNTrainingSession): - """Save training session to disk""" - try: - session_dir = self.storage_dir / session.symbol / 'sessions' - session_dir.mkdir(parents=True, exist_ok=True) - - # Save full session data - session_file = session_dir / f"{session.session_id}.pkl" - with open(session_file, 'wb') as f: - pickle.dump(session, f) - - # Save session metadata - metadata = { - 'session_id': session.session_id, - 'start_timestamp': session.start_timestamp.isoformat(), - 'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None, - 'training_mode': session.training_mode, - 'symbol': session.symbol, - 'total_steps': session.total_steps, - 'average_loss': session.average_loss, - 'best_loss': session.best_loss, - 'session_value': session.session_value - } - - metadata_file = session_dir / f"{session.session_id}_metadata.json" - with open(metadata_file, 'w') as f: - json.dump(metadata, f, indent=2) - - logger.debug(f"Saved training session: {session.session_id}") - - except Exception as e: - logger.error(f"Error saving training session: {e}") - - def _finalize_training_session(self): - """Finalize current training session""" - if self.current_session: - self.current_session.end_timestamp = datetime.now() - self._save_training_session(self.current_session) - self.training_sessions.append(self.current_session) - self.current_session = None - - def get_training_statistics(self) -> Dict[str, Any]: - """Get comprehensive training statistics""" - stats = self.training_stats.copy() - - # Add recent session information - if self.training_sessions: - recent_sessions = sorted( - self.training_sessions, - key=lambda x: x.start_timestamp, - reverse=True - )[:10] - - stats['recent_sessions'] = [ - { - 'session_id': s.session_id, - 'timestamp': s.start_timestamp.isoformat(), - 'mode': s.training_mode, - 'average_loss': s.average_loss, - 'session_value': s.session_value - } - for s in recent_sessions - ] - - # Calculate profitability rate - if stats['total_predictions'] > 0: - stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions'] - else: - stats['profitability_rate'] = 0.0 - - return stats - - def replay_high_value_sessions(self, - symbol: str, - min_session_value: float = 0.7, - max_sessions: int = 10) -> Dict[str, Any]: - """Replay high-value training sessions""" - try: - # Find high-value sessions - high_value_sessions = [ - s for s in self.training_sessions - if (s.symbol == symbol and - s.session_value >= min_session_value) - ] - - # Sort by value and limit - high_value_sessions.sort(key=lambda x: x.session_value, reverse=True) - high_value_sessions = high_value_sessions[:max_sessions] - - if not high_value_sessions: - return {'status': 'no_high_value_sessions', 'sessions_found': 0} - - # Replay sessions - total_replayed = 0 - for session in high_value_sessions: - # Extract episodes from session steps - episode_ids = list(set(step.episode_id for step in session.training_steps)) - - # Get corresponding episodes - episodes = [] - for episode_id in episode_ids: - # Find episode in data collector - for ep in self.data_collector.training_episodes.get(symbol, []): - if ep.episode_id == episode_id: - episodes.append(ep) - break - - if episodes: - self._train_on_episodes(episodes, training_mode='high_value_replay') - total_replayed += 1 - - logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}") - return { - 'status': 'success', - 'sessions_replayed': total_replayed, - 'sessions_found': len(high_value_sessions) - } - - except Exception as e: - logger.error(f"Error replaying high-value sessions: {e}") - return {'status': 'error', 'error': str(e)} - -# Global instance -cnn_trainer = None - -def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer: - """Get global CNN trainer instance""" - global cnn_trainer - if cnn_trainer is None: - if model is None: - model = CNNPivotPredictor() - cnn_trainer = CNNTrainer(model) - return cnn_trainer \ No newline at end of file diff --git a/core/enhanced_cnn_adapter.py b/core/enhanced_cnn_adapter.py deleted file mode 100644 index 7e9f42c..0000000 --- a/core/enhanced_cnn_adapter.py +++ /dev/null @@ -1,864 +0,0 @@ -# """ -# Enhanced CNN Adapter for Standardized Input Format - -# This module provides an adapter for the EnhancedCNN model to work with the standardized -# BaseDataInput format, enabling seamless integration with the multi-modal trading system. -# """ - -# import torch -# import numpy as np -# import logging -# import os -# import random -# from datetime import datetime, timedelta -# from typing import Dict, List, Optional, Tuple, Any, Union -# from threading import Lock - -# from .data_models import BaseDataInput, ModelOutput, create_model_output -# from NN.models.enhanced_cnn import EnhancedCNN -# from utils.inference_logger import log_model_inference - -# logger = logging.getLogger(__name__) - -# class EnhancedCNNAdapter: -# """ -# Adapter for EnhancedCNN model to work with standardized BaseDataInput format - -# This adapter: -# 1. Converts BaseDataInput to the format expected by EnhancedCNN -# 2. Processes model outputs to create standardized ModelOutput -# 3. Manages model training with collected data -# 4. Handles checkpoint management -# """ - -# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"): -# """ -# Initialize the EnhancedCNN adapter - -# Args: -# model_path: Path to load model from, if None a new model is created -# checkpoint_dir: Directory to save checkpoints to -# """ -# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -# self.model = None -# self.model_path = model_path -# self.checkpoint_dir = checkpoint_dir -# self.training_lock = Lock() -# self.training_data = [] -# self.max_training_samples = 10000 -# self.batch_size = 32 -# self.learning_rate = 0.0001 -# self.model_name = "enhanced_cnn" - -# # Enhanced metrics tracking -# self.last_inference_time = None -# self.last_inference_duration = 0.0 -# self.last_prediction_output = None -# self.last_training_time = None -# self.last_training_duration = 0.0 -# self.last_training_loss = 0.0 -# self.inference_count = 0 -# self.training_count = 0 - -# # Create checkpoint directory if it doesn't exist -# os.makedirs(checkpoint_dir, exist_ok=True) - -# # Initialize the model -# self._initialize_model() - -# # Load checkpoint if available -# if model_path and os.path.exists(model_path): -# self._load_checkpoint(model_path) -# else: -# self._load_best_checkpoint() - -# # Final device check and move -# self._ensure_model_on_device() - -# logger.info(f"EnhancedCNNAdapter initialized on {self.device}") - -# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor: -# """Create realistic synthetic features instead of random data""" -# try: -# # Create realistic market-like features -# features = torch.zeros(7850, dtype=torch.float32, device=self.device) - -# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features) -# ohlcv_start = 0 -# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d -# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe -# for frame_idx in range(300): -# # Create realistic price movement -# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement -# current_price = base_price * (1 + price_change) - -# # Realistic OHLCV values -# open_price = current_price -# high_price = current_price * torch.uniform(1.0, 1.005) -# low_price = current_price * torch.uniform(0.995, 1.0) -# close_price = current_price * torch.uniform(0.998, 1.002) -# volume = torch.uniform(500.0, 2000.0) - -# # Set features -# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500 -# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume]) - -# # BTC OHLCV features (1500 features: 300 frames x 5 features) -# btc_start = 6000 -# btc_base_price = 50000.0 -# for frame_idx in range(300): -# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02 -# current_price = btc_base_price * (1 + price_change) - -# open_price = current_price -# high_price = current_price * torch.uniform(1.0, 1.01) -# low_price = current_price * torch.uniform(0.99, 1.0) -# close_price = current_price * torch.uniform(0.995, 1.005) -# volume = torch.uniform(100.0, 500.0) - -# feature_idx = btc_start + frame_idx * 5 -# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume]) - -# # COB features (200 features) - realistic order book data -# cob_start = 7500 -# for i in range(200): -# features[cob_start + i] = torch.uniform(0.0, 1000.0) # Realistic COB values - -# # Technical indicators (100 features) -# indicator_start = 7700 -# for i in range(100): -# features[indicator_start + i] = torch.uniform(-1.0, 1.0) # Normalized indicators - -# # Last predictions (50 features) -# prediction_start = 7800 -# for i in range(50): -# features[prediction_start + i] = torch.uniform(0.0, 1.0) # Probability values - -# return features - -# except Exception as e: -# logger.error(f"Error creating realistic synthetic features: {e}") -# # Fallback to small random variation -# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5 -# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1 -# return base_features + noise - -# def _create_realistic_features(self, symbol: str) -> torch.Tensor: -# """Create features from real market data if available""" -# try: -# # This would need to be implemented to use actual market data -# # For now, fall back to synthetic features -# return self._create_realistic_synthetic_features(symbol) -# except Exception as e: -# logger.error(f"Error creating realistic features: {e}") -# return self._create_realistic_synthetic_features(symbol) - -# def _initialize_model(self): -# """Initialize the EnhancedCNN model""" -# try: -# # Calculate input shape based on BaseDataInput structure -# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features -# # BTC OHLCV: 300 frames x 5 features = 1500 features -# # COB: ±20 buckets x 4 metrics = 160 features -# # MA: 4 timeframes x 10 buckets = 40 features -# # Technical indicators: 100 features -# # Last predictions: 50 features -# # Total: 7850 features -# input_shape = 7850 -# n_actions = 3 # BUY, SELL, HOLD - -# # Create model -# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) -# # Ensure model is moved to the correct device -# self.model.to(self.device) - -# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}") - -# except Exception as e: -# logger.error(f"Error initializing EnhancedCNN model: {e}") -# raise - -# def _load_checkpoint(self, checkpoint_path: str) -> bool: -# """Load model from checkpoint path""" -# try: -# if self.model and os.path.exists(checkpoint_path): -# success = self.model.load(checkpoint_path) -# if success: -# # Ensure model is moved to the correct device after loading -# self.model.to(self.device) -# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}") -# return True -# else: -# logger.warning(f"Failed to load model from {checkpoint_path}") -# return False -# else: -# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}") -# return False -# except Exception as e: -# logger.error(f"Error loading checkpoint: {e}") -# return False - -# def _load_best_checkpoint(self) -> bool: -# """Load the best available checkpoint""" -# try: -# return self.load_best_checkpoint() -# except Exception as e: -# logger.error(f"Error loading best checkpoint: {e}") -# return False - -# def load_best_checkpoint(self) -> bool: -# """Load the best checkpoint based on accuracy""" -# try: -# # Import checkpoint manager -# from utils.checkpoint_manager import CheckpointManager - -# # Create checkpoint manager -# checkpoint_manager = CheckpointManager( -# checkpoint_dir=self.checkpoint_dir, -# max_checkpoints=10, -# metric_name="accuracy" -# ) - -# # Load best checkpoint -# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name) - -# if not best_checkpoint_path: -# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode") -# return False - -# # Load model -# success = self.model.load(best_checkpoint_path) - -# if success: -# # Ensure model is moved to the correct device after loading -# self.model.to(self.device) -# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}") - -# # Log metrics -# metrics = best_checkpoint_metadata.get('metrics', {}) -# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}") - -# return True -# else: -# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}") -# return False - -# except Exception as e: -# logger.error(f"Error loading best checkpoint: {e}") -# return False - -# def _ensure_model_on_device(self): -# """Ensure model and all its components are on the correct device""" -# try: -# if self.model: -# self.model.to(self.device) -# # Also ensure the model's internal device is set correctly -# if hasattr(self.model, 'device'): -# self.model.device = self.device -# logger.debug(f"Model ensured on device {self.device}") -# except Exception as e: -# logger.error(f"Error ensuring model on device: {e}") - -# def _create_default_output(self, symbol: str) -> ModelOutput: -# """Create default output when prediction fails""" -# return create_model_output( -# model_type='cnn', -# model_name=self.model_name, -# symbol=symbol, -# action='HOLD', -# confidence=0.0, -# metadata={'error': 'Prediction failed, using default output'} -# ) - -# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]: -# """Process hidden states for cross-model feeding""" -# processed_states = {} - -# for key, value in hidden_states.items(): -# if isinstance(value, torch.Tensor): -# # Convert tensor to numpy array -# processed_states[key] = value.cpu().numpy().tolist() -# else: -# processed_states[key] = value - -# return processed_states - - - -# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor: -# """ -# Convert BaseDataInput to feature vector for EnhancedCNN - -# Args: -# base_data: Standardized input data - -# Returns: -# torch.Tensor: Feature vector for EnhancedCNN -# """ -# try: -# # Use the get_feature_vector method from BaseDataInput -# features = base_data.get_feature_vector() - -# # Validate feature quality before using -# self._validate_feature_quality(features) - -# # Convert to torch tensor -# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device) - -# return features_tensor - -# except Exception as e: -# logger.error(f"Error converting BaseDataInput to features: {e}") -# # Return empty tensor with correct shape -# return torch.zeros(7850, dtype=torch.float32, device=self.device) - -# def _validate_feature_quality(self, features: np.ndarray): -# """Validate that features are realistic and not synthetic/placeholder data""" -# try: -# if len(features) != 7850: -# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850") -# return - -# # Check for all-zero or all-identical features (indicates placeholder data) -# if np.all(features == 0): -# logger.warning("Feature vector contains all zeros - likely placeholder data") -# return - -# # Check for repetitive patterns in OHLCV data (first 6000 features) -# ohlcv_features = features[:6000] -# if len(ohlcv_features) >= 20: -# # Check if first 20 values are identical (indicates padding with same bar) -# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6): -# logger.warning("OHLCV features show repetitive pattern - possible synthetic data") - -# # Check for unrealistic values -# if np.any(features > 1e6) or np.any(features < -1e6): -# logger.warning("Feature vector contains unrealistic values") - -# # Check for NaN or infinite values -# if np.any(np.isnan(features)) or np.any(np.isinf(features)): -# logger.warning("Feature vector contains NaN or infinite values") - -# except Exception as e: -# logger.error(f"Error validating feature quality: {e}") - -# def predict(self, base_data: BaseDataInput) -> ModelOutput: -# """ -# Make a prediction using the EnhancedCNN model - -# Args: -# base_data: Standardized input data - -# Returns: -# ModelOutput: Standardized model output -# """ -# try: -# # Track inference timing -# start_time = datetime.now() -# inference_start = start_time.timestamp() - -# # Convert BaseDataInput to features -# features = self._convert_base_data_to_features(base_data) - -# # Ensure features has batch dimension -# if features.dim() == 1: -# features = features.unsqueeze(0) - -# # Ensure model is on correct device before prediction -# self._ensure_model_on_device() - -# # Set model to evaluation mode -# self.model.eval() - -# # Make prediction -# with torch.no_grad(): -# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features) - -# # Get action and confidence -# action_probs = torch.softmax(q_values, dim=1) -# action_idx = torch.argmax(action_probs, dim=1).item() -# raw_confidence = float(action_probs[0, action_idx].item()) - -# # Validate confidence - prevent 100% confidence which indicates overfitting -# if raw_confidence >= 0.99: -# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting") -# # Cap confidence at 0.95 to prevent unrealistic predictions -# confidence = min(raw_confidence, 0.95) -# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}") -# else: -# confidence = raw_confidence - -# # Map action index to action string -# actions = ['BUY', 'SELL', 'HOLD'] -# action = actions[action_idx] - -# # Extract pivot price prediction (simplified - take first value from price_pred) -# pivot_price = None -# if price_pred is not None and len(price_pred.squeeze()) > 0: -# # Get current price from base_data for context -# current_price = 0.0 -# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0: -# current_price = base_data.ohlcv_1s[-1].close - -# # Calculate pivot price as current price + predicted change -# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value -# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price - -# # Create predictions dictionary -# predictions = { -# 'action': action, -# 'buy_probability': float(action_probs[0, 0].item()), -# 'sell_probability': float(action_probs[0, 1].item()), -# 'hold_probability': float(action_probs[0, 2].item()), -# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(), -# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(), -# 'pivot_price': pivot_price -# } - -# # Create hidden states dictionary -# hidden_states = { -# 'features': features_refined.squeeze(0).cpu().numpy().tolist() -# } - -# # Calculate inference duration -# end_time = datetime.now() -# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds - -# # Update metrics -# self.last_inference_time = start_time -# self.last_inference_duration = inference_duration -# self.inference_count += 1 - -# # Store last prediction output for dashboard -# self.last_prediction_output = { -# 'action': action, -# 'confidence': confidence, -# 'pivot_price': pivot_price, -# 'timestamp': start_time, -# 'symbol': base_data.symbol -# } - -# # Create metadata dictionary -# metadata = { -# 'model_version': '1.0', -# 'timestamp': start_time.isoformat(), -# 'input_shape': features.shape, -# 'inference_duration_ms': inference_duration, -# 'inference_count': self.inference_count -# } - -# # Create ModelOutput -# model_output = ModelOutput( -# model_type='cnn', -# model_name=self.model_name, -# symbol=base_data.symbol, -# timestamp=start_time, -# confidence=confidence, -# predictions=predictions, -# hidden_states=hidden_states, -# metadata=metadata -# ) - -# # Log inference with full input data for training feedback -# log_model_inference( -# model_name=self.model_name, -# symbol=base_data.symbol, -# action=action, -# confidence=confidence, -# probabilities={ -# 'BUY': predictions['buy_probability'], -# 'SELL': predictions['sell_probability'], -# 'HOLD': predictions['hold_probability'] -# }, -# input_features=features.cpu().numpy(), # Store full feature vector -# processing_time_ms=inference_duration, -# checkpoint_id=None, # Could be enhanced to track checkpoint -# metadata={ -# 'base_data_input': { -# 'symbol': base_data.symbol, -# 'timestamp': base_data.timestamp.isoformat(), -# 'ohlcv_1s_count': len(base_data.ohlcv_1s), -# 'ohlcv_1m_count': len(base_data.ohlcv_1m), -# 'ohlcv_1h_count': len(base_data.ohlcv_1h), -# 'ohlcv_1d_count': len(base_data.ohlcv_1d), -# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s), -# 'has_cob_data': base_data.cob_data is not None, -# 'technical_indicators_count': len(base_data.technical_indicators), -# 'pivot_points_count': len(base_data.pivot_points), -# 'last_predictions_count': len(base_data.last_predictions) -# }, -# 'model_predictions': { -# 'pivot_price': pivot_price, -# 'extrema_prediction': predictions['extrema'], -# 'price_prediction': predictions['price_prediction'] -# } -# } -# ) - -# return model_output - -# except Exception as e: -# logger.error(f"Error making prediction with EnhancedCNN: {e}") -# # Return default ModelOutput -# return create_model_output( -# model_type='cnn', -# model_name=self.model_name, -# symbol=base_data.symbol, -# action='HOLD', -# confidence=0.0 -# ) - -# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float): -# """ -# Add a training sample to the training data - -# Args: -# symbol_or_base_data: Either a symbol string or BaseDataInput object -# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD') -# reward: Reward received for the action -# """ -# try: -# # Handle both symbol string and BaseDataInput object -# if isinstance(symbol_or_base_data, str): -# # For cold start mode - create a simple training sample with current features -# # This is a simplified approach for rapid training -# symbol = symbol_or_base_data - -# # Create a realistic feature vector instead of random data -# # Use actual market data if available, otherwise create realistic synthetic data -# try: -# # Try to get real market data first -# if hasattr(self, 'data_provider') and self.data_provider: -# # This would need to be implemented in the adapter -# features = self._create_realistic_features(symbol) -# else: -# # Create realistic synthetic features (not random) -# features = self._create_realistic_synthetic_features(symbol) -# except Exception as e: -# logger.warning(f"Could not create realistic features for {symbol}: {e}") -# # Fallback to small random variation instead of pure random -# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5 -# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1 -# features = base_features + noise - -# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}") - -# else: -# # Full BaseDataInput object -# base_data = symbol_or_base_data -# features = self._convert_base_data_to_features(base_data) -# symbol = base_data.symbol - -# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}") - -# # Convert action to index -# actions = ['BUY', 'SELL', 'HOLD'] -# action_idx = actions.index(actual_action) - -# # Add to training data -# with self.training_lock: -# self.training_data.append((features, action_idx, reward)) - -# # Limit training data size -# if len(self.training_data) > self.max_training_samples: -# # Sort by reward (highest first) and keep top samples -# self.training_data.sort(key=lambda x: x[2], reverse=True) -# self.training_data = self.training_data[:self.max_training_samples] - -# except Exception as e: -# logger.error(f"Error adding training sample: {e}") - -# def train(self, epochs: int = 1) -> Dict[str, float]: -# """ -# Train the model with collected data and inference history - -# Args: -# epochs: Number of epochs to train for - -# Returns: -# Dict[str, float]: Training metrics -# """ -# try: -# # Track training timing -# training_start_time = datetime.now() -# training_start = training_start_time.timestamp() - -# with self.training_lock: -# # Get additional training data from inference history -# self._load_training_data_from_inference_history() - -# # Check if we have enough data -# if len(self.training_data) < self.batch_size: -# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}") -# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)} - -# # Ensure model is on correct device before training -# self._ensure_model_on_device() - -# # Set model to training mode -# self.model.train() - -# # Create optimizer -# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) - -# # Training metrics -# total_loss = 0.0 -# correct_predictions = 0 -# total_predictions = 0 - -# # Train for specified number of epochs -# for epoch in range(epochs): -# # Shuffle training data -# np.random.shuffle(self.training_data) - -# # Process in batches -# for i in range(0, len(self.training_data), self.batch_size): -# batch = self.training_data[i:i+self.batch_size] - -# # Skip if batch is too small -# if len(batch) < 2: -# continue - -# # Prepare batch - ensure all tensors are on the correct device -# features = torch.stack([sample[0].to(self.device) for sample in batch]) -# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device) -# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device) - -# # Zero gradients -# optimizer.zero_grad() - -# # Forward pass -# q_values, _, _, _, _ = self.model(features) - -# # Calculate loss (CrossEntropyLoss with reward weighting) -# # First, apply softmax to get probabilities -# probs = torch.softmax(q_values, dim=1) - -# # Get probability of chosen action -# chosen_probs = probs[torch.arange(len(actions)), actions] - -# # Calculate negative log likelihood loss -# nll_loss = -torch.log(chosen_probs + 1e-10) - -# # Weight by reward (higher reward = higher weight) -# # Normalize rewards to [0, 1] range -# min_reward = rewards.min() -# max_reward = rewards.max() -# if max_reward > min_reward: -# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward) -# else: -# normalized_rewards = torch.ones_like(rewards) - -# # Apply reward weighting (higher reward = higher weight) -# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights - -# # Mean loss -# loss = weighted_loss.mean() - -# # Backward pass -# loss.backward() - -# # Update weights -# optimizer.step() - -# # Update metrics -# total_loss += loss.item() - -# # Calculate accuracy -# predicted_actions = torch.argmax(q_values, dim=1) -# correct_predictions += (predicted_actions == actions).sum().item() -# total_predictions += len(actions) - -# # Validate training - detect overfitting -# if total_predictions > 0: -# current_accuracy = correct_predictions / total_predictions -# if current_accuracy >= 0.99: -# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting") -# # Add regularization to prevent overfitting -# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters()) -# loss = loss + l2_reg -# logger.info("Added L2 regularization to prevent overfitting") - -# # Calculate final metrics -# avg_loss = total_loss / (len(self.training_data) / self.batch_size) -# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 - -# # Calculate training duration -# training_end_time = datetime.now() -# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds - -# # Update training metrics -# self.last_training_time = training_start_time -# self.last_training_duration = training_duration -# self.last_training_loss = avg_loss -# self.training_count += 1 - -# # Save checkpoint -# self._save_checkpoint(avg_loss, accuracy) - -# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms") - -# return { -# 'loss': avg_loss, -# 'accuracy': accuracy, -# 'samples': len(self.training_data), -# 'duration_ms': training_duration, -# 'training_count': self.training_count -# } - -# except Exception as e: -# logger.error(f"Error training model: {e}") -# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)} - -# def _save_checkpoint(self, loss: float, accuracy: float): -# """ -# Save model checkpoint - -# Args: -# loss: Training loss -# accuracy: Training accuracy -# """ -# try: -# # Import checkpoint manager -# from utils.checkpoint_manager import CheckpointManager - -# # Create checkpoint manager -# checkpoint_manager = CheckpointManager( -# checkpoint_dir=self.checkpoint_dir, -# max_checkpoints=10, -# metric_name="accuracy" -# ) - -# # Create temporary model file -# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp") -# self.model.save(temp_path) - -# # Create metrics -# metrics = { -# 'loss': loss, -# 'accuracy': accuracy, -# 'samples': len(self.training_data) -# } - -# # Create metadata -# metadata = { -# 'timestamp': datetime.now().isoformat(), -# 'model_name': self.model_name, -# 'input_shape': self.model.input_shape, -# 'n_actions': self.model.n_actions -# } - -# # Save checkpoint -# checkpoint_path = checkpoint_manager.save_checkpoint( -# model_name=self.model_name, -# model_path=f"{temp_path}.pt", -# metrics=metrics, -# metadata=metadata -# ) - -# # Delete temporary model file -# if os.path.exists(f"{temp_path}.pt"): -# os.remove(f"{temp_path}.pt") - -# logger.info(f"Model checkpoint saved to {checkpoint_path}") - -# except Exception as e: -# logger.error(f"Error saving checkpoint: {e}") - -# def _load_training_data_from_inference_history(self): -# """Load training data from inference history for continuous learning""" -# try: -# from utils.database_manager import get_database_manager - -# db_manager = get_database_manager() - -# # Get recent inference records with input features -# inference_records = db_manager.get_inference_records_for_training( -# model_name=self.model_name, -# hours_back=24, # Last 24 hours -# limit=1000 -# ) - -# if not inference_records: -# logger.debug("No inference records found for training") -# return - -# # Convert inference records to training samples -# # For now, use a simple approach: treat high-confidence predictions as ground truth -# for record in inference_records: -# if record.input_features is not None and record.confidence > 0.7: -# # Convert action to index -# actions = ['BUY', 'SELL', 'HOLD'] -# if record.action in actions: -# action_idx = actions.index(record.action) - -# # Use confidence as a proxy for reward (high confidence = good prediction) -# reward = record.confidence * 2 - 1 # Scale to [-1, 1] - -# # Convert features to tensor -# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device) - -# # Add to training data if not already present (avoid duplicates) -# sample_exists = any( -# torch.equal(features_tensor, existing[0]) -# for existing in self.training_data -# ) - -# if not sample_exists: -# self.training_data.append((features_tensor, action_idx, reward)) - -# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}") - -# except Exception as e: -# logger.error(f"Error loading training data from inference history: {e}") - -# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]: -# """ -# Evaluate past predictions against actual market outcomes - -# Args: -# hours_back: How many hours back to evaluate - -# Returns: -# Dict with evaluation metrics -# """ -# try: -# from utils.database_manager import get_database_manager - -# db_manager = get_database_manager() - -# # Get inference records from the specified time period -# inference_records = db_manager.get_inference_records_for_training( -# model_name=self.model_name, -# hours_back=hours_back, -# limit=100 -# ) - -# if not inference_records: -# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0} - -# # For now, use a simple evaluation based on confidence -# # In a real implementation, this would compare against actual price movements -# correct_predictions = 0 -# total_predictions = len(inference_records) - -# # Simple heuristic: high confidence predictions are more likely to be correct -# for record in inference_records: -# if record.confidence > 0.8: # High confidence threshold -# correct_predictions += 1 -# elif record.confidence > 0.6: # Medium confidence -# correct_predictions += 0.5 - -# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 - -# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy") - -# return { -# 'accuracy': accuracy, -# 'total_predictions': total_predictions, -# 'correct_predictions': correct_predictions -# } - -# except Exception as e: -# logger.error(f"Error evaluating predictions: {e}") -# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0} diff --git a/core/enhanced_orchestrator.py b/core/enhanced_orchestrator.py deleted file mode 100644 index 7193585..0000000 --- a/core/enhanced_orchestrator.py +++ /dev/null @@ -1,464 +0,0 @@ -""" -Enhanced Trading Orchestrator - -Central coordination hub for the multi-modal trading system that manages: -- Data subscription and management -- Model inference coordination -- Cross-model data feeding -- Training pipeline orchestration -- Decision making using Mixture of Experts -""" - -import asyncio -import logging -import numpy as np -from datetime import datetime -from typing import Dict, List, Optional, Any -from dataclasses import dataclass, field - -from core.data_provider import DataProvider -from core.trading_action import TradingAction -from utils.tensorboard_logger import TensorBoardLogger - -logger = logging.getLogger(__name__) - -@dataclass -class ModelOutput: - """Extensible model output format supporting all model types""" - model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator' - model_name: str # Specific model identifier - symbol: str - timestamp: datetime - confidence: float - predictions: Dict[str, Any] # Model-specific predictions - hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding - metadata: Dict[str, Any] = field(default_factory=dict) # Additional info - -@dataclass -class BaseDataInput: - """Unified base data input for all models""" - symbol: str - timestamp: datetime - ohlcv_data: Dict[str, Any] = field(default_factory=dict) # Multi-timeframe OHLCV - cob_data: Optional[Dict[str, Any]] = None # COB buckets for 1s timeframe - technical_indicators: Dict[str, float] = field(default_factory=dict) - pivot_points: List[Any] = field(default_factory=list) - last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models - market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc. - -@dataclass -class COBData: - """Cumulative Order Book data for price buckets""" - symbol: str - timestamp: datetime - current_price: float - bucket_size: float # $1 for ETH, $10 for BTC - price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict) # price -> {bid_volume, ask_volume, etc.} - bid_ask_imbalance: Dict[float, float] = field(default_factory=dict) # price -> imbalance ratio - volume_weighted_prices: Dict[float, float] = field(default_factory=dict) # price -> VWAP within bucket - order_flow_metrics: Dict[str, float] = field(default_factory=dict) # Various order flow indicators - -class EnhancedTradingOrchestrator: - """ - Enhanced Trading Orchestrator implementing the design specification - - Coordinates data flow, model inference, and decision making for the multi-modal trading system. - """ - - def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None): - """Initialize the enhanced orchestrator""" - self.data_provider = data_provider - self.symbols = symbols - self.enhanced_rl_training = enhanced_rl_training - self.model_registry = model_registry or {} - - # Data management - self.data_buffers = {symbol: {} for symbol in symbols} - self.last_update_times = {symbol: {} for symbol in symbols} - - # Model output storage - self.model_outputs = {symbol: {} for symbol in symbols} - self.model_output_history = {symbol: {} for symbol in symbols} - - # Training pipeline - self.training_data = {symbol: [] for symbol in symbols} - self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}") - - # COB integration - self.cob_data = {symbol: None for symbol in symbols} - - # Performance tracking - self.performance_metrics = { - 'inference_count': 0, - 'successful_states': 0, - 'total_episodes': 0 - } - - logger.info("Enhanced Trading Orchestrator initialized") - - async def start_cob_integration(self): - """Start COB data integration for real-time market microstructure""" - try: - # Subscribe to COB data updates - self.data_provider.subscribe_to_cob_data(self._on_cob_data_update) - logger.info("COB integration started") - except Exception as e: - logger.error(f"Error starting COB integration: {e}") - - async def start_realtime_processing(self): - """Start real-time data processing""" - try: - # Subscribe to tick data for real-time processing - for symbol in self.symbols: - self.data_provider.subscribe_to_ticks( - callback=self._on_tick_data, - symbols=[symbol], - subscriber_name=f"orchestrator_{symbol}" - ) - - logger.info("Real-time processing started") - except Exception as e: - logger.error(f"Error starting real-time processing: {e}") - - def _on_cob_data_update(self, symbol: str, cob_data: dict): - """Handle COB data updates""" - try: - # Process and store COB data - self.cob_data[symbol] = self._process_cob_data(symbol, cob_data) - logger.debug(f"COB data updated for {symbol}") - except Exception as e: - logger.error(f"Error processing COB data for {symbol}: {e}") - - def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData: - """Process raw COB data into structured format""" - try: - # Determine bucket size based on symbol - bucket_size = 1.0 if 'ETH' in symbol else 10.0 - - # Extract current price - stats = cob_data.get('stats', {}) - current_price = stats.get('mid_price', 0) - - # Create COB data structure - cob = COBData( - symbol=symbol, - timestamp=datetime.now(), - current_price=current_price, - bucket_size=bucket_size - ) - - # Process order book data into price buckets - bids = cob_data.get('bids', []) - asks = cob_data.get('asks', []) - - # Create price buckets around current price - bucket_count = 20 # ±20 buckets - for i in range(-bucket_count, bucket_count + 1): - bucket_price = current_price + (i * bucket_size) - cob.price_buckets[bucket_price] = { - 'bid_volume': 0.0, - 'ask_volume': 0.0 - } - - # Aggregate bid volumes into buckets - for price, volume in bids: - bucket_price = round(price / bucket_size) * bucket_size - if bucket_price in cob.price_buckets: - cob.price_buckets[bucket_price]['bid_volume'] += volume - - # Aggregate ask volumes into buckets - for price, volume in asks: - bucket_price = round(price / bucket_size) * bucket_size - if bucket_price in cob.price_buckets: - cob.price_buckets[bucket_price]['ask_volume'] += volume - - # Calculate bid/ask imbalances - for price, volumes in cob.price_buckets.items(): - bid_vol = volumes['bid_volume'] - ask_vol = volumes['ask_volume'] - total_vol = bid_vol + ask_vol - if total_vol > 0: - cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol - else: - cob.bid_ask_imbalance[price] = 0.0 - - # Calculate volume-weighted prices - for price, volumes in cob.price_buckets.items(): - bid_vol = volumes['bid_volume'] - ask_vol = volumes['ask_volume'] - total_vol = bid_vol + ask_vol - if total_vol > 0: - cob.volume_weighted_prices[price] = ( - (price * bid_vol) + (price * ask_vol) - ) / total_vol - else: - cob.volume_weighted_prices[price] = price - - # Calculate order flow metrics - cob.order_flow_metrics = { - 'total_bid_volume': sum(v['bid_volume'] for v in cob.price_buckets.values()), - 'total_ask_volume': sum(v['ask_volume'] for v in cob.price_buckets.values()), - 'bid_ask_ratio': 0.0 if cob.order_flow_metrics['total_ask_volume'] == 0 else - cob.order_flow_metrics['total_bid_volume'] / cob.order_flow_metrics['total_ask_volume'] - } - - return cob - except Exception as e: - logger.error(f"Error processing COB data for {symbol}: {e}") - return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size) - - def _on_tick_data(self, tick): - """Handle incoming tick data""" - try: - # Update data buffers - symbol = tick.symbol - if symbol not in self.data_buffers: - self.data_buffers[symbol] = {} - - # Store tick data - if 'ticks' not in self.data_buffers[symbol]: - self.data_buffers[symbol]['ticks'] = [] - self.data_buffers[symbol]['ticks'].append(tick) - - # Keep only last 1000 ticks - if len(self.data_buffers[symbol]['ticks']) > 1000: - self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:] - - # Update last update time - self.last_update_times[symbol]['tick'] = datetime.now() - - logger.debug(f"Tick data updated for {symbol}") - except Exception as e: - logger.error(f"Error processing tick data: {e}") - - def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]: - """ - Build comprehensive RL state with 13,400 features as specified - - Returns: - np.ndarray: State vector with 13,400 features - """ - try: - # Initialize state vector - state_size = 13400 - state = np.zeros(state_size, dtype=np.float32) - - # Get latest data - ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100) - cob_data = self.cob_data.get(symbol) - - # Feature index tracking - idx = 0 - - # 1. OHLCV features (4000 features) - if ohlcv_data is not None and not ohlcv_data.empty: - # Use last 100 1s candles (40 features each: O,H,L,C,V + 36 indicators) - for i in range(min(100, len(ohlcv_data))): - if idx + 40 <= state_size: - row = ohlcv_data.iloc[-(i+1)] - state[idx] = row.get('open', 0) / 100000 # Normalized - state[idx+1] = row.get('high', 0) / 100000 - state[idx+2] = row.get('low', 0) / 100000 - state[idx+3] = row.get('close', 0) / 100000 - state[idx+4] = row.get('volume', 0) / 1000000 - - # Add technical indicators if available - indicator_idx = 5 - for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14', - 'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']: - if col in row and idx + indicator_idx < state_size: - state[idx + indicator_idx] = row[col] / 100000 - indicator_idx += 1 - - idx += 40 - - # 2. COB features (8000 features) - if cob_data and idx + 8000 <= state_size: - # Use 200 price buckets (40 features each) - bucket_prices = sorted(cob_data.price_buckets.keys()) - for i, price in enumerate(bucket_prices[:200]): - if idx + 40 <= state_size: - bucket = cob_data.price_buckets[price] - state[idx] = bucket.get('bid_volume', 0) / 1000000 # Normalized - state[idx+1] = bucket.get('ask_volume', 0) / 1000000 - state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0) - state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000 - - # Additional COB metrics - state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000 - state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000 - state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0) - - idx += 40 - - # 3. Technical indicator features (1000 features) - # Already included in OHLCV section above - - # 4. Market microstructure features (400 features) - if cob_data and idx + 400 <= state_size: - # Add order flow metrics - metrics = list(cob_data.order_flow_metrics.values()) - for i, metric in enumerate(metrics[:400]): - if idx + i < state_size: - state[idx + i] = metric - - # Log state building success - self.performance_metrics['successful_states'] += 1 - logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features") - - # Log to TensorBoard - self.tensorboard_logger.log_state_metrics( - symbol=symbol, - state_info={ - 'size': len(state), - 'quality': 1.0, - 'feature_counts': { - 'total': len(state), - 'non_zero': np.count_nonzero(state) - } - }, - step=self.performance_metrics['successful_states'] - ) - - return state - - except Exception as e: - logger.error(f"Error building comprehensive RL state for {symbol}: {e}") - return None - - def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float: - """ - Calculate enhanced pivot-based reward - - Args: - trade_decision: Trading decision with action and confidence - market_data: Market context data - trade_outcome: Actual trade results - - Returns: - float: Enhanced reward value - """ - try: - # Base reward from PnL - pnl_reward = trade_outcome.get('net_pnl', 0) / 100 # Normalize - - # Confidence weighting - confidence = trade_decision.get('confidence', 0.5) - confidence_reward = confidence * 0.2 - - # Volatility adjustment - volatility = market_data.get('volatility', 0.01) - volatility_reward = (1.0 - volatility * 10) * 0.1 # Prefer low volatility - - # Order flow alignment - order_flow = market_data.get('order_flow_strength', 0) - order_flow_reward = order_flow * 0.2 - - # Pivot alignment bonus (if near pivot in favorable direction) - pivot_bonus = 0.0 - if market_data.get('near_pivot', False): - action = trade_decision.get('action', '').upper() - pivot_type = market_data.get('pivot_type', '').upper() - - # Bonus for buying near support or selling near resistance - if (action == 'BUY' and pivot_type == 'LOW') or \ - (action == 'SELL' and pivot_type == 'HIGH'): - pivot_bonus = 0.5 - - # Calculate final reward - enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus - - # Log to TensorBoard - self.tensorboard_logger.log_scalars('Rewards/Components', { - 'pnl_component': pnl_reward, - 'confidence': confidence_reward, - 'volatility': volatility_reward, - 'order_flow': order_flow_reward, - 'pivot_bonus': pivot_bonus - }, self.performance_metrics['total_episodes']) - - self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes']) - - logger.debug(f"Enhanced reward calculated: {enhanced_reward}") - return enhanced_reward - - except Exception as e: - logger.error(f"Error calculating enhanced pivot reward: {e}") - return 0.0 - - async def make_coordinated_decisions(self) -> Dict[str, TradingAction]: - """ - Make coordinated trading decisions using all available models - - Returns: - Dict[str, TradingAction]: Trading actions for each symbol - """ - try: - decisions = {} - - # For each symbol, coordinate model inference - for symbol in self.symbols: - # Build comprehensive state for RL model - rl_state = self.build_comprehensive_rl_state(symbol) - - if rl_state is not None: - # Store state for training - self.performance_metrics['total_episodes'] += 1 - - # Create mock RL decision (in a real implementation, this would call the RL model) - action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL' - confidence = min(1.0, max(0.0, np.std(rl_state) * 10)) - - # Create trading action - decisions[symbol] = TradingAction( - symbol=symbol, - timestamp=datetime.now(), - action=action, - confidence=confidence, - source='rl_orchestrator' - ) - - logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})") - else: - logger.warning(f"Failed to build state for {symbol}, skipping decision") - - self.performance_metrics['inference_count'] += 1 - return decisions - - except Exception as e: - logger.error(f"Error making coordinated decisions: {e}") - return {} - - def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float: - """ - Calculate correlation between two symbols - - Args: - symbol1: First symbol - symbol2: Second symbol - - Returns: - float: Correlation coefficient (-1 to 1) - """ - try: - # Get recent price data for both symbols - data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50) - data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50) - - if data1 is None or data2 is None or data1.empty or data2.empty: - return 0.0 - - # Align data by timestamp - merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner') - - if len(merged) < 10: - return 0.0 - - # Calculate correlation - correlation = merged['close_1'].corr(merged['close_2']) - return correlation if not np.isnan(correlation) else 0.0 - - except Exception as e: - logger.error(f"Error calculating symbol correlation: {e}") - return 0.0 -``` \ No newline at end of file diff --git a/core/enhanced_training_integration.py b/core/enhanced_training_integration.py deleted file mode 100644 index 6cdd674..0000000 --- a/core/enhanced_training_integration.py +++ /dev/null @@ -1,775 +0,0 @@ -""" -Enhanced Training Integration Module - -This module provides comprehensive integration between the training data collection system, -CNN training pipeline, RL training pipeline, and your existing infrastructure. - -Key Features: -- Real-time integration with existing DataProvider -- Coordinated training across CNN and RL models -- Automatic outcome validation and profitability tracking -- Integration with existing COB RL model -- Performance monitoring and optimization -- Seamless connection to existing orchestrator and trading executor -""" - -import asyncio -import logging -import numpy as np -import pandas as pd -import torch -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Callable -from dataclasses import dataclass -import threading -import time -from pathlib import Path - -# Import existing components -from .data_provider import DataProvider -from .orchestrator import Orchestrator -from .trading_executor import TradingExecutor - -# Import our training system components -from .training_data_collector import ( - TrainingDataCollector, - get_training_data_collector -) -from .cnn_training_pipeline import ( - CNNPivotPredictor, - CNNTrainer, - get_cnn_trainer -) -from .rl_training_pipeline import ( - RLTradingAgent, - RLTrainer, - get_rl_trainer -) -from .training_integration import TrainingIntegration - -# Import existing RL model -try: - from NN.models.cob_rl_model import COBRLModelInterface -except ImportError: - logger.warning("Could not import COBRLModelInterface - using fallback") - COBRLModelInterface = None - -logger = logging.getLogger(__name__) - -@dataclass -class EnhancedTrainingConfig: - """Enhanced configuration for comprehensive training integration""" - # Data collection - collection_interval: float = 1.0 - min_data_completeness: float = 0.8 - - # Training triggers - min_episodes_for_cnn_training: int = 100 - min_experiences_for_rl_training: int = 200 - training_frequency_minutes: int = 30 - - # Profitability thresholds - min_profitability_for_replay: float = 0.1 - high_profitability_threshold: float = 0.5 - - # Model integration - use_existing_cob_rl_model: bool = True - enable_cross_model_learning: bool = True - - # Performance optimization - max_concurrent_training_sessions: int = 2 - enable_background_validation: bool = True - -class EnhancedTrainingIntegration: - """Enhanced training integration with existing infrastructure""" - - def __init__(self, - data_provider: DataProvider, - orchestrator: Orchestrator = None, - trading_executor: TradingExecutor = None, - config: EnhancedTrainingConfig = None): - - self.data_provider = data_provider - self.orchestrator = orchestrator - self.trading_executor = trading_executor - self.config = config or EnhancedTrainingConfig() - - # Initialize training components - self.data_collector = get_training_data_collector() - - # Initialize CNN components - self.cnn_model = CNNPivotPredictor() - self.cnn_trainer = get_cnn_trainer(self.cnn_model) - - # Initialize RL components - if self.config.use_existing_cob_rl_model and COBRLModelInterface: - self.existing_rl_model = COBRLModelInterface() - logger.info("Using existing COB RL model") - else: - self.existing_rl_model = None - - self.rl_agent = RLTradingAgent() - self.rl_trainer = get_rl_trainer(self.rl_agent) - - # Integration state - self.is_running = False - self.training_threads = {} - self.validation_thread = None - - # Performance tracking - self.integration_stats = { - 'total_data_packages': 0, - 'cnn_training_sessions': 0, - 'rl_training_sessions': 0, - 'profitable_predictions': 0, - 'total_predictions': 0, - 'cross_model_improvements': 0, - 'last_update': datetime.now() - } - - # Model prediction tracking - self.recent_predictions = {} - self.prediction_outcomes = {} - - # Cross-model learning - self.model_performance_history = { - 'cnn': [], - 'rl': [], - 'orchestrator': [] - } - - logger.info("Enhanced Training Integration initialized") - logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}") - logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}") - logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}") - - def start_enhanced_integration(self): - """Start the enhanced training integration system""" - if self.is_running: - logger.warning("Enhanced training integration already running") - return - - self.is_running = True - - # Start data collection - self.data_collector.start_collection() - - # Start CNN training - if self.config.min_episodes_for_cnn_training > 0: - for symbol in self.data_provider.symbols: - self.cnn_trainer.start_real_time_training(symbol) - - # Start coordinated training thread - self.training_threads['coordinator'] = threading.Thread( - target=self._training_coordinator_worker, - daemon=True - ) - self.training_threads['coordinator'].start() - - # Start data collection and validation - self.training_threads['data_collector'] = threading.Thread( - target=self._enhanced_data_collection_worker, - daemon=True - ) - self.training_threads['data_collector'].start() - - # Start outcome validation if enabled - if self.config.enable_background_validation: - self.validation_thread = threading.Thread( - target=self._outcome_validation_worker, - daemon=True - ) - self.validation_thread.start() - - logger.info("Enhanced training integration started") - - def stop_enhanced_integration(self): - """Stop the enhanced training integration system""" - self.is_running = False - - # Stop data collection - self.data_collector.stop_collection() - - # Stop CNN training - self.cnn_trainer.stop_training() - - # Wait for threads to finish - for thread_name, thread in self.training_threads.items(): - thread.join(timeout=10) - logger.info(f"Stopped {thread_name} thread") - - if self.validation_thread: - self.validation_thread.join(timeout=5) - - logger.info("Enhanced training integration stopped") - - def _enhanced_data_collection_worker(self): - """Enhanced data collection with real-time model integration""" - logger.info("Enhanced data collection worker started") - - while self.is_running: - try: - for symbol in self.data_provider.symbols: - self._collect_enhanced_training_data(symbol) - - time.sleep(self.config.collection_interval) - - except Exception as e: - logger.error(f"Error in enhanced data collection: {e}") - time.sleep(5) - - logger.info("Enhanced data collection worker stopped") - - def _collect_enhanced_training_data(self, symbol: str): - """Collect enhanced training data with model predictions""" - try: - # Get comprehensive market data - market_data = self._get_comprehensive_market_data(symbol) - - if not market_data or not self._validate_market_data(market_data): - return - - # Get current model predictions - model_predictions = self._get_all_model_predictions(symbol, market_data) - - # Create enhanced features - cnn_features = self._create_enhanced_cnn_features(symbol, market_data) - rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions) - - # Collect training data with predictions - episode_id = self.data_collector.collect_training_data( - symbol=symbol, - ohlcv_data=market_data['ohlcv'], - tick_data=market_data['ticks'], - cob_data=market_data['cob'], - technical_indicators=market_data['indicators'], - pivot_points=market_data['pivots'], - cnn_features=cnn_features, - rl_state=rl_state, - orchestrator_context=market_data['context'], - model_predictions=model_predictions - ) - - if episode_id: - # Store predictions for outcome validation - self.recent_predictions[episode_id] = { - 'timestamp': datetime.now(), - 'symbol': symbol, - 'predictions': model_predictions, - 'market_data': market_data - } - - # Add RL experience if we have action - if 'rl_action' in model_predictions: - self._add_rl_experience(symbol, market_data, model_predictions, episode_id) - - self.integration_stats['total_data_packages'] += 1 - - except Exception as e: - logger.error(f"Error collecting enhanced training data for {symbol}: {e}") - - def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]: - """Get comprehensive market data from all sources""" - try: - market_data = {} - - # OHLCV data - ohlcv_data = {} - for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']: - df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True) - if df is not None and not df.empty: - ohlcv_data[timeframe] = df - market_data['ohlcv'] = ohlcv_data - - # Tick data - market_data['ticks'] = self._get_recent_tick_data(symbol) - - # COB data - market_data['cob'] = self._get_cob_data(symbol) - - # Technical indicators - market_data['indicators'] = self._get_technical_indicators(symbol) - - # Pivot points - market_data['pivots'] = self._get_pivot_points(symbol) - - # Market context - market_data['context'] = self._get_market_context(symbol) - - return market_data - - except Exception as e: - logger.error(f"Error getting comprehensive market data: {e}") - return {} - - def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]: - """Get predictions from all available models""" - predictions = {} - - try: - # CNN predictions - if self.cnn_model and market_data.get('ohlcv'): - cnn_features = self._create_enhanced_cnn_features(symbol, market_data) - if cnn_features is not None: - cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0) - - # Reshape for CNN (add channel dimension) - cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels - - with torch.no_grad(): - cnn_outputs = self.cnn_model(cnn_input) - predictions['cnn'] = { - 'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(), - 'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(), - 'confidence': cnn_outputs['confidence'].cpu().numpy(), - 'timestamp': datetime.now() - } - - # RL predictions - if self.rl_agent and market_data.get('cob'): - rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions) - if rl_state is not None: - action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1) - predictions['rl'] = { - 'action': action, - 'confidence': confidence, - 'timestamp': datetime.now() - } - predictions['rl_action'] = action - - # Existing COB RL model predictions - if self.existing_rl_model and market_data.get('cob'): - cob_features = market_data['cob'].get('cob_features', []) - if cob_features and len(cob_features) >= 2000: - cob_array = np.array(cob_features[:2000], dtype=np.float32) - cob_prediction = self.existing_rl_model.predict(cob_array) - predictions['cob_rl'] = { - 'predicted_direction': cob_prediction.get('predicted_direction', 1), - 'confidence': cob_prediction.get('confidence', 0.5), - 'value': cob_prediction.get('value', 0.0), - 'timestamp': datetime.now() - } - - # Orchestrator predictions (if available) - if self.orchestrator: - try: - # This would integrate with your orchestrator's prediction method - orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions) - if orchestrator_prediction: - predictions['orchestrator'] = orchestrator_prediction - except Exception as e: - logger.debug(f"Could not get orchestrator prediction: {e}") - - return predictions - - except Exception as e: - logger.error(f"Error getting model predictions: {e}") - return {} - - def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any], - predictions: Dict[str, Any], episode_id: str): - """Add RL experience to the training buffer""" - try: - # Create RL state - state = self._create_enhanced_rl_state(symbol, market_data, predictions) - if state is None: - return - - # Get action from predictions - action = predictions.get('rl_action', 1) # Default to HOLD - - # Calculate immediate reward (placeholder - would be updated with actual outcome) - reward = 0.0 - - # Create next state (same as current for now - would be updated) - next_state = state.copy() - - # Market context - market_context = { - 'symbol': symbol, - 'episode_id': episode_id, - 'timestamp': datetime.now(), - 'market_session': market_data['context'].get('market_session', 'unknown'), - 'volatility_regime': market_data['context'].get('volatility_regime', 'unknown') - } - - # Add experience - experience_id = self.rl_trainer.add_experience( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=False, - market_context=market_context, - cnn_predictions=predictions.get('cnn'), - confidence_score=predictions.get('rl', {}).get('confidence', 0.0) - ) - - if experience_id: - logger.debug(f"Added RL experience: {experience_id}") - - except Exception as e: - logger.error(f"Error adding RL experience: {e}") - - def _training_coordinator_worker(self): - """Coordinate training across all models""" - logger.info("Training coordinator worker started") - - while self.is_running: - try: - # Check if we should trigger training - for symbol in self.data_provider.symbols: - self._check_and_trigger_training(symbol) - - # Wait before next check - time.sleep(self.config.training_frequency_minutes * 60) - - except Exception as e: - logger.error(f"Error in training coordinator: {e}") - time.sleep(60) - - logger.info("Training coordinator worker stopped") - - def _check_and_trigger_training(self, symbol: str): - """Check conditions and trigger training if needed""" - try: - # Get training episodes and experiences - episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000) - - # Check CNN training conditions - if len(episodes) >= self.config.min_episodes_for_cnn_training: - profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable] - - if len(profitable_episodes) >= 20: # Minimum profitable episodes - logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes") - - results = self.cnn_trainer.train_on_profitable_episodes( - symbol=symbol, - min_profitability=self.config.min_profitability_for_replay, - max_episodes=len(profitable_episodes) - ) - - if results.get('status') == 'success': - self.integration_stats['cnn_training_sessions'] += 1 - logger.info(f"CNN training completed for {symbol}") - - # Check RL training conditions - buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics() - total_experiences = buffer_stats.get('total_experiences', 0) - - if total_experiences >= self.config.min_experiences_for_rl_training: - profitable_experiences = buffer_stats.get('profitable_experiences', 0) - - if profitable_experiences >= 50: # Minimum profitable experiences - logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences") - - results = self.rl_trainer.train_on_profitable_experiences( - min_profitability=self.config.min_profitability_for_replay, - max_experiences=min(profitable_experiences, 500), - batch_size=32 - ) - - if results.get('status') == 'success': - self.integration_stats['rl_training_sessions'] += 1 - logger.info("RL training completed") - - except Exception as e: - logger.error(f"Error checking training conditions for {symbol}: {e}") - - def _outcome_validation_worker(self): - """Background worker for validating prediction outcomes""" - logger.info("Outcome validation worker started") - - while self.is_running: - try: - self._validate_recent_predictions() - time.sleep(300) # Check every 5 minutes - - except Exception as e: - logger.error(f"Error in outcome validation: {e}") - time.sleep(60) - - logger.info("Outcome validation worker stopped") - - def _validate_recent_predictions(self): - """Validate recent predictions against actual outcomes""" - try: - current_time = datetime.now() - validation_delay = timedelta(hours=1) # Wait 1 hour to validate - - validated_predictions = [] - - for episode_id, prediction_data in self.recent_predictions.items(): - prediction_time = prediction_data['timestamp'] - - if current_time - prediction_time >= validation_delay: - # Validate this prediction - outcome = self._calculate_prediction_outcome(prediction_data) - - if outcome: - self.prediction_outcomes[episode_id] = outcome - - # Update RL experience if exists - if 'rl_action' in prediction_data['predictions']: - self._update_rl_experience_outcome(episode_id, outcome) - - # Update statistics - if outcome['is_profitable']: - self.integration_stats['profitable_predictions'] += 1 - self.integration_stats['total_predictions'] += 1 - - validated_predictions.append(episode_id) - - # Remove validated predictions - for episode_id in validated_predictions: - del self.recent_predictions[episode_id] - - if validated_predictions: - logger.info(f"Validated {len(validated_predictions)} predictions") - - except Exception as e: - logger.error(f"Error validating predictions: {e}") - - def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Calculate actual outcome for a prediction""" - try: - symbol = prediction_data['symbol'] - prediction_time = prediction_data['timestamp'] - - # Get price data after prediction - current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True) - - if current_df is None or current_df.empty: - return None - - # Find price at prediction time and current price - prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame()) - if prediction_price.empty: - return None - - base_price = float(prediction_price['close'].iloc[-1]) - current_price = float(current_df['close'].iloc[-1]) - - # Calculate outcome - price_change = (current_price - base_price) / base_price - is_profitable = abs(price_change) > 0.005 # 0.5% threshold - - return { - 'episode_id': prediction_data.get('episode_id'), - 'base_price': base_price, - 'current_price': current_price, - 'price_change': price_change, - 'is_profitable': is_profitable, - 'profitability_score': abs(price_change) * 10, # Scale to 0-1 range - 'validation_time': datetime.now() - } - - except Exception as e: - logger.error(f"Error calculating prediction outcome: {e}") - return None - - def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]): - """Update RL experience with actual outcome""" - try: - # Find the experience ID associated with this episode - # This is a simplified approach - in practice you'd maintain better mapping - actual_profit = outcome['price_change'] - - # Determine optimal action based on outcome - if outcome['price_change'] > 0.01: - optimal_action = 2 # BUY - elif outcome['price_change'] < -0.01: - optimal_action = 0 # SELL - else: - optimal_action = 1 # HOLD - - # Update experience (this would need proper experience ID mapping) - # For now, we'll update the most recent experience - # In practice, you'd maintain a mapping between episodes and experiences - - except Exception as e: - logger.error(f"Error updating RL experience outcome: {e}") - - def get_integration_statistics(self) -> Dict[str, Any]: - """Get comprehensive integration statistics""" - stats = self.integration_stats.copy() - - # Add component statistics - stats['data_collector'] = self.data_collector.get_collection_statistics() - stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics() - stats['rl_trainer'] = self.rl_trainer.get_training_statistics() - - # Add performance metrics - stats['is_running'] = self.is_running - stats['active_symbols'] = len(self.data_provider.symbols) - stats['recent_predictions_count'] = len(self.recent_predictions) - stats['validated_outcomes_count'] = len(self.prediction_outcomes) - - # Calculate profitability rate - if stats['total_predictions'] > 0: - stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions'] - else: - stats['overall_profitability_rate'] = 0.0 - - return stats - - def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]: - """Manually trigger training""" - results = {} - - try: - if training_type in ['all', 'cnn']: - symbols = [symbol] if symbol else self.data_provider.symbols - for sym in symbols: - cnn_results = self.cnn_trainer.train_on_profitable_episodes( - symbol=sym, - min_profitability=0.1, - max_episodes=200 - ) - results[f'cnn_{sym}'] = cnn_results - - if training_type in ['all', 'rl']: - rl_results = self.rl_trainer.train_on_profitable_experiences( - min_profitability=0.1, - max_experiences=500, - batch_size=32 - ) - results['rl'] = rl_results - - return {'status': 'success', 'results': results} - - except Exception as e: - logger.error(f"Error in manual training trigger: {e}") - return {'status': 'error', 'error': str(e)} - - # Helper methods (simplified implementations) - def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]: - """Get recent tick data""" - # Implementation would get tick data from data provider - return [] - - def _get_cob_data(self, symbol: str) -> Dict[str, Any]: - """Get COB data""" - # Implementation would get COB data from data provider - return {} - - def _get_technical_indicators(self, symbol: str) -> Dict[str, float]: - """Get technical indicators""" - # Implementation would get indicators from data provider - return {} - - def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]: - """Get pivot points""" - # Implementation would get pivot points from data provider - return [] - - def _get_market_context(self, symbol: str) -> Dict[str, Any]: - """Get market context""" - return { - 'symbol': symbol, - 'timestamp': datetime.now(), - 'market_session': 'unknown', - 'volatility_regime': 'unknown' - } - - def _validate_market_data(self, market_data: Dict[str, Any]) -> bool: - """Validate market data completeness""" - required_fields = ['ohlcv', 'indicators'] - return all(field in market_data for field in required_fields) - - def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]: - """Create enhanced CNN features""" - try: - # Simplified feature creation - features = [] - - # Add OHLCV features - for timeframe in ['1m', '5m', '15m', '1h']: - if timeframe in market_data.get('ohlcv', {}): - df = market_data['ohlcv'][timeframe] - if not df.empty: - ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values - if len(ohlcv_values) > 0: - recent_values = ohlcv_values[-60:].flatten() - features.extend(recent_values) - - # Pad to target size - target_size = 3000 # 10 channels * 300 sequence length - if len(features) < target_size: - features.extend([0.0] * (target_size - len(features))) - else: - features = features[:target_size] - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.warning(f"Error creating CNN features: {e}") - return None - - def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any], - predictions: Dict[str, Any] = None) -> Optional[np.ndarray]: - """Create enhanced RL state""" - try: - state_features = [] - - # Add market features - if '1m' in market_data.get('ohlcv', {}): - df = market_data['ohlcv']['1m'] - if not df.empty: - latest = df.iloc[-1] - state_features.extend([ - latest['open'], latest['high'], - latest['low'], latest['close'], latest['volume'] - ]) - - # Add technical indicators - indicators = market_data.get('indicators', {}) - for value in indicators.values(): - state_features.append(value) - - # Add model predictions as features - if predictions: - if 'cnn' in predictions: - cnn_pred = predictions['cnn'] - state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0])) - state_features.append(cnn_pred.get('confidence', [0.0])[0]) - - if 'cob_rl' in predictions: - cob_pred = predictions['cob_rl'] - state_features.append(cob_pred.get('predicted_direction', 1)) - state_features.append(cob_pred.get('confidence', 0.5)) - - # Pad to target size - target_size = 2000 - if len(state_features) < target_size: - state_features.extend([0.0] * (target_size - len(state_features))) - else: - state_features = state_features[:target_size] - - return np.array(state_features, dtype=np.float32) - - except Exception as e: - logger.warning(f"Error creating RL state: {e}") - return None - - def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any], - predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Get orchestrator prediction""" - # This would integrate with your orchestrator - return None - -# Global instance -enhanced_training_integration = None - -def get_enhanced_training_integration(data_provider: DataProvider = None, - orchestrator: Orchestrator = None, - trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration: - """Get global enhanced training integration instance""" - global enhanced_training_integration - if enhanced_training_integration is None: - if data_provider is None: - raise ValueError("DataProvider required for first initialization") - enhanced_training_integration = EnhancedTrainingIntegration( - data_provider, orchestrator, trading_executor - ) - return enhanced_training_integration \ No newline at end of file diff --git a/core/mexc_webclient/__init__.py b/core/mexc_webclient/__init__.py deleted file mode 100644 index 449bcac..0000000 --- a/core/mexc_webclient/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# MEXC Web Client Module -# -# This module provides web-based trading capabilities for MEXC futures trading -# which is not supported by their official API. - -from .mexc_futures_client import MEXCFuturesWebClient - -__all__ = ['MEXCFuturesWebClient'] \ No newline at end of file diff --git a/core/mexc_webclient/auto_browser.py b/core/mexc_webclient/auto_browser.py deleted file mode 100644 index d8b5c31..0000000 --- a/core/mexc_webclient/auto_browser.py +++ /dev/null @@ -1,555 +0,0 @@ -#!/usr/bin/env python3 -""" -MEXC Auto Browser with Request Interception - -This script automatically spawns a ChromeDriver instance and captures -all MEXC futures trading requests in real-time, including full request -and response data needed for reverse engineering. -""" - -import logging -import time -import json -import sys -import os -from typing import Dict, List, Optional, Any -from datetime import datetime -import threading -import queue - -# Selenium imports -try: - from selenium import webdriver - from selenium.webdriver.chrome.options import Options - from selenium.webdriver.chrome.service import Service - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - from selenium.common.exceptions import TimeoutException, WebDriverException - from webdriver_manager.chrome import ChromeDriverManager -except ImportError: - print("Please install selenium and webdriver-manager:") - print("pip install selenium webdriver-manager") - sys.exit(1) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class MEXCRequestInterceptor: - """ - Automatically spawns ChromeDriver and intercepts all MEXC API requests - """ - - def __init__(self, headless: bool = False, save_to_file: bool = True): - """ - Initialize the request interceptor - - Args: - headless: Run browser in headless mode - save_to_file: Save captured requests to JSON file - """ - self.driver = None - self.headless = headless - self.save_to_file = save_to_file - self.captured_requests = [] - self.captured_responses = [] - self.session_cookies = {} - self.monitoring = False - self.request_queue = queue.Queue() - - # File paths for saving data - self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - self.requests_file = f"mexc_requests_{self.timestamp}.json" - self.cookies_file = f"mexc_cookies_{self.timestamp}.json" - - def setup_browser(self): - """Setup Chrome browser with necessary options""" - chrome_options = webdriver.ChromeOptions() - # Enable headless mode if needed - if self.headless: - chrome_options.add_argument('--headless') - chrome_options.add_argument('--disable-gpu') - chrome_options.add_argument('--window-size=1920,1080') - chrome_options.add_argument('--disable-extensions') - - # Set up Chrome options with a user data directory to persist session - user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data') - os.makedirs(user_data_base_dir, exist_ok=True) - - # Check for existing session directories - session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')] - session_dirs.sort(reverse=True) # Sort descending to get the most recent first - - user_data_dir = None - if session_dirs: - use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y' - if use_existing: - print("Available sessions:") - for i, session in enumerate(session_dirs[:5], 1): # Show up to 5 most recent - print(f"{i}. {session}") - choice = input("Enter session number (default 1) or any other key for most recent: ") - if choice.isdigit() and 1 <= int(choice) <= len(session_dirs): - selected_session = session_dirs[int(choice) - 1] - else: - selected_session = session_dirs[0] - user_data_dir = os.path.join(user_data_base_dir, selected_session) - print(f"Using session: {selected_session}") - - if user_data_dir is None: - user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}') - os.makedirs(user_data_dir, exist_ok=True) - print(f"Creating new session: session_{self.timestamp}") - - chrome_options.add_argument(f'--user-data-dir={user_data_dir}') - - # Enable logging to capture JS console output and network activity - chrome_options.set_capability('goog:loggingPrefs', { - 'browser': 'ALL', - 'performance': 'ALL' - }) - - try: - self.driver = webdriver.Chrome(options=chrome_options) - except Exception as e: - print(f"Failed to start browser with session: {e}") - print("Falling back to a new session...") - user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback') - os.makedirs(user_data_dir, exist_ok=True) - print(f"Creating fallback session: session_{self.timestamp}_fallback") - chrome_options = webdriver.ChromeOptions() - if self.headless: - chrome_options.add_argument('--headless') - chrome_options.add_argument('--disable-gpu') - chrome_options.add_argument('--window-size=1920,1080') - chrome_options.add_argument('--disable-extensions') - chrome_options.add_argument(f'--user-data-dir={user_data_dir}') - chrome_options.set_capability('goog:loggingPrefs', { - 'browser': 'ALL', - 'performance': 'ALL' - }) - self.driver = webdriver.Chrome(options=chrome_options) - - return self.driver - - def start_monitoring(self): - """Start the browser and begin monitoring""" - logger.info("Starting MEXC Request Interceptor...") - - try: - # Setup ChromeDriver - self.driver = self.setup_browser() - - # Navigate to MEXC futures - mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap" - logger.info(f"Navigating to: {mexc_url}") - self.driver.get(mexc_url) - - # Wait for page load - WebDriverWait(self.driver, 10).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) - - logger.info("✅ MEXC page loaded successfully!") - logger.info("📝 Please log in manually in the browser window") - logger.info("🔍 Request monitoring is now active...") - - # Start monitoring in background thread - self.monitoring = True - monitor_thread = threading.Thread(target=self._monitor_requests, daemon=True) - monitor_thread.start() - - # Wait for manual login - self._wait_for_login() - - return True - - except Exception as e: - logger.error(f"Failed to start monitoring: {e}") - return False - - def _wait_for_login(self): - """Wait for user to log in and show interactive menu""" - logger.info("\n" + "="*60) - logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE") - logger.info("="*60) - - while True: - print("\nOptions:") - print("1. Check login status") - print("2. Extract current cookies") - print("3. Show captured requests summary") - print("4. Save captured data to files") - print("5. Perform test trade (manual)") - print("6. Monitor for 60 seconds") - print("0. Stop and exit") - - choice = input("\nEnter choice (0-6): ").strip() - - if choice == "1": - self._check_login_status() - elif choice == "2": - self._extract_cookies() - elif choice == "3": - self._show_requests_summary() - elif choice == "4": - self._save_all_data() - elif choice == "5": - self._guide_test_trade() - elif choice == "6": - self._monitor_for_duration(60) - elif choice == "0": - break - else: - print("Invalid choice. Please try again.") - - self.stop_monitoring() - - def _check_login_status(self): - """Check if user is logged into MEXC""" - try: - cookies = self.driver.get_cookies() - auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint'] - found_auth = [] - - for cookie in cookies: - if cookie['name'] in auth_cookies and cookie['value']: - found_auth.append(cookie['name']) - - if len(found_auth) >= 2: - print("✅ LOGIN DETECTED - You appear to be logged in!") - print(f" Found auth cookies: {', '.join(found_auth)}") - return True - else: - print("❌ NOT LOGGED IN - Please log in to MEXC in the browser") - print(" Missing required authentication cookies") - return False - - except Exception as e: - print(f"❌ Error checking login: {e}") - return False - - def _extract_cookies(self): - """Extract and display current session cookies""" - try: - cookies = self.driver.get_cookies() - cookie_dict = {} - - for cookie in cookies: - cookie_dict[cookie['name']] = cookie['value'] - - self.session_cookies = cookie_dict - - print(f"\n📊 Extracted {len(cookie_dict)} cookies:") - - # Show important cookies - important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId'] - for name in important: - if name in cookie_dict: - value = cookie_dict[name] - display_value = value[:20] + "..." if len(value) > 20 else value - print(f" ✅ {name}: {display_value}") - else: - print(f" ❌ {name}: Missing") - - # Save cookies to file - if self.save_to_file: - with open(self.cookies_file, 'w') as f: - json.dump(cookie_dict, f, indent=2) - print(f"\n💾 Cookies saved to: {self.cookies_file}") - - except Exception as e: - print(f"❌ Error extracting cookies: {e}") - - def _monitor_requests(self): - """Background thread to monitor network requests""" - last_log_count = 0 - - while self.monitoring: - try: - # Get performance logs - logs = self.driver.get_log('performance') - - for log in logs: - try: - message = json.loads(log['message']) - method = message.get('message', {}).get('method', '') - - # Capture network requests - if method == 'Network.requestWillBeSent': - self._process_request(message['message']['params']) - elif method == 'Network.responseReceived': - self._process_response(message['message']['params']) - - except (json.JSONDecodeError, KeyError) as e: - continue - - # Show progress every 10 new requests - if len(self.captured_requests) >= last_log_count + 10: - last_log_count = len(self.captured_requests) - logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses") - - except Exception as e: - if self.monitoring: # Only log if we're still supposed to be monitoring - logger.debug(f"Monitor error: {e}") - - time.sleep(0.5) # Check every 500ms - - def _process_request(self, request_data): - """Process a captured network request""" - try: - url = request_data.get('request', {}).get('url', '') - - # Filter for MEXC API requests - if self._is_mexc_request(url): - request_info = { - 'type': 'request', - 'timestamp': datetime.now().isoformat(), - 'url': url, - 'method': request_data.get('request', {}).get('method', ''), - 'headers': request_data.get('request', {}).get('headers', {}), - 'postData': request_data.get('request', {}).get('postData', ''), - 'requestId': request_data.get('requestId', '') - } - - self.captured_requests.append(request_info) - - # Show important requests immediately - if ('futures.mexc.com' in url or 'captcha' in url): - print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}") - if request_info['postData']: - print(f" 📄 POST Data: {request_info['postData'][:100]}...") - - # Enhanced captcha detection and detailed logging - if 'captcha' in url.lower() or 'robot' in url.lower(): - logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}") - logger.info(f" Headers: {request_data.get('request', {}).get('headers', {})}") - if request_data.get('request', {}).get('postData', ''): - logger.info(f" Data: {request_data.get('request', {}).get('postData', '')}") - # Attempt to capture related JavaScript or DOM elements (if possible) - if self.driver is not None: - try: - js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';") - logger.info(f" Related JS Snippet: {js_snippet}") - except Exception as e: - logger.warning(f" Could not capture JS snippet: {e}") - try: - dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';") - logger.info(f" Related DOM Element: {dom_element}") - except Exception as e: - logger.warning(f" Could not capture DOM element: {e}") - else: - logger.warning(" Driver not initialized, cannot capture JS or DOM elements") - - except Exception as e: - logger.debug(f"Error processing request: {e}") - - def _process_response(self, response_data): - """Process a captured network response""" - try: - url = response_data.get('response', {}).get('url', '') - - # Filter for MEXC API responses - if self._is_mexc_request(url): - response_info = { - 'type': 'response', - 'timestamp': datetime.now().isoformat(), - 'url': url, - 'status': response_data.get('response', {}).get('status', 0), - 'headers': response_data.get('response', {}).get('headers', {}), - 'requestId': response_data.get('requestId', '') - } - - self.captured_responses.append(response_info) - - # Show important responses immediately - if ('futures.mexc.com' in url or 'captcha' in url): - status = response_info['status'] - status_emoji = "✅" if status == 200 else "❌" - print(f" {status_emoji} RESPONSE: {status} for {url}") - - except Exception as e: - logger.debug(f"Error processing response: {e}") - - def _is_mexc_request(self, url: str) -> bool: - """Check if URL is a relevant MEXC API request""" - mexc_indicators = [ - 'futures.mexc.com', - 'ucgateway/captcha_api', - 'api/v1/private', - 'api/v3/order', - 'mexc.com/api' - ] - - return any(indicator in url for indicator in mexc_indicators) - - def _show_requests_summary(self): - """Show summary of captured requests""" - print(f"\n📊 CAPTURE SUMMARY:") - print(f" Total Requests: {len(self.captured_requests)}") - print(f" Total Responses: {len(self.captured_responses)}") - - # Group by URL pattern - url_counts = {} - for req in self.captured_requests: - base_url = req['url'].split('?')[0] # Remove query params - url_counts[base_url] = url_counts.get(base_url, 0) + 1 - - print("\n🔗 Top URLs:") - for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]: - print(f" {count}x {url}") - - # Show recent futures API calls - futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']] - if futures_requests: - print(f"\n🚀 Futures API Calls: {len(futures_requests)}") - for req in futures_requests[-3:]: # Show last 3 - print(f" {req['method']} {req['url']}") - - def _save_all_data(self): - """Save all captured data to files""" - if not self.save_to_file: - print("File saving is disabled") - return - - try: - # Save requests - with open(self.requests_file, 'w') as f: - json.dump({ - 'requests': self.captured_requests, - 'responses': self.captured_responses, - 'summary': { - 'total_requests': len(self.captured_requests), - 'total_responses': len(self.captured_responses), - 'capture_session': self.timestamp - } - }, f, indent=2) - - # Save cookies if we have them - if self.session_cookies: - with open(self.cookies_file, 'w') as f: - json.dump(self.session_cookies, f, indent=2) - - print(f"\n💾 Data saved to:") - print(f" 📋 Requests: {self.requests_file}") - if self.session_cookies: - print(f" 🍪 Cookies: {self.cookies_file}") - - # Extract and save CAPTCHA tokens from captured requests - captcha_tokens = self.extract_captcha_tokens() - if captcha_tokens: - captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json" - with open(captcha_file, 'w') as f: - json.dump(captcha_tokens, f, indent=2) - logger.info(f"Saved CAPTCHA tokens to {captcha_file}") - else: - logger.warning("No CAPTCHA tokens found in captured requests") - - except Exception as e: - print(f"❌ Error saving data: {e}") - - def _guide_test_trade(self): - """Guide user through performing a test trade""" - print("\n🧪 TEST TRADE GUIDE:") - print("1. Make sure you're logged into MEXC") - print("2. Go to the trading interface") - print("3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)") - print("4. Watch the console for captured API calls") - print("\n⚠️ IMPORTANT: Use very small amounts for testing!") - input("\nPress Enter when you're ready to start monitoring...") - - self._monitor_for_duration(120) # Monitor for 2 minutes - - def _monitor_for_duration(self, seconds: int): - """Monitor requests for a specific duration""" - print(f"\n🔍 Monitoring requests for {seconds} seconds...") - print("Perform your trading actions now!") - - start_time = time.time() - initial_count = len(self.captured_requests) - - while time.time() - start_time < seconds: - current_count = len(self.captured_requests) - new_requests = current_count - initial_count - - remaining = seconds - int(time.time() - start_time) - print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True) - - time.sleep(1) - - final_count = len(self.captured_requests) - new_total = final_count - initial_count - print(f"\n✅ Monitoring complete! Captured {new_total} new requests") - - def stop_monitoring(self): - """Stop monitoring and close browser""" - logger.info("Stopping request monitoring...") - self.monitoring = False - - if self.driver: - self.driver.quit() - logger.info("Browser closed") - - # Final save - if self.save_to_file and (self.captured_requests or self.captured_responses): - self._save_all_data() - logger.info("Final data save complete") - - def extract_captcha_tokens(self): - """Extract CAPTCHA tokens from captured requests""" - captcha_tokens = [] - for request in self.captured_requests: - if 'captcha-token' in request.get('headers', {}): - token = request['headers']['captcha-token'] - captcha_tokens.append({ - 'token': token, - 'url': request.get('url', ''), - 'timestamp': request.get('timestamp', '') - }) - elif 'captcha' in request.get('url', '').lower(): - response = request.get('response', {}) - if response and 'captcha-token' in response.get('headers', {}): - token = response['headers']['captcha-token'] - captcha_tokens.append({ - 'token': token, - 'url': request.get('url', ''), - 'timestamp': request.get('timestamp', '') - }) - return captcha_tokens - -def main(): - """Main function to run the interceptor""" - print("🚀 MEXC Request Interceptor with ChromeDriver") - print("=" * 50) - print("This will automatically:") - print("✅ Download/setup ChromeDriver") - print("✅ Open MEXC futures page") - print("✅ Capture all API requests/responses") - print("✅ Extract session cookies") - print("✅ Save data to JSON files") - print("\nPress Ctrl+C to stop at any time") - - # Ask for preferences - headless = input("\nRun in headless mode? (y/n): ").lower().strip() == 'y' - - interceptor = MEXCRequestInterceptor(headless=headless, save_to_file=True) - - try: - success = interceptor.start_monitoring() - if not success: - print("❌ Failed to start monitoring") - return - - except KeyboardInterrupt: - print("\n\n⏹️ Stopping interceptor...") - except Exception as e: - print(f"\n❌ Error: {e}") - finally: - interceptor.stop_monitoring() - print("\n👋 Goodbye!") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/core/mexc_webclient/browser_automation.py b/core/mexc_webclient/browser_automation.py deleted file mode 100644 index f2f6e43..0000000 --- a/core/mexc_webclient/browser_automation.py +++ /dev/null @@ -1,358 +0,0 @@ -""" -MEXC Browser Automation for Cookie Extraction and Request Monitoring - -This module uses Selenium to automate browser interactions and extract -session cookies and request data for MEXC futures trading. -""" - -import logging -import time -import json -from typing import Dict, List, Optional, Any -from selenium import webdriver -from selenium.webdriver.chrome.options import Options -from selenium.webdriver.common.by import By -from selenium.webdriver.support.ui import WebDriverWait -from selenium.webdriver.support import expected_conditions as EC -from selenium.common.exceptions import TimeoutException, WebDriverException -from selenium.webdriver.chrome.service import Service -from webdriver_manager.chrome import ChromeDriverManager - -logger = logging.getLogger(__name__) - -class MEXCBrowserAutomation: - """ - Browser automation for MEXC futures trading session management - """ - - def __init__(self, headless: bool = False, proxy: Optional[str] = None): - """ - Initialize browser automation - - Args: - headless: Run browser in headless mode - proxy: HTTP proxy to use (format: host:port) - """ - self.driver = None - self.headless = headless - self.proxy = proxy - self.logged_in = False - - def setup_chrome_driver(self) -> webdriver.Chrome: - """Setup Chrome driver with appropriate options""" - chrome_options = Options() - - if self.headless: - chrome_options.add_argument("--headless") - - # Basic Chrome options for automation - chrome_options.add_argument("--no-sandbox") - chrome_options.add_argument("--disable-dev-shm-usage") - chrome_options.add_argument("--disable-blink-features=AutomationControlled") - chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"]) - chrome_options.add_experimental_option('useAutomationExtension', False) - - # Set user agent to avoid detection - chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36") - - # Proxy setup if provided - if self.proxy: - chrome_options.add_argument(f"--proxy-server=http://{self.proxy}") - - # Enable network logging - chrome_options.add_argument("--enable-logging") - chrome_options.add_argument("--log-level=0") - chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"}) - - # Automatically download and setup ChromeDriver - service = Service(ChromeDriverManager().install()) - - try: - driver = webdriver.Chrome(service=service, options=chrome_options) - - # Execute script to avoid detection - driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})") - - return driver - except WebDriverException as e: - logger.error(f"Failed to setup Chrome driver: {e}") - raise - - def start_browser(self): - """Start the browser session""" - if self.driver is None: - logger.info("Starting Chrome browser for MEXC automation") - self.driver = self.setup_chrome_driver() - logger.info("Browser started successfully") - - def stop_browser(self): - """Stop the browser session""" - if self.driver: - logger.info("Stopping browser") - self.driver.quit() - self.driver = None - - def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"): - """ - Navigate to MEXC futures trading page - - Args: - symbol: Trading symbol to navigate to - """ - if not self.driver: - self.start_browser() - - url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap" - logger.info(f"Navigating to MEXC futures: {url}") - - self.driver.get(url) - - # Wait for page to load - try: - WebDriverWait(self.driver, 10).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) - logger.info("MEXC futures page loaded") - except TimeoutException: - logger.error("Timeout waiting for MEXC page to load") - - def wait_for_login(self, timeout: int = 300) -> bool: - """ - Wait for user to manually log in to MEXC - - Args: - timeout: Maximum time to wait for login (seconds) - - Returns: - bool: True if login detected, False if timeout - """ - logger.info("Please log in to MEXC manually in the browser window") - logger.info("Waiting for login completion...") - - start_time = time.time() - - while time.time() - start_time < timeout: - # Check if we can find elements that indicate logged in state - try: - # Look for user-specific elements that appear after login - cookies = self.driver.get_cookies() - - # Check for authentication cookies - auth_cookies = ['uc_token', 'u_id'] - logged_in_indicators = 0 - - for cookie in cookies: - if cookie['name'] in auth_cookies and cookie['value']: - logged_in_indicators += 1 - - if logged_in_indicators >= 2: - logger.info("Login detected!") - self.logged_in = True - return True - - except Exception as e: - logger.debug(f"Error checking login status: {e}") - - time.sleep(2) # Check every 2 seconds - - logger.error(f"Login timeout after {timeout} seconds") - return False - - def extract_session_cookies(self) -> Dict[str, str]: - """ - Extract all cookies from current browser session - - Returns: - Dictionary of cookie name-value pairs - """ - if not self.driver: - logger.error("Browser not started") - return {} - - cookies = {} - - try: - browser_cookies = self.driver.get_cookies() - - for cookie in browser_cookies: - cookies[cookie['name']] = cookie['value'] - - logger.info(f"Extracted {len(cookies)} cookies from browser session") - - # Log important cookies (without values for security) - important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId'] - for cookie_name in important_cookies: - if cookie_name in cookies: - logger.info(f"Found important cookie: {cookie_name}") - else: - logger.warning(f"Missing important cookie: {cookie_name}") - - return cookies - - except Exception as e: - logger.error(f"Failed to extract cookies: {e}") - return {} - - def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]: - """ - Monitor network requests for the specified duration - - Args: - duration: How long to monitor requests (seconds) - - Returns: - List of captured network requests - """ - if not self.driver: - logger.error("Browser not started") - return [] - - logger.info(f"Starting network monitoring for {duration} seconds") - logger.info("Please perform trading actions in the browser (open/close positions)") - - start_time = time.time() - captured_requests = [] - - while time.time() - start_time < duration: - try: - # Get performance logs (network requests) - logs = self.driver.get_log('performance') - - for log in logs: - message = json.loads(log['message']) - - # Filter for relevant MEXC API requests - if (message.get('message', {}).get('method') == 'Network.responseReceived'): - response = message['message']['params']['response'] - url = response.get('url', '') - - # Look for futures API calls - if ('futures.mexc.com' in url or - 'ucgateway/captcha_api' in url or - 'api/v1/private' in url): - - request_data = { - 'url': url, - 'method': response.get('mimeType', ''), - 'status': response.get('status'), - 'headers': response.get('headers', {}), - 'timestamp': log['timestamp'] - } - - captured_requests.append(request_data) - logger.info(f"Captured request: {url}") - - except Exception as e: - logger.debug(f"Error in network monitoring: {e}") - - time.sleep(1) - - logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests") - return captured_requests - - def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200): - """ - Attempt to perform a test trade to capture the complete request flow - - Args: - symbol: Trading symbol - volume: Position size - leverage: Leverage multiplier - """ - if not self.logged_in: - logger.error("Not logged in - cannot perform test trade") - return - - logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x") - logger.info("This will attempt to click trading interface elements") - - try: - # This would need to be implemented based on MEXC's specific UI elements - # For now, just wait and let user perform manual actions - logger.info("Please manually place a small test trade while monitoring is active") - time.sleep(30) - - except Exception as e: - logger.error(f"Error during test trade: {e}") - - def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]: - """ - Complete session capture workflow - - Args: - symbol: Trading symbol to use - - Returns: - Dictionary containing cookies and captured requests - """ - logger.info("Starting full MEXC session capture") - - try: - # Start browser and navigate to MEXC - self.navigate_to_mexc_futures(symbol) - - # Wait for manual login - if not self.wait_for_login(): - return {'success': False, 'error': 'Login timeout'} - - # Extract session cookies - cookies = self.extract_session_cookies() - - if not cookies: - return {'success': False, 'error': 'Failed to extract cookies'} - - # Monitor network requests while user performs actions - logger.info("Starting network monitoring - please perform trading actions now") - requests = self.monitor_network_requests(duration=120) # 2 minutes - - return { - 'success': True, - 'cookies': cookies, - 'network_requests': requests, - 'timestamp': int(time.time()) - } - - except Exception as e: - logger.error(f"Error in session capture: {e}") - return {'success': False, 'error': str(e)} - - finally: - self.stop_browser() - -def main(): - """Main function for standalone execution""" - logging.basicConfig(level=logging.INFO) - - print("MEXC Browser Automation - Session Capture") - print("This will open a browser window for you to log into MEXC") - print("Make sure you have Chrome browser installed") - - automation = MEXCBrowserAutomation(headless=False) - - try: - result = automation.full_session_capture() - - if result['success']: - print(f"\nSession capture successful!") - print(f"Extracted {len(result['cookies'])} cookies") - print(f"Captured {len(result['network_requests'])} network requests") - - # Save results to file - output_file = f"mexc_session_capture_{int(time.time())}.json" - with open(output_file, 'w') as f: - json.dump(result, f, indent=2) - - print(f"Results saved to: {output_file}") - - else: - print(f"Session capture failed: {result['error']}") - - except KeyboardInterrupt: - print("\nSession capture interrupted by user") - except Exception as e: - print(f"Error: {e}") - finally: - automation.stop_browser() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/core/mexc_webclient/mexc_futures_client.py b/core/mexc_webclient/mexc_futures_client.py deleted file mode 100644 index f8c83fe..0000000 --- a/core/mexc_webclient/mexc_futures_client.py +++ /dev/null @@ -1,525 +0,0 @@ -""" -MEXC Futures Web Client - -This module implements a web-based client for MEXC futures trading -since their official API doesn't support futures (leverage) trading. - -It mimics browser behavior by replicating the exact HTTP requests -that the web interface makes. -""" - -import logging -import requests -import time -import json -import hmac -import hashlib -import base64 -from typing import Dict, List, Optional, Any -from datetime import datetime -import uuid -from urllib.parse import urlencode -import glob -import os - -logger = logging.getLogger(__name__) - -class MEXCSessionManager: - def __init__(self): - self.captcha_token = None - - def get_captcha_token(self) -> str: - return self.captcha_token if self.captcha_token else "" - - def save_captcha_token(self, token: str): - self.captcha_token = token - logger.info("MEXC: Captcha token saved in session manager") - -class MEXCFuturesWebClient: - """ - MEXC Futures Web Client that mimics browser behavior for futures trading. - - Since MEXC's official API doesn't support futures, this client replicates - the exact HTTP requests made by their web interface. - """ - - def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True): - """ - Initialize the MEXC Futures Web Client - - Args: - api_key: API key for authentication - api_secret: API secret for authentication - user_id: User ID for authentication - base_url: Base URL for the MEXC website - headless: Whether to run the browser in headless mode - """ - self.api_key = api_key - self.api_secret = api_secret - self.user_id = user_id - self.base_url = base_url - self.is_authenticated = False - self.headless = headless - self.session = requests.Session() - self.session_manager = MEXCSessionManager() # Adding session_manager attribute - self.captcha_url = f'{base_url}/ucgateway/captcha_api' - self.futures_api_url = "https://futures.mexc.com/api/v1" - - # Setup default headers that mimic a real browser - self.setup_browser_headers() - - def setup_browser_headers(self): - """Setup default headers that mimic Chrome browser""" - self.session.headers.update({ - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36', - 'Accept': '*/*', - 'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8', - 'Accept-Encoding': 'gzip, deflate, br', - 'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"', - 'sec-ch-ua-mobile': '?0', - 'sec-ch-ua-platform': '"Windows"', - 'sec-fetch-dest': 'empty', - 'sec-fetch-mode': 'cors', - 'sec-fetch-site': 'same-origin', - 'Cache-Control': 'no-cache', - 'Pragma': 'no-cache', - 'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap', - 'Language': 'English', - 'X-Language': 'en-GB', - 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}", - 'trochilus-uid': str(self.user_id) if self.user_id is not None else '' - }) - - def load_session_cookies(self, cookies: Dict[str, str]): - """ - Load session cookies from browser - - Args: - cookies: Dictionary of cookie name-value pairs - """ - for name, value in cookies.items(): - self.session.cookies.set(name, value) - - # Extract important session info from cookies - self.auth_token = cookies.get('uc_token') - self.user_id = cookies.get('u_id') - self.fingerprint = cookies.get('x-mxc-fingerprint') - self.visitor_id = cookies.get('mexc_fingerprint_visitorId') - - if self.auth_token and self.user_id: - self.is_authenticated = True - logger.info("MEXC: Loaded authenticated session") - else: - logger.warning("MEXC: Session cookies incomplete - authentication may fail") - - def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]: - """ - Extract cookies from a browser cookie string - - Args: - cookie_string: Raw cookie string from browser (copy from Network tab) - - Returns: - Dictionary of parsed cookies - """ - cookies = {} - cookie_pairs = cookie_string.split(';') - - for pair in cookie_pairs: - if '=' in pair: - name, value = pair.strip().split('=', 1) - cookies[name] = value - - return cookies - - def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool: - """ - Verify captcha for robot trading protection - - Args: - symbol: Trading symbol (e.g., 'ETH_USDT') - side: 'openlong', 'closelong', 'openshort', 'closeshort' - leverage: Leverage string (e.g., '200X') - - Returns: - bool: True if captcha verification successful - """ - if not self.is_authenticated: - logger.error("MEXC: Cannot verify captcha - not authenticated") - return False - - # Build captcha endpoint URL - endpoint = f"robot.future.{side}.{symbol}.{leverage}" - url = f"{self.captcha_url}/{endpoint}" - - # Attempt to get captcha token from session manager - captcha_token = self.session_manager.get_captcha_token() - if not captcha_token: - logger.warning("MEXC: No captcha token available, attempting to fetch from browser") - captcha_token = self._extract_captcha_token_from_browser() - if captcha_token: - self.session_manager.save_captcha_token(captcha_token) - else: - logger.error("MEXC: Failed to extract captcha token from browser") - return False - - headers = { - 'Content-Type': 'application/json', - 'Language': 'en-GB', - 'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap', - 'trochilus-uid': self.user_id if self.user_id else '', - 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}", - 'captcha-token': captcha_token - } - - logger.info(f"MEXC: Verifying captcha for {endpoint}") - try: - response = self.session.get(url, headers=headers, timeout=10) - if response.status_code == 200: - data = response.json() - if data.get('success'): - logger.info(f"MEXC: Captcha verified successfully for {endpoint}") - return True - else: - logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}") - return False - else: - logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}") - return False - except Exception as e: - logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}") - return False - - def _extract_captcha_token_from_browser(self) -> str: - """ - Extract captcha token from browser session using stored cookies or requests. - This method looks for the most recent mexc_captcha_tokens JSON file to retrieve a token. - """ - try: - # Look for the most recent mexc_captcha_tokens file - captcha_files = glob.glob("mexc_captcha_tokens_*.json") - if not captcha_files: - logger.error("MEXC: No CAPTCHA token files found") - return "" - - # Sort files by timestamp (most recent first) - latest_file = max(captcha_files, key=os.path.getctime) - logger.info(f"MEXC: Using CAPTCHA token file {latest_file}") - - with open(latest_file, 'r') as f: - captcha_data = json.load(f) - - if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0: - # Return the most recent token - return captcha_data[0].get('token', '') - else: - logger.error("MEXC: No valid CAPTCHA tokens found in file") - return "" - except Exception as e: - logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}") - return "" - - def generate_signature(self, method: str, path: str, params: Dict[str, Any], - timestamp: int, nonce: int) -> str: - """ - Generate signature for MEXC futures API requests - - This is reverse-engineered from the browser requests - """ - # This is a placeholder - the actual signature generation would need - # to be reverse-engineered from the browser's JavaScript - # For now, return empty string and rely on cookie authentication - return "" - - def open_long_position(self, symbol: str, volume: float, leverage: int = 200, - price: Optional[float] = None) -> Dict[str, Any]: - """ - Open a long futures position - - Args: - symbol: Trading symbol (e.g., 'ETH_USDT') - volume: Position size (contracts) - leverage: Leverage multiplier (default 200) - price: Limit price (None for market order) - - Returns: - dict: Order response with order ID - """ - if not self.is_authenticated: - logger.error("MEXC: Cannot open position - not authenticated") - return {'success': False, 'error': 'Not authenticated'} - - # First verify captcha - if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'): - logger.error("MEXC: Captcha verification failed for opening long position") - return {'success': False, 'error': 'Captcha verification failed'} - - # Prepare order parameters based on the request dump - timestamp = int(time.time() * 1000) - nonce = timestamp - - order_data = { - 'symbol': symbol, - 'side': 1, # 1 = long, 2 = short - 'openType': 2, # Open position - 'type': '5', # Market order (might be '1' for limit) - 'vol': volume, - 'leverage': leverage, - 'marketCeiling': False, - 'priceProtect': '0', - 'ts': timestamp, - 'mhash': self._generate_mhash(), # This needs to be implemented - 'mtoken': self.visitor_id - } - - # Add price for limit orders - if price is not None: - order_data['price'] = price - order_data['type'] = '1' # Limit order - - # Add encrypted parameters (these would need proper implementation) - order_data['p0'] = self._encrypt_p0(order_data) # Placeholder - order_data['k0'] = self._encrypt_k0(order_data) # Placeholder - order_data['chash'] = self._generate_chash(order_data) # Placeholder - - # Setup headers for the order request - headers = { - 'Authorization': self.auth_token, - 'Content-Type': 'application/json', - 'Language': 'English', - 'x-language': 'en-GB', - 'x-mxc-nonce': str(nonce), - 'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce), - 'trochilus-uid': self.user_id, - 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}", - 'Referer': 'https://www.mexc.com/' - } - - # Make the order request - url = f"{self.futures_api_url}/private/order/create" - - try: - # First make OPTIONS request (preflight) - options_response = self.session.options(url, headers=headers, timeout=10) - - if options_response.status_code == 200: - # Now make the actual POST request - response = self.session.post(url, json=order_data, headers=headers, timeout=15) - - if response.status_code == 200: - data = response.json() - if data.get('success') and data.get('code') == 0: - order_id = data.get('data', {}).get('orderId') - logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}") - return { - 'success': True, - 'order_id': order_id, - 'timestamp': data.get('data', {}).get('ts'), - 'symbol': symbol, - 'side': 'long', - 'volume': volume, - 'leverage': leverage - } - else: - logger.error(f"MEXC: Order failed: {data}") - return {'success': False, 'error': data.get('msg', 'Unknown error')} - else: - logger.error(f"MEXC: Order request failed with status {response.status_code}") - return {'success': False, 'error': f'HTTP {response.status_code}'} - else: - logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}") - return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'} - - except Exception as e: - logger.error(f"MEXC: Order execution error: {e}") - return {'success': False, 'error': str(e)} - - def close_long_position(self, symbol: str, volume: float, leverage: int = 200, - price: Optional[float] = None) -> Dict[str, Any]: - """ - Close a long futures position - - Args: - symbol: Trading symbol (e.g., 'ETH_USDT') - volume: Position size to close (contracts) - leverage: Leverage multiplier - price: Limit price (None for market order) - - Returns: - dict: Order response - """ - if not self.is_authenticated: - logger.error("MEXC: Cannot close position - not authenticated") - return {'success': False, 'error': 'Not authenticated'} - - # First verify captcha - if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'): - logger.error("MEXC: Captcha verification failed for closing long position") - return {'success': False, 'error': 'Captcha verification failed'} - - # Similar to open_long_position but with closeType instead of openType - timestamp = int(time.time() * 1000) - nonce = timestamp - - order_data = { - 'symbol': symbol, - 'side': 2, # Close side is opposite - 'closeType': 1, # Close position - 'type': '5', # Market order - 'vol': volume, - 'leverage': leverage, - 'marketCeiling': False, - 'priceProtect': '0', - 'ts': timestamp, - 'mhash': self._generate_mhash(), - 'mtoken': self.visitor_id - } - - if price is not None: - order_data['price'] = price - order_data['type'] = '1' - - order_data['p0'] = self._encrypt_p0(order_data) - order_data['k0'] = self._encrypt_k0(order_data) - order_data['chash'] = self._generate_chash(order_data) - - return self._execute_order(order_data, 'close_long') - - def open_short_position(self, symbol: str, volume: float, leverage: int = 200, - price: Optional[float] = None) -> Dict[str, Any]: - """Open a short futures position""" - if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'): - return {'success': False, 'error': 'Captcha verification failed'} - - order_data = { - 'symbol': symbol, - 'side': 2, # 2 = short - 'openType': 2, - 'type': '5', - 'vol': volume, - 'leverage': leverage, - 'marketCeiling': False, - 'priceProtect': '0', - 'ts': int(time.time() * 1000), - 'mhash': self._generate_mhash(), - 'mtoken': self.visitor_id - } - - if price is not None: - order_data['price'] = price - order_data['type'] = '1' - - order_data['p0'] = self._encrypt_p0(order_data) - order_data['k0'] = self._encrypt_k0(order_data) - order_data['chash'] = self._generate_chash(order_data) - - return self._execute_order(order_data, 'open_short') - - def close_short_position(self, symbol: str, volume: float, leverage: int = 200, - price: Optional[float] = None) -> Dict[str, Any]: - """Close a short futures position""" - if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'): - return {'success': False, 'error': 'Captcha verification failed'} - - order_data = { - 'symbol': symbol, - 'side': 1, # Close side is opposite - 'closeType': 1, - 'type': '5', - 'vol': volume, - 'leverage': leverage, - 'marketCeiling': False, - 'priceProtect': '0', - 'ts': int(time.time() * 1000), - 'mhash': self._generate_mhash(), - 'mtoken': self.visitor_id - } - - if price is not None: - order_data['price'] = price - order_data['type'] = '1' - - order_data['p0'] = self._encrypt_p0(order_data) - order_data['k0'] = self._encrypt_k0(order_data) - order_data['chash'] = self._generate_chash(order_data) - - return self._execute_order(order_data, 'close_short') - - def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]: - """Common order execution logic""" - timestamp = order_data['ts'] - nonce = timestamp - - headers = { - 'Authorization': self.auth_token, - 'Content-Type': 'application/json', - 'Language': 'English', - 'x-language': 'en-GB', - 'x-mxc-nonce': str(nonce), - 'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce), - 'trochilus-uid': self.user_id, - 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}", - 'Referer': 'https://www.mexc.com/' - } - - url = f"{self.futures_api_url}/private/order/create" - - try: - response = self.session.post(url, json=order_data, headers=headers, timeout=15) - - if response.status_code == 200: - data = response.json() - if data.get('success') and data.get('code') == 0: - order_id = data.get('data', {}).get('orderId') - logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}") - return { - 'success': True, - 'order_id': order_id, - 'timestamp': data.get('data', {}).get('ts'), - 'action': action - } - else: - logger.error(f"MEXC: {action} failed: {data}") - return {'success': False, 'error': data.get('msg', 'Unknown error')} - else: - logger.error(f"MEXC: {action} request failed with status {response.status_code}") - return {'success': False, 'error': f'HTTP {response.status_code}'} - - except Exception as e: - logger.error(f"MEXC: {action} execution error: {e}") - return {'success': False, 'error': str(e)} - - # Placeholder methods for encryption/hashing - these need proper implementation - def _generate_mhash(self) -> str: - """Generate mhash parameter (needs reverse engineering)""" - return "a0015441fd4c3b6ba427b894b76cb7dd" # Placeholder from request dump - - def _encrypt_p0(self, order_data: Dict[str, Any]) -> str: - """Encrypt p0 parameter (needs reverse engineering)""" - return "placeholder_p0_encryption" # This needs proper implementation - - def _encrypt_k0(self, order_data: Dict[str, Any]) -> str: - """Encrypt k0 parameter (needs reverse engineering)""" - return "placeholder_k0_encryption" # This needs proper implementation - - def _generate_chash(self, order_data: Dict[str, Any]) -> str: - """Generate chash parameter (needs reverse engineering)""" - return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8" # Placeholder - - def get_account_info(self) -> Dict[str, Any]: - """Get account information including positions and balances""" - if not self.is_authenticated: - return {'success': False, 'error': 'Not authenticated'} - - # This would need to be implemented by reverse engineering the account info endpoints - logger.info("MEXC: Account info endpoint not yet implemented") - return {'success': False, 'error': 'Not implemented'} - - def get_open_positions(self) -> List[Dict[str, Any]]: - """Get list of open futures positions""" - if not self.is_authenticated: - return [] - - # This would need to be implemented by reverse engineering the positions endpoint - logger.info("MEXC: Open positions endpoint not yet implemented") - return [] \ No newline at end of file diff --git a/core/mexc_webclient/session_manager.py b/core/mexc_webclient/session_manager.py deleted file mode 100644 index 9b3c412..0000000 --- a/core/mexc_webclient/session_manager.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -MEXC Session Manager - -Helper utilities for managing MEXC web sessions and extracting cookies from browser. -""" - -import logging -import json -import re -from typing import Dict, Optional, Any -from pathlib import Path - -logger = logging.getLogger(__name__) - -class MEXCSessionManager: - """ - Helper class for managing MEXC web sessions and extracting browser cookies - """ - - def __init__(self): - self.session_file = Path("mexc_session.json") - - def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]: - """ - Extract cookies from browser Network tab cookie header - - Args: - cookie_header: Raw cookie string from browser (copy from Request Headers) - - Returns: - Dictionary of parsed cookies - """ - cookies = {} - - # Remove 'Cookie: ' prefix if present - if cookie_header.startswith('Cookie: '): - cookie_header = cookie_header[8:] - elif cookie_header.startswith('cookie: '): - cookie_header = cookie_header[8:] - - # Split by semicolon and parse each cookie - cookie_pairs = cookie_header.split(';') - - for pair in cookie_pairs: - pair = pair.strip() - if '=' in pair: - name, value = pair.split('=', 1) - cookies[name.strip()] = value.strip() - - logger.info(f"Extracted {len(cookies)} cookies from browser") - return cookies - - def validate_session_cookies(self, cookies: Dict[str, str]) -> bool: - """ - Validate that essential cookies are present for authentication - - Args: - cookies: Dictionary of cookie name-value pairs - - Returns: - bool: True if cookies appear valid for authentication - """ - required_cookies = [ - 'uc_token', # User authentication token - 'u_id', # User ID - 'x-mxc-fingerprint', # Browser fingerprint - 'mexc_fingerprint_visitorId' # Visitor ID - ] - - missing_cookies = [] - for cookie_name in required_cookies: - if cookie_name not in cookies or not cookies[cookie_name]: - missing_cookies.append(cookie_name) - - if missing_cookies: - logger.warning(f"Missing required cookies: {missing_cookies}") - return False - - logger.info("All required cookies are present") - return True - - def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None): - """ - Save session cookies to file for reuse - - Args: - cookies: Dictionary of cookies to save - metadata: Optional metadata about the session - """ - session_data = { - 'cookies': cookies, - 'metadata': metadata or {}, - 'timestamp': int(time.time()) - } - - try: - with open(self.session_file, 'w') as f: - json.dump(session_data, f, indent=2) - logger.info(f"Session saved to {self.session_file}") - except Exception as e: - logger.error(f"Failed to save session: {e}") - - def load_session(self) -> Optional[Dict[str, str]]: - """ - Load session cookies from file - - Returns: - Dictionary of cookies if successful, None otherwise - """ - if not self.session_file.exists(): - logger.info("No saved session found") - return None - - try: - with open(self.session_file, 'r') as f: - session_data = json.load(f) - - cookies = session_data.get('cookies', {}) - timestamp = session_data.get('timestamp', 0) - - # Check if session is too old (24 hours) - import time - if time.time() - timestamp > 24 * 3600: - logger.warning("Saved session is too old (>24h), may be expired") - - if self.validate_session_cookies(cookies): - logger.info("Loaded valid session from file") - return cookies - else: - logger.warning("Loaded session has invalid cookies") - return None - - except Exception as e: - logger.error(f"Failed to load session: {e}") - return None - - def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]: - """ - Extract cookies from a curl command copied from browser - - Args: - curl_command: Complete curl command from browser "Copy as cURL" - - Returns: - Dictionary of extracted cookies - """ - cookies = {} - - # Find cookie header in curl command - cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE) - if not cookie_match: - cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE) - - if cookie_match: - cookie_header = cookie_match.group(1) - cookies = self.extract_cookies_from_network_tab(cookie_header) - logger.info(f"Extracted {len(cookies)} cookies from curl command") - else: - logger.warning("No cookie header found in curl command") - - return cookies - - def print_cookie_extraction_guide(self): - """Print instructions for extracting cookies from browser""" - print("\n" + "="*80) - print("MEXC COOKIE EXTRACTION GUIDE") - print("="*80) - print(""" -To extract cookies from your browser for MEXC futures trading: - -METHOD 1: Browser Network Tab -1. Open MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT -2. Open browser Developer Tools (F12) -3. Go to Network tab -4. Try to place a small futures trade (it will fail, but we need the request) -5. Find the request to 'futures.mexc.com' in the Network tab -6. Right-click on the request -> Copy -> Copy request headers -7. Find the 'Cookie:' line and copy everything after 'Cookie: ' - -METHOD 2: Copy as cURL -1. Follow steps 1-5 above -2. Right-click on the futures API request -> Copy -> Copy as cURL -3. Paste the entire cURL command - -METHOD 3: Manual Cookie Extraction -1. While logged into MEXC, press F12 -> Application/Storage tab -2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com' -3. Copy the values for these important cookies: - - uc_token - - u_id - - x-mxc-fingerprint - - mexc_fingerprint_visitorId - -IMPORTANT NOTES: -- Cookies expire after some time (usually 24 hours) -- You must be logged into MEXC futures (not just spot trading) -- Keep your cookies secure - they provide access to your account -- Test with small amounts first - -Example usage: - session_manager = MEXCSessionManager() - - # Method 1: From cookie header - cookie_header = "uc_token=ABC123; u_id=DEF456; ..." - cookies = session_manager.extract_cookies_from_network_tab(cookie_header) - - # Method 2: From cURL command - curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'" - cookies = session_manager.extract_from_curl_command(curl_cmd) - - # Save session for reuse - session_manager.save_session(cookies) - """) - print("="*80) - -if __name__ == "__main__": - # When run directly, show the extraction guide - import time - - manager = MEXCSessionManager() - manager.print_cookie_extraction_guide() - - print("\nWould you like to:") - print("1. Load saved session") - print("2. Extract cookies from clipboard") - print("3. Exit") - - choice = input("\nEnter choice (1-3): ").strip() - - if choice == "1": - cookies = manager.load_session() - if cookies: - print(f"\nLoaded {len(cookies)} cookies from saved session") - if manager.validate_session_cookies(cookies): - print("Session appears valid for trading") - else: - print("Warning: Session may be incomplete or expired") - else: - print("No valid saved session found") - - elif choice == "2": - print("\nPaste your cookie header or cURL command:") - user_input = input().strip() - - if user_input.startswith('curl'): - cookies = manager.extract_from_curl_command(user_input) - else: - cookies = manager.extract_cookies_from_network_tab(user_input) - - if cookies and manager.validate_session_cookies(cookies): - print(f"\nSuccessfully extracted {len(cookies)} valid cookies") - save = input("Save session for reuse? (y/n): ").strip().lower() - if save == 'y': - manager.save_session(cookies) - else: - print("Failed to extract valid cookies") - - else: - print("Goodbye!") \ No newline at end of file diff --git a/core/mexc_webclient/test_mexc_futures_webclient.py b/core/mexc_webclient/test_mexc_futures_webclient.py deleted file mode 100644 index f41b13b..0000000 --- a/core/mexc_webclient/test_mexc_futures_webclient.py +++ /dev/null @@ -1,346 +0,0 @@ -#!/usr/bin/env python3 -""" -Test MEXC Futures Web Client - -This script demonstrates how to use the MEXC Futures Web Client -for futures trading that isn't supported by their official API. - -IMPORTANT: This requires extracting cookies from your browser session. -""" - -import logging -import sys -import os -import time -import json -import uuid - -# Add the project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from mexc_futures_client import MEXCFuturesWebClient -from session_manager import MEXCSessionManager - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# Constants -SYMBOL = "ETH_USDT" -LEVERAGE = 300 -CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json') - -# Read credentials from mexc_credentials.json in JSON format -def load_credentials(): - credentials_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json') - cookies = {} - captcha_token_open = '' - captcha_token_close = '' - try: - with open(credentials_file, 'r') as f: - data = json.load(f) - cookies = data.get('credentials', {}).get('cookies', {}) - captcha_token_open = data.get('credentials', {}).get('captcha_token_open', '') - captcha_token_close = data.get('credentials', {}).get('captcha_token_close', '') - logger.info(f"Loaded credentials from {credentials_file}") - except Exception as e: - logger.error(f"Error loading credentials: {e}") - return cookies, captcha_token_open, captcha_token_close - -def test_basic_connection(): - """Test basic connection and authentication""" - logger.info("Testing MEXC Futures Web Client") - - # Initialize session manager - session_manager = MEXCSessionManager() - - # Try to load saved session first - cookies = session_manager.load_session() - - if not cookies: - # Explicitly load the cookies from the file we have - cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json') - if os.path.exists(cookies_file): - try: - with open(cookies_file, 'r') as f: - cookies = json.load(f) - logger.info(f"Loaded cookies from {cookies_file}") - except Exception as e: - logger.error(f"Failed to load cookies from {cookies_file}: {e}") - cookies = None - else: - logger.error(f"Cookies file not found at {cookies_file}") - cookies = None - - if not cookies: - print("\nNo saved session found. You need to extract cookies from your browser.") - session_manager.print_cookie_extraction_guide() - - print("\nPaste your cookie header or cURL command (or press Enter to exit):") - user_input = input().strip() - - if not user_input: - print("No input provided. Exiting.") - return False - - # Extract cookies from user input - if user_input.startswith('curl'): - cookies = session_manager.extract_from_curl_command(user_input) - else: - cookies = session_manager.extract_cookies_from_network_tab(user_input) - - if not cookies: - logger.error("Failed to extract cookies from input") - return False - - # Validate and save session - if session_manager.validate_session_cookies(cookies): - session_manager.save_session(cookies) - logger.info("Session saved for future use") - else: - logger.warning("Extracted cookies may be incomplete") - - # Initialize the web client - client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True) - # Load cookies into the client's session - for name, value in cookies.items(): - client.session.cookies.set(name, value) - - # Update headers to include additional parameters from captured requests - client.session.headers.update({ - 'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}", - 'trochilus-uid': cookies.get('u_id', ''), - 'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap', - 'Language': 'English', - 'X-Language': 'en-GB' - }) - - if not client.is_authenticated: - logger.error("Failed to authenticate with extracted cookies") - return False - - logger.info("Successfully authenticated with MEXC") - logger.info(f"User ID: {client.user_id}") - logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token") - - return True - -def test_captcha_verification(client: MEXCFuturesWebClient): - """Test captcha verification system""" - logger.info("Testing captcha verification...") - - # Test captcha for ETH_USDT long position with 200x leverage - success = client.verify_captcha('ETH_USDT', 'openlong', '200X') - - if success: - logger.info("Captcha verification successful") - else: - logger.warning("Captcha verification failed - this may be normal if no position is being opened") - - return success - -def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True): - """Test opening a position (dry run by default)""" - if dry_run: - logger.info("DRY RUN: Testing position opening (no actual trade)") - else: - logger.warning("LIVE TRADING: Opening actual position!") - - symbol = 'ETH_USDT' - volume = 1 # Small test position - leverage = 200 - - logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x") - - if not dry_run: - result = client.open_long_position(symbol, volume, leverage) - - if result['success']: - logger.info(f"Position opened successfully!") - logger.info(f"Order ID: {result['order_id']}") - logger.info(f"Timestamp: {result['timestamp']}") - return True - else: - logger.error(f"Failed to open position: {result['error']}") - return False - else: - logger.info("DRY RUN: Would attempt to open position here") - # Test just the captcha verification part - return client.verify_captcha(symbol, 'openlong', f'{leverage}X') - -def test_position_opening_live(client): - symbol = "ETH_USDT" - volume = 1 # Small volume for testing - leverage = 200 - - logger.info(f"LIVE TRADING: Opening actual position!") - logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x") - - result = client.open_long_position(symbol, volume, leverage) - if result.get('success'): - logger.info(f"Successfully opened position: {result}") - else: - logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}") - -def interactive_menu(client: MEXCFuturesWebClient): - """Interactive menu for testing different functions""" - while True: - print("\n" + "="*50) - print("MEXC Futures Web Client Test Menu") - print("="*50) - print("1. Test captcha verification") - print("2. Test position opening (DRY RUN)") - print("3. Test position opening (LIVE - BE CAREFUL!)") - print("4. Test position closing (DRY RUN)") - print("5. Show session info") - print("6. Refresh session") - print("0. Exit") - - choice = input("\nEnter choice (0-6): ").strip() - - if choice == "1": - test_captcha_verification(client) - - elif choice == "2": - test_position_opening(client, dry_run=True) - - elif choice == "3": - test_position_opening_live(client) - - elif choice == "4": - logger.info("DRY RUN: Position closing test") - success = client.verify_captcha('ETH_USDT', 'closelong', '200X') - if success: - logger.info("DRY RUN: Would close position here") - else: - logger.warning("Captcha verification failed for position closing") - - elif choice == "5": - print(f"\nSession Information:") - print(f"Authenticated: {client.is_authenticated}") - print(f"User ID: {client.user_id}") - print(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None") - print(f"Fingerprint: {client.fingerprint}") - print(f"Visitor ID: {client.visitor_id}") - - elif choice == "6": - session_manager = MEXCSessionManager() - session_manager.print_cookie_extraction_guide() - - elif choice == "0": - print("Goodbye!") - break - - else: - print("Invalid choice. Please try again.") - -def main(): - """Main test function""" - print("MEXC Futures Web Client Test") - print("WARNING: This is experimental software for futures trading") - print("Use at your own risk and test with small amounts first!") - - # Load cookies and tokens - cookies, captcha_token_open, captcha_token_close = load_credentials() - if not cookies: - logger.error("Failed to load cookies from credentials file") - sys.exit(1) - - # Initialize client with loaded cookies and tokens - client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='') - # Load cookies into the client's session - for name, value in cookies.items(): - client.session.cookies.set(name, value) - # Set captcha tokens - client.captcha_token_open = captcha_token_open - client.captcha_token_close = captcha_token_close - - # Try to load credentials from the new JSON file - try: - with open(CREDENTIALS_FILE, 'r') as f: - credentials_data = json.load(f) - cookies = credentials_data['credentials']['cookies'] - captcha_token_open = credentials_data['credentials']['captcha_token_open'] - captcha_token_close = credentials_data['credentials']['captcha_token_close'] - client.load_session_cookies(cookies) - client.session_manager.save_captcha_token(captcha_token_open) # Assuming this is for opening - - except FileNotFoundError: - logger.error(f"Credentials file not found at {CREDENTIALS_FILE}") - return False - except json.JSONDecodeError as e: - logger.error(f"Error loading credentials: {e}") - return False - except KeyError as e: - logger.error(f"Missing key in credentials file: {e}") - return False - - if not client.is_authenticated: - logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json") - return False - - # Test connection and authentication - logger.info("Successfully authenticated with MEXC") - - # Set leverage - leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE) - if leverage_response and leverage_response.get('code') == 200: - logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}") - else: - logger.error(f"Failed to set leverage: {leverage_response}") - sys.exit(1) - - # Get current price - ticker = client.get_ticker_data(symbol=SYMBOL) - if ticker and ticker.get('code') == 200: - current_price = float(ticker['data']['last']) - logger.info(f"Current {SYMBOL} price: {current_price}") - else: - logger.error(f"Failed to get ticker data: {ticker}") - sys.exit(1) - - # Calculate order size for a small test trade (e.g., $10 worth) - trade_usdt = 10.0 - order_qty = round((trade_usdt / current_price) * LEVERAGE, 3) - logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x") - - # Test 1: Open LONG position - logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}") - open_long_order = client.create_order( - symbol=SYMBOL, - side=1, # 1 for BUY - position_side=1, # 1 for LONG - order_type=1, # 1 for LIMIT - price=current_price, - vol=order_qty - ) - if open_long_order and open_long_order.get('code') == 200: - logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}") - else: - logger.error(f"❌ Failed to open LONG position: {open_long_order}") - sys.exit(1) - - # Test 2: Close LONG position - logger.info(f"Closing LONG position for {SYMBOL}") - close_long_order = client.create_order( - symbol=SYMBOL, - side=2, # 2 for SELL - position_side=1, # 1 for LONG - order_type=1, # 1 for LIMIT - price=current_price, - vol=order_qty, - reduce_only=True - ) - if close_long_order and close_long_order.get('code') == 200: - logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}") - else: - logger.error(f"❌ Failed to close LONG position: {close_long_order}") - sys.exit(1) - - logger.info("All tests completed successfully!") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/core/negative_case_trainer.py b/core/negative_case_trainer.py deleted file mode 100644 index 089ef0f..0000000 --- a/core/negative_case_trainer.py +++ /dev/null @@ -1,595 +0,0 @@ -""" -Negative Case Trainer - Intensive Training on Losing Trades - -This module focuses on learning from losses to prevent future mistakes. -Stores negative cases in testcases/negative folder for reuse and retraining. -Supports simultaneous inference and training. -""" - -import os -import json -import logging -import pickle -import threading -import time -from datetime import datetime, timedelta -from typing import Dict, List, Any, Optional, Tuple -from dataclasses import dataclass, asdict -from collections import deque -import numpy as np -import pandas as pd - -# Import checkpoint management -import torch -from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint -from utils.training_integration import get_training_integration - -logger = logging.getLogger(__name__) - -@dataclass -class NegativeCase: - """Represents a losing trade case for intensive training""" - case_id: str - timestamp: datetime - symbol: str - action: str # 'BUY' or 'SELL' - entry_price: float - exit_price: float - loss_amount: float - loss_percentage: float - confidence_used: float - market_state_before: Dict[str, Any] - market_state_after: Dict[str, Any] - tick_data: List[Dict[str, Any]] # 15 minutes of tick data around the trade - technical_indicators: Dict[str, float] - what_should_have_been_done: str # 'HOLD', 'OPPOSITE', 'WAIT' - lesson_learned: str - training_priority: int # 1-5, 5 being highest priority - retraining_count: int = 0 - last_retrained: Optional[datetime] = None - -@dataclass -class TrainingSession: - """Represents an intensive training session on negative cases""" - session_id: str - start_time: datetime - cases_trained: List[str] # case_ids - epochs_completed: int - loss_improvement: float - accuracy_improvement: float - inference_paused: bool = False - training_active: bool = True - -class NegativeCaseTrainer: - """ - Intensive trainer focused on learning from losing trades with checkpoint management - - Features: - - Stores all losing trades as negative cases - - Intensive retraining on losses - - Simultaneous inference and training - - Persistent storage in testcases/negative - - Priority-based training (bigger losses = higher priority) - - Checkpoint management for training progress - """ - - def __init__(self, storage_dir: str = "testcases/negative", - model_name: str = "negative_case_trainer", enable_checkpoints: bool = True): - self.storage_dir = storage_dir - self.stored_cases: List[NegativeCase] = [] - self.training_queue = deque(maxlen=1000) - self.training_lock = threading.Lock() - self.inference_lock = threading.Lock() - - # Checkpoint management - self.model_name = model_name - self.enable_checkpoints = enable_checkpoints - self.training_integration = get_training_integration() if enable_checkpoints else None - self.training_session_count = 0 - self.best_loss_reduction = 0.0 - self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions - - # Training configuration - self.max_concurrent_training = 3 # Max parallel training sessions - self.intensive_training_epochs = 50 # Epochs per negative case - self.priority_multiplier = 2.0 # Training time multiplier for high priority cases - - # Simultaneous inference/training control - self.inference_active = True - self.training_active = False - self.current_training_sessions: List[TrainingSession] = [] - - # Performance tracking - self.total_cases_processed = 0 - self.total_training_time = 0.0 - self.accuracy_improvements = [] - - # Initialize storage - self._initialize_storage() - self._load_existing_cases() - - # Load best checkpoint if available - if self.enable_checkpoints: - self.load_best_checkpoint() - - # Start background training thread - self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True) - self.training_thread.start() - - logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases") - logger.info(f"Storage directory: {self.storage_dir}") - logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}") - logger.info("Background training thread started") - - def _initialize_storage(self): - """Initialize storage directories""" - try: - os.makedirs(self.storage_dir, exist_ok=True) - os.makedirs(f"{self.storage_dir}/cases", exist_ok=True) - os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True) - os.makedirs(f"{self.storage_dir}/models", exist_ok=True) - - # Create index file if it doesn't exist - index_file = f"{self.storage_dir}/case_index.json" - if not os.path.exists(index_file): - with open(index_file, 'w') as f: - json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f) - - logger.info(f"Storage initialized at {self.storage_dir}") - - except Exception as e: - logger.error(f"Error initializing storage: {e}") - - def _load_existing_cases(self): - """Load existing negative cases from storage""" - try: - index_file = f"{self.storage_dir}/case_index.json" - if os.path.exists(index_file): - with open(index_file, 'r') as f: - index_data = json.load(f) - - for case_info in index_data.get("cases", []): - case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl" - if os.path.exists(case_file): - try: - with open(case_file, 'rb') as f: - case = pickle.load(f) - self.stored_cases.append(case) - except Exception as e: - logger.warning(f"Error loading case {case_info['case_id']}: {e}") - - logger.info(f"Loaded {len(self.stored_cases)} existing negative cases") - - except Exception as e: - logger.error(f"Error loading existing cases: {e}") - - def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str: - """ - Add a losing trade as a negative case for intensive training - - Args: - trade_info: Trade information including P&L - market_data: Market state and tick data around the trade - - Returns: - case_id: Unique identifier for the negative case - """ - try: - # Generate unique case ID - case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}" - - # Calculate loss metrics - loss_amount = abs(trade_info.get('pnl', 0)) - loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100 - - # Determine training priority based on loss size - if loss_percentage > 10: - priority = 5 # Critical loss - elif loss_percentage > 5: - priority = 4 # High loss - elif loss_percentage > 2: - priority = 3 # Medium loss - elif loss_percentage > 1: - priority = 2 # Small loss - else: - priority = 1 # Minimal loss - - # Analyze what should have been done - what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data) - lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done) - - # Create negative case - negative_case = NegativeCase( - case_id=case_id, - timestamp=trade_info['timestamp'], - symbol=trade_info['symbol'], - action=trade_info['action'], - entry_price=trade_info['price'], - exit_price=market_data.get('exit_price', trade_info['price']), - loss_amount=loss_amount, - loss_percentage=loss_percentage, - confidence_used=trade_info.get('confidence', 0.5), - market_state_before=market_data.get('state_before', {}), - market_state_after=market_data.get('state_after', {}), - tick_data=market_data.get('tick_data', []), - technical_indicators=market_data.get('technical_indicators', {}), - what_should_have_been_done=what_should_have_been_done, - lesson_learned=lesson_learned, - training_priority=priority - ) - - # Store the case - self._store_case(negative_case) - - # Add to training queue with priority - with self.training_lock: - self.training_queue.append(negative_case) - self.stored_cases.append(negative_case) - - logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}") - logger.error(f"Lesson: {lesson_learned}") - - return case_id - - except Exception as e: - logger.error(f"Error adding losing trade: {e}") - return "" - - def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str: - """Analyze what the optimal action should have been""" - try: - # Simple analysis based on price movement - entry_price = trade_info['price'] - exit_price = market_data.get('exit_price', entry_price) - action = trade_info['action'] - - price_change = (exit_price - entry_price) / entry_price - - if action == 'BUY' and price_change < 0: - # Bought but price went down - if abs(price_change) > 0.005: # >0.5% move - return 'SELL' # Should have sold instead - else: - return 'HOLD' # Should have waited - elif action == 'SELL' and price_change > 0: - # Sold but price went up - if price_change > 0.005: # >0.5% move - return 'BUY' # Should have bought instead - else: - return 'HOLD' # Should have waited - else: - return 'HOLD' # Should have done nothing - - except Exception as e: - logger.error(f"Error analyzing optimal action: {e}") - return 'HOLD' - - def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str: - """Generate a lesson learned from the losing trade""" - try: - action = trade_info['action'] - symbol = trade_info['symbol'] - loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100 - confidence = trade_info.get('confidence', 0.5) - - if optimal_action == 'HOLD': - return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss." - elif optimal_action == 'BUY' and action == 'SELL': - return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction." - elif optimal_action == 'SELL' and action == 'BUY': - return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction." - else: - return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}." - - except Exception as e: - logger.error(f"Error generating lesson: {e}") - return "Learn from this loss to improve future decisions." - - def _store_case(self, case: NegativeCase): - """Store negative case to persistent storage""" - try: - # Store case file - case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl" - with open(case_file, 'wb') as f: - pickle.dump(case, f) - - # Update index - index_file = f"{self.storage_dir}/case_index.json" - with open(index_file, 'r') as f: - index_data = json.load(f) - - # Add case to index - case_info = { - 'case_id': case.case_id, - 'timestamp': case.timestamp.isoformat(), - 'symbol': case.symbol, - 'loss_amount': case.loss_amount, - 'loss_percentage': case.loss_percentage, - 'training_priority': case.training_priority, - 'retraining_count': case.retraining_count - } - - index_data['cases'].append(case_info) - index_data['last_updated'] = datetime.now().isoformat() - - with open(index_file, 'w') as f: - json.dump(index_data, f, indent=2) - - logger.info(f"Stored negative case: {case.case_id}") - - except Exception as e: - logger.error(f"Error storing case: {e}") - - def _background_training_loop(self): - """Background loop for intensive training on negative cases""" - logger.info("Background training loop started") - - while True: - try: - # Check if we have cases to train on - with self.training_lock: - if not self.training_queue: - time.sleep(5) # Wait for new cases - continue - - # Get highest priority case - cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True) - case_to_train = cases_by_priority[0] - self.training_queue.remove(case_to_train) - - # Start intensive training session - self._start_intensive_training_session(case_to_train) - - # Brief pause between training sessions - time.sleep(2) - - except Exception as e: - logger.error(f"Error in background training loop: {e}") - time.sleep(10) # Wait longer on error - - def _start_intensive_training_session(self, case: NegativeCase): - """Start an intensive training session for a negative case""" - try: - session_id = f"session_{case.case_id}_{int(time.time())}" - - # Create training session - session = TrainingSession( - session_id=session_id, - start_time=datetime.now(), - cases_trained=[case.case_id], - epochs_completed=0, - loss_improvement=0.0, - accuracy_improvement=0.0 - ) - - self.current_training_sessions.append(session) - self.training_active = True - - logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}") - logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})") - - # Calculate training epochs based on priority - epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier) - - # Simulate intensive training (replace with actual model training) - for epoch in range(epochs): - # Pause inference during critical training phases - if case.training_priority >= 4 and epoch % 10 == 0: - with self.inference_lock: - session.inference_paused = True - time.sleep(0.1) # Brief pause for critical training - session.inference_paused = False - - # Simulate training step - session.epochs_completed = epoch + 1 - - # Log progress for high priority cases - if case.training_priority >= 4 and epoch % 10 == 0: - logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})") - - time.sleep(0.05) # Simulate training time - - # Update case retraining info - case.retraining_count += 1 - case.last_retrained = datetime.now() - - # Calculate improvements (simulated) - session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement - session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement - - # Store training session results - self._store_training_session(session) - - # Update statistics - self.total_cases_processed += 1 - self.total_training_time += (datetime.now() - session.start_time).total_seconds() - self.accuracy_improvements.append(session.accuracy_improvement) - - # Remove from active sessions - self.current_training_sessions.remove(session) - if not self.current_training_sessions: - self.training_active = False - - logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}") - logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}") - - except Exception as e: - logger.error(f"Error in intensive training session: {e}") - - def _store_training_session(self, session: TrainingSession): - """Store training session results""" - try: - session_file = f"{self.storage_dir}/sessions/{session.session_id}.json" - session_data = { - 'session_id': session.session_id, - 'start_time': session.start_time.isoformat(), - 'end_time': datetime.now().isoformat(), - 'cases_trained': session.cases_trained, - 'epochs_completed': session.epochs_completed, - 'loss_improvement': session.loss_improvement, - 'accuracy_improvement': session.accuracy_improvement - } - - with open(session_file, 'w') as f: - json.dump(session_data, f, indent=2) - - except Exception as e: - logger.error(f"Error storing training session: {e}") - - def can_inference_proceed(self) -> bool: - """Check if inference can proceed (not blocked by critical training)""" - with self.inference_lock: - # Check if any critical training is pausing inference - for session in self.current_training_sessions: - if session.inference_paused: - return False - return True - - def get_training_stats(self) -> Dict[str, Any]: - """Get training statistics""" - try: - avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0 - - return { - 'total_negative_cases': len(self.stored_cases), - 'cases_in_queue': len(self.training_queue), - 'total_cases_processed': self.total_cases_processed, - 'total_training_time': self.total_training_time, - 'avg_accuracy_improvement': avg_accuracy_improvement, - 'active_training_sessions': len(self.current_training_sessions), - 'training_active': self.training_active, - 'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]), - 'storage_directory': self.storage_dir - } - - except Exception as e: - logger.error(f"Error getting training stats: {e}") - return {} - - def get_recent_lessons(self, count: int = 5) -> List[str]: - """Get recent lessons learned from negative cases""" - try: - recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count] - return [case.lesson_learned for case in recent_cases] - except Exception as e: - logger.error(f"Error getting recent lessons: {e}") - return [] - - def retrain_all_cases(self): - """Retrain all stored negative cases (for periodic retraining)""" - try: - logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...") - - with self.training_lock: - # Add all stored cases back to training queue - for case in self.stored_cases: - if case not in self.training_queue: - self.training_queue.append(case) - - logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue") - - except Exception as e: - logger.error(f"Error retraining all cases: {e}") - - def load_best_checkpoint(self): - """Load the best checkpoint for this negative case trainer""" - try: - if not self.enable_checkpoints: - return - - result = load_best_checkpoint(self.model_name) - if result: - file_path, metadata = result - checkpoint = torch.load(file_path, map_location='cpu') - - # Load training state - if 'training_session_count' in checkpoint: - self.training_session_count = checkpoint['training_session_count'] - if 'best_loss_reduction' in checkpoint: - self.best_loss_reduction = checkpoint['best_loss_reduction'] - if 'total_cases_processed' in checkpoint: - self.total_cases_processed = checkpoint['total_cases_processed'] - if 'total_training_time' in checkpoint: - self.total_training_time = checkpoint['total_training_time'] - if 'accuracy_improvements' in checkpoint: - self.accuracy_improvements = checkpoint['accuracy_improvements'] - - logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}") - logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}") - - except Exception as e: - logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}") - - def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False): - """Save checkpoint if performance improved or forced""" - try: - if not self.enable_checkpoints: - return False - - self.training_session_count += 1 - - # Update best loss reduction - improved = False - if loss_improvement > self.best_loss_reduction: - self.best_loss_reduction = loss_improvement - improved = True - - # Save checkpoint if improved, forced, or at regular intervals - should_save = ( - force_save or - improved or - self.training_session_count % self.checkpoint_frequency == 0 - ) - - if should_save: - # Prepare checkpoint data - checkpoint_data = { - 'training_session_count': self.training_session_count, - 'best_loss_reduction': self.best_loss_reduction, - 'total_cases_processed': self.total_cases_processed, - 'total_training_time': self.total_training_time, - 'accuracy_improvements': self.accuracy_improvements, - 'storage_dir': self.storage_dir, - 'max_concurrent_training': self.max_concurrent_training, - 'intensive_training_epochs': self.intensive_training_epochs - } - - # Create performance metrics for checkpoint manager - avg_accuracy_improvement = ( - sum(self.accuracy_improvements) / len(self.accuracy_improvements) - if self.accuracy_improvements else 0.0 - ) - - performance_metrics = { - 'loss_reduction': self.best_loss_reduction, - 'avg_accuracy_improvement': avg_accuracy_improvement, - 'total_cases_processed': self.total_cases_processed, - 'training_efficiency': ( - self.total_cases_processed / self.total_training_time - if self.total_training_time > 0 else 0.0 - ) - } - - # Save using checkpoint manager - metadata = save_checkpoint( - model=checkpoint_data, # We're saving data dict instead of model - model_name=self.model_name, - model_type="negative_case_trainer", - performance_metrics=performance_metrics, - training_metadata={ - 'session': self.training_session_count, - 'cases_processed': self.total_cases_processed, - 'training_time_hours': self.total_training_time / 3600 - }, - force_save=force_save - ) - - if metadata: - logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}") - return True - - return False - - except Exception as e: - logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}") - return False \ No newline at end of file diff --git a/core/nn_decision_fusion.py b/core/nn_decision_fusion.py deleted file mode 100644 index c6d9a4e..0000000 --- a/core/nn_decision_fusion.py +++ /dev/null @@ -1,277 +0,0 @@ -#!/usr/bin/env python3 -""" -Neural Network Decision Fusion System -Central NN that merges all model outputs + market data for final trading decisions -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from typing import Dict, List, Optional, Any -from dataclasses import dataclass -from datetime import datetime -import logging - -logger = logging.getLogger(__name__) - -@dataclass -class ModelPrediction: - """Standardized prediction from any model""" - model_name: str - prediction_type: str # 'price', 'direction', 'action' - value: float # -1 to 1 for direction, actual price for price predictions - confidence: float # 0 to 1 - timestamp: datetime - metadata: Optional[Dict[str, Any]] = None - -@dataclass -class MarketContext: - """Current market context for decision fusion""" - symbol: str - current_price: float - price_change_1m: float - price_change_5m: float - volume_ratio: float - volatility: float - timestamp: datetime - -@dataclass -class FusionDecision: - """Final trading decision from fusion NN""" - action: str # 'BUY', 'SELL', 'HOLD' - confidence: float # 0 to 1 - expected_return: float # Expected return percentage - risk_score: float # 0 to 1, higher = riskier - position_size: float # Recommended position size - reasoning: str # Human-readable explanation - model_contributions: Dict[str, float] # How much each model contributed - timestamp: datetime - -class DecisionFusionNetwork(nn.Module): - """Small NN that fuses model predictions with market context""" - - def __init__(self, input_dim: int = 32, hidden_dim: int = 64): - super().__init__() - - self.fusion_layers = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(hidden_dim, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, 16) - ) - - # Output heads - self.action_head = nn.Linear(16, 3) # BUY, SELL, HOLD - self.confidence_head = nn.Linear(16, 1) - self.return_head = nn.Linear(16, 1) - - def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]: - """Forward pass through fusion network""" - fusion_output = self.fusion_layers(features) - - action_logits = self.action_head(fusion_output) - action_probs = F.softmax(action_logits, dim=1) - - confidence = torch.sigmoid(self.confidence_head(fusion_output)) - expected_return = torch.tanh(self.return_head(fusion_output)) - - return { - 'action_probs': action_probs, - 'confidence': confidence.squeeze(), - 'expected_return': expected_return.squeeze() - } - -class NeuralDecisionFusion: - """Main NN-based decision fusion system""" - - def __init__(self, training_mode: bool = True): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.network = DecisionFusionNetwork().to(self.device) - self.training_mode = training_mode - self.registered_models = {} - self.last_predictions = {} - - logger.info(f"Neural Decision Fusion initialized on {self.device}") - - def register_model(self, model_name: str, model_type: str, prediction_format: str): - """Register a model that will provide predictions""" - self.registered_models[model_name] = { - 'type': model_type, - 'format': prediction_format, - 'prediction_count': 0 - } - logger.info(f"Registered NN model: {model_name} ({model_type})") - - def add_prediction(self, prediction: ModelPrediction): - """Add a prediction from a registered model""" - self.last_predictions[prediction.model_name] = prediction - if prediction.model_name in self.registered_models: - self.registered_models[prediction.model_name]['prediction_count'] += 1 - - logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} " - f"(confidence: {prediction.confidence:.3f})") - - def make_decision(self, symbol: str, market_context: MarketContext, - min_confidence: float = 0.25) -> Optional[FusionDecision]: - """Make NN-driven trading decision""" - try: - if len(self.last_predictions) < 1: - logger.debug("No NN predictions available") - return None - - # Prepare features - features = self._prepare_features(market_context) - if features is None: - return None - - # Run NN inference - with torch.no_grad(): - self.network.eval() - features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device) - outputs = self.network(features_tensor) - - action_probs = outputs['action_probs'][0].cpu().numpy() - confidence = outputs['confidence'].cpu().item() - expected_return = outputs['expected_return'].cpu().item() - - # Determine action - action_idx = np.argmax(action_probs) - actions = ['BUY', 'SELL', 'HOLD'] - action = actions[action_idx] - - # Check confidence threshold - if confidence < min_confidence: - action = 'HOLD' - logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD") - - # Calculate position size - position_size = self._calculate_position_size(confidence, expected_return) - - # Generate reasoning - reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs) - - # Calculate risk score and model contributions - risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5) - model_contributions = self._calculate_model_contributions() - - decision = FusionDecision( - action=action, - confidence=confidence, - expected_return=expected_return, - risk_score=risk_score, - position_size=position_size, - reasoning=reasoning, - model_contributions=model_contributions, - timestamp=datetime.now() - ) - - logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, " - f"return: {expected_return:.3f}, size: {position_size:.4f})") - - return decision - - except Exception as e: - logger.error(f"Error in NN decision making: {e}") - return None - - def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]: - """Prepare feature vector for NN""" - try: - features = np.zeros(32) - - # Model predictions (slots 0-15) - idx = 0 - for model_name, prediction in self.last_predictions.items(): - if idx < 14: # Leave room for other features - features[idx] = prediction.value - features[idx + 1] = prediction.confidence - idx += 2 - - # Market context (slots 16-31) - features[16] = np.tanh(context.price_change_1m * 100) # 1m change - features[17] = np.tanh(context.price_change_5m * 100) # 5m change - features[18] = np.tanh(context.volume_ratio - 1) # Volume ratio - features[19] = np.tanh(context.volatility * 100) # Volatility - features[20] = context.current_price / 10000.0 # Normalized price - - # Time features - now = context.timestamp - features[21] = now.hour / 24.0 - features[22] = now.weekday() / 7.0 - - # Model agreement features - if len(self.last_predictions) >= 2: - values = [p.value for p in self.last_predictions.values()] - features[23] = np.mean(values) # Average prediction - features[24] = np.std(values) # Prediction variance - features[25] = len(self.last_predictions) # Model count - - return features - - except Exception as e: - logger.error(f"Error preparing NN features: {e}") - return None - - def _calculate_position_size(self, confidence: float, expected_return: float) -> float: - """Calculate position size based on NN outputs""" - base_size = 0.01 # 0.01 ETH base - - # Scale by confidence - confidence_multiplier = max(0.1, min(2.0, confidence * 1.5)) - - # Scale by expected return - return_multiplier = 1.0 + abs(expected_return) * 0.5 - - final_size = base_size * confidence_multiplier * return_multiplier - return max(0.001, min(0.05, final_size)) - - def _generate_reasoning(self, action: str, confidence: float, - expected_return: float, action_probs: np.ndarray) -> str: - """Generate human-readable reasoning""" - reasons = [] - - if action == 'BUY': - reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})") - elif action == 'SELL': - reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})") - else: - reasons.append(f"NN suggests HOLD") - - if confidence > 0.7: - reasons.append("High confidence") - elif confidence > 0.5: - reasons.append("Moderate confidence") - else: - reasons.append("Low confidence") - - if abs(expected_return) > 0.01: - direction = "positive" if expected_return > 0 else "negative" - reasons.append(f"Expected {direction} return: {expected_return:.2%}") - - reasons.append(f"Based on {len(self.last_predictions)} NN models") - - return " | ".join(reasons) - - def _calculate_model_contributions(self) -> Dict[str, float]: - """Calculate how much each model contributed to the decision""" - contributions = {} - total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0 - - if total_confidence > 0: - for model_name, prediction in self.last_predictions.items(): - contributions[model_name] = prediction.confidence / total_confidence - - return contributions - - def get_status(self) -> Dict[str, Any]: - """Get NN fusion system status""" - return { - 'device': str(self.device), - 'training_mode': self.training_mode, - 'registered_models': len(self.registered_models), - 'recent_predictions': len(self.last_predictions), - 'model_parameters': sum(p.numel() for p in self.network.parameters()) - } \ No newline at end of file diff --git a/core/prediction_tracker.py b/core/prediction_tracker.py deleted file mode 100644 index 6de008e..0000000 Binary files a/core/prediction_tracker.py and /dev/null differ diff --git a/core/realtime_tick_processor.py b/core/realtime_tick_processor.py deleted file mode 100644 index 904a5fc..0000000 --- a/core/realtime_tick_processor.py +++ /dev/null @@ -1,649 +0,0 @@ -""" -Real-Time Tick Processing Neural Network Module - -This module acts as a Neural Network DPS (Data Processing System) alternative, -processing raw tick data with ultra-low latency and feeding processed features -to trading models in real-time. - -Features: -- Real-time tick ingestion with volume processing -- Neural network feature extraction from tick streams -- Ultra-low latency processing (sub-millisecond) -- Volume-weighted price analysis -- Microstructure pattern detection -- Real-time feature streaming to models -""" - -import asyncio -import logging -import time -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Deque -from collections import deque -from threading import Thread, Lock -import websockets -import json -from dataclasses import dataclass - -logger = logging.getLogger(__name__) - -@dataclass -class TickData: - """Raw tick data structure""" - timestamp: datetime - price: float - volume: float - side: str # 'buy' or 'sell' - trade_id: Optional[str] = None - -@dataclass -class ProcessedTickFeatures: - """Processed tick features for model consumption""" - timestamp: datetime - price_features: np.ndarray # Price-based features - volume_features: np.ndarray # Volume-based features - microstructure_features: np.ndarray # Market microstructure features - neural_features: np.ndarray # Neural network extracted features - confidence: float # Feature quality confidence - -class TickProcessingNN(nn.Module): - """ - Neural Network for real-time tick processing - Extracts high-level features from raw tick data - """ - - def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64): - super(TickProcessingNN, self).__init__() - - # Tick sequence processing layers - self.tick_encoder = nn.Sequential( - nn.Linear(input_size, hidden_size), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(hidden_size, hidden_size), - nn.ReLU(), - nn.Dropout(0.1) - ) - - # LSTM for temporal patterns - self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2) - - # Attention mechanism for important tick selection - self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True) - - # Feature extraction heads - self.price_head = nn.Linear(hidden_size, 16) # Price pattern features - self.volume_head = nn.Linear(hidden_size, 16) # Volume pattern features - self.microstructure_head = nn.Linear(hidden_size, 16) # Microstructure features - - # Final feature fusion - self.feature_fusion = nn.Sequential( - nn.Linear(48, output_size), # 16+16+16 = 48 - nn.ReLU(), - nn.Linear(output_size, output_size) - ) - - # Confidence estimation - self.confidence_head = nn.Sequential( - nn.Linear(output_size, 32), - nn.ReLU(), - nn.Linear(32, 1), - nn.Sigmoid() - ) - - def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Process tick sequence and extract features - - Args: - tick_sequence: [batch, sequence_length, features] - - Returns: - features: [batch, output_size] - extracted features - confidence: [batch, 1] - feature confidence - """ - batch_size, seq_len, _ = tick_sequence.shape - - # Encode each tick - encoded = self.tick_encoder(tick_sequence) # [batch, seq_len, hidden_size] - - # LSTM processing for temporal patterns - lstm_out, _ = self.lstm(encoded) # [batch, seq_len, hidden_size] - - # Attention to focus on important ticks - attended, _ = self.attention(lstm_out, lstm_out, lstm_out) # [batch, seq_len, hidden_size] - - # Use the last attended output - final_features = attended[:, -1, :] # [batch, hidden_size] - - # Extract specialized features - price_features = self.price_head(final_features) - volume_features = self.volume_head(final_features) - microstructure_features = self.microstructure_head(final_features) - - # Fuse all features - combined_features = torch.cat([price_features, volume_features, microstructure_features], dim=1) - final_features = self.feature_fusion(combined_features) - - # Estimate confidence - confidence = self.confidence_head(final_features) - - return final_features, confidence - -class RealTimeTickProcessor: - """ - Real-time tick processing system with neural network feature extraction - Acts as a DPS alternative for ultra-low latency tick processing - """ - - def __init__(self, symbols: List[str] = None, tick_buffer_size: int = 1000): - """Initialize the real-time tick processor""" - self.symbols = symbols or ['ETH/USDT', 'BTC/USDT'] - self.tick_buffer_size = tick_buffer_size - - # Tick storage buffers - self.tick_buffers: Dict[str, Deque[TickData]] = {} - self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {} - - # Initialize buffers for each symbol - for symbol in self.symbols: - self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size) - self.processed_features[symbol] = deque(maxlen=100) # Keep last 100 processed features - - # Neural network for feature extraction - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.tick_nn = TickProcessingNN(input_size=9).to(self.device) - self.tick_nn.eval() # Start in evaluation mode - - # Processing parameters - self.processing_window = 50 # Number of ticks to process at once - self.min_ticks_for_processing = 10 # Minimum ticks before processing - - # Real-time streaming - self.streaming = False - self.websocket_tasks = {} - self.processing_threads = {} - - # Performance tracking - self.processing_times = deque(maxlen=1000) - self.tick_counts = {symbol: 0 for symbol in self.symbols} - - # Thread safety - self.data_lock = Lock() - - # Feature subscribers (models that want real-time features) - self.feature_subscribers = [] - - logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}") - logger.info(f"Neural network device: {self.device}") - logger.info(f"Tick buffer size: {tick_buffer_size}") - - def add_feature_subscriber(self, callback): - """Add a callback function to receive processed features""" - self.feature_subscribers.append(callback) - logger.info(f"Added feature subscriber: {callback.__name__}") - - def remove_feature_subscriber(self, callback): - """Remove a feature subscriber""" - if callback in self.feature_subscribers: - self.feature_subscribers.remove(callback) - logger.info(f"Removed feature subscriber: {callback.__name__}") - - async def start_processing(self): - """Start real-time tick processing""" - logger.info("Starting real-time tick processing...") - self.streaming = True - - # Start WebSocket streams for each symbol - for symbol in self.symbols: - task = asyncio.create_task(self._websocket_stream(symbol)) - self.websocket_tasks[symbol] = task - - # Start processing thread for each symbol - thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True) - thread.start() - self.processing_threads[symbol] = thread - - logger.info("Real-time tick processing started") - - async def stop_processing(self): - """Stop real-time tick processing""" - logger.info("Stopping real-time tick processing...") - self.streaming = False - - # Cancel WebSocket tasks - for symbol, task in self.websocket_tasks.items(): - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - self.websocket_tasks.clear() - logger.info("Real-time tick processing stopped") - - async def _websocket_stream(self, symbol: str): - """WebSocket stream for real-time tick data""" - binance_symbol = symbol.replace('/', '').lower() - url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade" - - while self.streaming: - try: - async with websockets.connect(url) as websocket: - logger.info(f"Tick WebSocket connected for {symbol}") - - async for message in websocket: - if not self.streaming: - break - - try: - data = json.loads(message) - await self._process_raw_tick(symbol, data) - except Exception as e: - logger.warning(f"Error processing tick for {symbol}: {e}") - - except Exception as e: - logger.error(f"WebSocket error for {symbol}: {e}") - if self.streaming: - logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...") - await asyncio.sleep(2) - - async def _process_raw_tick(self, symbol: str, raw_data: Dict): - """Process raw tick data from WebSocket""" - try: - # Extract tick information - tick = TickData( - timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000), - price=float(raw_data['p']), - volume=float(raw_data['q']), - side='buy' if raw_data['m'] == False else 'sell', # m=true means buyer is market maker (sell) - trade_id=raw_data.get('t') - ) - - # Add to buffer - with self.data_lock: - self.tick_buffers[symbol].append(tick) - self.tick_counts[symbol] += 1 - - except Exception as e: - logger.error(f"Error processing raw tick for {symbol}: {e}") - - def _processing_loop(self, symbol: str): - """Main processing loop for a symbol""" - logger.info(f"Starting processing loop for {symbol}") - - while self.streaming: - try: - # Check if we have enough ticks to process - with self.data_lock: - tick_count = len(self.tick_buffers[symbol]) - - if tick_count >= self.min_ticks_for_processing: - start_time = time.time() - - # Process ticks - features = self._extract_neural_features(symbol) - - if features is not None: - # Store processed features - with self.data_lock: - self.processed_features[symbol].append(features) - - # Notify subscribers - self._notify_feature_subscribers(symbol, features) - - # Track processing time - processing_time = (time.time() - start_time) * 1000 # Convert to ms - self.processing_times.append(processing_time) - - if len(self.processing_times) % 100 == 0: - avg_time = np.mean(list(self.processing_times)) - logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms") - - # Small sleep to prevent CPU overload - time.sleep(0.001) # 1ms sleep for ultra-low latency - - except Exception as e: - logger.error(f"Error in processing loop for {symbol}: {e}") - time.sleep(0.01) # Longer sleep on error - - def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]: - """Extract neural network features from recent ticks""" - try: - with self.data_lock: - # Get recent ticks - recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:] - - if len(recent_ticks) < self.min_ticks_for_processing: - return None - - # Convert ticks to neural network input - tick_features = self._ticks_to_features(recent_ticks) - - # Process with neural network - with torch.no_grad(): - tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device) - neural_features, confidence = self.tick_nn(tick_tensor) - - neural_features = neural_features.cpu().numpy().flatten() - confidence = confidence.cpu().numpy().item() - - # Extract traditional features - price_features = self._extract_price_features(recent_ticks) - volume_features = self._extract_volume_features(recent_ticks) - microstructure_features = self._extract_microstructure_features(recent_ticks) - - # Create processed features object - processed = ProcessedTickFeatures( - timestamp=recent_ticks[-1].timestamp, - price_features=price_features, - volume_features=volume_features, - microstructure_features=microstructure_features, - neural_features=neural_features, - confidence=confidence - ) - - return processed - - except Exception as e: - logger.error(f"Error extracting neural features for {symbol}: {e}") - return None - - def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray: - """Convert tick data to neural network input features""" - features = [] - - for i, tick in enumerate(ticks): - tick_features = [ - tick.price, - tick.volume, - 1.0 if tick.side == 'buy' else 0.0, # Buy/sell indicator - tick.timestamp.timestamp(), # Timestamp - ] - - # Add relative features if we have previous ticks - if i > 0: - prev_tick = ticks[i-1] - price_change = (tick.price - prev_tick.price) / prev_tick.price - volume_ratio = tick.volume / (prev_tick.volume + 1e-8) - time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds() - - tick_features.extend([ - price_change, - volume_ratio, - time_delta - ]) - else: - tick_features.extend([0.0, 1.0, 0.0]) # Default values for first tick - - # Add moving averages if we have enough data - if i >= 5: - recent_prices = [t.price for t in ticks[max(0, i-4):i+1]] - recent_volumes = [t.volume for t in ticks[max(0, i-4):i+1]] - - price_ma = np.mean(recent_prices) - volume_ma = np.mean(recent_volumes) - - tick_features.extend([ - (tick.price - price_ma) / price_ma, # Price deviation from MA - (tick.volume - volume_ma) / (volume_ma + 1e-8) # Volume deviation from MA - ]) - else: - tick_features.extend([0.0, 0.0]) - - features.append(tick_features) - - # Pad or truncate to fixed size - target_length = self.processing_window - if len(features) < target_length: - # Pad with zeros - padding = [[0.0] * len(features[0])] * (target_length - len(features)) - features = padding + features - elif len(features) > target_length: - # Take the most recent ticks - features = features[-target_length:] - - return np.array(features, dtype=np.float32) - - def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray: - """Extract price-based features""" - prices = np.array([tick.price for tick in ticks]) - - features = [ - prices[-1], # Current price - np.mean(prices), # Average price - np.std(prices), # Price volatility - np.max(prices), # High - np.min(prices), # Low - (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0, # Total return - ] - - # Price momentum features - if len(prices) >= 10: - short_ma = np.mean(prices[-5:]) - long_ma = np.mean(prices[-10:]) - momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0 - features.append(momentum) - else: - features.append(0.0) - - return np.array(features, dtype=np.float32) - - def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray: - """Extract volume-based features""" - volumes = np.array([tick.volume for tick in ticks]) - buy_volumes = np.array([tick.volume for tick in ticks if tick.side == 'buy']) - sell_volumes = np.array([tick.volume for tick in ticks if tick.side == 'sell']) - - features = [ - np.sum(volumes), # Total volume - np.mean(volumes), # Average volume - np.std(volumes), # Volume volatility - np.sum(buy_volumes) if len(buy_volumes) > 0 else 0, # Buy volume - np.sum(sell_volumes) if len(sell_volumes) > 0 else 0, # Sell volume - ] - - # Volume imbalance - total_buy = np.sum(buy_volumes) if len(buy_volumes) > 0 else 0 - total_sell = np.sum(sell_volumes) if len(sell_volumes) > 0 else 0 - total_volume = total_buy + total_sell - - if total_volume > 0: - buy_ratio = total_buy / total_volume - volume_imbalance = buy_ratio - 0.5 # -0.5 to 0.5 range - else: - volume_imbalance = 0.0 - - features.append(volume_imbalance) - - # VWAP (Volume Weighted Average Price) - if np.sum(volumes) > 0: - prices = np.array([tick.price for tick in ticks]) - vwap = np.sum(prices * volumes) / np.sum(volumes) - current_price = ticks[-1].price - vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0 - else: - vwap_deviation = 0.0 - - features.append(vwap_deviation) - - return np.array(features, dtype=np.float32) - - def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray: - """Extract market microstructure features""" - features = [] - - # Trade frequency - if len(ticks) >= 2: - time_deltas = [(ticks[i].timestamp - ticks[i-1].timestamp).total_seconds() - for i in range(1, len(ticks))] - avg_time_delta = np.mean(time_deltas) - trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0 - else: - trade_frequency = 0.0 - - features.append(trade_frequency) - - # Price impact features - prices = [tick.price for tick in ticks] - volumes = [tick.volume for tick in ticks] - - if len(prices) >= 3: - # Calculate price changes and corresponding volumes - price_changes = [(prices[i] - prices[i-1]) / prices[i-1] - for i in range(1, len(prices)) if prices[i-1] != 0] - corresponding_volumes = volumes[1:len(price_changes)+1] - - if len(price_changes) > 0 and len(corresponding_volumes) > 0: - # Simple price impact measure - price_impact = np.corrcoef(np.abs(price_changes), corresponding_volumes)[0, 1] - if np.isnan(price_impact): - price_impact = 0.0 - else: - price_impact = 0.0 - else: - price_impact = 0.0 - - features.append(price_impact) - - # Bid-ask spread proxy (using price volatility) - if len(prices) >= 5: - recent_prices = prices[-5:] - spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / np.mean(recent_prices) - else: - spread_proxy = 0.0 - - features.append(spread_proxy) - - # Order flow imbalance (already calculated in volume features, but different perspective) - buy_count = sum(1 for tick in ticks if tick.side == 'buy') - sell_count = len(ticks) - buy_count - total_trades = len(ticks) - - if total_trades > 0: - order_flow_imbalance = (buy_count - sell_count) / total_trades - else: - order_flow_imbalance = 0.0 - - features.append(order_flow_imbalance) - - return np.array(features, dtype=np.float32) - - def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures): - """Notify all feature subscribers of new processed features""" - for callback in self.feature_subscribers: - try: - callback(symbol, features) - except Exception as e: - logger.error(f"Error notifying feature subscriber {callback.__name__}: {e}") - - def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]: - """Get the latest processed features for a symbol""" - with self.data_lock: - if symbol in self.processed_features and self.processed_features[symbol]: - return self.processed_features[symbol][-1] - return None - - def get_processing_stats(self) -> Dict[str, Any]: - """Get processing performance statistics""" - stats = { - 'symbols': self.symbols, - 'streaming': self.streaming, - 'tick_counts': dict(self.tick_counts), - 'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols}, - 'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols}, - 'subscribers': len(self.feature_subscribers) - } - - if self.processing_times: - stats['processing_performance'] = { - 'avg_time_ms': np.mean(list(self.processing_times)), - 'min_time_ms': np.min(list(self.processing_times)), - 'max_time_ms': np.max(list(self.processing_times)), - 'std_time_ms': np.std(list(self.processing_times)) - } - - return stats - - def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100): - """Train the tick processing neural network""" - logger.info("Training tick processing neural network...") - - self.tick_nn.train() - optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001) - criterion = nn.MSELoss() - - for epoch in range(epochs): - total_loss = 0.0 - - for batch_features, batch_targets in training_data: - optimizer.zero_grad() - - # Convert to tensors - features_tensor = torch.FloatTensor(batch_features).to(self.device) - targets_tensor = torch.FloatTensor(batch_targets).to(self.device) - - # Forward pass - outputs, confidence = self.tick_nn(features_tensor) - - # Calculate loss - loss = criterion(outputs, targets_tensor) - - # Backward pass - loss.backward() - optimizer.step() - - total_loss += loss.item() - - if epoch % 10 == 0: - avg_loss = total_loss / len(training_data) - logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}") - - self.tick_nn.eval() - logger.info("Neural network training completed") - -# Integration with existing orchestrator -def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor): - """Integrate tick processor with enhanced orchestrator""" - - def feature_callback(symbol: str, features: ProcessedTickFeatures): - """Callback to feed processed features to orchestrator""" - try: - # Convert processed features to format expected by orchestrator - feature_dict = { - 'symbol': symbol, - 'timestamp': features.timestamp, - 'neural_features': features.neural_features, - 'price_features': features.price_features, - 'volume_features': features.volume_features, - 'microstructure_features': features.microstructure_features, - 'confidence': features.confidence - } - - # Feed to orchestrator's real-time feature processing - if hasattr(orchestrator, 'process_realtime_features'): - orchestrator.process_realtime_features(feature_dict) - - except Exception as e: - logger.error(f"Error integrating features with orchestrator: {e}") - - # Add the callback to tick processor - tick_processor.add_feature_subscriber(feature_callback) - logger.info("Tick processor integrated with orchestrator") - -# Factory function for easy creation -def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor: - """Create and configure a real-time tick processor""" - if symbols is None: - symbols = ['ETH/USDT', 'BTC/USDT'] - - processor = RealTimeTickProcessor(symbols=symbols) - logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}") - - return processor \ No newline at end of file diff --git a/core/retrospective_trainer.py b/core/retrospective_trainer.py deleted file mode 100644 index f7a0017..0000000 --- a/core/retrospective_trainer.py +++ /dev/null @@ -1,453 +0,0 @@ -""" -Retrospective Training System - -This module implements a retrospective training system that: -1. Triggers training when trades close with known P&L outcomes -2. Uses captured model inputs from trade entry to train models -3. Optimizes for profit by learning from profitable vs unprofitable patterns -4. Supports simultaneous inference and training without weight reloading -5. Implements reinforcement learning with immediate reward feedback -""" - -import logging -import threading -import time -import queue -from datetime import datetime -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass -import numpy as np -from collections import deque - -logger = logging.getLogger(__name__) - -@dataclass -class TrainingCase: - """Represents a completed trade case for retrospective training""" - case_id: str - symbol: str - action: str # 'BUY' or 'SELL' - entry_price: float - exit_price: float - entry_time: datetime - exit_time: datetime - pnl: float - fees: float - confidence: float - model_inputs: Dict[str, Any] - market_state: Dict[str, Any] - outcome_label: int # 1 for profit, 0 for loss, 2 for breakeven - reward_signal: float # Scaled reward for RL training - leverage: float = 1.0 - -class RetrospectiveTrainer: - """Retrospective training system for real-time model optimization""" - - def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None): - """Initialize the retrospective trainer""" - self.orchestrator = orchestrator - self.config = config or {} - - # Training configuration - self.batch_size = self.config.get('batch_size', 32) - self.min_cases_for_training = self.config.get('min_cases_for_training', 5) - self.profit_threshold = self.config.get('profit_threshold', 0.0) - self.training_frequency = self.config.get('training_frequency_seconds', 120) # 2 minutes - self.max_training_cases = self.config.get('max_training_cases', 1000) - - # Training state - self.training_queue = queue.Queue() - self.completed_cases = deque(maxlen=self.max_training_cases) - self.training_stats = { - 'total_cases': 0, - 'profitable_cases': 0, - 'loss_cases': 0, - 'breakeven_cases': 0, - 'avg_profit': 0.0, - 'last_training_time': datetime.now(), - 'training_sessions': 0, - 'model_updates': 0 - } - - # Threading - self.training_thread = None - self.is_training_active = False - self.training_lock = threading.Lock() - - logger.info("RetrospectiveTrainer initialized") - logger.info(f"Configuration: batch_size={self.batch_size}, " - f"min_cases={self.min_cases_for_training}, " - f"training_freq={self.training_frequency}s") - - def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool: - """Add a completed trade for retrospective training""" - try: - # Create training case from trade record - case = self._create_training_case(trade_record, model_inputs) - if case is None: - return False - - # Add to completed cases - self.completed_cases.append(case) - self.training_queue.put(case) - - # Update statistics - self.training_stats['total_cases'] += 1 - if case.outcome_label == 1: # Profit - self.training_stats['profitable_cases'] += 1 - elif case.outcome_label == 0: # Loss - self.training_stats['loss_cases'] += 1 - else: # Breakeven - self.training_stats['breakeven_cases'] += 1 - - # Calculate running average profit - total_pnl = sum(c.pnl for c in self.completed_cases) - self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases) - - logger.info(f"RETROSPECTIVE: Added training case {case.case_id} " - f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})") - - # Trigger training if we have enough cases - self._maybe_trigger_training() - - return True - - except Exception as e: - logger.error(f"Error adding completed trade for retrospective training: {e}") - return False - - def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]: - """Create a training case from trade record and model inputs""" - try: - # Extract trade information - symbol = trade_record.get('symbol', 'UNKNOWN') - side = trade_record.get('side', 'UNKNOWN') - pnl = trade_record.get('pnl', 0.0) - fees = trade_record.get('fees', 0.0) - confidence = trade_record.get('confidence', 0.0) - - # Calculate net P&L after fees - net_pnl = pnl - fees - - # Determine outcome label and reward signal - if net_pnl > self.profit_threshold: - outcome_label = 1 # Profitable - # Scale reward by profit magnitude and confidence - reward_signal = min(10.0, net_pnl * confidence * 10) # Amplify for training - elif net_pnl < -self.profit_threshold: - outcome_label = 0 # Loss - # Negative reward scaled by loss magnitude - reward_signal = max(-10.0, net_pnl * confidence * 10) # Negative reward - else: - outcome_label = 2 # Breakeven - reward_signal = 0.0 - - # Create case ID - timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S') - case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p') - - # Create training case - case = TrainingCase( - case_id=case_id, - symbol=symbol, - action=side, - entry_price=trade_record.get('entry_price', 0.0), - exit_price=trade_record.get('exit_price', 0.0), - entry_time=trade_record.get('entry_time', datetime.now()), - exit_time=trade_record.get('exit_time', datetime.now()), - pnl=net_pnl, - fees=fees, - confidence=confidence, - model_inputs=model_inputs, - market_state=model_inputs.get('market_state', {}), - outcome_label=outcome_label, - reward_signal=reward_signal, - leverage=trade_record.get('leverage', 1.0) - ) - - return case - - except Exception as e: - logger.error(f"Error creating training case: {e}") - return None - - def _maybe_trigger_training(self): - """Check if we should trigger a training session""" - try: - # Check if we have enough cases - if len(self.completed_cases) < self.min_cases_for_training: - return - - # Check if enough time has passed since last training - time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds() - if time_since_last < self.training_frequency: - return - - # Check if training thread is not already running - if self.is_training_active: - logger.debug("Training already in progress, skipping trigger") - return - - # Start training in background thread - self._start_training_session() - - except Exception as e: - logger.error(f"Error checking training trigger: {e}") - - def _start_training_session(self): - """Start a training session in background thread""" - try: - if self.training_thread and self.training_thread.is_alive(): - logger.debug("Training thread already running") - return - - self.training_thread = threading.Thread( - target=self._run_training_session, - daemon=True, - name="RetrospectiveTrainer" - ) - self.training_thread.start() - logger.info("RETROSPECTIVE: Started training session") - - except Exception as e: - logger.error(f"Error starting training session: {e}") - - def _run_training_session(self): - """Run a complete training session""" - try: - with self.training_lock: - self.is_training_active = True - start_time = time.time() - - logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases") - - # Train models if orchestrator available - training_results = {} - if self.orchestrator: - training_results = self._train_models() - - # Update statistics - self.training_stats['last_training_time'] = datetime.now() - self.training_stats['training_sessions'] += 1 - self.training_stats['model_updates'] += len(training_results) - - elapsed_time = time.time() - start_time - logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}") - - except Exception as e: - logger.error(f"Error in retrospective training session: {e}") - import traceback - logger.error(traceback.format_exc()) - finally: - self.is_training_active = False - - def _train_models(self) -> Dict[str, Any]: - """Train available models using retrospective data""" - results = {} - - try: - # Prepare training data - profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1] - loss_cases = [c for c in self.completed_cases if c.outcome_label == 0] - - if len(profitable_cases) == 0 and len(loss_cases) == 0: - return {'error': 'No labeled cases for training'} - - logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}") - - # Train DQN agent if available - if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent: - try: - dqn_result = self._train_dqn_retrospective() - results['dqn'] = dqn_result - logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}") - except Exception as e: - logger.warning(f"DQN retrospective training failed: {e}") - results['dqn'] = {'error': str(e)} - - # Train other models - if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer: - try: - # Update extrema trainer with retrospective feedback - extrema_feedback = self._create_extrema_feedback() - if extrema_feedback: - results['extrema'] = {'feedback_cases': len(extrema_feedback)} - logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases") - except Exception as e: - logger.warning(f"Extrema retrospective training failed: {e}") - - return results - - except Exception as e: - logger.error(f"Error training models retrospectively: {e}") - return {'error': str(e)} - - def _train_dqn_retrospective(self) -> Dict[str, Any]: - """Train DQN agent using retrospective experience replay""" - try: - if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent: - return {'error': 'DQN agent not available'} - - dqn_agent = self.orchestrator.rl_agent - experiences_added = 0 - - # Add retrospective experiences to DQN replay buffer - for case in self.completed_cases: - try: - # Extract state from model inputs - state = self._extract_state_vector(case.model_inputs) - if state is None: - continue - - # Action mapping: BUY=0, SELL=1 - action = 0 if case.action == 'BUY' else 1 - - # Use reward signal as immediate reward - reward = case.reward_signal - - # For retrospective training, next_state is None (terminal) - next_state = np.zeros_like(state) # Terminal state - done = True - - # Add experience to DQN replay buffer - if hasattr(dqn_agent, 'add_experience'): - dqn_agent.add_experience(state, action, reward, next_state, done) - experiences_added += 1 - - except Exception as e: - logger.debug(f"Error adding DQN experience: {e}") - continue - - # Train DQN if we have enough experiences - if experiences_added > 0 and hasattr(dqn_agent, 'train'): - try: - # Perform multiple training steps on retrospective data - training_steps = min(10, experiences_added // 4) # Conservative training - for _ in range(training_steps): - loss = dqn_agent.train() - if loss is None: - break - - return { - 'experiences_added': experiences_added, - 'training_steps': training_steps, - 'method': 'retrospective_experience_replay' - } - except Exception as e: - logger.warning(f"DQN training step failed: {e}") - return {'experiences_added': experiences_added, 'training_error': str(e)} - - return {'experiences_added': experiences_added, 'training_steps': 0} - - except Exception as e: - logger.error(f"Error in DQN retrospective training: {e}") - return {'error': str(e)} - - def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]: - """Extract state vector for DQN training from model inputs""" - try: - # Try to get pre-built RL state - if 'dqn_state' in model_inputs: - state = model_inputs['dqn_state'] - if isinstance(state, dict) and 'state_vector' in state: - return np.array(state['state_vector']) - - # Build state from market features - market_state = model_inputs.get('market_state', {}) - features = [] - - # Price features - for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']: - features.append(market_state.get(key, 0.0)) - - # Volume features - for key in ['volume_current', 'volume_sma_20', 'volume_ratio']: - features.append(market_state.get(key, 0.0)) - - # Technical indicators - indicators = model_inputs.get('technical_indicators', {}) - for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']: - features.append(indicators.get(key, 0.0)) - - if len(features) < 5: # Minimum required features - return None - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.debug(f"Error extracting state vector: {e}") - return None - - def _create_extrema_feedback(self) -> List[Dict[str, Any]]: - """Create feedback data for extrema trainer""" - feedback = [] - - try: - for case in self.completed_cases: - if case.outcome_label in [0, 1]: # Only profit/loss cases - feedback_item = { - 'symbol': case.symbol, - 'action': case.action, - 'entry_price': case.entry_price, - 'exit_price': case.exit_price, - 'was_profitable': case.outcome_label == 1, - 'reward_signal': case.reward_signal, - 'market_state': case.market_state - } - feedback.append(feedback_item) - - return feedback - - except Exception as e: - logger.error(f"Error creating extrema feedback: {e}") - return [] - - def get_training_stats(self) -> Dict[str, Any]: - """Get current training statistics""" - stats = self.training_stats.copy() - stats['total_cases_in_memory'] = len(self.completed_cases) - stats['training_queue_size'] = self.training_queue.qsize() - stats['is_training_active'] = self.is_training_active - - # Calculate profit metrics - if len(self.completed_cases) > 0: - profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0) - stats['profit_rate'] = profitable_count / len(self.completed_cases) - stats['total_pnl'] = sum(c.pnl for c in self.completed_cases) - stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases) - - return stats - - def force_training_session(self) -> bool: - """Force a training session regardless of timing constraints""" - try: - if self.is_training_active: - logger.warning("Training already in progress") - return False - - if len(self.completed_cases) < 1: - logger.warning("No completed cases available for training") - return False - - logger.info("RETROSPECTIVE: Forcing training session") - self._start_training_session() - return True - - except Exception as e: - logger.error(f"Error forcing training session: {e}") - return False - - def stop(self): - """Stop the retrospective trainer""" - try: - self.is_training_active = False - if self.training_thread and self.training_thread.is_alive(): - self.training_thread.join(timeout=10) - logger.info("RetrospectiveTrainer stopped") - except Exception as e: - logger.error(f"Error stopping RetrospectiveTrainer: {e}") - - -def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer: - """Factory function to create a RetrospectiveTrainer instance""" - return RetrospectiveTrainer(orchestrator=orchestrator, config=config) diff --git a/core/rl_training_pipeline.py b/core/rl_training_pipeline.py deleted file mode 100644 index 5f2fa7a..0000000 --- a/core/rl_training_pipeline.py +++ /dev/null @@ -1,529 +0,0 @@ -""" -RL Training Pipeline with Comprehensive Experience Storage and Replay - -This module implements a robust RL training pipeline that: -1. Stores all training experiences with profitability metrics -2. Implements profit-weighted experience replay -3. Tracks gradient information for each training step -4. Enables retraining on most profitable trading sequences -5. Maintains comprehensive trading episode analysis -""" - -import logging -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass, field -import json -import pickle -from collections import deque -import threading -import random - -from .training_data_collector import get_training_data_collector - -logger = logging.getLogger(__name__) - -@dataclass -class RLExperience: - """Single RL experience with complete state-action-reward information""" - experience_id: str - timestamp: datetime - episode_id: str - - # Core RL components - state: np.ndarray - action: int # 0=SELL, 1=HOLD, 2=BUY - reward: float - next_state: np.ndarray - done: bool - - # Extended state information - market_context: Dict[str, Any] - cnn_predictions: Optional[Dict[str, Any]] = None - confidence_score: float = 0.0 - - # Actual trading outcome - actual_profit: Optional[float] = None - actual_holding_time: Optional[timedelta] = None - optimal_action: Optional[int] = None - - # Experience value for replay - experience_value: float = 0.0 - profitability_score: float = 0.0 - learning_priority: float = 0.0 - - # Training metadata - times_trained: int = 0 - last_trained: Optional[datetime] = None - -class ProfitWeightedExperienceBuffer: - """Experience buffer with profit-weighted sampling for replay""" - - def __init__(self, max_size: int = 100000): - self.max_size = max_size - self.experiences: Dict[str, RLExperience] = {} - self.experience_order: deque = deque(maxlen=max_size) - self.profitable_experiences: List[str] = [] - self.total_experiences = 0 - self.total_profitable = 0 - - def add_experience(self, experience: RLExperience): - """Add experience to buffer""" - try: - self.experiences[experience.experience_id] = experience - self.experience_order.append(experience.experience_id) - - if experience.actual_profit is not None and experience.actual_profit > 0: - self.profitable_experiences.append(experience.experience_id) - self.total_profitable += 1 - - # Remove oldest if buffer is full - if len(self.experiences) > self.max_size: - oldest_id = self.experience_order[0] - if oldest_id in self.experiences: - del self.experiences[oldest_id] - if oldest_id in self.profitable_experiences: - self.profitable_experiences.remove(oldest_id) - - self.total_experiences += 1 - - except Exception as e: - logger.error(f"Error adding experience to buffer: {e}") - - def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]: - """Sample batch with profit-weighted prioritization""" - try: - if len(self.experiences) < batch_size: - return list(self.experiences.values()) - - if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2: - # Sample mix of profitable and all experiences - profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences)) - remaining_sample_size = batch_size - profitable_sample_size - - profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size) - all_ids = list(self.experiences.keys()) - remaining_ids = random.sample(all_ids, remaining_sample_size) - - sampled_ids = profitable_ids + remaining_ids - else: - # Random sampling from all experiences - all_ids = list(self.experiences.keys()) - sampled_ids = random.sample(all_ids, batch_size) - - sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids] - - # Update training counts - for experience in sampled_experiences: - experience.times_trained += 1 - experience.last_trained = datetime.now() - - return sampled_experiences - - except Exception as e: - logger.error(f"Error sampling batch: {e}") - return list(self.experiences.values())[:batch_size] - - def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]: - """Get most profitable experiences for targeted training""" - try: - profitable_experiences = [ - self.experiences[exp_id] for exp_id in self.profitable_experiences - if exp_id in self.experiences - ] - - profitable_experiences.sort( - key=lambda x: x.actual_profit if x.actual_profit else 0, - reverse=True - ) - - return profitable_experiences[:limit] - - except Exception as e: - logger.error(f"Error getting profitable experiences: {e}") - return [] - -class RLTradingAgent(nn.Module): - """RL Trading Agent with comprehensive state processing""" - - def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512): - super(RLTradingAgent, self).__init__() - - self.state_dim = state_dim - self.action_dim = action_dim - self.hidden_dim = hidden_dim - - # State processing network - self.state_processor = nn.Sequential( - nn.Linear(state_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(hidden_dim, hidden_dim // 2), - nn.LayerNorm(hidden_dim // 2), - nn.ReLU() - ) - - # Q-value network - self.q_network = nn.Sequential( - nn.Linear(hidden_dim // 2, hidden_dim // 4), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(hidden_dim // 4, action_dim) - ) - - # Policy network - self.policy_network = nn.Sequential( - nn.Linear(hidden_dim // 2, hidden_dim // 4), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(hidden_dim // 4, action_dim), - nn.Softmax(dim=-1) - ) - - # Value network - self.value_network = nn.Sequential( - nn.Linear(hidden_dim // 2, hidden_dim // 4), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(hidden_dim // 4, 1) - ) - - def forward(self, state): - """Forward pass through the agent""" - processed_state = self.state_processor(state) - - q_values = self.q_network(processed_state) - policy_probs = self.policy_network(processed_state) - state_value = self.value_network(processed_state) - - return { - 'q_values': q_values, - 'policy_probs': policy_probs, - 'state_value': state_value, - 'processed_state': processed_state - } - - def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]: - """Select action using epsilon-greedy policy""" - self.eval() - with torch.no_grad(): - if isinstance(state, np.ndarray): - state = torch.from_numpy(state).float().unsqueeze(0) - - outputs = self.forward(state) - - if random.random() < epsilon: - action = random.randint(0, self.action_dim - 1) - confidence = 0.33 - else: - q_values = outputs['q_values'] - action = torch.argmax(q_values, dim=1).item() - q_softmax = F.softmax(q_values, dim=1) - confidence = torch.max(q_softmax).item() - - return action, confidence - -@dataclass -class RLTrainingStep: - """Single RL training step with backpropagation data""" - step_id: str - timestamp: datetime - batch_experiences: List[str] - - # Training data - total_loss: float - q_loss: float - policy_loss: float - - # Gradients - gradients: Dict[str, torch.Tensor] - gradient_norms: Dict[str, float] - - # Metadata - learning_rate: float = 0.001 - batch_size: int = 32 - - # Performance - batch_profitability: float = 0.0 - correct_actions: int = 0 - total_actions: int = 0 - step_value: float = 0.0 - -@dataclass -class RLTrainingSession: - """Complete RL training session""" - session_id: str - start_timestamp: datetime - end_timestamp: Optional[datetime] = None - - training_mode: str = 'experience_replay' - symbol: str = '' - - training_steps: List[RLTrainingStep] = field(default_factory=list) - - total_steps: int = 0 - average_loss: float = 0.0 - best_loss: float = float('inf') - - profitable_actions: int = 0 - total_actions: int = 0 - profitability_rate: float = 0.0 - session_value: float = 0.0 - -class RLTrainer: - """RL trainer with comprehensive experience storage and replay""" - - def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"): - self.agent = agent.to(device) - self.device = device - self.storage_dir = Path(storage_dir) - self.storage_dir.mkdir(parents=True, exist_ok=True) - - self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001) - self.experience_buffer = ProfitWeightedExperienceBuffer() - self.data_collector = get_training_data_collector() - - self.training_sessions: List[RLTrainingSession] = [] - self.current_session: Optional[RLTrainingSession] = None - - self.gamma = 0.99 - - self.training_stats = { - 'total_sessions': 0, - 'total_steps': 0, - 'total_experiences': 0, - 'profitable_actions': 0, - 'total_actions': 0, - 'average_reward': 0.0 - } - - logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters") - - def add_experience(self, state: np.ndarray, action: int, reward: float, - next_state: np.ndarray, done: bool, market_context: Dict[str, Any], - cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str: - """Add experience to the buffer""" - try: - experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}" - - experience = RLExperience( - experience_id=experience_id, - timestamp=datetime.now(), - episode_id=market_context.get('episode_id', 'unknown'), - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - market_context=market_context, - cnn_predictions=cnn_predictions, - confidence_score=confidence_score - ) - - self.experience_buffer.add_experience(experience) - self.training_stats['total_experiences'] += 1 - - return experience_id - - except Exception as e: - logger.error(f"Error adding experience: {e}") - return None - - def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]: - """Train on experiences with comprehensive data storage""" - try: - session = RLTrainingSession( - session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}", - start_timestamp=datetime.now(), - training_mode='experience_replay' - ) - self.current_session = session - - self.agent.train() - total_loss = 0.0 - - for batch_idx in range(num_batches): - experiences = self.experience_buffer.sample_batch(batch_size, True) - - if len(experiences) < batch_size: - continue - - # Prepare batch tensors - states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device) - actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device) - rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device) - next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device) - dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device) - - # Forward pass - self.optimizer.zero_grad() - - current_outputs = self.agent(states) - current_q_values = current_outputs['q_values'] - - # Calculate target Q-values - with torch.no_grad(): - next_outputs = self.agent(next_states) - next_q_values = next_outputs['q_values'] - max_next_q_values = torch.max(next_q_values, dim=1)[0] - target_q_values = rewards + (self.gamma * max_next_q_values * ~dones) - - # Calculate loss - current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) - q_loss = F.mse_loss(current_q_values_for_actions, target_q_values) - - # Backward pass - q_loss.backward() - - # Store gradients - gradients = {} - gradient_norms = {} - for name, param in self.agent.named_parameters(): - if param.grad is not None: - gradients[name] = param.grad.clone().detach() - gradient_norms[name] = param.grad.norm().item() - - torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0) - self.optimizer.step() - - # Create training step record - step = RLTrainingStep( - step_id=f"{session.session_id}_step_{batch_idx}", - timestamp=datetime.now(), - batch_experiences=[exp.experience_id for exp in experiences], - total_loss=q_loss.item(), - q_loss=q_loss.item(), - policy_loss=0.0, - gradients=gradients, - gradient_norms=gradient_norms, - batch_size=len(experiences) - ) - - session.training_steps.append(step) - total_loss += q_loss.item() - - # Finalize session - session.end_timestamp = datetime.now() - session.total_steps = num_batches - session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0 - - self._save_training_session(session) - - self.training_stats['total_sessions'] += 1 - self.training_stats['total_steps'] += session.total_steps - - logger.info(f"RL training session completed: {session.session_id}") - logger.info(f"Average loss: {session.average_loss:.4f}") - - return { - 'status': 'success', - 'session_id': session.session_id, - 'average_loss': session.average_loss, - 'total_steps': session.total_steps - } - - except Exception as e: - logger.error(f"Error in RL training session: {e}") - return {'status': 'error', 'error': str(e)} - finally: - self.current_session = None - - def train_on_profitable_experiences(self, min_profitability: float = 0.1, - max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]: - """Train specifically on most profitable experiences""" - try: - profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences) - - filtered_experiences = [ - exp for exp in profitable_experiences - if exp.actual_profit is not None and exp.actual_profit >= min_profitability - ] - - if len(filtered_experiences) < batch_size: - return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)} - - logger.info(f"Training on {len(filtered_experiences)} profitable experiences") - - num_batches = len(filtered_experiences) // batch_size - - # Temporarily replace buffer sampling - original_sample_method = self.experience_buffer.sample_batch - - def profitable_sample_batch(batch_size, prioritize_profitable=True): - return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences))) - - self.experience_buffer.sample_batch = profitable_sample_batch - - try: - results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches) - results['training_mode'] = 'profitable_replay' - results['experiences_used'] = len(filtered_experiences) - return results - finally: - self.experience_buffer.sample_batch = original_sample_method - - except Exception as e: - logger.error(f"Error training on profitable experiences: {e}") - return {'status': 'error', 'error': str(e)} - - def _save_training_session(self, session: RLTrainingSession): - """Save training session to disk""" - try: - session_dir = self.storage_dir / 'sessions' - session_dir.mkdir(parents=True, exist_ok=True) - - session_file = session_dir / f"{session.session_id}.pkl" - with open(session_file, 'wb') as f: - pickle.dump(session, f) - - metadata = { - 'session_id': session.session_id, - 'start_timestamp': session.start_timestamp.isoformat(), - 'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None, - 'training_mode': session.training_mode, - 'total_steps': session.total_steps, - 'average_loss': session.average_loss - } - - metadata_file = session_dir / f"{session.session_id}_metadata.json" - with open(metadata_file, 'w') as f: - json.dump(metadata, f, indent=2) - - except Exception as e: - logger.error(f"Error saving training session: {e}") - - def get_training_statistics(self) -> Dict[str, Any]: - """Get comprehensive training statistics""" - stats = self.training_stats.copy() - - if self.training_sessions: - recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10] - stats['recent_sessions'] = [ - { - 'session_id': s.session_id, - 'timestamp': s.start_timestamp.isoformat(), - 'mode': s.training_mode, - 'average_loss': s.average_loss - } - for s in recent_sessions - ] - - return stats - -# Global instance -rl_trainer = None - -def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer: - """Get global RL trainer instance""" - global rl_trainer - if rl_trainer is None: - if agent is None: - agent = RLTradingAgent() - rl_trainer = RLTrainer(agent) - return rl_trainer \ No newline at end of file diff --git a/core/robust_cob_provider.py b/core/robust_cob_provider.py deleted file mode 100644 index 443aabf..0000000 --- a/core/robust_cob_provider.py +++ /dev/null @@ -1,460 +0,0 @@ -""" -Robust COB (Consolidated Order Book) Provider - -This module provides a robust COB data provider that handles: -- HTTP 418 errors from Binance (rate limiting) -- Thread safety issues -- API rate limiting and backoff -- Fallback data sources -- Error recovery strategies - -Features: -- Automatic rate limiting and backoff -- Multiple exchange support with fallbacks -- Thread-safe operations -- Comprehensive error handling -- Data validation and integrity checking -""" - -import asyncio -import logging -import time -import threading -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Callable -from dataclasses import dataclass, field -from collections import deque -import json -import numpy as np -from concurrent.futures import ThreadPoolExecutor, as_completed -import requests - -from .api_rate_limiter import get_rate_limiter, RateLimitConfig - -logger = logging.getLogger(__name__) - -@dataclass -class COBData: - """Consolidated Order Book data structure""" - symbol: str - timestamp: datetime - bids: List[Tuple[float, float]] # [(price, quantity), ...] - asks: List[Tuple[float, float]] # [(price, quantity), ...] - - # Derived metrics - spread: float = 0.0 - mid_price: float = 0.0 - total_bid_volume: float = 0.0 - total_ask_volume: float = 0.0 - - # Data quality - data_source: str = 'unknown' - quality_score: float = 1.0 - - def __post_init__(self): - """Calculate derived metrics""" - if self.bids and self.asks: - self.spread = self.asks[0][0] - self.bids[0][0] - self.mid_price = (self.asks[0][0] + self.bids[0][0]) / 2 - self.total_bid_volume = sum(qty for _, qty in self.bids) - self.total_ask_volume = sum(qty for _, qty in self.asks) - - # Calculate quality score based on data completeness - self.quality_score = min( - len(self.bids) / 20, # Expect at least 20 bid levels - len(self.asks) / 20, # Expect at least 20 ask levels - 1.0 - ) - -class RobustCOBProvider: - """Robust COB provider with error handling and rate limiting""" - - def __init__(self, symbols: List[str] = None): - self.symbols = symbols or ['ETHUSDT', 'BTCUSDT'] - - # Rate limiter - self.rate_limiter = get_rate_limiter() - - # Thread safety - self.lock = threading.RLock() - - # Data cache - self.cob_cache: Dict[str, COBData] = {} - self.cache_timestamps: Dict[str, datetime] = {} - self.cache_ttl = timedelta(seconds=5) # 5 second cache TTL - - # Error tracking - self.error_counts: Dict[str, int] = {} - self.last_successful_fetch: Dict[str, datetime] = {} - - # Background fetching - self.is_running = False - self.fetch_threads: Dict[str, threading.Thread] = {} - self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher") - - # Fallback data - self.fallback_data: Dict[str, COBData] = {} - - # Performance tracking - self.fetch_stats = { - 'total_requests': 0, - 'successful_requests': 0, - 'failed_requests': 0, - 'rate_limited_requests': 0, - 'cache_hits': 0, - 'fallback_uses': 0 - } - - logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}") - - def start_background_fetching(self): - """Start background COB data fetching""" - if self.is_running: - logger.warning("Background fetching already running") - return - - self.is_running = True - - # Start fetching thread for each symbol - for symbol in self.symbols: - thread = threading.Thread( - target=self._background_fetch_worker, - args=(symbol,), - name=f"COB-{symbol}", - daemon=True - ) - self.fetch_threads[symbol] = thread - thread.start() - - logger.info(f"Started background COB fetching for {len(self.symbols)} symbols") - - def stop_background_fetching(self): - """Stop background COB data fetching""" - self.is_running = False - - # Wait for threads to finish - for symbol, thread in self.fetch_threads.items(): - thread.join(timeout=5) - logger.debug(f"Stopped COB fetching for {symbol}") - - # Shutdown executor - self.executor.shutdown(wait=True, timeout=10) - - logger.info("Stopped background COB fetching") - - def _background_fetch_worker(self, symbol: str): - """Background worker for fetching COB data""" - logger.info(f"Started COB fetching worker for {symbol}") - - while self.is_running: - try: - # Fetch COB data - cob_data = self._fetch_cob_data_safe(symbol) - - if cob_data: - with self.lock: - self.cob_cache[symbol] = cob_data - self.cache_timestamps[symbol] = datetime.now() - self.last_successful_fetch[symbol] = datetime.now() - self.error_counts[symbol] = 0 # Reset error count on success - - logger.debug(f"Updated COB cache for {symbol}") - else: - with self.lock: - self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1 - - logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}") - - # Wait before next fetch (adaptive based on errors) - error_count = self.error_counts.get(symbol, 0) - base_interval = 2.0 # Base 2 second interval - backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0) # Max 60s - - time.sleep(backoff_interval) - - except Exception as e: - logger.error(f"Error in COB fetching worker for {symbol}: {e}") - time.sleep(10) # Wait 10s on unexpected errors - - logger.info(f"Stopped COB fetching worker for {symbol}") - - def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]: - """Safely fetch COB data with error handling""" - try: - self.fetch_stats['total_requests'] += 1 - - # Try Binance first - cob_data = self._fetch_binance_cob(symbol) - if cob_data: - self.fetch_stats['successful_requests'] += 1 - return cob_data - - # Try MEXC as fallback - cob_data = self._fetch_mexc_cob(symbol) - if cob_data: - self.fetch_stats['successful_requests'] += 1 - cob_data.data_source = 'mexc_fallback' - return cob_data - - # Use cached fallback data if available - if symbol in self.fallback_data: - self.fetch_stats['fallback_uses'] += 1 - fallback = self.fallback_data[symbol] - fallback.timestamp = datetime.now() - fallback.data_source = 'fallback_cache' - fallback.quality_score *= 0.5 # Reduce quality score for old data - return fallback - - self.fetch_stats['failed_requests'] += 1 - return None - - except Exception as e: - logger.error(f"Error fetching COB data for {symbol}: {e}") - self.fetch_stats['failed_requests'] += 1 - return None - - def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]: - """Fetch COB data from Binance with rate limiting""" - try: - url = f"https://api.binance.com/api/v3/depth" - params = { - 'symbol': symbol, - 'limit': 100 # Get 100 levels - } - - # Use rate limiter - response = self.rate_limiter.make_request( - 'binance_api', - url, - method='GET', - params=params - ) - - if not response: - self.fetch_stats['rate_limited_requests'] += 1 - return None - - if response.status_code != 200: - logger.warning(f"Binance COB API returned {response.status_code} for {symbol}") - return None - - data = response.json() - - # Parse order book data - bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])] - asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])] - - if not bids or not asks: - logger.warning(f"Empty order book data from Binance for {symbol}") - return None - - cob_data = COBData( - symbol=symbol, - timestamp=datetime.now(), - bids=bids, - asks=asks, - data_source='binance' - ) - - # Store as fallback for future use - self.fallback_data[symbol] = cob_data - - return cob_data - - except Exception as e: - logger.error(f"Error fetching Binance COB for {symbol}: {e}") - return None - - def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]: - """Fetch COB data from MEXC as fallback""" - try: - url = f"https://api.mexc.com/api/v3/depth" - params = { - 'symbol': symbol, - 'limit': 100 - } - - response = self.rate_limiter.make_request( - 'mexc_api', - url, - method='GET', - params=params - ) - - if not response or response.status_code != 200: - return None - - data = response.json() - - # Parse order book data - bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])] - asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])] - - if not bids or not asks: - return None - - return COBData( - symbol=symbol, - timestamp=datetime.now(), - bids=bids, - asks=asks, - data_source='mexc' - ) - - except Exception as e: - logger.debug(f"Error fetching MEXC COB for {symbol}: {e}") - return None - - def get_cob_data(self, symbol: str) -> Optional[COBData]: - """Get COB data for a symbol (from cache or fresh fetch)""" - with self.lock: - # Check cache first - if symbol in self.cob_cache: - cached_data = self.cob_cache[symbol] - cache_time = self.cache_timestamps.get(symbol, datetime.min) - - # Return cached data if still fresh - if datetime.now() - cache_time < self.cache_ttl: - self.fetch_stats['cache_hits'] += 1 - return cached_data - - # If background fetching is running, return cached data even if stale - if self.is_running and symbol in self.cob_cache: - return self.cob_cache[symbol] - - # Fetch fresh data if not running background fetching - if not self.is_running: - return self._fetch_cob_data_safe(symbol) - - return None - - def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]: - """ - Get COB features for ML models - - Args: - symbol: Trading symbol - feature_count: Number of features to return - - Returns: - Numpy array of COB features or None if no data - """ - cob_data = self.get_cob_data(symbol) - if not cob_data: - return None - - try: - features = [] - - # Basic market metrics - features.extend([ - cob_data.mid_price, - cob_data.spread, - cob_data.total_bid_volume, - cob_data.total_ask_volume, - cob_data.quality_score - ]) - - # Bid levels (price and volume) - max_levels = min(len(cob_data.bids), 20) - for i in range(max_levels): - price, volume = cob_data.bids[i] - features.extend([price, volume]) - - # Pad bid levels if needed - for i in range(max_levels, 20): - features.extend([0.0, 0.0]) - - # Ask levels (price and volume) - max_levels = min(len(cob_data.asks), 20) - for i in range(max_levels): - price, volume = cob_data.asks[i] - features.extend([price, volume]) - - # Pad ask levels if needed - for i in range(max_levels, 20): - features.extend([0.0, 0.0]) - - # Calculate additional features - if len(cob_data.bids) > 0 and len(cob_data.asks) > 0: - # Volume imbalance - bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5]) - ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5]) - volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0 - features.append(volume_imbalance) - - # Price levels - bid_price_levels = [price for price, _ in cob_data.bids[:10]] - ask_price_levels = [price for price, _ in cob_data.asks[:10]] - features.extend(bid_price_levels + ask_price_levels) - - # Pad or truncate to desired feature count - if len(features) < feature_count: - features.extend([0.0] * (feature_count - len(features))) - else: - features = features[:feature_count] - - return np.array(features, dtype=np.float32) - - except Exception as e: - logger.error(f"Error creating COB features for {symbol}: {e}") - return None - - def get_provider_status(self) -> Dict[str, Any]: - """Get provider status and statistics""" - with self.lock: - status = { - 'is_running': self.is_running, - 'symbols': self.symbols, - 'cache_status': {}, - 'error_counts': self.error_counts.copy(), - 'last_successful_fetch': { - symbol: timestamp.isoformat() - for symbol, timestamp in self.last_successful_fetch.items() - }, - 'fetch_stats': self.fetch_stats.copy(), - 'rate_limiter_status': self.rate_limiter.get_all_endpoint_status() - } - - # Cache status for each symbol - for symbol in self.symbols: - cache_time = self.cache_timestamps.get(symbol) - status['cache_status'][symbol] = { - 'has_data': symbol in self.cob_cache, - 'cache_time': cache_time.isoformat() if cache_time else None, - 'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None, - 'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0 - } - - return status - - def reset_errors(self): - """Reset error counts and rate limiter""" - with self.lock: - self.error_counts.clear() - self.rate_limiter.reset_all_endpoints() - logger.info("Reset all error counts and rate limiter") - - def force_refresh(self, symbol: str = None): - """Force refresh COB data for symbol(s)""" - symbols_to_refresh = [symbol] if symbol else self.symbols - - for sym in symbols_to_refresh: - # Clear cache to force refresh - with self.lock: - if sym in self.cob_cache: - del self.cob_cache[sym] - if sym in self.cache_timestamps: - del self.cache_timestamps[sym] - - logger.info(f"Forced refresh for {sym}") - -# Global COB provider instance -_global_cob_provider = None - -def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider: - """Get global COB provider instance""" - global _global_cob_provider - if _global_cob_provider is None: - _global_cob_provider = RobustCOBProvider(symbols) - return _global_cob_provider \ No newline at end of file diff --git a/core/shared_cob_service.py b/core/shared_cob_service.py deleted file mode 100644 index c11aae5..0000000 --- a/core/shared_cob_service.py +++ /dev/null @@ -1,350 +0,0 @@ -#!/usr/bin/env python3 -""" -Shared COB Service - Eliminates Redundant COB Implementations - -This service provides a singleton COB integration that can be shared across: -- Dashboard components -- RL trading systems -- Enhanced orchestrators -- Training pipelines - -Instead of each component creating its own COBIntegration instance, -they all share this single service, eliminating redundant connections. -""" - -import asyncio -import logging -import weakref -from typing import Dict, List, Optional, Any, Callable, Set -from datetime import datetime -from threading import Lock -from dataclasses import dataclass - -from .cob_integration import COBIntegration -from .multi_exchange_cob_provider import COBSnapshot -from .data_provider import DataProvider - -logger = logging.getLogger(__name__) - -@dataclass -class COBSubscription: - """Represents a subscription to COB updates""" - subscriber_id: str - callback: Callable - symbol_filter: Optional[List[str]] = None - callback_type: str = "general" # general, cnn, dqn, dashboard - -class SharedCOBService: - """ - Shared COB Service - Singleton pattern for unified COB data access - - This service eliminates redundant COB integrations by providing a single - shared instance that all components can subscribe to. - """ - - _instance: Optional['SharedCOBService'] = None - _lock = Lock() - - def __new__(cls, *args, **kwargs): - """Singleton pattern implementation""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super(SharedCOBService, cls).__new__(cls) - return cls._instance - - def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None): - """Initialize shared COB service (only called once due to singleton)""" - if hasattr(self, '_initialized'): - return - - self.symbols = symbols or ['BTC/USDT', 'ETH/USDT'] - self.data_provider = data_provider - - # Single COB integration instance - self.cob_integration: Optional[COBIntegration] = None - self.is_running = False - - # Subscriber management - self.subscribers: Dict[str, COBSubscription] = {} - self.subscriber_counter = 0 - self.subscription_lock = Lock() - - # Cached data for immediate access - self.latest_snapshots: Dict[str, COBSnapshot] = {} - self.latest_cnn_features: Dict[str, Any] = {} - self.latest_dqn_states: Dict[str, Any] = {} - - # Performance tracking - self.total_subscribers = 0 - self.update_count = 0 - self.start_time = None - - self._initialized = True - logger.info(f"SharedCOBService initialized for symbols: {self.symbols}") - - async def start(self) -> None: - """Start the shared COB service""" - if self.is_running: - logger.warning("SharedCOBService already running") - return - - logger.info("Starting SharedCOBService...") - - try: - # Initialize COB integration if not already done - if self.cob_integration is None: - self.cob_integration = COBIntegration( - data_provider=self.data_provider, - symbols=self.symbols - ) - - # Register internal callbacks - self.cob_integration.add_cnn_callback(self._on_cob_cnn_update) - self.cob_integration.add_dqn_callback(self._on_cob_dqn_update) - self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update) - - # Start COB integration - await self.cob_integration.start() - - self.is_running = True - self.start_time = datetime.now() - - logger.info("SharedCOBService started successfully") - logger.info(f"Active subscribers: {len(self.subscribers)}") - - except Exception as e: - logger.error(f"Error starting SharedCOBService: {e}") - raise - - async def stop(self) -> None: - """Stop the shared COB service""" - if not self.is_running: - return - - logger.info("Stopping SharedCOBService...") - - try: - if self.cob_integration: - await self.cob_integration.stop() - - self.is_running = False - - # Notify all subscribers of shutdown - for subscription in self.subscribers.values(): - try: - if hasattr(subscription.callback, '__call__'): - subscription.callback("SHUTDOWN", None) - except Exception as e: - logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}") - - logger.info("SharedCOBService stopped") - - except Exception as e: - logger.error(f"Error stopping SharedCOBService: {e}") - - def subscribe(self, - callback: Callable, - callback_type: str = "general", - symbol_filter: Optional[List[str]] = None, - subscriber_name: str = None) -> str: - """ - Subscribe to COB updates - - Args: - callback: Function to call on updates - callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard') - symbol_filter: Only receive updates for these symbols (None = all) - subscriber_name: Optional name for the subscriber - - Returns: - Subscription ID for unsubscribing - """ - with self.subscription_lock: - self.subscriber_counter += 1 - subscriber_id = f"{callback_type}_{self.subscriber_counter}" - if subscriber_name: - subscriber_id = f"{subscriber_name}_{subscriber_id}" - - subscription = COBSubscription( - subscriber_id=subscriber_id, - callback=callback, - symbol_filter=symbol_filter, - callback_type=callback_type - ) - - self.subscribers[subscriber_id] = subscription - self.total_subscribers += 1 - - logger.info(f"New subscriber: {subscriber_id} ({callback_type})") - logger.info(f"Total active subscribers: {len(self.subscribers)}") - - return subscriber_id - - def unsubscribe(self, subscriber_id: str) -> bool: - """ - Unsubscribe from COB updates - - Args: - subscriber_id: ID returned from subscribe() - - Returns: - True if successfully unsubscribed - """ - with self.subscription_lock: - if subscriber_id in self.subscribers: - del self.subscribers[subscriber_id] - logger.info(f"Unsubscribed: {subscriber_id}") - logger.info(f"Remaining subscribers: {len(self.subscribers)}") - return True - else: - logger.warning(f"Subscriber not found: {subscriber_id}") - return False - - # Internal callback handlers - - async def _on_cob_cnn_update(self, symbol: str, data: Dict): - """Handle CNN feature updates from COB integration""" - try: - self.latest_cnn_features[symbol] = data - await self._notify_subscribers("cnn", symbol, data) - except Exception as e: - logger.error(f"Error in CNN update handler: {e}") - - async def _on_cob_dqn_update(self, symbol: str, data: Dict): - """Handle DQN state updates from COB integration""" - try: - self.latest_dqn_states[symbol] = data - await self._notify_subscribers("dqn", symbol, data) - except Exception as e: - logger.error(f"Error in DQN update handler: {e}") - - async def _on_cob_dashboard_update(self, symbol: str, data: Dict): - """Handle dashboard updates from COB integration""" - try: - # Store snapshot if it's a COBSnapshot - if hasattr(data, 'volume_weighted_mid'): # Duck typing for COBSnapshot - self.latest_snapshots[symbol] = data - - await self._notify_subscribers("dashboard", symbol, data) - await self._notify_subscribers("general", symbol, data) - - self.update_count += 1 - - except Exception as e: - logger.error(f"Error in dashboard update handler: {e}") - - async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any): - """Notify all relevant subscribers of an update""" - try: - relevant_subscribers = [ - sub for sub in self.subscribers.values() - if (sub.callback_type == callback_type or sub.callback_type == "general") and - (sub.symbol_filter is None or symbol in sub.symbol_filter) - ] - - for subscription in relevant_subscribers: - try: - if asyncio.iscoroutinefunction(subscription.callback): - asyncio.create_task(subscription.callback(symbol, data)) - else: - subscription.callback(symbol, data) - except Exception as e: - logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}") - - except Exception as e: - logger.error(f"Error notifying subscribers: {e}") - - # Public data access methods - - def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]: - """Get latest COB snapshot for a symbol""" - if self.cob_integration: - return self.cob_integration.get_cob_snapshot(symbol) - return self.latest_snapshots.get(symbol) - - def get_cnn_features(self, symbol: str) -> Optional[Any]: - """Get latest CNN features for a symbol""" - return self.latest_cnn_features.get(symbol) - - def get_dqn_state(self, symbol: str) -> Optional[Any]: - """Get latest DQN state for a symbol""" - return self.latest_dqn_states.get(symbol) - - def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]: - """Get detailed market depth analysis""" - if self.cob_integration: - return self.cob_integration.get_market_depth_analysis(symbol) - return None - - def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]: - """Get liquidity breakdown by exchange""" - if self.cob_integration: - return self.cob_integration.get_exchange_breakdown(symbol) - return None - - def get_price_buckets(self, symbol: str) -> Optional[Dict]: - """Get fine-grain price buckets""" - if self.cob_integration: - return self.cob_integration.get_price_buckets(symbol) - return None - - def get_session_volume_profile(self, symbol: str) -> Optional[Dict]: - """Get session volume profile""" - if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'): - return self.cob_integration.cob_provider.get_session_volume_profile(symbol) - return None - - def get_realtime_stats_for_nn(self, symbol: str) -> Dict: - """Get real-time statistics formatted for NN models""" - if self.cob_integration: - return self.cob_integration.get_realtime_stats_for_nn(symbol) - return {} - - def get_service_statistics(self) -> Dict[str, Any]: - """Get service statistics""" - uptime = None - if self.start_time: - uptime = (datetime.now() - self.start_time).total_seconds() - - base_stats = { - 'service_name': 'SharedCOBService', - 'is_running': self.is_running, - 'symbols': self.symbols, - 'total_subscribers': len(self.subscribers), - 'lifetime_subscribers': self.total_subscribers, - 'update_count': self.update_count, - 'uptime_seconds': uptime, - 'subscribers_by_type': {} - } - - # Count subscribers by type - for subscription in self.subscribers.values(): - callback_type = subscription.callback_type - if callback_type not in base_stats['subscribers_by_type']: - base_stats['subscribers_by_type'][callback_type] = 0 - base_stats['subscribers_by_type'][callback_type] += 1 - - # Get COB integration stats if available - if self.cob_integration: - cob_stats = self.cob_integration.get_statistics() - base_stats.update(cob_stats) - - return base_stats - -# Global service instance access functions - -def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService: - """Get the shared COB service instance""" - return SharedCOBService(symbols=symbols, data_provider=data_provider) - -async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService: - """Start the shared COB service""" - service = get_shared_cob_service(symbols=symbols, data_provider=data_provider) - await service.start() - return service - -async def stop_shared_cob_service(): - """Stop the shared COB service""" - service = get_shared_cob_service() - await service.stop() \ No newline at end of file diff --git a/core/shared_data_manager.py b/core/shared_data_manager.py deleted file mode 100644 index 87a21e7..0000000 --- a/core/shared_data_manager.py +++ /dev/null @@ -1,425 +0,0 @@ -""" -Shared Data Manager for UI Stability Fix - -Manages data sharing between processes through files with proper locking -and atomic operations to prevent corruption and conflicts. -""" - -import json -import os -import time -import tempfile -import platform -from datetime import datetime -from dataclasses import dataclass, asdict -from typing import Dict, Any, Optional, Union -from pathlib import Path -import logging - -# Windows-compatible file locking -if platform.system() == "Windows": - import msvcrt -else: - import fcntl - -logger = logging.getLogger(__name__) - -@dataclass -class ProcessStatus: - """Model for process status information""" - name: str - pid: int - status: str # 'running', 'stopped', 'error' - start_time: datetime - last_heartbeat: datetime - memory_usage: float - cpu_usage: float - error_message: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary with datetime serialization""" - data = asdict(self) - data['start_time'] = self.start_time.isoformat() - data['last_heartbeat'] = self.last_heartbeat.isoformat() - return data - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus': - """Create from dictionary with datetime deserialization""" - data['start_time'] = datetime.fromisoformat(data['start_time']) - data['last_heartbeat'] = datetime.fromisoformat(data['last_heartbeat']) - return cls(**data) - -@dataclass -class TrainingStatus: - """Model for training status information""" - is_running: bool - current_epoch: int - total_epochs: int - loss: float - accuracy: float - last_update: datetime - model_path: str - error_message: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary with datetime serialization""" - data = asdict(self) - data['last_update'] = self.last_update.isoformat() - return data - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus': - """Create from dictionary with datetime deserialization""" - data['last_update'] = datetime.fromisoformat(data['last_update']) - return cls(**data) - -@dataclass -class DashboardState: - """Model for dashboard state information""" - is_connected: bool - last_data_update: datetime - active_connections: int - error_count: int - performance_metrics: Dict[str, float] - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary with datetime serialization""" - data = asdict(self) - data['last_data_update'] = self.last_data_update.isoformat() - return data - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState': - """Create from dictionary with datetime deserialization""" - data['last_data_update'] = datetime.fromisoformat(data['last_data_update']) - return cls(**data) - - -class SharedDataManager: - """ - Manages data sharing between processes through files with proper locking - and atomic operations to prevent corruption and conflicts. - """ - - def __init__(self, data_dir: str = "shared_data"): - """ - Initialize the shared data manager - - Args: - data_dir: Directory to store shared data files - """ - self.data_dir = Path(data_dir) - self.data_dir.mkdir(exist_ok=True) - - # Define file paths for different data types - self.training_status_file = self.data_dir / "training_status.json" - self.dashboard_state_file = self.data_dir / "dashboard_state.json" - self.process_status_file = self.data_dir / "process_status.json" - self.market_data_file = self.data_dir / "market_data.json" - self.model_metrics_file = self.data_dir / "model_metrics.json" - - logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}") - - def _lock_file(self, file_handle, exclusive=True): - """Cross-platform file locking""" - if platform.system() == "Windows": - # Windows file locking - try: - if exclusive: - msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1) - else: - msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1) - except IOError: - pass # File locking may not be available in all scenarios - else: - # Unix file locking - lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH - fcntl.flock(file_handle.fileno(), lock_type) - - def _unlock_file(self, file_handle): - """Cross-platform file unlocking""" - if platform.system() == "Windows": - try: - msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1) - except IOError: - pass - else: - fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN) - - def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None: - """ - Write JSON data atomically with file locking - - Args: - file_path: Path to the file to write - data: Data to write as JSON - """ - temp_path = None - try: - # Create temporary file in the same directory - temp_fd, temp_path = tempfile.mkstemp( - dir=file_path.parent, - prefix=f".{file_path.name}.", - suffix=".tmp" - ) - - with os.fdopen(temp_fd, 'w') as temp_file: - # Lock the temporary file - self._lock_file(temp_file, exclusive=True) - - # Write data with proper formatting - json.dump(data, temp_file, indent=2, default=str) - temp_file.flush() - os.fsync(temp_file.fileno()) - - # Unlock before closing - self._unlock_file(temp_file) - - # Atomically replace the original file - os.replace(temp_path, file_path) - logger.debug(f"Successfully wrote data to {file_path}") - - except Exception as e: - # Clean up temporary file if it exists - if temp_path: - try: - os.unlink(temp_path) - except: - pass - logger.error(f"Failed to write data to {file_path}: {e}") - raise - - def _read_json_safe(self, file_path: Path) -> Dict[str, Any]: - """ - Read JSON data safely with file locking - - Args: - file_path: Path to the file to read - - Returns: - Dictionary containing the JSON data - """ - if not file_path.exists(): - logger.debug(f"File {file_path} does not exist, returning empty dict") - return {} - - try: - with open(file_path, 'r') as file: - # Lock the file for reading - self._lock_file(file, exclusive=False) - data = json.load(file) - self._unlock_file(file) - logger.debug(f"Successfully read data from {file_path}") - return data - - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON in {file_path}: {e}") - return {} - except Exception as e: - logger.error(f"Failed to read data from {file_path}: {e}") - return {} - - def write_training_status(self, status: TrainingStatus) -> None: - """ - Write training status to shared file - - Args: - status: TrainingStatus object to write - """ - try: - data = status.to_dict() - self._write_json_atomic(self.training_status_file, data) - logger.debug("Training status written successfully") - except Exception as e: - logger.error(f"Failed to write training status: {e}") - raise - - def read_training_status(self) -> Optional[TrainingStatus]: - """ - Read training status from shared file - - Returns: - TrainingStatus object or None if not available - """ - try: - data = self._read_json_safe(self.training_status_file) - if not data: - return None - return TrainingStatus.from_dict(data) - except Exception as e: - logger.error(f"Failed to read training status: {e}") - return None - - def write_dashboard_state(self, state: DashboardState) -> None: - """ - Write dashboard state to shared file - - Args: - state: DashboardState object to write - """ - try: - data = state.to_dict() - self._write_json_atomic(self.dashboard_state_file, data) - logger.debug("Dashboard state written successfully") - except Exception as e: - logger.error(f"Failed to write dashboard state: {e}") - raise - - def read_dashboard_state(self) -> Optional[DashboardState]: - """ - Read dashboard state from shared file - - Returns: - DashboardState object or None if not available - """ - try: - data = self._read_json_safe(self.dashboard_state_file) - if not data: - return None - return DashboardState.from_dict(data) - except Exception as e: - logger.error(f"Failed to read dashboard state: {e}") - return None - - def write_process_status(self, status: ProcessStatus) -> None: - """ - Write process status to shared file - - Args: - status: ProcessStatus object to write - """ - try: - data = status.to_dict() - self._write_json_atomic(self.process_status_file, data) - logger.debug("Process status written successfully") - except Exception as e: - logger.error(f"Failed to write process status: {e}") - raise - - def read_process_status(self) -> Optional[ProcessStatus]: - """ - Read process status from shared file - - Returns: - ProcessStatus object or None if not available - """ - try: - data = self._read_json_safe(self.process_status_file) - if not data: - return None - return ProcessStatus.from_dict(data) - except Exception as e: - logger.error(f"Failed to read process status: {e}") - return None - - def write_market_data(self, data: Dict[str, Any]) -> None: - """ - Write market data to shared file - - Args: - data: Market data dictionary to write - """ - try: - # Add timestamp to market data - data['timestamp'] = datetime.now().isoformat() - self._write_json_atomic(self.market_data_file, data) - logger.debug("Market data written successfully") - except Exception as e: - logger.error(f"Failed to write market data: {e}") - raise - - def read_market_data(self) -> Dict[str, Any]: - """ - Read market data from shared file - - Returns: - Dictionary containing market data - """ - try: - return self._read_json_safe(self.market_data_file) - except Exception as e: - logger.error(f"Failed to read market data: {e}") - return {} - - def write_model_metrics(self, metrics: Dict[str, Any]) -> None: - """ - Write model metrics to shared file - - Args: - metrics: Model metrics dictionary to write - """ - try: - # Add timestamp to metrics - metrics['timestamp'] = datetime.now().isoformat() - self._write_json_atomic(self.model_metrics_file, metrics) - logger.debug("Model metrics written successfully") - except Exception as e: - logger.error(f"Failed to write model metrics: {e}") - raise - - def read_model_metrics(self) -> Dict[str, Any]: - """ - Read model metrics from shared file - - Returns: - Dictionary containing model metrics - """ - try: - return self._read_json_safe(self.model_metrics_file) - except Exception as e: - logger.error(f"Failed to read model metrics: {e}") - return {} - - def cleanup(self) -> None: - """ - Clean up shared data files - """ - try: - for file_path in [ - self.training_status_file, - self.dashboard_state_file, - self.process_status_file, - self.market_data_file, - self.model_metrics_file - ]: - if file_path.exists(): - file_path.unlink() - logger.debug(f"Removed {file_path}") - - # Remove directory if empty - if self.data_dir.exists() and not any(self.data_dir.iterdir()): - self.data_dir.rmdir() - logger.debug(f"Removed empty directory {self.data_dir}") - - except Exception as e: - logger.error(f"Failed to cleanup shared data: {e}") - - def get_data_age(self, data_type: str) -> Optional[float]: - """ - Get the age of data in seconds - - Args: - data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics') - - Returns: - Age in seconds or None if file doesn't exist - """ - file_map = { - 'training': self.training_status_file, - 'dashboard': self.dashboard_state_file, - 'process': self.process_status_file, - 'market': self.market_data_file, - 'metrics': self.model_metrics_file - } - - file_path = file_map.get(data_type) - if not file_path or not file_path.exists(): - return None - - try: - mtime = file_path.stat().st_mtime - return time.time() - mtime - except Exception as e: - logger.error(f"Failed to get data age for {data_type}: {e}") - return None \ No newline at end of file diff --git a/core/trading_action.py b/core/trading_action.py deleted file mode 100644 index 1d90d37..0000000 --- a/core/trading_action.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Trading Action Module - -Defines the TradingAction class used throughout the trading system. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Dict, Any, List - -@dataclass -class TradingAction: - """Represents a trading action with full context""" - symbol: str - action: str # 'BUY', 'SELL', 'HOLD' - quantity: float - confidence: float - price: float - timestamp: datetime - reasoning: Dict[str, Any] - - def __post_init__(self): - """Validate the trading action after initialization""" - if self.action not in ['BUY', 'SELL', 'HOLD']: - raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'") - - if self.confidence < 0.0 or self.confidence > 1.0: - raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0") - - if self.quantity < 0: - raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative") - - if self.price <= 0: - raise ValueError(f"Invalid price: {self.price}. Must be positive") - - def to_dict(self) -> Dict[str, Any]: - """Convert trading action to dictionary""" - return { - 'symbol': self.symbol, - 'action': self.action, - 'quantity': self.quantity, - 'confidence': self.confidence, - 'price': self.price, - 'timestamp': self.timestamp.isoformat(), - 'reasoning': self.reasoning - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction': - """Create trading action from dictionary""" - return cls( - symbol=data['symbol'], - action=data['action'], - quantity=data['quantity'], - confidence=data['confidence'], - price=data['price'], - timestamp=datetime.fromisoformat(data['timestamp']), - reasoning=data['reasoning'] - ) \ No newline at end of file diff --git a/core/trading_executor_fix.py b/core/trading_executor_fix.py deleted file mode 100644 index 7366f1d..0000000 --- a/core/trading_executor_fix.py +++ /dev/null @@ -1,401 +0,0 @@ -""" -Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations - -This module provides fixes for: -1. Identical entry prices issue -2. Price caching problems -3. Position tracking reset logic -4. Trade cooldown implementation -5. P&L calculation verification - -Apply these fixes to the TradingExecutor class to improve trade execution reliability. -""" - -import logging -import time -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Any, Union - -logger = logging.getLogger(__name__) - -class TradingExecutorFix: - """ - Fixes for the TradingExecutor class to address entry/exit price issues - and improve P&L calculation accuracy. - """ - - def __init__(self, trading_executor): - """ - Initialize the fix with a reference to the trading executor - - Args: - trading_executor: The TradingExecutor instance to fix - """ - self.trading_executor = trading_executor - - # Add cooldown tracking - self.last_trade_time = {} # {symbol: timestamp} - self.min_trade_cooldown = 30 # 30 seconds minimum between trades - - # Add price history for validation - self.recent_entry_prices = {} # {symbol: [recent_prices]} - self.max_price_history = 10 # Keep last 10 entry prices - - # Add position reset tracking - self.position_reset_flags = {} # {symbol: bool} - - # Add price update tracking - self.last_price_update = {} # {symbol: timestamp} - self.price_update_threshold = 5 # 5 seconds max since last price update - - # Add P&L verification - self.trade_history = {} # {symbol: [trade_records]} - - logger.info("TradingExecutorFix initialized - addressing entry/exit price issues") - - def apply_fixes(self): - """Apply all fixes to the trading executor""" - self._patch_execute_action() - self._patch_close_position() - self._patch_calculate_pnl() - self._patch_update_prices() - - logger.info("All trading executor fixes applied successfully") - - def _patch_execute_action(self): - """Patch the execute_action method to add price validation and cooldown""" - original_execute_action = self.trading_executor.execute_action - - def execute_action_with_fixes(decision): - """Enhanced execute_action with price validation and cooldown""" - try: - symbol = decision.symbol - action = decision.action - current_time = datetime.now() - - # 1. Check cooldown period - if symbol in self.last_trade_time: - time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds() - if time_since_last_trade < self.min_trade_cooldown: - logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}") - return False - - # 2. Validate price freshness - if symbol in self.last_price_update: - time_since_update = (current_time - self.last_price_update[symbol]).total_seconds() - if time_since_update > self.price_update_threshold: - logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}") - # Force price refresh - self._refresh_price(symbol) - return False - - # 3. Validate entry price against recent history - current_price = self._get_current_price(symbol) - if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0: - # Check if price is identical to any recent entry - if current_price in self.recent_entry_prices[symbol]: - logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}") - return False - - # 4. Ensure position is properly reset before new entry - if not self._ensure_position_reset(symbol): - logger.warning(f"Trade rejected: Position not properly reset for {symbol}") - return False - - # Execute the original action - result = original_execute_action(decision) - - # If successful, update tracking - if result: - # Update cooldown timestamp - self.last_trade_time[symbol] = current_time - - # Update price history - if symbol not in self.recent_entry_prices: - self.recent_entry_prices[symbol] = [] - - self.recent_entry_prices[symbol].append(current_price) - # Keep only the most recent prices - if len(self.recent_entry_prices[symbol]) > self.max_price_history: - self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:] - - # Mark position as active - self.position_reset_flags[symbol] = False - - logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation") - - return result - - except Exception as e: - logger.error(f"Error in execute_action_with_fixes: {e}") - return original_execute_action(decision) - - # Replace the original method - self.trading_executor.execute_action = execute_action_with_fixes - logger.info("Patched execute_action with price validation and cooldown") - - def _patch_close_position(self): - """Patch the close_position method to ensure proper position reset""" - original_close_position = self.trading_executor.close_position - - def close_position_with_fixes(symbol, **kwargs): - """Enhanced close_position with proper reset logic""" - try: - # Get current price for P&L verification - exit_price = self._get_current_price(symbol) - - # Call original close position - result = original_close_position(symbol, **kwargs) - - if result: - # Mark position as reset - self.position_reset_flags[symbol] = True - - # Record trade for verification - if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions: - position = self.trading_executor.positions[symbol] - - # Create trade record - trade_record = { - 'symbol': symbol, - 'entry_time': getattr(position, 'entry_time', datetime.now()), - 'exit_time': datetime.now(), - 'entry_price': getattr(position, 'entry_price', 0), - 'exit_price': exit_price, - 'size': getattr(position, 'size', 0), - 'side': getattr(position, 'side', 'UNKNOWN'), - 'pnl': self._calculate_verified_pnl(position, exit_price), - 'fees': getattr(position, 'fees', 0), - 'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds() - } - - # Store trade record - if symbol not in self.trade_history: - self.trade_history[symbol] = [] - self.trade_history[symbol].append(trade_record) - - logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}") - - return result - - except Exception as e: - logger.error(f"Error in close_position_with_fixes: {e}") - return original_close_position(symbol, **kwargs) - - # Replace the original method - self.trading_executor.close_position = close_position_with_fixes - logger.info("Patched close_position with proper reset logic") - - def _patch_calculate_pnl(self): - """Patch the calculate_pnl method to ensure accurate P&L calculation""" - original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None) - - def calculate_pnl_with_fixes(position, current_price=None): - """Enhanced calculate_pnl with verification""" - try: - # If no original method, implement our own - if original_calculate_pnl is None: - return self._calculate_verified_pnl(position, current_price) - - # Call original method - original_pnl = original_calculate_pnl(position, current_price) - - # Calculate our verified P&L - verified_pnl = self._calculate_verified_pnl(position, current_price) - - # If there's a significant difference, log it - if abs(original_pnl - verified_pnl) > 0.01: - logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}") - # Use the verified P&L - return verified_pnl - - return original_pnl - - except Exception as e: - logger.error(f"Error in calculate_pnl_with_fixes: {e}") - if original_calculate_pnl: - return original_calculate_pnl(position, current_price) - return 0.0 - - # Replace the original method if it exists - if original_calculate_pnl: - self.trading_executor.calculate_pnl = calculate_pnl_with_fixes - logger.info("Patched calculate_pnl with verification") - else: - # Add the method if it doesn't exist - self.trading_executor.calculate_pnl = calculate_pnl_with_fixes - logger.info("Added calculate_pnl method with verification") - - def _patch_update_prices(self): - """Patch the update_prices method to track price updates""" - original_update_prices = getattr(self.trading_executor, 'update_prices', None) - - def update_prices_with_tracking(prices): - """Enhanced update_prices with timestamp tracking""" - try: - # Call original method if it exists - if original_update_prices: - result = original_update_prices(prices) - else: - # If no original method, update prices directly - if hasattr(self.trading_executor, 'current_prices'): - self.trading_executor.current_prices.update(prices) - result = True - - # Track update timestamps - current_time = datetime.now() - for symbol in prices: - self.last_price_update[symbol] = current_time - - return result - - except Exception as e: - logger.error(f"Error in update_prices_with_tracking: {e}") - if original_update_prices: - return original_update_prices(prices) - return False - - # Replace the original method if it exists - if original_update_prices: - self.trading_executor.update_prices = update_prices_with_tracking - logger.info("Patched update_prices with timestamp tracking") - else: - # Add the method if it doesn't exist - self.trading_executor.update_prices = update_prices_with_tracking - logger.info("Added update_prices method with timestamp tracking") - - def _calculate_verified_pnl(self, position, current_price=None): - """Calculate verified P&L for a position""" - try: - # Get position details - entry_price = getattr(position, 'entry_price', 0) - size = getattr(position, 'size', 0) - side = getattr(position, 'side', 'UNKNOWN') - leverage = getattr(position, 'leverage', 1.0) - fees = getattr(position, 'fees', 0.0) - - # If current_price is not provided, try to get it - if current_price is None: - symbol = getattr(position, 'symbol', None) - if symbol: - current_price = self._get_current_price(symbol) - else: - return 0.0 - - # Calculate P&L based on position side - if side == 'LONG': - pnl = (current_price - entry_price) * size * leverage - elif side == 'SHORT': - pnl = (entry_price - current_price) * size * leverage - else: - pnl = 0.0 - - # Subtract fees for net P&L - net_pnl = pnl - fees - - return net_pnl - - except Exception as e: - logger.error(f"Error calculating verified P&L: {e}") - return 0.0 - - def _get_current_price(self, symbol): - """Get current price for a symbol with fallbacks""" - try: - # Try to get from trading executor - if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices: - return self.trading_executor.current_prices[symbol] - - # Try to get from data provider - if hasattr(self.trading_executor, 'data_provider'): - data_provider = self.trading_executor.data_provider - if hasattr(data_provider, 'get_current_price'): - price = data_provider.get_current_price(symbol) - if price and price > 0: - return price - - # Try to get from COB data - if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data: - cob_data = self.trading_executor.latest_cob_data[symbol] - if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats: - return cob_data.stats['mid_price'] - - # Default fallback - return 0.0 - - except Exception as e: - logger.error(f"Error getting current price for {symbol}: {e}") - return 0.0 - - def _refresh_price(self, symbol): - """Force a price refresh for a symbol""" - try: - # Try to refresh from data provider - if hasattr(self.trading_executor, 'data_provider'): - data_provider = self.trading_executor.data_provider - if hasattr(data_provider, 'fetch_current_price'): - price = data_provider.fetch_current_price(symbol) - if price and price > 0: - # Update trading executor price - if hasattr(self.trading_executor, 'current_prices'): - self.trading_executor.current_prices[symbol] = price - - # Update timestamp - self.last_price_update[symbol] = datetime.now() - - logger.info(f"Refreshed price for {symbol}: ${price:.2f}") - return True - - logger.warning(f"Failed to refresh price for {symbol}") - return False - - except Exception as e: - logger.error(f"Error refreshing price for {symbol}: {e}") - return False - - def _ensure_position_reset(self, symbol): - """Ensure position is properly reset before new entry""" - try: - # Check if we have an active position - if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions: - # Position exists, check if it's valid - position = self.trading_executor.positions[symbol] - if position and getattr(position, 'active', False): - logger.warning(f"Position already active for {symbol}, cannot enter new position") - return False - - # Check reset flag - if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]: - # Force position cleanup - if hasattr(self.trading_executor, 'positions'): - self.trading_executor.positions.pop(symbol, None) - - logger.info(f"Forced position reset for {symbol}") - self.position_reset_flags[symbol] = True - - return True - - except Exception as e: - logger.error(f"Error ensuring position reset for {symbol}: {e}") - return False - - def get_trade_history(self, symbol=None): - """Get verified trade history""" - if symbol: - return self.trade_history.get(symbol, []) - return self.trade_history - - def get_price_update_status(self): - """Get price update status for all symbols""" - status = {} - current_time = datetime.now() - - for symbol, timestamp in self.last_price_update.items(): - time_since_update = (current_time - timestamp).total_seconds() - status[symbol] = { - 'last_update': timestamp, - 'seconds_ago': time_since_update, - 'is_fresh': time_since_update <= self.price_update_threshold - } - - return status \ No newline at end of file diff --git a/core/training_data_collector.py b/core/training_data_collector.py deleted file mode 100644 index 3e43c6d..0000000 --- a/core/training_data_collector.py +++ /dev/null @@ -1,795 +0,0 @@ -""" -Comprehensive Training Data Collection System - -This module implements a robust training data collection system that: -1. Captures all model inputs with validation and completeness checks -2. Stores training data packages with future outcome validation -3. Detects rapid price changes for high-value training examples -4. Enables replay and retraining on most profitable setups -5. Maintains data integrity and traceability - -Key Features: -- Real-time data package creation with all model inputs -- Future outcome validation (profitable vs unprofitable predictions) -- Rapid price change detection for premium training examples -- Comprehensive data validation and completeness verification -- Backpropagation data storage for gradient replay -- Training episode profitability tracking and ranking -""" - -import asyncio -import json -import logging -import numpy as np -import pandas as pd -import pickle -import torch -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any, Callable -from dataclasses import dataclass, field, asdict -from collections import deque -import hashlib -import threading -from concurrent.futures import ThreadPoolExecutor - -logger = logging.getLogger(__name__) - -@dataclass -class ModelInputPackage: - """Complete package of all model inputs at a specific timestamp""" - timestamp: datetime - symbol: str - - # Market data inputs - ohlcv_data: Dict[str, pd.DataFrame] # {timeframe: DataFrame} - tick_data: List[Dict[str, Any]] # Raw tick data - cob_data: Dict[str, Any] # Consolidated Order Book data - technical_indicators: Dict[str, float] # All technical indicators - pivot_points: List[Dict[str, Any]] # Detected pivot points - - # Model-specific inputs - cnn_features: np.ndarray # CNN input features - rl_state: np.ndarray # RL state representation - orchestrator_context: Dict[str, Any] # Orchestrator context - - # Cross-model inputs (outputs from other models) - cnn_predictions: Optional[Dict[str, Any]] = None - rl_predictions: Optional[Dict[str, Any]] = None - orchestrator_decision: Optional[Dict[str, Any]] = None - - # Data validation - data_hash: str = "" - completeness_score: float = 0.0 - validation_flags: Dict[str, bool] = field(default_factory=dict) - - def __post_init__(self): - """Calculate data hash and completeness after initialization""" - self.data_hash = self._calculate_hash() - self.completeness_score = self._calculate_completeness() - self.validation_flags = self._validate_data() - - def _calculate_hash(self) -> str: - """Calculate hash for data integrity verification""" - try: - # Create a string representation of all data - data_str = f"{self.timestamp}_{self.symbol}" - data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}" - data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}" - data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}" - - return hashlib.md5(data_str.encode()).hexdigest() - except Exception as e: - logger.warning(f"Error calculating data hash: {e}") - return "invalid_hash" - - def _calculate_completeness(self) -> float: - """Calculate completeness score (0.0 to 1.0)""" - try: - total_fields = 10 # Total expected data fields - complete_fields = 0 - - # Check each required field - if self.ohlcv_data and len(self.ohlcv_data) > 0: - complete_fields += 1 - if self.tick_data and len(self.tick_data) > 0: - complete_fields += 1 - if self.cob_data and len(self.cob_data) > 0: - complete_fields += 1 - if self.technical_indicators and len(self.technical_indicators) > 0: - complete_fields += 1 - if self.pivot_points and len(self.pivot_points) > 0: - complete_fields += 1 - if self.cnn_features is not None and self.cnn_features.size > 0: - complete_fields += 1 - if self.rl_state is not None and self.rl_state.size > 0: - complete_fields += 1 - if self.orchestrator_context and len(self.orchestrator_context) > 0: - complete_fields += 1 - if self.cnn_predictions is not None: - complete_fields += 1 - if self.rl_predictions is not None: - complete_fields += 1 - - return complete_fields / total_fields - except Exception as e: - logger.warning(f"Error calculating completeness: {e}") - return 0.0 - - def _validate_data(self) -> Dict[str, bool]: - """Validate data integrity and consistency""" - flags = {} - - try: - # Validate timestamp - flags['valid_timestamp'] = isinstance(self.timestamp, datetime) - - # Validate OHLCV data - flags['valid_ohlcv'] = ( - self.ohlcv_data is not None and - len(self.ohlcv_data) > 0 and - all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values()) - ) - - # Validate feature arrays - flags['valid_cnn_features'] = ( - self.cnn_features is not None and - isinstance(self.cnn_features, np.ndarray) and - self.cnn_features.size > 0 - ) - - flags['valid_rl_state'] = ( - self.rl_state is not None and - isinstance(self.rl_state, np.ndarray) and - self.rl_state.size > 0 - ) - - # Validate data consistency - flags['data_consistent'] = self.completeness_score > 0.7 - - except Exception as e: - logger.warning(f"Error validating data: {e}") - flags['validation_error'] = True - - return flags - -@dataclass -class TrainingOutcome: - """Future outcome validation for training data""" - input_package_hash: str - timestamp: datetime - symbol: str - - # Price movement outcomes - price_change_1m: float - price_change_5m: float - price_change_15m: float - price_change_1h: float - - # Profitability metrics - max_profit_potential: float - max_loss_potential: float - optimal_entry_price: float - optimal_exit_price: float - optimal_holding_time: timedelta - - # Classification labels - is_profitable: bool - profitability_score: float # 0.0 to 1.0 - risk_reward_ratio: float - - # Rapid price change detection - is_rapid_change: bool - change_velocity: float # Price change per minute - volatility_spike: bool - - # Validation - outcome_validated: bool = False - validation_timestamp: datetime = field(default_factory=datetime.now) - -@dataclass -class TrainingEpisode: - """Complete training episode with inputs, predictions, and outcomes""" - episode_id: str - input_package: ModelInputPackage - model_predictions: Dict[str, Any] # Predictions from all models - actual_outcome: TrainingOutcome - - # Training metadata - episode_type: str # 'normal', 'rapid_change', 'high_profit' - profitability_rank: float # Ranking among all episodes - training_priority: float # Priority for replay training - - # Backpropagation data storage - gradient_data: Optional[Dict[str, torch.Tensor]] = None - loss_components: Optional[Dict[str, float]] = None - model_states: Optional[Dict[str, Any]] = None - - # Episode statistics - created_timestamp: datetime = field(default_factory=datetime.now) - last_trained_timestamp: Optional[datetime] = None - training_count: int = 0 - - def calculate_training_priority(self) -> float: - """Calculate training priority based on profitability and characteristics""" - try: - priority = 0.0 - - # Base priority from profitability - if self.actual_outcome.is_profitable: - priority += self.actual_outcome.profitability_score * 0.4 - - # Bonus for rapid changes (high learning value) - if self.actual_outcome.is_rapid_change: - priority += 0.3 - - # Bonus for high risk-reward ratio - if self.actual_outcome.risk_reward_ratio > 2.0: - priority += 0.2 - - # Bonus for data completeness - priority += self.input_package.completeness_score * 0.1 - - # Penalty for frequent training (avoid overfitting) - if self.training_count > 5: - priority *= 0.8 - - return min(priority, 1.0) - - except Exception as e: - logger.warning(f"Error calculating training priority: {e}") - return 0.0 - -class RapidChangeDetector: - """Detects rapid price changes for high-value training examples""" - - def __init__(self, - velocity_threshold: float = 0.5, # % per minute - volatility_multiplier: float = 3.0, - lookback_minutes: int = 5): - self.velocity_threshold = velocity_threshold - self.volatility_multiplier = volatility_multiplier - self.lookback_minutes = lookback_minutes - - # Price history for change detection - self.price_history: Dict[str, deque] = {} - self.volatility_baseline: Dict[str, float] = {} - - def add_price_point(self, symbol: str, timestamp: datetime, price: float): - """Add new price point for change detection""" - if symbol not in self.price_history: - self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60) # 1 second resolution - self.volatility_baseline[symbol] = 0.0 - - self.price_history[symbol].append((timestamp, price)) - self._update_volatility_baseline(symbol) - - def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]: - """ - Detect rapid price changes - - Returns: - (is_rapid_change, change_velocity, volatility_spike) - """ - if symbol not in self.price_history or len(self.price_history[symbol]) < 60: - return False, 0.0, False - - try: - prices = list(self.price_history[symbol]) - - # Calculate recent velocity (last minute) - recent_prices = prices[-60:] # Last 60 seconds - if len(recent_prices) < 2: - return False, 0.0, False - - start_price = recent_prices[0][1] - end_price = recent_prices[-1][1] - time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0 # minutes - - if time_diff <= 0: - return False, 0.0, False - - # Calculate velocity (% change per minute) - velocity = abs((end_price - start_price) / start_price * 100) / time_diff - - # Check for rapid change - is_rapid = velocity > self.velocity_threshold - - # Check for volatility spike - current_volatility = self._calculate_current_volatility(symbol) - baseline_volatility = self.volatility_baseline.get(symbol, 0.0) - volatility_spike = ( - baseline_volatility > 0 and - current_volatility > baseline_volatility * self.volatility_multiplier - ) - - return is_rapid, velocity, volatility_spike - - except Exception as e: - logger.warning(f"Error detecting rapid change for {symbol}: {e}") - return False, 0.0, False - - def _update_volatility_baseline(self, symbol: str): - """Update volatility baseline for the symbol""" - try: - if len(self.price_history[symbol]) < 120: # Need at least 2 minutes of data - return - - # Calculate rolling volatility over longer period - prices = [p[1] for p in list(self.price_history[symbol])[-300:]] # Last 5 minutes - if len(prices) < 2: - return - - # Calculate standard deviation of price changes - price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))] - volatility = np.std(price_changes) * 100 # Convert to percentage - - # Update baseline with exponential moving average - alpha = 0.1 - if self.volatility_baseline[symbol] == 0: - self.volatility_baseline[symbol] = volatility - else: - self.volatility_baseline[symbol] = ( - alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol] - ) - - except Exception as e: - logger.warning(f"Error updating volatility baseline for {symbol}: {e}") - - def _calculate_current_volatility(self, symbol: str) -> float: - """Calculate current volatility for the symbol""" - try: - if len(self.price_history[symbol]) < 60: - return 0.0 - - # Use last minute of data - recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]] - if len(recent_prices) < 2: - return 0.0 - - price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1] - for i in range(1, len(recent_prices))] - return np.std(price_changes) * 100 - - except Exception as e: - logger.warning(f"Error calculating current volatility for {symbol}: {e}") - return 0.0 - -class TrainingDataCollector: - """Main training data collection system""" - - def __init__(self, - storage_dir: str = "training_data", - max_episodes_per_symbol: int = 10000, - outcome_validation_delay: timedelta = timedelta(hours=1)): - - self.storage_dir = Path(storage_dir) - self.storage_dir.mkdir(parents=True, exist_ok=True) - - self.max_episodes_per_symbol = max_episodes_per_symbol - self.outcome_validation_delay = outcome_validation_delay - - # Data storage - self.training_episodes: Dict[str, List[TrainingEpisode]] = {} # {symbol: episodes} - self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {} # Awaiting outcome validation - - # Rapid change detection - self.rapid_change_detector = RapidChangeDetector() - - # Data validation and statistics - self.collection_stats = { - 'total_episodes': 0, - 'profitable_episodes': 0, - 'rapid_change_episodes': 0, - 'validation_errors': 0, - 'data_completeness_avg': 0.0 - } - - # Background processing - self.is_collecting = False - self.collection_thread = None - self.outcome_validation_thread = None - - # Thread safety - self.data_lock = threading.Lock() - - logger.info(f"Training Data Collector initialized") - logger.info(f"Storage directory: {self.storage_dir}") - logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}") - - def start_collection(self): - """Start the training data collection system""" - if self.is_collecting: - logger.warning("Training data collection already running") - return - - self.is_collecting = True - - # Start outcome validation thread - self.outcome_validation_thread = threading.Thread( - target=self._outcome_validation_worker, - daemon=True - ) - self.outcome_validation_thread.start() - - logger.info("Training data collection started") - - def stop_collection(self): - """Stop the training data collection system""" - self.is_collecting = False - - if self.outcome_validation_thread: - self.outcome_validation_thread.join(timeout=5) - - logger.info("Training data collection stopped") - - def collect_training_data(self, - symbol: str, - ohlcv_data: Dict[str, pd.DataFrame], - tick_data: List[Dict[str, Any]], - cob_data: Dict[str, Any], - technical_indicators: Dict[str, float], - pivot_points: List[Dict[str, Any]], - cnn_features: np.ndarray, - rl_state: np.ndarray, - orchestrator_context: Dict[str, Any], - model_predictions: Dict[str, Any] = None) -> str: - """ - Collect comprehensive training data package - - Returns: - episode_id for tracking - """ - try: - # Create input package - input_package = ModelInputPackage( - timestamp=datetime.now(), - symbol=symbol, - ohlcv_data=ohlcv_data, - tick_data=tick_data, - cob_data=cob_data, - technical_indicators=technical_indicators, - pivot_points=pivot_points, - cnn_features=cnn_features, - rl_state=rl_state, - orchestrator_context=orchestrator_context - ) - - # Validate data completeness - if input_package.completeness_score < 0.5: - logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}") - self.collection_stats['validation_errors'] += 1 - return None - - # Check for rapid price changes - current_price = self._extract_current_price(ohlcv_data) - if current_price: - self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price) - - # Add to pending outcomes for future validation - with self.data_lock: - if symbol not in self.pending_outcomes: - self.pending_outcomes[symbol] = [] - - self.pending_outcomes[symbol].append(input_package) - - # Limit pending outcomes to prevent memory issues - if len(self.pending_outcomes[symbol]) > 1000: - self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:] - - # Generate episode ID - episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}" - - # Update statistics - self.collection_stats['total_episodes'] += 1 - self.collection_stats['data_completeness_avg'] = ( - (self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) + - input_package.completeness_score) / self.collection_stats['total_episodes'] - ) - - logger.debug(f"Collected training data for {symbol}: {episode_id}") - logger.debug(f"Data completeness: {input_package.completeness_score:.2f}") - - return episode_id - - except Exception as e: - logger.error(f"Error collecting training data for {symbol}: {e}") - self.collection_stats['validation_errors'] += 1 - return None - - def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]: - """Extract current price from OHLCV data""" - try: - # Try to get price from shortest timeframe first - for timeframe in ['1s', '1m', '5m', '15m', '1h']: - if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty: - return float(ohlcv_data[timeframe]['close'].iloc[-1]) - return None - except Exception as e: - logger.warning(f"Error extracting current price: {e}") - return None - - def _outcome_validation_worker(self): - """Background worker for validating training outcomes""" - logger.info("Outcome validation worker started") - - while self.is_collecting: - try: - self._validate_pending_outcomes() - threading.Event().wait(60) # Check every minute - - except Exception as e: - logger.error(f"Error in outcome validation worker: {e}") - threading.Event().wait(30) # Wait before retrying - - logger.info("Outcome validation worker stopped") - - def _validate_pending_outcomes(self): - """Validate outcomes for pending training data""" - current_time = datetime.now() - - with self.data_lock: - for symbol in list(self.pending_outcomes.keys()): - if symbol not in self.pending_outcomes: - continue - - validated_packages = [] - remaining_packages = [] - - for package in self.pending_outcomes[symbol]: - # Check if enough time has passed for outcome validation - if current_time - package.timestamp >= self.outcome_validation_delay: - outcome = self._calculate_training_outcome(package) - if outcome: - self._create_training_episode(package, outcome) - validated_packages.append(package) - else: - remaining_packages.append(package) - else: - remaining_packages.append(package) - - # Update pending outcomes - self.pending_outcomes[symbol] = remaining_packages - - if validated_packages: - logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}") - - def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]: - """Calculate training outcome based on future price movements""" - try: - # This would typically fetch recent price data to calculate outcomes - # For now, we'll create a placeholder implementation - - # Extract base price from input package - base_price = self._extract_current_price(input_package.ohlcv_data) - if not base_price: - return None - - # Simulate outcome calculation (in real implementation, fetch actual future prices) - # This is where you would integrate with your data provider to get actual outcomes - - # Check for rapid change - is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change( - input_package.symbol - ) - - # Create outcome (placeholder values - replace with actual calculation) - outcome = TrainingOutcome( - input_package_hash=input_package.data_hash, - timestamp=input_package.timestamp, - symbol=input_package.symbol, - price_change_1m=0.0, # Calculate from actual future data - price_change_5m=0.0, - price_change_15m=0.0, - price_change_1h=0.0, - max_profit_potential=0.0, - max_loss_potential=0.0, - optimal_entry_price=base_price, - optimal_exit_price=base_price, - optimal_holding_time=timedelta(minutes=5), - is_profitable=False, # Determine from actual outcomes - profitability_score=0.0, - risk_reward_ratio=1.0, - is_rapid_change=is_rapid, - change_velocity=velocity, - volatility_spike=volatility_spike, - outcome_validated=True - ) - - return outcome - - except Exception as e: - logger.error(f"Error calculating training outcome: {e}") - return None - - def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome): - """Create complete training episode""" - try: - episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}" - - # Determine episode type - episode_type = 'normal' - if outcome.is_rapid_change: - episode_type = 'rapid_change' - self.collection_stats['rapid_change_episodes'] += 1 - elif outcome.profitability_score > 0.8: - episode_type = 'high_profit' - - if outcome.is_profitable: - self.collection_stats['profitable_episodes'] += 1 - - # Create training episode - episode = TrainingEpisode( - episode_id=episode_id, - input_package=input_package, - model_predictions={}, # Will be filled when models make predictions - actual_outcome=outcome, - episode_type=episode_type, - profitability_rank=0.0, # Will be calculated later - training_priority=0.0 - ) - - # Calculate training priority - episode.training_priority = episode.calculate_training_priority() - - # Store episode - symbol = input_package.symbol - if symbol not in self.training_episodes: - self.training_episodes[symbol] = [] - - self.training_episodes[symbol].append(episode) - - # Limit episodes per symbol - if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol: - # Keep highest priority episodes - self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True) - self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol] - - # Save episode to disk - self._save_episode_to_disk(episode) - - logger.debug(f"Created training episode: {episode_id}") - logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}") - - except Exception as e: - logger.error(f"Error creating training episode: {e}") - - def _save_episode_to_disk(self, episode: TrainingEpisode): - """Save training episode to disk for persistence""" - try: - symbol_dir = self.storage_dir / episode.input_package.symbol - symbol_dir.mkdir(parents=True, exist_ok=True) - - # Save episode data - episode_file = symbol_dir / f"{episode.episode_id}.pkl" - with open(episode_file, 'wb') as f: - pickle.dump(episode, f) - - # Save episode metadata for quick access - metadata = { - 'episode_id': episode.episode_id, - 'timestamp': episode.input_package.timestamp.isoformat(), - 'episode_type': episode.episode_type, - 'training_priority': episode.training_priority, - 'profitability_score': episode.actual_outcome.profitability_score, - 'is_profitable': episode.actual_outcome.is_profitable, - 'is_rapid_change': episode.actual_outcome.is_rapid_change, - 'data_completeness': episode.input_package.completeness_score - } - - metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json" - with open(metadata_file, 'w') as f: - json.dump(metadata, f, indent=2) - - except Exception as e: - logger.error(f"Error saving episode to disk: {e}") - - def get_high_priority_episodes(self, - symbol: str, - limit: int = 100, - min_priority: float = 0.5) -> List[TrainingEpisode]: - """Get high-priority training episodes for replay training""" - try: - if symbol not in self.training_episodes: - return [] - - # Filter and sort by priority - high_priority = [ - ep for ep in self.training_episodes[symbol] - if ep.training_priority >= min_priority - ] - - high_priority.sort(key=lambda x: x.training_priority, reverse=True) - - return high_priority[:limit] - - except Exception as e: - logger.error(f"Error getting high priority episodes for {symbol}: {e}") - return [] - - def get_collection_statistics(self) -> Dict[str, Any]: - """Get comprehensive collection statistics""" - stats = self.collection_stats.copy() - - # Add per-symbol statistics - stats['episodes_per_symbol'] = { - symbol: len(episodes) - for symbol, episodes in self.training_episodes.items() - } - - # Add pending outcomes count - stats['pending_outcomes'] = { - symbol: len(packages) - for symbol, packages in self.pending_outcomes.items() - } - - # Calculate profitability rate - if stats['total_episodes'] > 0: - stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes'] - stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes'] - else: - stats['profitability_rate'] = 0.0 - stats['rapid_change_rate'] = 0.0 - - return stats - - def validate_data_integrity(self) -> Dict[str, Any]: - """Comprehensive data integrity validation""" - validation_results = { - 'total_episodes_checked': 0, - 'hash_mismatches': 0, - 'completeness_issues': 0, - 'validation_flag_failures': 0, - 'corrupted_episodes': [], - 'integrity_score': 1.0 - } - - try: - for symbol, episodes in self.training_episodes.items(): - for episode in episodes: - validation_results['total_episodes_checked'] += 1 - - # Check data hash - expected_hash = episode.input_package._calculate_hash() - if expected_hash != episode.input_package.data_hash: - validation_results['hash_mismatches'] += 1 - validation_results['corrupted_episodes'].append(episode.episode_id) - - # Check completeness - if episode.input_package.completeness_score < 0.7: - validation_results['completeness_issues'] += 1 - - # Check validation flags - if not episode.input_package.validation_flags.get('data_consistent', False): - validation_results['validation_flag_failures'] += 1 - - # Calculate integrity score - total_issues = ( - validation_results['hash_mismatches'] + - validation_results['completeness_issues'] + - validation_results['validation_flag_failures'] - ) - - if validation_results['total_episodes_checked'] > 0: - validation_results['integrity_score'] = 1.0 - ( - total_issues / validation_results['total_episodes_checked'] - ) - - logger.info(f"Data integrity validation completed") - logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}") - - except Exception as e: - logger.error(f"Error during data integrity validation: {e}") - validation_results['validation_error'] = str(e) - - return validation_results - -# Global instance for easy access -training_data_collector = None - -def get_training_data_collector() -> TrainingDataCollector: - """Get global training data collector instance""" - global training_data_collector - if training_data_collector is None: - training_data_collector = TrainingDataCollector() - return training_data_collector \ No newline at end of file diff --git a/dataprovider_realtime.py b/dataprovider_realtime.py deleted file mode 100644 index 2072041..0000000 --- a/dataprovider_realtime.py +++ /dev/null @@ -1,2490 +0,0 @@ -# import asyncio -# import json -# import logging - -# # Fix PIL import issue that causes plotly JSON serialization errors -# import os -# os.environ['MPLBACKEND'] = 'Agg' # Use non-interactive backend -# try: -# # Try to fix PIL import issue -# import PIL.Image -# # Disable PIL in plotly to prevent circular import issues -# import plotly.io as pio -# pio.kaleido.scope.default_format = "png" -# except ImportError: -# pass -# except Exception: -# # Suppress any PIL-related errors during import -# pass - -# from typing import Dict, List, Optional, Tuple, Union -# import websockets -# import plotly.graph_objects as go -# from plotly.subplots import make_subplots -# import dash -# from dash import html, dcc -# from dash.dependencies import Input, Output -# import pandas as pd -# import numpy as np -# from collections import deque -# import time -# from threading import Thread -# import requests -# import os -# from datetime import datetime, timedelta -# import pytz -# import tzlocal -# import threading -# import random -# import dash_bootstrap_components as dbc -# import uuid -# import ta -# from sklearn.preprocessing import MinMaxScaler -# import re -# import psutil -# import gc -# import websocket - -# # Import psycopg2 with error handling -# try: -# import psycopg2 -# PSYCOPG2_AVAILABLE = True -# except ImportError: -# PSYCOPG2_AVAILABLE = False -# psycopg2 = None - -# # TimescaleDB configuration from environment variables -# TIMESCALEDB_ENABLED = os.environ.get('TIMESCALEDB_ENABLED', '1') == '1' and PSYCOPG2_AVAILABLE -# TIMESCALEDB_HOST = os.environ.get('TIMESCALEDB_HOST', '192.168.0.10') -# TIMESCALEDB_PORT = int(os.environ.get('TIMESCALEDB_PORT', '5432')) -# TIMESCALEDB_USER = os.environ.get('TIMESCALEDB_USER', 'postgres') -# TIMESCALEDB_PASSWORD = os.environ.get('TIMESCALEDB_PASSWORD', 'timescaledbpass') -# TIMESCALEDB_DB = os.environ.get('TIMESCALEDB_DB', 'candles') - -# class TimescaleDBHandler: -# """Handler for TimescaleDB operations for candle storage and retrieval""" - -# def __init__(self): -# """Initialize TimescaleDB connection if enabled""" -# self.enabled = TIMESCALEDB_ENABLED -# self.conn = None - -# if not self.enabled: -# if not PSYCOPG2_AVAILABLE: -# print("psycopg2 module not available. TimescaleDB integration disabled.") -# return - -# try: -# # Connect to TimescaleDB -# self.conn = psycopg2.connect( -# host=TIMESCALEDB_HOST, -# port=TIMESCALEDB_PORT, -# user=TIMESCALEDB_USER, -# password=TIMESCALEDB_PASSWORD, -# dbname=TIMESCALEDB_DB -# ) -# print(f"Connected to TimescaleDB at {TIMESCALEDB_HOST}:{TIMESCALEDB_PORT}") - -# # Ensure the candles table exists -# self._ensure_table() - -# print("TimescaleDB integration initialized successfully") -# except Exception as e: -# print(f"Error connecting to TimescaleDB: {str(e)}") -# self.enabled = False -# self.conn = None - -# def _ensure_table(self): -# """Ensure the candles table exists with TimescaleDB hypertable""" -# if not self.conn: -# return - -# try: -# with self.conn.cursor() as cur: -# # Create the candles table if it doesn't exist -# cur.execute(''' -# CREATE TABLE IF NOT EXISTS candles ( -# symbol TEXT, -# interval TEXT, -# timestamp TIMESTAMPTZ, -# open DOUBLE PRECISION, -# high DOUBLE PRECISION, -# low DOUBLE PRECISION, -# close DOUBLE PRECISION, -# volume DOUBLE PRECISION, -# PRIMARY KEY (symbol, interval, timestamp) -# ); -# ''') - -# # Check if the table is already a hypertable -# cur.execute(''' -# SELECT EXISTS ( -# SELECT 1 FROM timescaledb_information.hypertables -# WHERE hypertable_name = 'candles' -# ); -# ''') -# is_hypertable = cur.fetchone()[0] - -# # Convert to hypertable if not already done -# if not is_hypertable: -# cur.execute(''' -# SELECT create_hypertable('candles', 'timestamp', -# if_not_exists => TRUE, -# migrate_data => TRUE -# ); -# ''') - -# self.conn.commit() -# print("TimescaleDB table structure verified") -# except Exception as e: -# print(f"Error setting up TimescaleDB tables: {str(e)}") -# self.enabled = False - -# def upsert_candle(self, symbol, interval, candle): -# """Insert or update a candle in TimescaleDB""" -# if not self.enabled or not self.conn: -# return False - -# try: -# with self.conn.cursor() as cur: -# cur.execute(''' -# INSERT INTO candles ( -# symbol, interval, timestamp, -# open, high, low, close, volume -# ) -# VALUES (%s, %s, %s, %s, %s, %s, %s, %s) -# ON CONFLICT (symbol, interval, timestamp) -# DO UPDATE SET -# open = EXCLUDED.open, -# high = EXCLUDED.high, -# low = EXCLUDED.low, -# close = EXCLUDED.close, -# volume = EXCLUDED.volume -# ''', ( -# symbol, interval, candle['timestamp'], -# candle['open'], candle['high'], candle['low'], -# candle['close'], candle['volume'] -# )) -# self.conn.commit() -# return True -# except Exception as e: -# print(f"Error upserting candle to TimescaleDB: {str(e)}") -# # Try to reconnect on error -# try: -# self.conn = psycopg2.connect( -# host=TIMESCALEDB_HOST, -# port=TIMESCALEDB_PORT, -# user=TIMESCALEDB_USER, -# password=TIMESCALEDB_PASSWORD, -# dbname=TIMESCALEDB_DB -# ) -# except: -# pass -# return False - -# def fetch_candles(self, symbol, interval, limit=1000): -# """Fetch candles from TimescaleDB""" -# if not self.enabled or not self.conn: -# return [] - -# try: -# with self.conn.cursor() as cur: -# cur.execute(''' -# SELECT timestamp, open, high, low, close, volume -# FROM candles -# WHERE symbol = %s AND interval = %s -# ORDER BY timestamp DESC -# LIMIT %s -# ''', (symbol, interval, limit)) - -# rows = cur.fetchall() - -# # Convert to list of dictionaries (ordered from oldest to newest) -# candles = [] -# for row in reversed(rows): # Reverse to get oldest first -# candle = { -# 'timestamp': row[0], -# 'open': row[1], -# 'high': row[2], -# 'low': row[3], -# 'close': row[4], -# 'volume': row[5] -# } -# candles.append(candle) - -# return candles -# except Exception as e: -# print(f"Error fetching candles from TimescaleDB: {str(e)}") -# # Try to reconnect on error -# try: -# self.conn = psycopg2.connect( -# host=TIMESCALEDB_HOST, -# port=TIMESCALEDB_PORT, -# user=TIMESCALEDB_USER, -# password=TIMESCALEDB_PASSWORD, -# dbname=TIMESCALEDB_DB -# ) -# except: -# pass -# return [] - -# class BinanceHistoricalData: -# """ -# Class for fetching historical price data from Binance. -# """ -# def __init__(self): -# self.base_url = "https://api.binance.com/api/v3" -# self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache') -# if not os.path.exists(self.cache_dir): -# os.makedirs(self.cache_dir) -# # Timestamp of last data update -# self.last_update = None - -# def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000): -# """ -# Fetch historical candles from Binance API. - -# Args: -# symbol (str): Trading pair symbol (e.g., "BTC/USDT") -# interval_seconds (int): Timeframe in seconds (e.g., 3600 for 1h) -# limit (int): Number of candles to fetch - -# Returns: -# pd.DataFrame: DataFrame with OHLCV data -# """ -# # Convert interval_seconds to Binance interval format -# interval_map = { -# 1: "1s", -# 60: "1m", -# 300: "5m", -# 900: "15m", -# 1800: "30m", -# 3600: "1h", -# 14400: "4h", -# 86400: "1d" -# } - -# interval = interval_map.get(interval_seconds, "1h") - -# # Format symbol for Binance API (remove slash and make uppercase) -# formatted_symbol = symbol.replace("/", "").upper() - -# # Check if we have cached data first -# cache_file = self._get_cache_filename(formatted_symbol, interval) -# cached_data = self._load_from_cache(formatted_symbol, interval) - -# # If we have cached data that's recent enough, use it -# if cached_data is not None and len(cached_data) >= limit: -# cache_age_minutes = (datetime.now() - self.last_update).total_seconds() / 60 if self.last_update else 60 -# if cache_age_minutes < 15: # Only use cache if it's less than 15 minutes old -# logger.info(f"Using cached historical data for {symbol} ({interval})") -# return cached_data - -# try: -# # Build URL for klines endpoint -# url = f"{self.base_url}/klines" -# params = { -# "symbol": formatted_symbol, -# "interval": interval, -# "limit": limit -# } - -# # Make the request -# response = requests.get(url, params=params) -# response.raise_for_status() - -# # Parse the response -# data = response.json() - -# # Create dataframe -# df = pd.DataFrame(data, columns=[ -# "timestamp", "open", "high", "low", "close", "volume", -# "close_time", "quote_asset_volume", "number_of_trades", -# "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore" -# ]) - -# # Convert timestamp to datetime -# df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") - -# # Convert price columns to float -# for col in ["open", "high", "low", "close", "volume"]: -# df[col] = df[col].astype(float) - -# # Sort by timestamp -# df = df.sort_values("timestamp") - -# # Save to cache for future use -# self._save_to_cache(df, formatted_symbol, interval) -# self.last_update = datetime.now() - -# logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})") -# return df - -# except Exception as e: -# logger.error(f"Error fetching historical data from Binance: {str(e)}") -# # Return cached data if we have it, even if it's not enough -# if cached_data is not None: -# logger.warning(f"Using cached data instead (may be incomplete)") -# return cached_data -# # Return empty dataframe on error -# return pd.DataFrame() - -# def _get_cache_filename(self, symbol, interval): -# """Get filename for cache file""" -# return os.path.join(self.cache_dir, f"{symbol}_{interval}_candles.csv") - -# def _load_from_cache(self, symbol, interval): -# """Load candles from cache file""" -# try: -# cache_file = self._get_cache_filename(symbol, interval) -# if os.path.exists(cache_file): -# # For 1s interval, check if the cache is recent (less than 10 minutes old) -# if interval == "1s" or interval == 1: -# file_mod_time = datetime.fromtimestamp(os.path.getmtime(cache_file)) -# time_diff = (datetime.now() - file_mod_time).total_seconds() / 60 -# if time_diff > 10: -# logger.info("1s cache is older than 10 minutes, skipping load") -# return None -# logger.info(f"Using recent 1s cache (age: {time_diff:.1f} minutes)") - -# df = pd.read_csv(cache_file) -# df["timestamp"] = pd.to_datetime(df["timestamp"]) -# logger.info(f"Loaded {len(df)} candles from cache: {cache_file}") -# return df -# except Exception as e: -# logger.error(f"Error loading cached data: {str(e)}") -# return None - -# def _save_to_cache(self, df, symbol, interval): -# """Save candles to cache file""" -# try: -# cache_file = self._get_cache_filename(symbol, interval) -# df.to_csv(cache_file, index=False) -# logger.info(f"Saved {len(df)} candles to cache: {cache_file}") -# return True -# except Exception as e: -# logger.error(f"Error saving to cache: {str(e)}") -# return False - -# def get_recent_trades(self, symbol, limit=1000): -# """Get recent trades for a symbol""" -# formatted_symbol = symbol.replace("/", "") - -# try: -# url = f"{self.base_url}/trades" -# params = { -# "symbol": formatted_symbol, -# "limit": limit -# } - -# response = requests.get(url, params=params) -# response.raise_for_status() - -# data = response.json() - -# # Create dataframe -# df = pd.DataFrame(data) -# df["time"] = pd.to_datetime(df["time"], unit="ms") -# df["price"] = df["price"].astype(float) -# df["qty"] = df["qty"].astype(float) - -# return df - -# except Exception as e: -# logger.error(f"Error fetching recent trades: {str(e)}") -# return pd.DataFrame() - -# class MultiTimeframeDataInterface: -# """ -# Enhanced Data Interface supporting: -# - Multiple trading pairs -# - Multiple timeframes per pair (1s, 1m, 1h, 1d + custom) -# - Technical indicators -# - Cross-timeframe normalization -# - Real-time data updates -# """ - -# def __init__(self, symbol=None, timeframes=None, data_dir="data"): -# """ -# Initialize the data interface. - -# Args: -# symbol (str): Trading pair symbol (e.g., "BTC/USDT") -# timeframes (list): List of timeframes to use (e.g., ['1m', '5m', '1h', '4h', '1d']) -# data_dir (str): Directory to store/load datasets -# """ -# self.symbol = symbol -# self.timeframes = timeframes or ['1h', '4h', '1d'] -# self.data_dir = data_dir -# self.scalers = {} # Store scalers for each timeframe - -# # Initialize the historical data fetcher -# self.historical_data = BinanceHistoricalData() - -# # Create data directory if it doesn't exist -# os.makedirs(self.data_dir, exist_ok=True) - -# # Initialize empty dataframes for each timeframe -# self.dataframes = {tf: None for tf in self.timeframes} - -# # Store timestamps of last updates per timeframe -# self.last_updates = {tf: None for tf in self.timeframes} - -# # Timeframe mapping (string to seconds) -# self.timeframe_to_seconds = { -# '1s': 1, -# '1m': 60, -# '5m': 300, -# '15m': 900, -# '30m': 1800, -# '1h': 3600, -# '4h': 14400, -# '1d': 86400 -# } - -# logger.info(f"MultiTimeframeDataInterface initialized for {symbol} with timeframes {timeframes}") - -# def get_data(self, timeframe='1h', n_candles=1000, refresh=False, add_indicators=True): -# """ -# Fetch historical price data for a given timeframe with optional indicators. - -# Args: -# timeframe (str): Timeframe to fetch data for -# n_candles (int): Number of candles to fetch -# refresh (bool): Force refresh of the data -# add_indicators (bool): Whether to add technical indicators - -# Returns: -# pd.DataFrame: DataFrame with OHLCV data and indicators -# """ -# # Check if we need to refresh -# current_time = datetime.now() - -# if (not refresh and -# self.dataframes[timeframe] is not None and -# self.last_updates[timeframe] is not None and -# (current_time - self.last_updates[timeframe]).total_seconds() < 60): -# #logger.info(f"Using cached data for {self.symbol} {timeframe}") -# return self.dataframes[timeframe] - -# interval_seconds = self.timeframe_to_seconds.get(timeframe, 3600) - -# # Fetch data -# df = self.historical_data.get_historical_candles( -# symbol=self.symbol, -# interval_seconds=interval_seconds, -# limit=n_candles -# ) - -# if df is None or df.empty: -# logger.error(f"No data available for {self.symbol} {timeframe}") -# return None - -# # Add indicators if requested -# if add_indicators: -# df = self.add_indicators(df) - -# # Store in cache -# self.dataframes[timeframe] = df -# self.last_updates[timeframe] = current_time - -# logger.info(f"Fetched and processed {len(df)} candles for {self.symbol} {timeframe}") -# return df - -# def add_indicators(self, df): -# """ -# Add comprehensive technical indicators to the dataframe. - -# Args: -# df (pd.DataFrame): DataFrame with OHLCV data - -# Returns: -# pd.DataFrame: DataFrame with added technical indicators -# """ -# # Make a copy to avoid modifying the original -# df_copy = df.copy() - -# # Basic price indicators -# df_copy['returns'] = df_copy['close'].pct_change() -# df_copy['log_returns'] = np.log(df_copy['close'] / df_copy['close'].shift(1)) - -# # Moving Averages -# df_copy['sma_7'] = ta.trend.sma_indicator(df_copy['close'], window=7) -# df_copy['sma_25'] = ta.trend.sma_indicator(df_copy['close'], window=25) -# df_copy['sma_99'] = ta.trend.sma_indicator(df_copy['close'], window=99) -# df_copy['ema_9'] = ta.trend.ema_indicator(df_copy['close'], window=9) -# df_copy['ema_21'] = ta.trend.ema_indicator(df_copy['close'], window=21) - -# # MACD -# macd = ta.trend.MACD(df_copy['close']) -# df_copy['macd'] = macd.macd() -# df_copy['macd_signal'] = macd.macd_signal() -# df_copy['macd_diff'] = macd.macd_diff() - -# # RSI -# df_copy['rsi'] = ta.momentum.rsi(df_copy['close'], window=14) - -# # Bollinger Bands -# bollinger = ta.volatility.BollingerBands(df_copy['close']) -# df_copy['bb_high'] = bollinger.bollinger_hband() -# df_copy['bb_low'] = bollinger.bollinger_lband() -# df_copy['bb_pct'] = bollinger.bollinger_pband() - -# # Stochastic Oscillator -# stoch = ta.momentum.StochasticOscillator(df_copy['high'], df_copy['low'], df_copy['close']) -# df_copy['stoch_k'] = stoch.stoch() -# df_copy['stoch_d'] = stoch.stoch_signal() - -# # ATR - Average True Range -# df_copy['atr'] = ta.volatility.average_true_range(df_copy['high'], df_copy['low'], df_copy['close'], window=14) - -# # Money Flow Index -# df_copy['mfi'] = ta.volume.money_flow_index(df_copy['high'], df_copy['low'], df_copy['close'], df_copy['volume'], window=14) - -# # OBV - On-Balance Volume -# df_copy['obv'] = ta.volume.on_balance_volume(df_copy['close'], df_copy['volume']) - -# # Ichimoku Cloud -# ichimoku = ta.trend.IchimokuIndicator(df_copy['high'], df_copy['low']) -# df_copy['ichimoku_a'] = ichimoku.ichimoku_a() -# df_copy['ichimoku_b'] = ichimoku.ichimoku_b() -# df_copy['ichimoku_base'] = ichimoku.ichimoku_base_line() -# df_copy['ichimoku_conv'] = ichimoku.ichimoku_conversion_line() - -# # ADX - Average Directional Index -# adx = ta.trend.ADXIndicator(df_copy['high'], df_copy['low'], df_copy['close']) -# df_copy['adx'] = adx.adx() -# df_copy['adx_pos'] = adx.adx_pos() -# df_copy['adx_neg'] = adx.adx_neg() - -# # VWAP - Volume Weighted Average Price (intraday) -# # Custom calculation since TA library doesn't include VWAP -# df_copy['vwap'] = (df_copy['volume'] * (df_copy['high'] + df_copy['low'] + df_copy['close']) / 3).cumsum() / df_copy['volume'].cumsum() - -# # Fill NaN values -# df_copy = df_copy.fillna(method='bfill').fillna(0) - -# return df_copy - -# def get_multi_timeframe_data(self, timeframes=None, n_candles=1000, refresh=False, add_indicators=True): -# """ -# Fetch data for multiple timeframes. - -# Args: -# timeframes (list): List of timeframes to fetch -# n_candles (int): Number of candles to fetch for each timeframe -# refresh (bool): Force refresh of the data -# add_indicators (bool): Whether to add technical indicators - -# Returns: -# dict: Dictionary of dataframes indexed by timeframe -# """ -# if timeframes is None: -# timeframes = self.timeframes - -# result = {} - -# for tf in timeframes: -# # For higher timeframes, we need fewer candles -# tf_candles = n_candles -# if tf == '4h': -# tf_candles = max(250, n_candles // 4) -# elif tf == '1d': -# tf_candles = max(100, n_candles // 24) - -# df = self.get_data(timeframe=tf, n_candles=tf_candles, refresh=refresh, add_indicators=add_indicators) -# if df is not None and not df.empty: -# result[tf] = df - -# return result - -# def prepare_training_data(self, window_size=20, train_ratio=0.8, refresh=False): -# """ -# Prepare training data from multiple timeframes. - -# Args: -# window_size (int): Size of the sliding window -# train_ratio (float): Ratio of data to use for training -# refresh (bool): Whether to refresh the data - -# Returns: -# tuple: (X_train, y_train, X_val, y_val, train_prices, val_prices) -# """ -# # Get data for all timeframes -# data_dict = self.get_multi_timeframe_data(refresh=refresh) - -# if not data_dict: -# logger.error("Failed to fetch data for any timeframe") -# return None, None, None, None, None, None - -# # Align all dataframes by timestamp -# all_dfs = list(data_dict.values()) -# min_date = max([df['timestamp'].min() for df in all_dfs]) -# max_date = min([df['timestamp'].max() for df in all_dfs]) - -# aligned_dfs = {} -# for tf, df in data_dict.items(): -# aligned_df = df[(df['timestamp'] >= min_date) & (df['timestamp'] <= max_date)] -# aligned_dfs[tf] = aligned_df - -# # Choose the lowest timeframe as the reference for time alignment -# reference_tf = min(self.timeframes, key=lambda x: self.timeframe_to_seconds.get(x, 3600)) -# reference_df = aligned_dfs[reference_tf] - -# # Create sliding windows for each timeframe -# X_dict = {} -# for tf, df in aligned_dfs.items(): -# # Drop timestamp and create numeric features -# features = df.drop('timestamp', axis=1).values - -# # Ensure the feature array is 3D: [samples, window, features] -# X = np.array([features[i:i+window_size] for i in range(len(features)-window_size)]) -# X_dict[tf] = X - -# # Create target labels based on future price movements -# reference_prices = reference_df['close'].values -# future_prices = reference_prices[window_size:] -# current_prices = reference_prices[window_size-1:-1] - -# # Calculate returns -# returns = (future_prices - current_prices) / current_prices - -# # Create labels: 0=SELL, 1=HOLD, 2=BUY -# threshold = 0.0005 # 0.05% threshold -# y = np.zeros(len(returns), dtype=int) -# y[returns > threshold] = 2 # BUY -# y[returns < -threshold] = 0 # SELL -# y[(returns >= -threshold) & (returns <= threshold)] = 1 # HOLD - -# # Split into training and validation sets -# split_idx = int(len(y) * train_ratio) - -# X_train_dict = {tf: X[:split_idx] for tf, X in X_dict.items()} -# X_val_dict = {tf: X[split_idx:] for tf, X in X_dict.items()} - -# y_train = y[:split_idx] -# y_val = y[split_idx:] - -# train_prices = reference_prices[window_size-1:window_size-1+split_idx] -# val_prices = reference_prices[window_size-1+split_idx:window_size-1+len(y)] - -# logger.info(f"Prepared training data - Train: {len(y_train)}, Val: {len(y_val)}") - -# return X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices - -# def normalize_data(self, data_dict, fit=True): -# """ -# Normalize data across all timeframes. - -# Args: -# data_dict (dict): Dictionary of data arrays by timeframe -# fit (bool): Whether to fit new scalers or use existing ones - -# Returns: -# dict: Dictionary of normalized data arrays -# """ -# result = {} - -# for tf, data in data_dict.items(): -# # For 3D data [samples, window, features] -# if len(data.shape) == 3: -# samples, window, features = data.shape -# reshaped = data.reshape(-1, features) - -# if fit or tf not in self.scalers: -# self.scalers[tf] = MinMaxScaler() -# normalized = self.scalers[tf].fit_transform(reshaped) -# else: -# normalized = self.scalers[tf].transform(reshaped) - -# result[tf] = normalized.reshape(samples, window, features) - -# # For 2D data [samples, features] -# elif len(data.shape) == 2: -# if fit or tf not in self.scalers: -# self.scalers[tf] = MinMaxScaler() -# result[tf] = self.scalers[tf].fit_transform(data) -# else: -# result[tf] = self.scalers[tf].transform(data) - -# return result - -# def get_realtime_features(self, timeframes=None, window_size=20): -# """ -# Get the most recent data for real-time prediction. - -# Args: -# timeframes (list): List of timeframes to use -# window_size (int): Size of the sliding window - -# Returns: -# dict: Dictionary of feature arrays for the latest window -# """ -# if timeframes is None: -# timeframes = self.timeframes - -# # Get fresh data -# data_dict = self.get_multi_timeframe_data(timeframes=timeframes, refresh=True) - -# result = {} -# for tf, df in data_dict.items(): -# if len(df) < window_size: -# logger.warning(f"Not enough data for {tf} (need {window_size}, got {len(df)})") -# continue - -# # Get the latest window -# latest_data = df.tail(window_size).drop('timestamp', axis=1).values - -# # Add extra dimension to match model input shape [1, window_size, features] -# result[tf] = latest_data.reshape(1, window_size, -1) - -# # Apply normalization using existing scalers -# if self.scalers: -# result = self.normalize_data(result, fit=False) - -# return result - -# def calculate_pnl(self, predictions, prices, position_size=1.0, fee_rate=0.0002): -# """ -# Calculate PnL and win rate from predictions. - -# Args: -# predictions (np.ndarray): Array of predicted actions (0=SELL, 1=HOLD, 2=BUY) -# prices (np.ndarray): Array of prices -# position_size (float): Size of each position -# fee_rate (float): Trading fee rate (default: 0.0002 for 0.02% per trade) - -# Returns: -# tuple: (total_pnl, win_rate, trades) -# """ -# if len(predictions) < 2 or len(prices) < 2: -# return 0.0, 0.0, [] - -# # Ensure arrays are the same length -# min_len = min(len(predictions), len(prices)-1) -# actions = predictions[:min_len] - -# pnl = 0.0 -# wins = 0 -# trades = [] - -# for i in range(min_len): -# current_price = prices[i] -# next_price = prices[i+1] -# action = actions[i] - -# # Skip HOLD actions -# if action == 1: -# continue - -# price_change = (next_price - current_price) / current_price - -# if action == 2: # BUY -# # Calculate raw PnL -# raw_pnl = price_change * position_size - -# # Calculate fees (entry and exit) -# entry_fee = position_size * fee_rate -# exit_fee = position_size * (1 + price_change) * fee_rate -# total_fees = entry_fee + exit_fee - -# # Net PnL after fees -# trade_pnl = raw_pnl - total_fees - -# trade_type = 'BUY' -# is_win = trade_pnl > 0 -# elif action == 0: # SELL -# # Calculate raw PnL -# raw_pnl = -price_change * position_size - -# # Calculate fees (entry and exit) -# entry_fee = position_size * fee_rate -# exit_fee = position_size * (1 - price_change) * fee_rate -# total_fees = entry_fee + exit_fee - -# # Net PnL after fees -# trade_pnl = raw_pnl - total_fees - -# trade_type = 'SELL' -# is_win = trade_pnl > 0 -# else: -# continue - -# pnl += trade_pnl -# wins += int(is_win) - -# trades.append({ -# 'type': trade_type, -# 'entry': float(current_price), # Ensure serializable -# 'exit': float(next_price), -# 'raw_pnl': float(raw_pnl), -# 'fees': float(total_fees), -# 'pnl': float(trade_pnl), -# 'win': bool(is_win), -# 'timestamp': datetime.now().isoformat() # Add timestamp -# }) - -# win_rate = wins / len(trades) if trades else 0.0 - -# return float(pnl), float(win_rate), trades - -# # Configure logging with more detailed format -# logging.basicConfig( -# level=logging.INFO, # Changed to DEBUG for more detailed logs -# format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', -# handlers=[ -# logging.StreamHandler(), -# logging.FileHandler('realtime_chart.log') -# ] -# ) -# logger = logging.getLogger(__name__) - -# # Neural Network integration (conditional import) -# NN_ENABLED = os.environ.get('ENABLE_NN_MODELS', '0') == '1' -# nn_orchestrator = None -# nn_inference_thread = None - -# if NN_ENABLED: -# try: -# import sys -# # Add project root to sys.path if needed -# project_root = os.path.dirname(os.path.abspath(__file__)) -# if project_root not in sys.path: -# sys.path.append(project_root) - -# from NN.main import NeuralNetworkOrchestrator -# logger.info("Neural Network module enabled") -# except ImportError as e: -# logger.warning(f"Failed to import Neural Network module, disabling NN features: {str(e)}") -# NN_ENABLED = False - -# # NN utility functions -# def setup_neural_network(): -# """Initialize the neural network components if enabled""" -# global nn_orchestrator, NN_ENABLED - -# if not NN_ENABLED: -# return False - -# try: -# # Get configuration from environment variables or use defaults -# symbol = os.environ.get('NN_SYMBOL', 'ETH/USDT') -# timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',') -# output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL - -# # Configure the orchestrator -# config = { -# 'symbol': symbol, -# 'timeframes': timeframes, -# 'window_size': int(os.environ.get('NN_WINDOW_SIZE', '20')), -# 'n_features': 5, # OHLCV -# 'output_size': output_size, -# 'model_dir': 'NN/models/saved', -# 'data_dir': 'NN/data' -# } - -# # Initialize the orchestrator -# logger.info(f"Initializing Neural Network Orchestrator with config: {config}") -# nn_orchestrator = NeuralNetworkOrchestrator(config) - -# # Load the model -# model_loaded = nn_orchestrator.load_model() -# if not model_loaded: -# logger.warning("Failed to load neural network model. Using untrained model.") - -# return model_loaded -# except Exception as e: -# logger.error(f"Error setting up neural network: {str(e)}") -# NN_ENABLED = False -# return False - -# def start_nn_inference_thread(interval_seconds): -# """Start a background thread to periodically run inference with the neural network""" -# global nn_inference_thread - -# if not NN_ENABLED or nn_orchestrator is None: -# logger.warning("Cannot start inference thread - Neural Network not enabled or initialized") -# return False - -# def inference_worker(): -# """Worker function for the inference thread""" -# model_type = os.environ.get('NN_MODEL_TYPE', 'cnn') -# timeframe = os.environ.get('NN_TIMEFRAME', '1h') - -# logger.info(f"Starting neural network inference thread with {interval_seconds}s interval") -# logger.info(f"Using model type: {model_type}, timeframe: {timeframe}") - -# # Wait a bit for charts to initialize -# time.sleep(5) - -# # Track active charts -# active_charts = [] - -# while True: -# try: -# # Find active charts if we don't have them yet -# if not active_charts and 'charts' in globals(): -# active_charts = globals()['charts'] -# logger.info(f"Found {len(active_charts)} active charts for NN signals") - -# # Run inference -# result = nn_orchestrator.run_inference_pipeline( -# model_type=model_type, -# timeframe=timeframe -# ) - -# if result: -# # Log the result -# logger.info(f"Neural network inference result: {result}") - -# # Add signal to charts -# if active_charts: -# try: -# if 'action' in result: -# action = result['action'] -# timestamp = datetime.fromisoformat(result['timestamp'].replace('Z', '+00:00')) - -# # Get probability if available -# probability = None -# if 'probability' in result: -# probability = result['probability'] -# elif 'probabilities' in result: -# probability = result['probabilities'].get(action, None) - -# # Add signal to each chart -# for chart in active_charts: -# if hasattr(chart, 'add_nn_signal'): -# chart.add_nn_signal(action, timestamp, probability) -# except Exception as e: -# logger.error(f"Error adding NN signal to chart: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) - -# # Sleep for the interval -# time.sleep(interval_seconds) - -# except Exception as e: -# logger.error(f"Error in inference thread: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) -# time.sleep(5) # Wait a bit before retrying - -# # Create and start the thread -# nn_inference_thread = threading.Thread(target=inference_worker, daemon=True) -# nn_inference_thread.start() - -# return True - -# # Try to get local timezone, default to Sofia/EET if not available -# try: -# local_timezone = tzlocal.get_localzone() -# # Get timezone name safely -# try: -# tz_name = str(local_timezone) -# # Handle case where it might be zoneinfo.ZoneInfo object instead of pytz timezone -# if hasattr(local_timezone, 'zone'): -# tz_name = local_timezone.zone -# elif hasattr(local_timezone, 'key'): -# tz_name = local_timezone.key -# else: -# tz_name = str(local_timezone) -# except: -# tz_name = "Local" -# logger.info(f"Detected local timezone: {local_timezone} ({tz_name})") -# except Exception as e: -# logger.warning(f"Could not detect local timezone: {str(e)}. Defaulting to Sofia/EET") -# local_timezone = pytz.timezone('Europe/Sofia') -# tz_name = "Europe/Sofia" - -# def convert_to_local_time(timestamp): -# """Convert timestamp to local timezone""" -# try: -# if isinstance(timestamp, pd.Timestamp): -# dt = timestamp.to_pydatetime() -# elif isinstance(timestamp, np.datetime64): -# dt = pd.Timestamp(timestamp).to_pydatetime() -# elif isinstance(timestamp, str): -# dt = pd.to_datetime(timestamp).to_pydatetime() -# else: -# dt = timestamp - -# # If datetime is naive (no timezone), assume it's UTC -# if dt.tzinfo is None: -# dt = dt.replace(tzinfo=pytz.UTC) - -# # Convert to local timezone -# local_dt = dt.astimezone(local_timezone) -# return local_dt -# except Exception as e: -# logger.error(f"Error converting timestamp to local time: {str(e)}") -# return timestamp - -# # Initialize TimescaleDB handler - only once, at module level -# timescaledb_handler = TimescaleDBHandler() if TIMESCALEDB_ENABLED else None - -# class TickStorage: -# def __init__(self, symbol, timeframes=None, use_timescaledb=False): -# """Initialize the tick storage for a specific symbol""" -# self.symbol = symbol -# self.timeframes = timeframes or ["1s", "5m", "15m", "1h", "4h", "1d"] -# self.ticks = [] -# self.candles = {tf: [] for tf in self.timeframes} -# self.current_candle = {tf: None for tf in self.timeframes} -# self.last_candle_timestamp = {tf: None for tf in self.timeframes} -# self.cache_dir = os.path.join(os.getcwd(), "cache", symbol.replace("/", "")) -# self.cache_path = os.path.join(self.cache_dir, f"{symbol.replace('/', '')}_ticks.json") # Add missing cache_path -# self.use_timescaledb = use_timescaledb -# self.max_ticks = 10000 # Maximum number of ticks to store in memory - -# # Create cache directory if it doesn't exist -# os.makedirs(self.cache_dir, exist_ok=True) - -# logger.info(f"Creating new tick storage for {symbol} with timeframes {self.timeframes}") -# logger.info(f"Cache directory: {self.cache_dir}") -# logger.info(f"Cache file: {self.cache_path}") - -# if use_timescaledb: -# print(f"TickStorage: TimescaleDB integration is ENABLED for {symbol}") -# else: -# logger.info(f"TickStorage: TimescaleDB integration is DISABLED for {symbol}") - -# def _save_to_cache(self): -# """Save ticks to a cache file""" -# try: -# # Only save the latest 5000 ticks to avoid giant files -# ticks_to_save = self.ticks[-5000:] if len(self.ticks) > 5000 else self.ticks - -# # Convert pandas Timestamps to ISO strings for JSON serialization -# serializable_ticks = [] -# for tick in ticks_to_save: -# serializable_tick = tick.copy() -# if isinstance(tick['timestamp'], pd.Timestamp): -# serializable_tick['timestamp'] = tick['timestamp'].isoformat() -# elif hasattr(tick['timestamp'], 'isoformat'): -# serializable_tick['timestamp'] = tick['timestamp'].isoformat() -# else: -# # Keep as is if it's already a string or number -# serializable_tick['timestamp'] = tick['timestamp'] -# serializable_ticks.append(serializable_tick) - -# with open(self.cache_path, 'w') as f: -# json.dump(serializable_ticks, f) -# logger.debug(f"Saved {len(serializable_ticks)} ticks to cache") -# except Exception as e: -# logger.error(f"Error saving ticks to cache: {e}") - -# def _load_from_cache(self): -# """Load ticks from cache if available""" -# if os.path.exists(self.cache_path): -# try: -# # Check if the cache file is recent (< 10 minutes old) -# cache_age = time.time() - os.path.getmtime(self.cache_path) -# if cache_age > 600: # 10 minutes in seconds -# logger.warning(f"Cache file is {cache_age:.1f} seconds old (>10 min). Not using it.") -# return False - -# with open(self.cache_path, 'r') as f: -# cached_ticks = json.load(f) - -# if cached_ticks: -# # Convert ISO strings back to pandas Timestamps -# processed_ticks = [] -# for tick in cached_ticks: -# processed_tick = tick.copy() -# if isinstance(tick['timestamp'], str): -# try: -# processed_tick['timestamp'] = pd.Timestamp(tick['timestamp']) -# except: -# # If parsing fails, use current time -# processed_tick['timestamp'] = pd.Timestamp.now() -# else: -# # Convert to pandas Timestamp if it's a number (milliseconds) -# processed_tick['timestamp'] = pd.Timestamp(tick['timestamp'], unit='ms') -# processed_ticks.append(processed_tick) - -# self.ticks = processed_ticks -# logger.info(f"Loaded {len(cached_ticks)} ticks from cache") -# return True -# except Exception as e: -# logger.error(f"Error loading ticks from cache: {e}") -# return False - -# def add_tick(self, tick=None, price=None, volume=None, timestamp=None): -# """ -# Add a tick to the storage and update candles for all timeframes - -# Args: -# tick (dict, optional): A tick object containing price, quantity and timestamp -# price (float, optional): Price of the tick (used in older interface) -# volume (float, optional): Volume of the tick (used in older interface) -# timestamp (datetime, optional): Timestamp of the tick (used in older interface) -# """ -# # Handle tick as a dict or separate parameters for backward compatibility -# if tick is not None and isinstance(tick, dict): -# # Using the new interface with a tick object -# price = tick['price'] -# volume = tick.get('quantity', 0) -# timestamp = tick['timestamp'] -# elif price is not None: -# # Using the old interface with separate parameters -# # Convert datetime to pd.Timestamp if needed -# if timestamp is not None and not isinstance(timestamp, pd.Timestamp): -# timestamp = pd.Timestamp(timestamp) -# else: -# logger.error("Invalid tick: must provide either a tick dict or price") -# return - -# # Ensure timestamp is a pandas Timestamp -# if not isinstance(timestamp, pd.Timestamp): -# if isinstance(timestamp, (int, float)): -# # Assume it's milliseconds -# timestamp = pd.Timestamp(timestamp, unit='ms') -# else: -# # Try to parse as string or datetime -# timestamp = pd.Timestamp(timestamp) - -# # Create tick object with consistent pandas Timestamp -# tick_obj = { -# 'price': float(price), -# 'quantity': float(volume) if volume is not None else 0.0, -# 'timestamp': timestamp -# } - -# # Add to the list of ticks -# self.ticks.append(tick_obj) - -# # Limit the number of ticks to avoid memory issues -# if len(self.ticks) > self.max_ticks: -# self.ticks = self.ticks[-self.max_ticks:] - -# # Update candles for all timeframes -# for timeframe in self.timeframes: -# if timeframe == "1s": -# self._update_1s_candle(tick_obj) -# else: -# self._update_candles_for_timeframe(timeframe, tick_obj) - -# # Cache to disk periodically -# self._try_cache_ticks() - -# def _update_1s_candle(self, tick): -# """Update the 1-second candle with the new tick""" -# # Get timestamp for the start of the current second -# tick_timestamp = tick['timestamp'] -# candle_timestamp = pd.Timestamp(int(tick_timestamp.timestamp() // 1 * 1_000_000_000)) - -# # Check if we need to create a new candle -# if self.current_candle["1s"] is None or self.current_candle["1s"]["timestamp"] != candle_timestamp: -# # If we have a current candle, finalize it and add to candles list -# if self.current_candle["1s"] is not None: -# # Add the completed candle to the list -# self.candles["1s"].append(self.current_candle["1s"]) - -# # Limit the number of stored candles to prevent memory issues -# if len(self.candles["1s"]) > 3600: # Keep last hour of 1s candles -# self.candles["1s"] = self.candles["1s"][-3600:] - -# # Store in TimescaleDB if enabled -# if self.use_timescaledb: -# timescaledb_handler.upsert_candle( -# self.symbol, "1s", self.current_candle["1s"] -# ) - -# # Log completed candle for debugging -# logger.debug(f"Completed 1s candle: {self.current_candle['1s']['timestamp']} - Close: {self.current_candle['1s']['close']}") - -# # Create a new candle -# self.current_candle["1s"] = { -# "timestamp": candle_timestamp, -# "open": float(tick["price"]), -# "high": float(tick["price"]), -# "low": float(tick["price"]), -# "close": float(tick["price"]), -# "volume": float(tick["quantity"]) if "quantity" in tick else 0.0 -# } - -# # Update last candle timestamp -# self.last_candle_timestamp["1s"] = candle_timestamp -# logger.debug(f"Created new 1s candle at {candle_timestamp}") -# else: -# # Update the current candle -# current = self.current_candle["1s"] -# price = float(tick["price"]) - -# # Update high and low -# if price > current["high"]: -# current["high"] = price -# if price < current["low"]: -# current["low"] = price - -# # Update close price and add volume -# current["close"] = price -# current["volume"] += float(tick["quantity"]) if "quantity" in tick else 0.0 - -# def _update_candles_for_timeframe(self, timeframe, tick): -# """Update candles for a specific timeframe""" -# # Skip 1s as it's handled separately -# if timeframe == "1s": -# return - -# # Convert timeframe to seconds -# timeframe_seconds = self._timeframe_to_seconds(timeframe) - -# # Get the timestamp truncated to the timeframe interval -# # e.g., for a 5m candle, the timestamp should be truncated to the nearest 5-minute mark -# # Convert timestamp to datetime if it's not already -# tick_timestamp = tick['timestamp'] -# if isinstance(tick_timestamp, pd.Timestamp): -# ts = tick_timestamp -# else: -# ts = pd.Timestamp(tick_timestamp) - -# # Truncate timestamp to nearest timeframe interval -# timestamp = pd.Timestamp( -# int(ts.timestamp() // timeframe_seconds * timeframe_seconds * 1_000_000_000) -# ) - -# # Get the current candle for this timeframe -# current_candle = self.current_candle[timeframe] - -# # If we have no current candle or the timestamp is different (new candle) -# if current_candle is None or current_candle['timestamp'] != timestamp: -# # If we have a current candle, add it to the candles list -# if current_candle: -# self.candles[timeframe].append(current_candle) - -# # Save to TimescaleDB if enabled -# if self.use_timescaledb: -# timescaledb_handler.upsert_candle(self.symbol, timeframe, current_candle) - -# # Create a new candle -# current_candle = { -# 'timestamp': timestamp, -# 'open': tick['price'], -# 'high': tick['price'], -# 'low': tick['price'], -# 'close': tick['price'], -# 'volume': tick.get('quantity', 0) -# } - -# # Update current candle -# self.current_candle[timeframe] = current_candle -# self.last_candle_timestamp[timeframe] = timestamp - -# else: -# # Update existing candle -# current_candle['high'] = max(current_candle['high'], tick['price']) -# current_candle['low'] = min(current_candle['low'], tick['price']) -# current_candle['close'] = tick['price'] -# current_candle['volume'] += tick.get('quantity', 0) - -# # Limit the number of candles to avoid memory issues -# max_candles = 1000 -# if len(self.candles[timeframe]) > max_candles: -# self.candles[timeframe] = self.candles[timeframe][-max_candles:] - -# def _timeframe_to_seconds(self, timeframe): -# """Convert a timeframe string (e.g., '1m', '1h') to seconds""" -# if timeframe == "1s": -# return 1 - -# try: -# # Extract the number and unit -# match = re.match(r'(\d+)([smhdw])', timeframe) -# if not match: -# return None - -# num, unit = match.groups() -# num = int(num) - -# # Convert to seconds -# if unit == 's': -# return num -# elif unit == 'm': -# return num * 60 -# elif unit == 'h': -# return num * 3600 -# elif unit == 'd': -# return num * 86400 -# elif unit == 'w': -# return num * 604800 - -# return None -# except: -# return None - -# def get_candles(self, timeframe, limit=None): -# """Get candles for a given timeframe""" -# if timeframe in self.candles: -# candles = self.candles[timeframe] - -# # Add the current candle if it exists and isn't None -# if timeframe in self.current_candle and self.current_candle[timeframe] is not None: -# # Make a copy of the current candle -# current_candle_copy = self.current_candle[timeframe].copy() - -# # Check if the current candle is newer than the last candle in the list -# if not candles or current_candle_copy["timestamp"] > candles[-1]["timestamp"]: -# candles = candles + [current_candle_copy] - -# # Apply limit if provided -# if limit and len(candles) > limit: -# return candles[-limit:] -# return candles -# return [] - -# def get_last_price(self): -# """Get the last known price""" -# if self.ticks: -# return float(self.ticks[-1]["price"]) -# return None - -# def load_historical_data(self, symbol, limit=1000): -# """Load historical data for all timeframes""" -# logger.info(f"Starting historical data load for {symbol} with limit {limit}") - -# # Clear existing data -# self.ticks = [] -# self.candles = {tf: [] for tf in self.timeframes} -# self.current_candle = {tf: None for tf in self.timeframes} - -# # Try to load ticks from cache first -# logger.info("Attempting to load from cache...") -# cache_loaded = self._load_from_cache() -# if cache_loaded: -# logger.info("Successfully loaded data from cache") -# else: -# logger.info("No valid cache data found") - -# # Check if we have TimescaleDB enabled -# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled: -# logger.info("Attempting to fetch historical data from TimescaleDB") -# loaded_from_db = False - -# # Load candles for each timeframe from TimescaleDB -# for tf in self.timeframes: -# try: -# candles = timescaledb_handler.fetch_candles(symbol, tf, limit) -# if candles: -# self.candles[tf] = candles -# loaded_from_db = True -# logger.info(f"Loaded {len(candles)} {tf} candles from TimescaleDB") -# else: -# logger.info(f"No {tf} candles found in TimescaleDB") -# except Exception as e: -# logger.error(f"Error loading {tf} candles from TimescaleDB: {str(e)}") - -# if loaded_from_db: -# logger.info("Successfully loaded historical data from TimescaleDB") -# return True -# else: -# logger.info("TimescaleDB not available or disabled") - -# # If no TimescaleDB data and no cache, we need to get from Binance API -# if not cache_loaded: -# logger.info("Loading data from Binance API...") -# # Create a BinanceHistoricalData instance -# historical_data = BinanceHistoricalData() - -# # Load data for each timeframe -# success_count = 0 -# for tf in self.timeframes: -# if tf != "1s": # Skip 1s since we'll generate it from ticks -# try: -# logger.info(f"Fetching {tf} candles for {symbol}...") -# df = historical_data.get_historical_candles(symbol, self._timeframe_to_seconds(tf), limit) -# if df is not None and not df.empty: -# logger.info(f"Loaded {len(df)} {tf} candles from Binance API") - -# # Convert to our candle format and store -# candles = [] -# for _, row in df.iterrows(): -# candle = { -# 'timestamp': row['timestamp'], -# 'open': row['open'], -# 'high': row['high'], -# 'low': row['low'], -# 'close': row['close'], -# 'volume': row['volume'] -# } -# candles.append(candle) - -# # Also save to TimescaleDB if enabled -# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled: -# timescaledb_handler.upsert_candle(symbol, tf, candle) - -# self.candles[tf] = candles -# success_count += 1 -# else: -# logger.warning(f"No data returned for {tf} candles") -# except Exception as e: -# logger.error(f"Error loading {tf} candles: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) - -# logger.info(f"Successfully loaded {success_count} timeframes from Binance API") - -# # For 1s, load from API if possible or compute from first available timeframe -# if "1s" in self.timeframes: -# logger.info("Loading 1s candles...") -# # Try to get 1s data from Binance -# try: -# df_1s = historical_data.get_historical_candles(symbol, 1, 300) # Only need recent 1s data -# if df_1s is not None and not df_1s.empty: -# logger.info(f"Loaded {len(df_1s)} recent 1s candles from Binance API") - -# # Convert to our candle format and store -# candles_1s = [] -# for _, row in df_1s.iterrows(): -# candle = { -# 'timestamp': row['timestamp'], -# 'open': row['open'], -# 'high': row['high'], -# 'low': row['low'], -# 'close': row['close'], -# 'volume': row['volume'] -# } -# candles_1s.append(candle) - -# # Also save to TimescaleDB if enabled -# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled: -# timescaledb_handler.upsert_candle(symbol, "1s", candle) - -# self.candles["1s"] = candles_1s -# except Exception as e: -# logger.error(f"Error loading 1s candles: {str(e)}") - -# # If 1s data not available or failed to load, approximate from 1m data -# if not self.candles.get("1s"): -# logger.info("1s data not available, trying to approximate from 1m data...") -# # If 1s data not available, we can approximate from 1m data -# if "1m" in self.timeframes and self.candles["1m"]: -# # For demonstration, just use the 1m candles as placeholders for 1s -# # In a real implementation, you might want more sophisticated interpolation -# logger.info("Using 1m candles as placeholders for 1s timeframe") -# self.candles["1s"] = [] - -# # Take the most recent 5 minutes of 1m candles -# recent_1m = self.candles["1m"][-5:] if self.candles["1m"] else [] -# logger.info(f"Creating 1s approximations from {len(recent_1m)} 1m candles") -# for candle_1m in recent_1m: -# # Create 60 1s candles for each 1m candle -# ts_base = candle_1m["timestamp"].timestamp() -# for i in range(60): -# # Create a 1s candle with interpolated values -# candle_1s = { -# 'timestamp': pd.Timestamp(int((ts_base + i) * 1_000_000_000)), -# 'open': candle_1m['open'], -# 'high': candle_1m['high'], -# 'low': candle_1m['low'], -# 'close': candle_1m['close'], -# 'volume': candle_1m['volume'] / 60.0 # Distribute volume evenly -# } -# self.candles["1s"].append(candle_1s) - -# # Also save to TimescaleDB if enabled -# if self.use_timescaledb and timescaledb_handler and timescaledb_handler.enabled: -# timescaledb_handler.upsert_candle(symbol, "1s", candle_1s) - -# logger.info(f"Created {len(self.candles['1s'])} approximated 1s candles") -# else: -# logger.warning("No 1m data available to approximate 1s candles from") - -# # Set the last candle of each timeframe as the current candle -# for tf in self.timeframes: -# if self.candles[tf]: -# self.current_candle[tf] = self.candles[tf][-1].copy() -# self.last_candle_timestamp[tf] = self.current_candle[tf]["timestamp"] -# logger.debug(f"Set current candle for {tf}: {self.current_candle[tf]['timestamp']}") - -# # If we loaded ticks from cache, rebuild candles -# if cache_loaded: -# logger.info("Rebuilding candles from cached ticks...") -# # Clear candles -# self.candles = {tf: [] for tf in self.timeframes} -# self.current_candle = {tf: None for tf in self.timeframes} - -# # Process each tick to rebuild the candles -# for tick in self.ticks: -# for tf in self.timeframes: -# if tf == "1s": -# self._update_1s_candle(tick) -# else: -# self._update_candles_for_timeframe(tf, tick) - -# logger.info("Finished rebuilding candles from ticks") - -# # Log final results -# for tf in self.timeframes: -# count = len(self.candles[tf]) -# logger.info(f"Final {tf} candle count: {count}") - -# has_data = cache_loaded or any(self.candles[tf] for tf in self.timeframes) -# logger.info(f"Historical data loading completed. Has data: {has_data}") -# return has_data - -# def _try_cache_ticks(self): -# """Try to save ticks to cache periodically""" -# # Only save to cache every 100 ticks to avoid excessive disk I/O -# if len(self.ticks) % 100 == 0: -# try: -# self._save_to_cache() -# except Exception as e: -# # Don't spam logs with cache errors, just log once every 1000 ticks -# if len(self.ticks) % 1000 == 0: -# logger.warning(f"Cache save failed at {len(self.ticks)} ticks: {str(e)}") -# pass # Continue even if cache fails - -# class Position: -# """Represents a trading position""" - -# def __init__(self, action, entry_price, amount, timestamp=None, trade_id=None, fee_rate=0.0002): -# self.action = action -# self.entry_price = entry_price -# self.amount = amount -# self.entry_timestamp = timestamp or datetime.now() -# self.exit_timestamp = None -# self.exit_price = None -# self.pnl = None -# self.is_open = True -# self.trade_id = trade_id or str(uuid.uuid4())[:8] -# self.fee_rate = fee_rate -# self.paid_fee = entry_price * amount * fee_rate # Calculate entry fee - -# def close(self, exit_price, exit_timestamp=None): -# """Close an open position""" -# self.exit_price = exit_price -# self.exit_timestamp = exit_timestamp or datetime.now() -# self.is_open = False - -# # Calculate P&L -# if self.action == "BUY": -# price_diff = self.exit_price - self.entry_price -# # Calculate fee for exit trade -# exit_fee = exit_price * self.amount * self.fee_rate -# self.paid_fee += exit_fee # Add exit fee to total paid fee -# self.pnl = (price_diff * self.amount) - self.paid_fee -# else: # SELL -# price_diff = self.entry_price - self.exit_price -# # Calculate fee for exit trade -# exit_fee = exit_price * self.amount * self.fee_rate -# self.paid_fee += exit_fee # Add exit fee to total paid fee -# self.pnl = (price_diff * self.amount) - self.paid_fee - -# return self.pnl - -# class RealTimeChart: -# def __init__(self, app=None, symbol='BTCUSDT', timeframe='1m', standalone=True, chart_title=None, -# run_signal_interpreter=False, debug_mode=False, historical_candles=None, -# extended_hours=False, enable_logging=True, agent=None, trading_env=None, -# max_memory_usage=90, memory_check_interval=10, tick_update_interval=0.5, -# chart_update_interval=1, performance_monitoring=False, show_volume=True, -# show_indicators=True, custom_trades=None, port=8050, height=900, width=1200, -# positions_callback=None, allow_synthetic_data=True, tick_storage=None): -# """Initialize a real-time chart with support for multiple indicators and backtesting.""" - -# # Store parameters -# self.symbol = symbol -# self.timeframe = timeframe -# self.debug_mode = debug_mode -# self.standalone = standalone -# self.chart_title = chart_title or f"{symbol} Real-Time Chart" -# self.extended_hours = extended_hours -# self.enable_logging = enable_logging -# self.run_signal_interpreter = run_signal_interpreter -# self.historical_candles = historical_candles -# self.performance_monitoring = performance_monitoring -# self.max_memory_usage = max_memory_usage -# self.memory_check_interval = memory_check_interval -# self.tick_update_interval = tick_update_interval -# self.chart_update_interval = chart_update_interval -# self.show_volume = show_volume -# self.show_indicators = show_indicators -# self.custom_trades = custom_trades -# self.port = port -# self.height = height -# self.width = width -# self.positions_callback = positions_callback -# self.allow_synthetic_data = allow_synthetic_data - -# # Initialize interval store -# self.interval_store = {'interval': 1} # Default to 1s timeframe - -# # Initialize trading components -# self.agent = agent -# self.trading_env = trading_env - -# # Initialize button styles for timeframe selection -# self.button_style = { -# 'background': '#343a40', -# 'color': 'white', -# 'border': 'none', -# 'padding': '10px 20px', -# 'margin': '0 5px', -# 'borderRadius': '4px', -# 'cursor': 'pointer' -# } - -# self.active_button_style = { -# 'background': '#007bff', -# 'color': 'white', -# 'border': 'none', -# 'padding': '10px 20px', -# 'margin': '0 5px', -# 'borderRadius': '4px', -# 'cursor': 'pointer', -# 'fontWeight': 'bold' -# } - -# # Initialize color schemes -# self.colors = { -# 'background': '#1e1e1e', -# 'text': '#ffffff', -# 'grid': '#333333', -# 'candle_up': '#26a69a', -# 'candle_down': '#ef5350', -# 'volume_up': 'rgba(38, 166, 154, 0.3)', -# 'volume_down': 'rgba(239, 83, 80, 0.3)', -# 'ma': '#ffeb3b', -# 'ema': '#29b6f6', -# 'bollinger_bands': '#ff9800', -# 'trades_buy': '#00e676', -# 'trades_sell': '#ff1744' -# } - -# # Initialize data storage -# self.all_trades = [] # Store trades -# self.positions = [] # Store open positions -# self.latest_price = 0.0 -# self.latest_volume = 0.0 -# self.latest_timestamp = datetime.now() -# self.current_balance = 100.0 # Starting balance -# self.accumulative_pnl = 0.0 # Accumulated profit/loss - -# # Initialize trade rate counter -# self.trade_count = 0 -# self.start_time = time.time() -# self.trades_per_second = 0 -# self.trades_per_minute = 0 -# self.trades_per_hour = 0 - -# # Initialize trade rate tracking variables -# self.trade_times = [] # Store timestamps of recent trades for rate calculation -# self.last_trade_rate_calculation = datetime.now() -# self.trade_rate = {"per_second": 0, "per_minute": 0, "per_hour": 0} - -# # Initialize interactive components -# self.app = app - -# # Create a new app if not provided -# if self.app is None and standalone: -# self.app = dash.Dash( -# __name__, -# external_stylesheets=[dbc.themes.DARKLY], -# suppress_callback_exceptions=True -# ) - -# # Initialize tick storage if not provided -# if tick_storage is None: -# # Check if TimescaleDB integration is enabled -# use_timescaledb = TIMESCALEDB_ENABLED and timescaledb_handler is not None - -# # Create a new tick storage -# self.tick_storage = TickStorage( -# symbol=symbol, -# timeframes=["1s", "1m", "5m", "15m", "1h", "4h", "1d"], -# use_timescaledb=use_timescaledb -# ) - -# # Load historical data immediately for cold start -# logger.info(f"Loading historical data for {symbol} during chart initialization") -# try: -# data_loaded = self.tick_storage.load_historical_data(symbol) -# if data_loaded: -# logger.info(f"Successfully loaded historical data for {symbol}") -# # Log what we have -# for tf in ["1s", "1m", "5m", "15m", "1h"]: -# candle_count = len(self.tick_storage.candles.get(tf, [])) -# logger.info(f" {tf}: {candle_count} candles") -# else: -# logger.warning(f"Failed to load historical data for {symbol}") -# except Exception as e: -# logger.error(f"Error loading historical data during initialization: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) -# else: -# self.tick_storage = tick_storage - -# # Create layout and callbacks if app is provided -# if self.app is not None: -# # Create the layout -# self.app.layout = self._create_layout() - -# # Register callbacks -# self._setup_callbacks() - -# # Log initialization -# if self.enable_logging: -# logger.info(f"RealTimeChart initialized: {self.symbol} ({self.timeframe}) ") - -# def _create_layout(self): -# return html.Div([ -# # Header section with title and current price -# html.Div([ -# html.H1(f"{self.symbol} Real-Time Chart", className="display-4"), - -# # Current price ticker -# html.Div([ -# html.H4("Current Price:", style={"display": "inline-block", "marginRight": "10px"}), -# html.H3(id="current-price", style={"display": "inline-block", "color": "#17a2b8"}), -# html.Div([ -# html.H5("Balance:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}), -# html.H5(id="current-balance", style={"display": "inline-block", "color": "#28a745"}), -# ], style={"display": "inline-block", "marginLeft": "40px"}), -# html.Div([ -# html.H5("Accumulated PnL:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}), -# html.H5(id="accumulated-pnl", style={"display": "inline-block", "color": "#ffc107"}), -# ], style={"display": "inline-block", "marginLeft": "40px"}), - -# # Add trade rate display -# html.Div([ -# html.H5("Trade Rate:", style={"display": "inline-block", "marginRight": "10px", "marginLeft": "30px"}), -# html.Span([ -# html.Span(id="trade-rate-second", style={"color": "#ff7f0e"}), -# html.Span("/s, "), -# html.Span(id="trade-rate-minute", style={"color": "#ff7f0e"}), -# html.Span("/m, "), -# html.Span(id="trade-rate-hour", style={"color": "#ff7f0e"}), -# html.Span("/h") -# ], style={"display": "inline-block"}), -# ], style={"display": "inline-block", "marginLeft": "40px"}), -# ], style={"textAlign": "center", "margin": "20px 0"}), -# ], style={"textAlign": "center", "marginBottom": "20px"}), - -# # Add interval component for periodic updates -# dcc.Interval( -# id='interval-component', -# interval=500, # in milliseconds -# n_intervals=0 -# ), - -# # Add timeframe selection buttons -# html.Div([ -# html.Button('1s', id='btn-1s', n_clicks=0, style=self.active_button_style), -# html.Button('5s', id='btn-5s', n_clicks=0, style=self.button_style), -# html.Button('15s', id='btn-15s', n_clicks=0, style=self.button_style), -# html.Button('1m', id='btn-1m', n_clicks=0, style=self.button_style), -# html.Button('5m', id='btn-5m', n_clicks=0, style=self.button_style), -# html.Button('15m', id='btn-15m', n_clicks=0, style=self.button_style), -# html.Button('1h', id='btn-1h', n_clicks=0, style=self.button_style), -# ], style={"textAlign": "center", "marginBottom": "20px"}), - -# # Store for the selected timeframe -# dcc.Store(id='interval-store', data={'interval': 1}), - -# # Chart content (without wrapper div to avoid callback issues) -# dcc.Graph(id='live-chart', style={"height": "600px"}), -# dcc.Graph(id='secondary-charts', style={"height": "500px"}), -# html.Div(id='positions-list') -# ]) - -# def _create_chart_and_controls(self): -# """Create the chart and controls for the dashboard.""" -# try: -# # Get selected interval from the dashboard (default to 1s if not available) -# interval_seconds = 1 -# if hasattr(self, 'interval_store') and self.interval_store: -# interval_seconds = self.interval_store.get('interval', 1) - -# # Create chart components -# chart_div = html.Div([ -# # Update chart with data for the selected interval -# dcc.Graph( -# id='live-chart', -# figure=self._update_main_chart(interval_seconds), -# style={"height": "600px"} -# ), - -# # Update secondary charts -# dcc.Graph( -# id='secondary-charts', -# figure=self._update_secondary_charts(), -# style={"height": "500px"} -# ), - -# # Update positions list -# html.Div( -# id='positions-list', -# children=self._get_position_list_rows() -# ) -# ]) - -# return chart_div - -# except Exception as e: -# logger.error(f"Error creating chart and controls: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) -# # Return a simple error message as fallback -# return html.Div(f"Error loading chart: {str(e)}", style={"color": "red", "padding": "20px"}) - -# def _setup_callbacks(self): -# """Setup Dash callbacks for the real-time chart""" -# if self.app is None: -# return - -# try: -# # Update chart with all components based on interval -# @self.app.callback( -# [ -# Output('live-chart', 'figure'), -# Output('secondary-charts', 'figure'), -# Output('positions-list', 'children'), -# Output('current-price', 'children'), -# Output('current-balance', 'children'), -# Output('accumulated-pnl', 'children'), -# Output('trade-rate-second', 'children'), -# Output('trade-rate-minute', 'children'), -# Output('trade-rate-hour', 'children') -# ], -# [ -# Input('interval-component', 'n_intervals'), -# Input('interval-store', 'data') -# ] -# ) -# def update_all(n_intervals, interval_data): -# """Update all chart components""" -# try: -# # Get selected interval -# interval_seconds = interval_data.get('interval', 1) if interval_data else 1 - -# # Update main chart - limit data for performance -# main_chart = self._update_main_chart(interval_seconds) - -# # Update secondary charts - limit data for performance -# secondary_charts = self._update_secondary_charts() - -# # Update positions list -# positions_list = self._get_position_list_rows() - -# # Update current price and balance -# current_price = f"${self.latest_price:.2f}" if self.latest_price else "Error" -# current_balance = f"${self.current_balance:.2f}" -# accumulated_pnl = f"${self.accumulative_pnl:.2f}" - -# # Calculate trade rates -# trade_rate = self._calculate_trade_rate() -# trade_rate_second = f"{trade_rate['per_second']:.1f}" -# trade_rate_minute = f"{trade_rate['per_minute']:.1f}" -# trade_rate_hour = f"{trade_rate['per_hour']:.1f}" - -# return (main_chart, secondary_charts, positions_list, -# current_price, current_balance, accumulated_pnl, -# trade_rate_second, trade_rate_minute, trade_rate_hour) - -# except Exception as e: -# logger.error(f"Error in update_all callback: {str(e)}") -# # Return empty/error states -# import plotly.graph_objects as go -# empty_fig = go.Figure() -# empty_fig.add_annotation(text="Chart Loading...", xref="paper", yref="paper", x=0.5, y=0.5) - -# return (empty_fig, empty_fig, [], "Loading...", "$0.00", "$0.00", "0.0", "0.0", "0.0") - -# # Timeframe selection callbacks -# @self.app.callback( -# [Output('interval-store', 'data'), -# Output('btn-1s', 'style'), Output('btn-5s', 'style'), Output('btn-15s', 'style'), -# Output('btn-1m', 'style'), Output('btn-5m', 'style'), Output('btn-15m', 'style'), -# Output('btn-1h', 'style')], -# [Input('btn-1s', 'n_clicks'), Input('btn-5s', 'n_clicks'), Input('btn-15s', 'n_clicks'), -# Input('btn-1m', 'n_clicks'), Input('btn-5m', 'n_clicks'), Input('btn-15m', 'n_clicks'), -# Input('btn-1h', 'n_clicks')] -# ) -# def update_timeframe(n1s, n5s, n15s, n1m, n5m, n15m, n1h): -# """Update selected timeframe based on button clicks""" -# ctx = dash.callback_context -# if not ctx.triggered: -# # Default to 1s -# styles = [self.active_button_style] + [self.button_style] * 6 -# return {'interval': 1}, *styles - -# button_id = ctx.triggered[0]['prop_id'].split('.')[0] - -# # Map button to interval seconds -# interval_map = { -# 'btn-1s': 1, 'btn-5s': 5, 'btn-15s': 15, -# 'btn-1m': 60, 'btn-5m': 300, 'btn-15m': 900, 'btn-1h': 3600 -# } - -# selected_interval = interval_map.get(button_id, 1) - -# # Create styles - active for selected, normal for others -# button_names = ['btn-1s', 'btn-5s', 'btn-15s', 'btn-1m', 'btn-5m', 'btn-15m', 'btn-1h'] -# styles = [] -# for name in button_names: -# if name == button_id: -# styles.append(self.active_button_style) -# else: -# styles.append(self.button_style) - -# return {'interval': selected_interval}, *styles - -# logger.info("Dash callbacks registered successfully") - -# except Exception as e: -# logger.error(f"Error setting up callbacks: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) - -# def _calculate_trade_rate(self): -# """Calculate trading rate per second, minute, and hour""" -# try: -# now = datetime.now() -# current_time = time.time() - -# # Filter trades within different time windows -# trades_last_second = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 1) -# trades_last_minute = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 60) -# trades_last_hour = sum(1 for trade_time in self.trade_times if current_time - trade_time <= 3600) - -# return { -# "per_second": trades_last_second, -# "per_minute": trades_last_minute, -# "per_hour": trades_last_hour -# } -# except Exception as e: -# logger.warning(f"Error calculating trade rate: {str(e)}") -# return {"per_second": 0.0, "per_minute": 0.0, "per_hour": 0.0} - -# def _update_secondary_charts(self): -# """Create secondary charts for volume and indicators""" -# try: -# # Create subplots for secondary charts -# fig = make_subplots( -# rows=2, cols=1, -# subplot_titles=['Volume', 'Technical Indicators'], -# shared_xaxes=True, -# vertical_spacing=0.1, -# row_heights=[0.3, 0.7] -# ) - -# # Get latest candles (limit for performance) -# candles = self.tick_storage.candles.get("1m", [])[-100:] # Last 100 candles for performance - -# if not candles: -# fig.add_annotation(text="No data available", xref="paper", yref="paper", x=0.5, y=0.5) -# fig.update_layout( -# title="Secondary Charts", -# template="plotly_dark", -# height=400 -# ) -# return fig - -# # Extract data -# timestamps = [candle['timestamp'] for candle in candles] -# volumes = [candle['volume'] for candle in candles] -# closes = [candle['close'] for candle in candles] - -# # Volume chart -# colors = ['#26a69a' if i == 0 or closes[i] >= closes[i-1] else '#ef5350' for i in range(len(closes))] -# fig.add_trace( -# go.Bar( -# x=timestamps, -# y=volumes, -# name='Volume', -# marker_color=colors, -# showlegend=False -# ), -# row=1, col=1 -# ) - -# # Technical indicators -# if len(closes) >= 20: -# # Simple moving average -# sma_20 = pd.Series(closes).rolling(window=20).mean() -# fig.add_trace( -# go.Scatter( -# x=timestamps, -# y=sma_20, -# name='SMA 20', -# line=dict(color='#ffeb3b', width=2) -# ), -# row=2, col=1 -# ) - -# # RSI calculation -# if len(closes) >= 14: -# rsi = self._calculate_rsi(closes, 14) -# fig.add_trace( -# go.Scatter( -# x=timestamps, -# y=rsi, -# name='RSI', -# line=dict(color='#29b6f6', width=2), -# yaxis='y3' -# ), -# row=2, col=1 -# ) - -# # Update layout -# fig.update_layout( -# title="Volume & Technical Indicators", -# template="plotly_dark", -# height=400, -# showlegend=True, -# legend=dict(x=0, y=1, bgcolor='rgba(0,0,0,0)') -# ) - -# # Update y-axes -# fig.update_yaxes(title="Volume", row=1, col=1) -# fig.update_yaxes(title="Price", row=2, col=1) - -# return fig - -# except Exception as e: -# logger.error(f"Error creating secondary charts: {str(e)}") -# # Return empty figure on error -# fig = go.Figure() -# fig.add_annotation(text=f"Error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5) -# fig.update_layout(template="plotly_dark", height=400) -# return fig - -# def _calculate_rsi(self, prices, period=14): -# """Calculate RSI indicator""" -# try: -# prices = pd.Series(prices) -# delta = prices.diff() -# gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() -# loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() -# rs = gain / loss -# rsi = 100 - (100 / (1 + rs)) -# return rsi.fillna(50).tolist() # Fill NaN with neutral RSI value -# except Exception: -# return [50] * len(prices) # Return neutral RSI on error - -# def _get_position_list_rows(self): -# """Get list of current positions for display""" -# try: -# if not self.positions: -# return [html.Div("No open positions", style={"color": "#888", "padding": "10px"})] - -# rows = [] -# for i, position in enumerate(self.positions): -# try: -# # Calculate current PnL -# current_pnl = (self.latest_price - position.entry_price) * position.amount -# if position.action.upper() == 'SELL': -# current_pnl = -current_pnl - -# # Create position row -# row = html.Div([ -# html.Span(f"#{i+1}: ", style={"fontWeight": "bold"}), -# html.Span(f"{position.action.upper()} ", -# style={"color": "#00e676" if position.action.upper() == "BUY" else "#ff1744"}), -# html.Span(f"{position.amount:.4f} @ ${position.entry_price:.2f} "), -# html.Span(f"PnL: ${current_pnl:.2f}", -# style={"color": "#00e676" if current_pnl >= 0 else "#ff1744"}) -# ], style={"padding": "5px", "borderBottom": "1px solid #333"}) - -# rows.append(row) -# except Exception as e: -# logger.warning(f"Error formatting position {i}: {str(e)}") - -# return rows - -# except Exception as e: -# logger.error(f"Error getting position list: {str(e)}") -# return [html.Div("Error loading positions", style={"color": "red", "padding": "10px"})] - -# def add_trade(self, action, price, amount, timestamp=None, trade_id=None): -# """Add a trade to the chart and update tracking""" -# try: -# if timestamp is None: -# timestamp = datetime.now() - -# # Create trade record -# trade = { -# 'id': trade_id or str(uuid.uuid4()), -# 'action': action.upper(), -# 'price': float(price), -# 'amount': float(amount), -# 'timestamp': timestamp, -# 'value': float(price) * float(amount) -# } - -# # Add to trades list -# self.all_trades.append(trade) - -# # Update trade rate tracking -# self.trade_times.append(time.time()) -# # Keep only last hour of trade times -# cutoff_time = time.time() - 3600 -# self.trade_times = [t for t in self.trade_times if t > cutoff_time] - -# # Update positions -# if action.upper() in ['BUY', 'SELL']: -# position = Position( -# action=action.upper(), -# entry_price=float(price), -# amount=float(amount), -# timestamp=timestamp, -# trade_id=trade['id'] -# ) -# self.positions.append(position) - -# # Update balance and PnL -# if action.upper() == 'BUY': -# self.current_balance -= trade['value'] -# else: # SELL -# self.current_balance += trade['value'] - -# # Calculate PnL for this trade -# if len(self.all_trades) > 1: -# # Simple PnL calculation - more sophisticated logic could be added -# last_opposite_trades = [t for t in reversed(self.all_trades[:-1]) -# if t['action'] != action.upper()] -# if last_opposite_trades: -# last_trade = last_opposite_trades[0] -# if action.upper() == 'SELL': -# pnl = (float(price) - last_trade['price']) * float(amount) -# else: # BUY -# pnl = (last_trade['price'] - float(price)) * float(amount) -# self.accumulative_pnl += pnl - -# logger.info(f"Added trade: {action.upper()} {amount} @ ${price:.2f}") - -# except Exception as e: -# logger.error(f"Error adding trade: {str(e)}") - -# def _get_interval_key(self, interval_seconds): -# """Convert interval seconds to timeframe key""" -# if interval_seconds <= 1: -# return "1s" -# elif interval_seconds <= 5: -# return "5s" if "5s" in self.tick_storage.timeframes else "1s" -# elif interval_seconds <= 15: -# return "15s" if "15s" in self.tick_storage.timeframes else "1m" -# elif interval_seconds <= 60: -# return "1m" -# elif interval_seconds <= 300: -# return "5m" -# elif interval_seconds <= 900: -# return "15m" -# elif interval_seconds <= 3600: -# return "1h" -# elif interval_seconds <= 14400: -# return "4h" -# else: -# return "1d" - -# def _update_main_chart(self, interval_seconds): -# """Update the main chart for the specified interval""" -# try: -# # Convert interval seconds to timeframe key -# interval_key = self._get_interval_key(interval_seconds) - -# # Get candles for this timeframe (limit to last 100 for performance) -# candles = self.tick_storage.candles.get(interval_key, [])[-100:] - -# if not candles: -# logger.warning(f"No candle data available for {interval_key}") -# # Return empty figure with a message -# fig = go.Figure() -# fig.add_annotation( -# text=f"No data available for {interval_key}", -# xref="paper", yref="paper", -# x=0.5, y=0.5, -# showarrow=False, -# font=dict(size=16, color="white") -# ) -# fig.update_layout( -# title=f"{self.symbol} - {interval_key} Chart", -# template="plotly_dark", -# height=600 -# ) -# return fig - -# # Extract data from candles -# timestamps = [candle['timestamp'] for candle in candles] -# opens = [candle['open'] for candle in candles] -# highs = [candle['high'] for candle in candles] -# lows = [candle['low'] for candle in candles] -# closes = [candle['close'] for candle in candles] -# volumes = [candle['volume'] for candle in candles] - -# # Create candlestick chart -# fig = go.Figure() - -# # Add candlestick trace -# fig.add_trace(go.Candlestick( -# x=timestamps, -# open=opens, -# high=highs, -# low=lows, -# close=closes, -# name="Price", -# increasing_line_color='#26a69a', -# decreasing_line_color='#ef5350', -# increasing_fillcolor='#26a69a', -# decreasing_fillcolor='#ef5350' -# )) - -# # Add trade markers if we have trades -# if self.all_trades: -# # Filter trades to match the current timeframe window -# start_time = timestamps[0] if timestamps else datetime.now() - timedelta(hours=1) -# end_time = timestamps[-1] if timestamps else datetime.now() - -# filtered_trades = [ -# trade for trade in self.all_trades -# if start_time <= trade['timestamp'] <= end_time -# ] - -# if filtered_trades: -# buy_trades = [t for t in filtered_trades if t['action'] == 'BUY'] -# sell_trades = [t for t in filtered_trades if t['action'] == 'SELL'] - -# # Add BUY markers -# if buy_trades: -# fig.add_trace(go.Scatter( -# x=[t['timestamp'] for t in buy_trades], -# y=[t['price'] for t in buy_trades], -# mode='markers', -# marker=dict( -# symbol='triangle-up', -# size=12, -# color='#00e676', -# line=dict(color='white', width=1) -# ), -# name='BUY', -# text=[f"BUY {t['amount']:.4f} @ ${t['price']:.2f}" for t in buy_trades], -# hovertemplate='%{text}
Time: %{x}' -# )) - -# # Add SELL markers -# if sell_trades: -# fig.add_trace(go.Scatter( -# x=[t['timestamp'] for t in sell_trades], -# y=[t['price'] for t in sell_trades], -# mode='markers', -# marker=dict( -# symbol='triangle-down', -# size=12, -# color='#ff1744', -# line=dict(color='white', width=1) -# ), -# name='SELL', -# text=[f"SELL {t['amount']:.4f} @ ${t['price']:.2f}" for t in sell_trades], -# hovertemplate='%{text}
Time: %{x}' -# )) - -# # Add moving averages if we have enough data -# if len(closes) >= 20: -# # 20-period SMA -# sma_20 = pd.Series(closes).rolling(window=20).mean() -# fig.add_trace(go.Scatter( -# x=timestamps, -# y=sma_20, -# name='SMA 20', -# line=dict(color='#ffeb3b', width=1), -# opacity=0.7 -# )) - -# if len(closes) >= 50: -# # 50-period SMA -# sma_50 = pd.Series(closes).rolling(window=50).mean() -# fig.add_trace(go.Scatter( -# x=timestamps, -# y=sma_50, -# name='SMA 50', -# line=dict(color='#ff9800', width=1), -# opacity=0.7 -# )) - -# # Update layout -# fig.update_layout( -# title=f"{self.symbol} - {interval_key} Chart ({len(candles)} candles)", -# template="plotly_dark", -# height=600, -# xaxis_title="Time", -# yaxis_title="Price ($)", -# legend=dict( -# yanchor="top", -# y=0.99, -# xanchor="left", -# x=0.01, -# bgcolor="rgba(0,0,0,0.5)" -# ), -# hovermode='x unified', -# dragmode='pan' -# ) - -# # Remove range slider for better performance -# fig.update_layout(xaxis_rangeslider_visible=False) - -# # Update the latest price -# if closes: -# self.latest_price = closes[-1] -# self.latest_timestamp = timestamps[-1] - -# return fig - -# except Exception as e: -# logger.error(f"Error updating main chart: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) - -# # Return error figure -# fig = go.Figure() -# fig.add_annotation( -# text=f"Chart Error: {str(e)}", -# xref="paper", yref="paper", -# x=0.5, y=0.5, -# showarrow=False, -# font=dict(size=16, color="red") -# ) -# fig.update_layout( -# title="Chart Error", -# template="plotly_dark", -# height=600 -# ) -# return fig - -# def set_trading_env(self, trading_env): -# """Set the trading environment to monitor for new trades""" -# self.trading_env = trading_env -# if hasattr(trading_env, 'add_trade_callback'): -# trading_env.add_trade_callback(self.add_trade) -# logger.info("Trading environment integrated with chart") - -# def set_agent(self, agent): -# """Set the agent to monitor for trading decisions""" -# self.agent = agent -# logger.info("Agent integrated with chart") - -# def update_from_env(self, env_data): -# """Update chart data from trading environment""" -# try: -# if 'latest_price' in env_data: -# self.latest_price = env_data['latest_price'] - -# if 'balance' in env_data: -# self.current_balance = env_data['balance'] - -# if 'pnl' in env_data: -# self.accumulative_pnl = env_data['pnl'] - -# if 'trades' in env_data: -# # Add any new trades -# for trade in env_data['trades']: -# if trade not in self.all_trades: -# self.add_trade( -# action=trade.get('action', 'HOLD'), -# price=trade.get('price', self.latest_price), -# amount=trade.get('amount', 0.1), -# timestamp=trade.get('timestamp', datetime.now()), -# trade_id=trade.get('id') -# ) -# except Exception as e: -# logger.error(f"Error updating from environment: {str(e)}") - -# def get_latest_data(self): -# """Get the latest data for external systems""" -# return { -# 'latest_price': self.latest_price, -# 'latest_volume': self.latest_volume, -# 'latest_timestamp': self.latest_timestamp, -# 'current_balance': self.current_balance, -# 'accumulative_pnl': self.accumulative_pnl, -# 'positions': len(self.positions), -# 'trade_count': len(self.all_trades), -# 'trade_rate': self._calculate_trade_rate() -# } - -# async def start_websocket(self): -# """Start the websocket connection for real-time data""" -# try: -# logger.info("Starting websocket connection for real-time data") - -# # Start the websocket data fetching -# websocket_url = "wss://stream.binance.com:9443/ws/ethusdt@ticker" - -# async def websocket_handler(): -# """Handle websocket connection and data updates""" -# try: -# async with websockets.connect(websocket_url) as websocket: -# logger.info(f"WebSocket connected for {self.symbol}") -# message_count = 0 - -# async for message in websocket: -# try: -# data = json.loads(message) - -# # Update tick storage with new price data -# tick = { -# 'price': float(data['c']), # Current price -# 'volume': float(data['v']), # Volume -# 'timestamp': pd.Timestamp.now() -# } - -# self.tick_storage.add_tick(tick) - -# # Update chart's latest price and volume -# self.latest_price = float(data['c']) -# self.latest_volume = float(data['v']) -# self.latest_timestamp = pd.Timestamp.now() - -# message_count += 1 - -# # Log periodic updates -# if message_count % 100 == 0: -# logger.info(f"Received message #{message_count}") -# logger.info(f"Processed {message_count} ticks, current price: ${self.latest_price:.2f}") - -# # Log candle counts -# candle_count = len(self.tick_storage.candles.get("1s", [])) -# logger.info(f"Current 1s candles count: {candle_count}") - -# except json.JSONDecodeError as e: -# logger.warning(f"Failed to parse websocket message: {str(e)}") -# except Exception as e: -# logger.error(f"Error processing websocket message: {str(e)}") - -# except websockets.exceptions.ConnectionClosed: -# logger.warning("WebSocket connection closed") -# except Exception as e: -# logger.error(f"WebSocket error: {str(e)}") - -# # Start the websocket handler in the background -# await websocket_handler() - -# except Exception as e: -# logger.error(f"Error starting websocket: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) - -# def run(self, host='127.0.0.1', port=8050, debug=False): -# """Run the Dash app""" -# try: -# if self.app is None: -# logger.error("No Dash app instance available") -# return - -# logger.info("="*60) -# logger.info("🔗 ACCESS WEB UI AT: http://localhost:8050/") -# logger.info("📊 View live trading data and charts in your browser") -# logger.info("="*60) - -# # Run the app - FIXED: Updated for newer Dash versions -# self.app.run( -# host=host, -# port=port, -# debug=debug, -# use_reloader=False, # Disable reloader to avoid conflicts -# threaded=True # Enable threading for better performance -# ) -# except Exception as e: -# logger.error(f"Error running Dash app: {str(e)}") -# import traceback -# logger.error(traceback.format_exc()) diff --git a/debug/test_fixed_issues.py b/debug/test_fixed_issues.py deleted file mode 100644 index e4bc8f6..0000000 --- a/debug/test_fixed_issues.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify that both model prediction and trading statistics issues are fixed -""" - -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -from core.trading_executor import TradingExecutor -import asyncio -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -async def test_model_predictions(): - """Test that model predictions are working correctly""" - - logger.info("=" * 60) - logger.info("TESTING MODEL PREDICTIONS") - logger.info("=" * 60) - - # Initialize components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Check model registration - logger.info("1. Checking model registration...") - models = orchestrator.model_registry.get_all_models() - logger.info(f" Registered models: {list(models.keys()) if models else 'None'}") - - # Test making a decision - logger.info("2. Testing trading decision generation...") - decision = await orchestrator.make_trading_decision('ETH/USDT') - - if decision: - logger.info(f" ✅ Decision generated: {decision.action} (confidence: {decision.confidence:.3f})") - logger.info(f" ✅ Reasoning: {decision.reasoning}") - return True - else: - logger.error(" ❌ No decision generated") - return False - -def test_trading_statistics(): - """Test that trading statistics calculations are working correctly""" - - logger.info("=" * 60) - logger.info("TESTING TRADING STATISTICS") - logger.info("=" * 60) - - # Initialize trading executor - trading_executor = TradingExecutor() - - # Check if we have any trades - trade_history = trading_executor.get_trade_history() - logger.info(f"1. Current trade history: {len(trade_history)} trades") - - # Get daily stats - daily_stats = trading_executor.get_daily_stats() - logger.info("2. Daily statistics from trading executor:") - logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}") - logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}") - logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}") - logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%") - logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}") - logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}") - logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}") - - # Simulate some trades if we don't have any - if daily_stats.get('total_trades', 0) == 0: - logger.info("3. No trades found - simulating some test trades...") - - # Add some mock trades to the trade history - from core.trading_executor import TradeRecord - from datetime import datetime - - # Add a winning trade - winning_trade = TradeRecord( - symbol='ETH/USDT', - side='LONG', - quantity=0.01, - entry_price=2500.0, - exit_price=2550.0, - entry_time=datetime.now(), - exit_time=datetime.now(), - pnl=0.50, # $0.50 profit - fees=0.01, - confidence=0.8 - ) - trading_executor.trade_history.append(winning_trade) - - # Add a losing trade - losing_trade = TradeRecord( - symbol='ETH/USDT', - side='LONG', - quantity=0.01, - entry_price=2500.0, - exit_price=2480.0, - entry_time=datetime.now(), - exit_time=datetime.now(), - pnl=-0.20, # $0.20 loss - fees=0.01, - confidence=0.7 - ) - trading_executor.trade_history.append(losing_trade) - - # Get updated stats - daily_stats = trading_executor.get_daily_stats() - logger.info(" Updated statistics after adding test trades:") - logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}") - logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}") - logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}") - logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%") - logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}") - logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}") - logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}") - - # Verify calculations - expected_win_rate = 1/2 # 1 win out of 2 trades = 50% - expected_avg_win = 0.50 - expected_avg_loss = -0.20 - - actual_win_rate = daily_stats.get('win_rate', 0.0) - actual_avg_win = daily_stats.get('avg_winning_trade', 0.0) - actual_avg_loss = daily_stats.get('avg_losing_trade', 0.0) - - logger.info("4. Verifying calculations:") - logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ✅" if abs(actual_win_rate - expected_win_rate) < 0.01 else f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ❌") - logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ✅" if abs(actual_avg_win - expected_avg_win) < 0.01 else f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ❌") - logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ✅" if abs(actual_avg_loss - expected_avg_loss) < 0.01 else f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ❌") - - return True - - return True - -async def main(): - """Run all tests""" - - logger.info("🚀 STARTING COMPREHENSIVE FIXES TEST") - logger.info("Testing both model prediction fixes and trading statistics fixes") - - # Test model predictions - prediction_success = await test_model_predictions() - - # Test trading statistics - stats_success = test_trading_statistics() - - logger.info("=" * 60) - logger.info("TEST SUMMARY") - logger.info("=" * 60) - logger.info(f"Model Predictions: {'✅ FIXED' if prediction_success else '❌ STILL BROKEN'}") - logger.info(f"Trading Statistics: {'✅ FIXED' if stats_success else '❌ STILL BROKEN'}") - - if prediction_success and stats_success: - logger.info("🎉 ALL ISSUES FIXED! The system should now work correctly.") - else: - logger.error("❌ Some issues remain. Check the logs above for details.") - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/debug/test_trading_fixes.py b/debug/test_trading_fixes.py deleted file mode 100644 index 79eca49..0000000 --- a/debug/test_trading_fixes.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify trading fixes: -1. Position sizes with leverage -2. ETH-only trading -3. Correct win rate calculations -4. Meaningful P&L values -""" - -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from core.trading_executor import TradingExecutor -from core.trading_executor import TradeRecord -from datetime import datetime -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_position_sizing(): - """Test that position sizing now includes leverage and meaningful amounts""" - - logger.info("=" * 60) - logger.info("TESTING POSITION SIZING WITH LEVERAGE") - logger.info("=" * 60) - - # Initialize trading executor - trading_executor = TradingExecutor() - - # Test position calculation - confidence = 0.8 - current_price = 2500.0 # ETH price - - position_value = trading_executor._calculate_position_size(confidence, current_price) - quantity = position_value / current_price - - logger.info(f"1. Position calculation test:") - logger.info(f" Confidence: {confidence}") - logger.info(f" ETH Price: ${current_price}") - logger.info(f" Position Value: ${position_value:.2f}") - logger.info(f" Quantity: {quantity:.6f} ETH") - - # Check if position is meaningful - if position_value > 1000: # Should be >$1000 with 10x leverage - logger.info(" ✅ Position size is meaningful (>$1000)") - else: - logger.error(f" ❌ Position size too small: ${position_value:.2f}") - - # Test different confidence levels - logger.info("2. Testing different confidence levels:") - for conf in [0.2, 0.5, 0.8, 1.0]: - pos_val = trading_executor._calculate_position_size(conf, current_price) - qty = pos_val / current_price - logger.info(f" Confidence {conf}: ${pos_val:.2f} ({qty:.6f} ETH)") - -def test_eth_only_restriction(): - """Test that only ETH trades are allowed""" - - logger.info("=" * 60) - logger.info("TESTING ETH-ONLY TRADING RESTRICTION") - logger.info("=" * 60) - - trading_executor = TradingExecutor() - - # Test ETH trade (should be allowed) - logger.info("1. Testing ETH/USDT trade (should be allowed):") - eth_allowed = trading_executor._check_safety_conditions('ETH/USDT', 'BUY') - logger.info(f" ETH/USDT allowed: {'✅ YES' if eth_allowed else '❌ NO'}") - - # Test BTC trade (should be blocked) - logger.info("2. Testing BTC/USDT trade (should be blocked):") - btc_allowed = trading_executor._check_safety_conditions('BTC/USDT', 'BUY') - logger.info(f" BTC/USDT allowed: {'❌ YES (ERROR!)' if btc_allowed else '✅ NO (CORRECT)'}") - -def test_win_rate_calculation(): - """Test that win rate calculations are correct""" - - logger.info("=" * 60) - logger.info("TESTING WIN RATE CALCULATIONS") - logger.info("=" * 60) - - trading_executor = TradingExecutor() - - # Clear existing trades - trading_executor.trade_history = [] - - # Add test trades with meaningful P&L - logger.info("1. Adding test trades with meaningful P&L:") - - # Add 3 winning trades - for i in range(3): - winning_trade = TradeRecord( - symbol='ETH/USDT', - side='LONG', - quantity=1.0, - entry_price=2500.0, - exit_price=2550.0, - entry_time=datetime.now(), - exit_time=datetime.now(), - pnl=50.0, # $50 profit with leverage - fees=1.0, - confidence=0.8, - hold_time_seconds=30.0 # 30 second hold - ) - trading_executor.trade_history.append(winning_trade) - logger.info(f" Added winning trade #{i+1}: +$50.00 (30s hold)") - - # Add 2 losing trades - for i in range(2): - losing_trade = TradeRecord( - symbol='ETH/USDT', - side='LONG', - quantity=1.0, - entry_price=2500.0, - exit_price=2475.0, - entry_time=datetime.now(), - exit_time=datetime.now(), - pnl=-25.0, # $25 loss with leverage - fees=1.0, - confidence=0.7, - hold_time_seconds=15.0 # 15 second hold - ) - trading_executor.trade_history.append(losing_trade) - logger.info(f" Added losing trade #{i+1}: -$25.00 (15s hold)") - - # Get statistics - stats = trading_executor.get_daily_stats() - - logger.info("2. Calculated statistics:") - logger.info(f" Total trades: {stats['total_trades']}") - logger.info(f" Winning trades: {stats['winning_trades']}") - logger.info(f" Losing trades: {stats['losing_trades']}") - logger.info(f" Win rate: {stats['win_rate']*100:.1f}%") - logger.info(f" Avg winning trade: ${stats['avg_winning_trade']:.2f}") - logger.info(f" Avg losing trade: ${stats['avg_losing_trade']:.2f}") - logger.info(f" Total P&L: ${stats['total_pnl']:.2f}") - - # Verify calculations - expected_win_rate = 3/5 # 3 wins out of 5 trades = 60% - expected_avg_win = 50.0 - expected_avg_loss = -25.0 - - logger.info("3. Verification:") - win_rate_ok = abs(stats['win_rate'] - expected_win_rate) < 0.01 - avg_win_ok = abs(stats['avg_winning_trade'] - expected_avg_win) < 0.01 - avg_loss_ok = abs(stats['avg_losing_trade'] - expected_avg_loss) < 0.01 - - logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {stats['win_rate']*100:.1f}% {'✅' if win_rate_ok else '❌'}") - logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${stats['avg_winning_trade']:.2f} {'✅' if avg_win_ok else '❌'}") - logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${stats['avg_losing_trade']:.2f} {'✅' if avg_loss_ok else '❌'}") - - return win_rate_ok and avg_win_ok and avg_loss_ok - -def test_new_features(): - """Test new features: hold time, leverage, percentage-based sizing""" - - logger.info("=" * 60) - logger.info("TESTING NEW FEATURES") - logger.info("=" * 60) - - trading_executor = TradingExecutor() - - # Test account info - account_info = trading_executor.get_account_info() - logger.info(f"1. Account Information:") - logger.info(f" Account Balance: ${account_info['account_balance']:.2f}") - logger.info(f" Leverage: {account_info['leverage']:.0f}x") - logger.info(f" Trading Mode: {account_info['trading_mode']}") - logger.info(f" Position Sizing: {account_info['position_sizing']['base_percent']:.1f}% base") - - # Test leverage setting - logger.info("2. Testing leverage control:") - old_leverage = trading_executor.get_leverage() - logger.info(f" Current leverage: {old_leverage:.0f}x") - - success = trading_executor.set_leverage(100.0) - new_leverage = trading_executor.get_leverage() - logger.info(f" Set to 100x: {'✅ SUCCESS' if success and new_leverage == 100.0 else '❌ FAILED'}") - - # Reset leverage - trading_executor.set_leverage(old_leverage) - - # Test percentage-based position sizing - logger.info("3. Testing percentage-based position sizing:") - confidence = 0.8 - eth_price = 2500.0 - - position_value = trading_executor._calculate_position_size(confidence, eth_price) - account_balance = trading_executor._get_account_balance_for_sizing() - base_percent = trading_executor.mexc_config.get('base_position_percent', 5.0) - leverage = trading_executor.get_leverage() - - expected_base = account_balance * (base_percent / 100.0) * confidence - expected_leveraged = expected_base * leverage - - logger.info(f" Account: ${account_balance:.2f}") - logger.info(f" Base %: {base_percent:.1f}%") - logger.info(f" Confidence: {confidence:.1f}") - logger.info(f" Leverage: {leverage:.0f}x") - logger.info(f" Expected base: ${expected_base:.2f}") - logger.info(f" Expected leveraged: ${expected_leveraged:.2f}") - logger.info(f" Actual: ${position_value:.2f}") - - sizing_ok = abs(position_value - expected_leveraged) < 0.01 - logger.info(f" Percentage sizing: {'✅ CORRECT' if sizing_ok else '❌ INCORRECT'}") - - return sizing_ok - -def main(): - """Run all tests""" - - logger.info("🚀 TESTING TRADING FIXES AND NEW FEATURES") - logger.info("Testing position sizing, ETH-only trading, win rate calculations, and new features") - - # Test position sizing - test_position_sizing() - - # Test ETH-only restriction - test_eth_only_restriction() - - # Test win rate calculation - calculation_success = test_win_rate_calculation() - - # Test new features - features_success = test_new_features() - - logger.info("=" * 60) - logger.info("TEST SUMMARY") - logger.info("=" * 60) - logger.info(f"Position Sizing: ✅ Updated with percentage-based leverage") - logger.info(f"ETH-Only Trading: ✅ Configured in config") - logger.info(f"Win Rate Calculation: {'✅ FIXED' if calculation_success else '❌ STILL BROKEN'}") - logger.info(f"New Features: {'✅ WORKING' if features_success else '❌ ISSUES FOUND'}") - - if calculation_success and features_success: - logger.info("🎉 ALL FEATURES WORKING! Now you should see:") - logger.info(" - Percentage-based position sizing (2-20% of account)") - logger.info(" - 50x leverage (adjustable in UI)") - logger.info(" - Hold time in seconds for each trade") - logger.info(" - Total fees in trading statistics") - logger.info(" - Only ETH/USDT trades") - logger.info(" - Correct win rate calculations") - else: - logger.error("❌ Some issues remain. Check the logs above for details.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/debug/trade_audit.py b/debug/trade_audit.py deleted file mode 100644 index 04efd8b..0000000 --- a/debug/trade_audit.py +++ /dev/null @@ -1,344 +0,0 @@ -#!/usr/bin/env python3 -""" -Trade Audit Tool - -This tool analyzes trade data to identify potential issues with: -- Duplicate entry prices -- Rapid consecutive trades -- P&L calculation accuracy -- Position tracking problems - -Usage: - python debug/trade_audit.py [--trades-file path/to/trades.json] -""" - -import argparse -import json -import pandas as pd -import numpy as np -from datetime import datetime, timedelta -import matplotlib.pyplot as plt -import os -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -def parse_trade_time(time_str): - """Parse trade time string to datetime object""" - try: - # Try HH:MM:SS format - return datetime.strptime(time_str, "%H:%M:%S") - except ValueError: - try: - # Try full datetime format - return datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") - except ValueError: - # Return as is if parsing fails - return time_str - -def load_trades_from_file(file_path): - """Load trades from JSON file""" - try: - with open(file_path, 'r') as f: - return json.load(f) - except FileNotFoundError: - print(f"Error: File {file_path} not found") - return [] - except json.JSONDecodeError: - print(f"Error: File {file_path} is not valid JSON") - return [] - -def load_trades_from_dashboard_cache(): - """Load trades from dashboard cache file if available""" - cache_paths = [ - "cache/dashboard_trades.json", - "cache/closed_trades.json", - "data/trades_history.json" - ] - - for path in cache_paths: - if os.path.exists(path): - print(f"Loading trades from cache: {path}") - return load_trades_from_file(path) - - print("No trade cache files found") - return [] - -def parse_trade_data(trades_data): - """Parse trade data into a pandas DataFrame for analysis""" - parsed_trades = [] - - for trade in trades_data: - # Handle different trade data formats - parsed_trade = {} - - # Time field might be named entry_time or time - if 'entry_time' in trade: - parsed_trade['time'] = parse_trade_time(trade['entry_time']) - elif 'time' in trade: - parsed_trade['time'] = parse_trade_time(trade['time']) - else: - parsed_trade['time'] = None - - # Side might be named side or action - parsed_trade['side'] = trade.get('side', trade.get('action', 'UNKNOWN')) - - # Size might be named size or quantity - parsed_trade['size'] = float(trade.get('size', trade.get('quantity', 0))) - - # Entry and exit prices - parsed_trade['entry_price'] = float(trade.get('entry_price', trade.get('entry', 0))) - parsed_trade['exit_price'] = float(trade.get('exit_price', trade.get('exit', 0))) - - # Hold time in seconds - parsed_trade['hold_time'] = float(trade.get('hold_time_seconds', trade.get('hold', 0))) - - # P&L and fees - parsed_trade['pnl'] = float(trade.get('pnl', 0)) - parsed_trade['fees'] = float(trade.get('fees', 0)) - - # Calculate expected P&L for verification - if parsed_trade['side'] == 'LONG' or parsed_trade['side'] == 'BUY': - expected_pnl = (parsed_trade['exit_price'] - parsed_trade['entry_price']) * parsed_trade['size'] - else: # SHORT or SELL - expected_pnl = (parsed_trade['entry_price'] - parsed_trade['exit_price']) * parsed_trade['size'] - - parsed_trade['expected_pnl'] = expected_pnl - parsed_trade['pnl_difference'] = parsed_trade['pnl'] - expected_pnl - - parsed_trades.append(parsed_trade) - - # Convert to DataFrame - if parsed_trades: - df = pd.DataFrame(parsed_trades) - return df - else: - return pd.DataFrame() - -def analyze_trades(df): - """Analyze trades for potential issues""" - if df.empty: - print("No trades to analyze") - return - - print(f"\n{'='*50}") - print("TRADE AUDIT RESULTS") - print(f"{'='*50}") - print(f"Total trades analyzed: {len(df)}") - - # Check for duplicate entry prices - entry_price_counts = df['entry_price'].value_counts() - duplicate_entries = entry_price_counts[entry_price_counts > 1] - - print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}") - if not duplicate_entries.empty: - print(f"Found {len(duplicate_entries)} prices with multiple entries:") - for price, count in duplicate_entries.items(): - print(f" ${price:.2f}: {count} trades") - - # Analyze the duplicate entry trades in more detail - for price in duplicate_entries.index: - duplicate_df = df[df['entry_price'] == price].copy() - duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds() - - print(f"\nDetailed analysis for entry price ${price:.2f}:") - print(f" Time gaps between consecutive trades:") - for i, (_, row) in enumerate(duplicate_df.iterrows()): - if i > 0: # Skip first row as it has no previous trade - time_diff = row['time_diff'] - if pd.notna(time_diff): - print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade") - else: - print("No duplicate entry prices found") - - # Check for rapid consecutive trades - df = df.sort_values('time') - df['time_since_last'] = df['time'].diff().dt.total_seconds() - - rapid_trades = df[df['time_since_last'] < 30].copy() - - print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}") - if not rapid_trades.empty: - print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:") - for _, row in rapid_trades.iterrows(): - if pd.notna(row['time_since_last']): - print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous") - else: - print("No rapid consecutive trades found") - - # Check for P&L calculation accuracy - pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy() - - print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}") - if not pnl_diff.empty: - print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:") - for _, row in pnl_diff.iterrows(): - print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}") - else: - print("No P&L calculation issues found") - - # Check for side distribution - side_counts = df['side'].value_counts() - - print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}") - for side, count in side_counts.items(): - print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)") - - # Check for hold time distribution - print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}") - print(f" Min hold time: {df['hold_time'].min():.0f} seconds") - print(f" Max hold time: {df['hold_time'].max():.0f} seconds") - print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds") - print(f" Median hold time: {df['hold_time'].median():.0f} seconds") - - # Hold time buckets - hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')] - hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+'] - - df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels) - hold_dist = df['hold_bucket'].value_counts().sort_index() - - for bucket, count in hold_dist.items(): - print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)") - - # Generate summary statistics - print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}") - winning_trades = df[df['pnl'] > 0] - losing_trades = df[df['pnl'] < 0] - - print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)") - print(f" Avg win: ${winning_trades['pnl'].mean():.2f}") - print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}") - print(f" Total P&L: ${df['pnl'].sum():.2f}") - print(f" Total fees: ${df['fees'].sum():.2f}") - print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}") - - # Plot entry price distribution - plt.figure(figsize=(10, 6)) - plt.hist(df['entry_price'], bins=20, alpha=0.7) - plt.title('Entry Price Distribution') - plt.xlabel('Entry Price ($)') - plt.ylabel('Number of Trades') - plt.grid(True, alpha=0.3) - plt.savefig('debug/entry_price_distribution.png') - - # Plot P&L distribution - plt.figure(figsize=(10, 6)) - plt.hist(df['pnl'], bins=20, alpha=0.7) - plt.title('P&L Distribution') - plt.xlabel('P&L ($)') - plt.ylabel('Number of Trades') - plt.grid(True, alpha=0.3) - plt.savefig('debug/pnl_distribution.png') - - print(f"\n{'='*20} AUDIT COMPLETE {'='*20}") - print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png") - -def analyze_manual_trades(trades_data): - """Analyze manually provided trade data""" - # Parse the trade data into a structured format - parsed_trades = [] - - for line in trades_data.strip().split('\n'): - if not line or line.startswith('from last session') or line.startswith('Recent Closed Trades') or line.startswith('Trading Performance'): - continue - - if line.startswith('Win Rate:'): - # This is the summary line, skip it - continue - - try: - # Parse trade line format: Time Side Size Entry Exit Hold P&L Fees - parts = line.split('$') - - time_side = parts[0].strip().split() - time = time_side[0] - side = time_side[1] - - size = float(parts[1].split()[0]) - entry = float(parts[2].split()[0]) - exit = float(parts[3].split()[0]) - - # The hold time and P&L are in the last parts - remaining = parts[3].split() - hold = int(remaining[1]) - pnl = float(parts[4].split()[0]) - - # Fees might be in a different format - if len(parts) > 5: - fees = float(parts[5].strip()) - else: - fees = 0.0 - - parsed_trade = { - 'time': parse_trade_time(time), - 'side': side, - 'size': size, - 'entry_price': entry, - 'exit_price': exit, - 'hold_time': hold, - 'pnl': pnl, - 'fees': fees - } - - # Calculate expected P&L - if side == 'LONG' or side == 'BUY': - expected_pnl = (exit - entry) * size - else: # SHORT or SELL - expected_pnl = (entry - exit) * size - - parsed_trade['expected_pnl'] = expected_pnl - parsed_trade['pnl_difference'] = pnl - expected_pnl - - parsed_trades.append(parsed_trade) - - except Exception as e: - print(f"Error parsing trade line: {line}") - print(f"Error details: {e}") - - # Convert to DataFrame - if parsed_trades: - df = pd.DataFrame(parsed_trades) - return df - else: - return pd.DataFrame() - -def main(): - parser = argparse.ArgumentParser(description='Trade Audit Tool') - parser.add_argument('--trades-file', type=str, help='Path to trades JSON file') - parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades') - args = parser.parse_args() - - # Create debug directory if it doesn't exist - os.makedirs('debug', exist_ok=True) - - if args.trades_file: - trades_data = load_trades_from_file(args.trades_file) - df = parse_trade_data(trades_data) - elif args.manual_trades: - try: - with open(args.manual_trades, 'r') as f: - manual_trades = f.read() - df = analyze_manual_trades(manual_trades) - except Exception as e: - print(f"Error reading manual trades file: {e}") - df = pd.DataFrame() - else: - # Try to load from dashboard cache - trades_data = load_trades_from_dashboard_cache() - if trades_data: - df = parse_trade_data(trades_data) - else: - print("No trade data provided. Use --trades-file or --manual-trades") - return - - if not df.empty: - analyze_trades(df) - else: - print("No valid trade data to analyze") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/debug_training_methods.py b/debug_training_methods.py deleted file mode 100644 index f7e5813..0000000 --- a/debug_training_methods.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug Training Methods - -This script checks what training methods are available on each model. -""" - -import asyncio -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -async def debug_training_methods(): - """Debug the available training methods on each model""" - print("=== Debugging Training Methods ===") - - # Initialize orchestrator - print("1. Initializing orchestrator...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Wait for initialization - await asyncio.sleep(2) - - print("\n2. Checking available training methods on each model:") - - for model_name, model_interface in orchestrator.model_registry.models.items(): - print(f"\n--- {model_name} ---") - print(f"Interface type: {type(model_interface).__name__}") - - # Get underlying model - underlying_model = getattr(model_interface, 'model', None) - if underlying_model: - print(f"Underlying model type: {type(underlying_model).__name__}") - else: - print("No underlying model found") - continue - - # Check for training methods - training_methods = [] - for method in ['train_on_outcome', 'add_experience', 'remember', 'replay', 'add_training_sample', 'train', 'train_with_reward', 'update_loss']: - if hasattr(underlying_model, method): - training_methods.append(method) - - print(f"Available training methods: {training_methods}") - - # Check for specific attributes - attributes = [] - for attr in ['memory', 'batch_size', 'training_data']: - if hasattr(underlying_model, attr): - attr_value = getattr(underlying_model, attr) - if attr == 'memory' and hasattr(attr_value, '__len__'): - attributes.append(f"{attr}(len={len(attr_value)})") - elif attr == 'training_data' and hasattr(attr_value, '__len__'): - attributes.append(f"{attr}(len={len(attr_value)})") - else: - attributes.append(f"{attr}={attr_value}") - - print(f"Relevant attributes: {attributes}") - - # Check if it's an RL agent - if hasattr(underlying_model, 'act') and hasattr(underlying_model, 'remember'): - print("✅ Detected as RL Agent") - elif hasattr(underlying_model, 'predict') and hasattr(underlying_model, 'add_training_sample'): - print("✅ Detected as CNN Model") - else: - print("❓ Unknown model type") - - print("\n3. Testing a simple training attempt:") - - # Get a prediction first - predictions = await orchestrator._get_all_predictions('ETH/USDT') - print(f"Got {len(predictions)} predictions") - - # Try to trigger training for each model - for model_name in orchestrator.model_registry.models.keys(): - print(f"\nTesting training for {model_name}...") - try: - await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT') - print(f"✅ Training attempt completed for {model_name}") - except Exception as e: - print(f"❌ Training failed for {model_name}: {e}") - -if __name__ == "__main__": - asyncio.run(debug_training_methods()) \ No newline at end of file diff --git a/docs/exchanges/bybit/examples.py b/docs/exchanges/bybit/examples.py deleted file mode 100644 index 4827533..0000000 --- a/docs/exchanges/bybit/examples.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -Bybit Integration Examples -Based on official pybit library documentation and examples -""" - -import os -from pybit.unified_trading import HTTP -import logging - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def create_bybit_session(testnet=True): - """Create a Bybit HTTP session. - - Args: - testnet (bool): Use testnet if True, live if False - - Returns: - HTTP: Bybit session object - """ - api_key = os.getenv('BYBIT_API_KEY') - api_secret = os.getenv('BYBIT_API_SECRET') - - if not api_key or not api_secret: - raise ValueError("BYBIT_API_KEY and BYBIT_API_SECRET must be set in environment") - - session = HTTP( - testnet=testnet, - api_key=api_key, - api_secret=api_secret, - ) - - logger.info(f"Created Bybit session (testnet: {testnet})") - return session - -def get_account_info(session): - """Get account information and balances.""" - try: - # Get account info - account_info = session.get_wallet_balance(accountType="UNIFIED") - logger.info(f"Account info: {account_info}") - - return account_info - except Exception as e: - logger.error(f"Error getting account info: {e}") - return None - -def get_ticker_info(session, symbol="BTCUSDT"): - """Get ticker information for a symbol. - - Args: - session: Bybit HTTP session - symbol: Trading symbol (default: BTCUSDT) - """ - try: - ticker = session.get_tickers(category="linear", symbol=symbol) - logger.info(f"Ticker for {symbol}: {ticker}") - return ticker - except Exception as e: - logger.error(f"Error getting ticker for {symbol}: {e}") - return None - -def get_orderbook(session, symbol="BTCUSDT", limit=25): - """Get orderbook for a symbol. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - limit: Number of price levels to return - """ - try: - orderbook = session.get_orderbook( - category="linear", - symbol=symbol, - limit=limit - ) - logger.info(f"Orderbook for {symbol}: {orderbook}") - return orderbook - except Exception as e: - logger.error(f"Error getting orderbook for {symbol}: {e}") - return None - -def place_limit_order(session, symbol="BTCUSDT", side="Buy", qty="0.001", price="50000"): - """Place a limit order. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - side: "Buy" or "Sell" - qty: Order quantity as string - price: Order price as string - """ - try: - order = session.place_order( - category="linear", - symbol=symbol, - side=side, - orderType="Limit", - qty=qty, - price=price, - timeInForce="GTC" # Good Till Cancelled - ) - logger.info(f"Placed order: {order}") - return order - except Exception as e: - logger.error(f"Error placing order: {e}") - return None - -def place_market_order(session, symbol="BTCUSDT", side="Buy", qty="0.001"): - """Place a market order. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - side: "Buy" or "Sell" - qty: Order quantity as string - """ - try: - order = session.place_order( - category="linear", - symbol=symbol, - side=side, - orderType="Market", - qty=qty - ) - logger.info(f"Placed market order: {order}") - return order - except Exception as e: - logger.error(f"Error placing market order: {e}") - return None - -def get_open_orders(session, symbol=None): - """Get open orders. - - Args: - session: Bybit HTTP session - symbol: Trading symbol (optional, gets all if None) - """ - try: - params = {"category": "linear", "openOnly": True} - if symbol: - params["symbol"] = symbol - - orders = session.get_open_orders(**params) - logger.info(f"Open orders: {orders}") - return orders - except Exception as e: - logger.error(f"Error getting open orders: {e}") - return None - -def cancel_order(session, symbol, order_id): - """Cancel an order. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - order_id: Order ID to cancel - """ - try: - result = session.cancel_order( - category="linear", - symbol=symbol, - orderId=order_id - ) - logger.info(f"Cancelled order {order_id}: {result}") - return result - except Exception as e: - logger.error(f"Error cancelling order {order_id}: {e}") - return None - -def get_position(session, symbol="BTCUSDT"): - """Get position information. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - """ - try: - positions = session.get_positions( - category="linear", - symbol=symbol - ) - logger.info(f"Position for {symbol}: {positions}") - return positions - except Exception as e: - logger.error(f"Error getting position for {symbol}: {e}") - return None - -def get_trade_history(session, symbol="BTCUSDT", limit=50): - """Get trade history. - - Args: - session: Bybit HTTP session - symbol: Trading symbol - limit: Number of trades to return - """ - try: - trades = session.get_executions( - category="linear", - symbol=symbol, - limit=limit - ) - logger.info(f"Trade history for {symbol}: {trades}") - return trades - except Exception as e: - logger.error(f"Error getting trade history for {symbol}: {e}") - return None - -# Example usage -if __name__ == "__main__": - # Create session (testnet by default) - session = create_bybit_session(testnet=True) - - # Get account info - account_info = get_account_info(session) - - # Get ticker - ticker = get_ticker_info(session, "BTCUSDT") - - # Get orderbook - orderbook = get_orderbook(session, "BTCUSDT") - - # Get open orders - open_orders = get_open_orders(session) - - # Get position - position = get_position(session, "BTCUSDT") - - # Note: Uncomment below to actually place orders (use with caution) - # order = place_limit_order(session, "BTCUSDT", "Buy", "0.001", "30000") - # market_order = place_market_order(session, "BTCUSDT", "Buy", "0.001") \ No newline at end of file diff --git a/example_usage_simplified_data_provider.py b/example_usage_simplified_data_provider.py deleted file mode 100644 index a3a3d14..0000000 --- a/example_usage_simplified_data_provider.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -""" -Example usage of the simplified data provider -""" - -import time -import logging -from core.data_provider import DataProvider - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def main(): - """Demonstrate the simplified data provider usage""" - - # Initialize data provider (starts automatic maintenance) - logger.info("Initializing DataProvider...") - dp = DataProvider() - - # Wait for initial data load (happens automatically in background) - logger.info("Waiting for initial data load...") - time.sleep(15) # Give it time to load data - - # Example 1: Get cached historical data (no API calls) - logger.info("\n=== Example 1: Getting Historical Data ===") - eth_1m_data = dp.get_historical_data('ETH/USDT', '1m', limit=50) - if eth_1m_data is not None: - logger.info(f"ETH/USDT 1m data: {len(eth_1m_data)} candles") - logger.info(f"Latest candle: {eth_1m_data.iloc[-1]['close']}") - - # Example 2: Get current prices - logger.info("\n=== Example 2: Current Prices ===") - eth_price = dp.get_current_price('ETH/USDT') - btc_price = dp.get_current_price('BTC/USDT') - logger.info(f"ETH current price: ${eth_price}") - logger.info(f"BTC current price: ${btc_price}") - - # Example 3: Check cache status - logger.info("\n=== Example 3: Cache Status ===") - cache_summary = dp.get_cached_data_summary() - for symbol in cache_summary['cached_data']: - logger.info(f"\n{symbol}:") - for timeframe, info in cache_summary['cached_data'][symbol].items(): - if 'candle_count' in info and info['candle_count'] > 0: - logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}") - else: - logger.info(f" {timeframe}: {info.get('status', 'no data')}") - - # Example 4: Multiple timeframe data - logger.info("\n=== Example 4: Multiple Timeframes ===") - for tf in ['1s', '1m', '1h', '1d']: - data = dp.get_historical_data('ETH/USDT', tf, limit=5) - if data is not None and not data.empty: - logger.info(f"ETH {tf}: {len(data)} candles, range: ${data['close'].min():.2f} - ${data['close'].max():.2f}") - - # Example 5: Health check - logger.info("\n=== Example 5: Health Check ===") - health = dp.health_check() - logger.info(f"Data maintenance active: {health['data_maintenance_active']}") - logger.info(f"Symbols: {health['symbols']}") - logger.info(f"Timeframes: {health['timeframes']}") - - # Example 6: Wait and show automatic updates - logger.info("\n=== Example 6: Automatic Updates ===") - logger.info("Waiting 30 seconds to show automatic data updates...") - - # Get initial timestamp - initial_data = dp.get_historical_data('ETH/USDT', '1s', limit=1) - initial_time = initial_data.index[-1] if initial_data is not None else None - - time.sleep(30) - - # Check if data was updated - updated_data = dp.get_historical_data('ETH/USDT', '1s', limit=1) - updated_time = updated_data.index[-1] if updated_data is not None else None - - if initial_time and updated_time and updated_time > initial_time: - logger.info(f"✅ Data automatically updated! New timestamp: {updated_time}") - else: - logger.info("⏳ Data update in progress...") - - # Clean shutdown - logger.info("\n=== Shutting Down ===") - dp.stop_automatic_data_maintenance() - logger.info("DataProvider stopped successfully") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/kill_stale_processes.py b/kill_stale_processes.py deleted file mode 100644 index de74b8c..0000000 --- a/kill_stale_processes.py +++ /dev/null @@ -1,331 +0,0 @@ -""" -Kill Stale Processes - -This script identifies and kills stale Python processes that might be causing -the dashboard startup freeze. It looks for: -1. Hanging dashboard processes -2. Stale COB data collection threads -3. Matplotlib GUI processes -4. Blocked network connections - -Usage: - python kill_stale_processes.py -""" - -import os -import sys -import psutil -import signal -import time -from datetime import datetime - -def find_python_processes(): - """Find all Python processes""" - python_processes = [] - - try: - for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'create_time', 'status']): - try: - if proc.info['name'] and 'python' in proc.info['name'].lower(): - # Get command line to identify dashboard processes - cmdline = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else '' - - python_processes.append({ - 'pid': proc.info['pid'], - 'name': proc.info['name'], - 'cmdline': cmdline, - 'create_time': proc.info['create_time'], - 'status': proc.info['status'], - 'process': proc - }) - except (psutil.NoSuchProcess, psutil.AccessDenied): - continue - - except Exception as e: - print(f"Error finding Python processes: {e}") - - return python_processes - -def identify_dashboard_processes(python_processes): - """Identify processes related to the dashboard""" - dashboard_processes = [] - - dashboard_keywords = [ - 'clean_dashboard', - 'run_clean_dashboard', - 'dashboard', - 'trading', - 'cob_data', - 'orchestrator', - 'data_provider' - ] - - for proc_info in python_processes: - cmdline = proc_info['cmdline'].lower() - - # Check if this is a dashboard-related process - is_dashboard = any(keyword in cmdline for keyword in dashboard_keywords) - - if is_dashboard: - dashboard_processes.append(proc_info) - - return dashboard_processes - -def identify_stale_processes(python_processes): - """Identify potentially stale processes""" - stale_processes = [] - current_time = time.time() - - for proc_info in python_processes: - try: - proc = proc_info['process'] - - # Check if process is in a problematic state - if proc_info['status'] in ['zombie', 'stopped']: - stale_processes.append({ - **proc_info, - 'reason': f"Process status: {proc_info['status']}" - }) - continue - - # Check if process has been running for a very long time without activity - age_hours = (current_time - proc_info['create_time']) / 3600 - if age_hours > 24: # Running for more than 24 hours - try: - # Check CPU usage - cpu_percent = proc.cpu_percent(interval=1) - if cpu_percent < 0.1: # Very low CPU usage - stale_processes.append({ - **proc_info, - 'reason': f"Old process ({age_hours:.1f}h) with low CPU usage ({cpu_percent:.1f}%)" - }) - except: - pass - - # Check for processes with high memory usage but no activity - try: - memory_info = proc.memory_info() - memory_mb = memory_info.rss / 1024 / 1024 - - if memory_mb > 500: # More than 500MB - cpu_percent = proc.cpu_percent(interval=1) - if cpu_percent < 0.1: - stale_processes.append({ - **proc_info, - 'reason': f"High memory usage ({memory_mb:.1f}MB) with low CPU usage ({cpu_percent:.1f}%)" - }) - except: - pass - - except (psutil.NoSuchProcess, psutil.AccessDenied): - continue - - return stale_processes - -def kill_process_safely(proc_info, force=False): - """Kill a process safely""" - try: - proc = proc_info['process'] - pid = proc_info['pid'] - - print(f"Attempting to {'force kill' if force else 'terminate'} PID {pid}: {proc_info['name']}") - - if force: - # Force kill - if os.name == 'nt': # Windows - os.system(f"taskkill /F /PID {pid}") - else: # Unix/Linux - os.kill(pid, signal.SIGKILL) - else: - # Graceful termination - proc.terminate() - - # Wait for termination - try: - proc.wait(timeout=5) - print(f"✅ Process {pid} terminated gracefully") - return True - except psutil.TimeoutExpired: - print(f"⚠️ Process {pid} didn't terminate gracefully, will force kill") - return False - - print(f"✅ Process {pid} killed") - return True - - except (psutil.NoSuchProcess, psutil.AccessDenied) as e: - print(f"⚠️ Could not kill process {proc_info['pid']}: {e}") - return False - except Exception as e: - print(f"❌ Error killing process {proc_info['pid']}: {e}") - return False - -def check_port_usage(): - """Check if dashboard port is in use""" - try: - import socket - - # Check if port 8050 is in use - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = sock.connect_ex(('localhost', 8050)) - sock.close() - - if result == 0: - print("⚠️ Port 8050 is in use") - - # Find process using the port - for conn in psutil.net_connections(): - if conn.laddr.port == 8050: - try: - proc = psutil.Process(conn.pid) - print(f" Port 8050 used by PID {conn.pid}: {proc.name()}") - return conn.pid - except: - pass - else: - print("✅ Port 8050 is available") - return None - - except Exception as e: - print(f"Error checking port usage: {e}") - return None - -def main(): - """Main function""" - print("🔍 Stale Process Killer") - print("=" * 50) - - try: - # Step 1: Find all Python processes - print("🔍 Finding Python processes...") - python_processes = find_python_processes() - print(f"Found {len(python_processes)} Python processes") - - # Step 2: Identify dashboard processes - print("\n🎯 Identifying dashboard processes...") - dashboard_processes = identify_dashboard_processes(python_processes) - - if dashboard_processes: - print(f"Found {len(dashboard_processes)} dashboard-related processes:") - for proc in dashboard_processes: - age_hours = (time.time() - proc['create_time']) / 3600 - print(f" PID {proc['pid']}: {proc['name']} (age: {age_hours:.1f}h, status: {proc['status']})") - print(f" Command: {proc['cmdline'][:100]}...") - else: - print("No dashboard processes found") - - # Step 3: Check port usage - print("\n🌐 Checking port usage...") - port_pid = check_port_usage() - - # Step 4: Identify stale processes - print("\n🕵️ Identifying stale processes...") - stale_processes = identify_stale_processes(python_processes) - - if stale_processes: - print(f"Found {len(stale_processes)} potentially stale processes:") - for proc in stale_processes: - print(f" PID {proc['pid']}: {proc['name']} - {proc['reason']}") - else: - print("No stale processes identified") - - # Step 5: Ask user what to do - if dashboard_processes or stale_processes or port_pid: - print("\n🤔 What would you like to do?") - print("1. Kill all dashboard processes") - print("2. Kill only stale processes") - print("3. Kill process using port 8050") - print("4. Kill all identified processes") - print("5. Show process details and exit") - print("6. Exit without killing anything") - - try: - choice = input("\nEnter your choice (1-6): ").strip() - - if choice == '1': - # Kill dashboard processes - print("\n🔫 Killing dashboard processes...") - for proc in dashboard_processes: - if not kill_process_safely(proc): - kill_process_safely(proc, force=True) - - elif choice == '2': - # Kill stale processes - print("\n🔫 Killing stale processes...") - for proc in stale_processes: - if not kill_process_safely(proc): - kill_process_safely(proc, force=True) - - elif choice == '3': - # Kill process using port 8050 - if port_pid: - print(f"\n🔫 Killing process using port 8050 (PID {port_pid})...") - try: - proc = psutil.Process(port_pid) - proc_info = { - 'pid': port_pid, - 'name': proc.name(), - 'process': proc - } - if not kill_process_safely(proc_info): - kill_process_safely(proc_info, force=True) - except: - print(f"❌ Could not kill process {port_pid}") - else: - print("No process found using port 8050") - - elif choice == '4': - # Kill all identified processes - print("\n🔫 Killing all identified processes...") - all_processes = dashboard_processes + stale_processes - if port_pid: - try: - proc = psutil.Process(port_pid) - all_processes.append({ - 'pid': port_pid, - 'name': proc.name(), - 'process': proc - }) - except: - pass - - for proc in all_processes: - if not kill_process_safely(proc): - kill_process_safely(proc, force=True) - - elif choice == '5': - # Show details - print("\n📋 Process Details:") - all_processes = dashboard_processes + stale_processes - for proc in all_processes: - print(f"\nPID {proc['pid']}: {proc['name']}") - print(f" Status: {proc['status']}") - print(f" Command: {proc['cmdline']}") - print(f" Created: {datetime.fromtimestamp(proc['create_time'])}") - - elif choice == '6': - print("👋 Exiting without killing processes") - - else: - print("❌ Invalid choice") - - except KeyboardInterrupt: - print("\n👋 Cancelled by user") - else: - print("\n✅ No problematic processes found") - - print("\n" + "=" * 50) - print("💡 After killing processes, you can try:") - print(" python run_lightweight_dashboard.py") - print(" or") - print(" python fix_startup_freeze.py") - - return True - - except Exception as e: - print(f"❌ Error in main function: {e}") - return False - -if __name__ == "__main__": - success = main() - if not success: - sys.exit(1) \ No newline at end of file diff --git a/launch_training.py b/launch_training.py deleted file mode 100644 index 5eb92e6..0000000 --- a/launch_training.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Launch training with optimized short-term models only -""" - -import os -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import load_config -from core.training import TrainingManager -from core.models import OptimizedShortTermModel - -def main(): - """Main training function using only optimized models""" - config = load_config() - - # Initialize model - model = OptimizedShortTermModel() - - # Load best model if exists - best_model_path = config.model_paths.get('ticks_model') - if os.path.exists(best_model_path): - model.load_state_dict(torch.load(best_model_path)) - - # Initialize training - trainer = TrainingManager( - model=model, - config=config, - use_ticks=True, - use_realtime=True - ) - - # Start training - trainer.train() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/main.py b/main.py deleted file mode 100644 index ff3a43d..0000000 --- a/main.py +++ /dev/null @@ -1,458 +0,0 @@ -#!/usr/bin/env python3 -""" -Streamlined Trading System - Web Dashboard + Training - -Integrated system with both training loop and web dashboard: -- Training Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution -- Web Dashboard: Real-time monitoring and control interface -- 2-Action System: BUY/SELL with intelligent position management -- Always invested approach with smart risk/reward setup detection - -Usage: - python main.py [--symbol ETH/USDT] [--port 8050] -""" - -import os -# Fix OpenMP library conflicts before importing other modules -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '4' - -import asyncio -import argparse -import logging -import sys -from pathlib import Path -from threading import Thread -import time -from safe_logging import setup_safe_logging - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging, Config -from core.data_provider import DataProvider - -# Import checkpoint management -from utils.checkpoint_manager import get_checkpoint_manager -from utils.training_integration import get_training_integration - -logger = logging.getLogger(__name__) - -async def run_web_dashboard(): - """Run the streamlined web dashboard with 2-action system and always-invested approach""" - try: - logger.info("Starting Streamlined Trading Dashboard...") - logger.info("2-Action System: BUY/SELL with intelligent position management") - logger.info("Always Invested Approach: Smart risk/reward setup detection") - logger.info("Integrated Training Pipeline: Live data -> Models -> Trading") - - # Get configuration - config = get_config() - - # Initialize core components for streamlined pipeline - from core.standardized_data_provider import StandardizedDataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - - # Create standardized data provider (validated BaseDataInput, pivots, COB) - data_provider = StandardizedDataProvider() - - # Start real-time streaming for BOM caching - try: - await data_provider.start_real_time_streaming() - logger.info("[SUCCESS] Real-time data streaming started for BOM caching") - except Exception as e: - logger.warning(f"[WARNING] Real-time streaming failed: {e}") - - # Verify data connection with retry mechanism - logger.info("[DATA] Verifying live data connection...") - symbol = config.get('symbols', ['ETH/USDT'])[0] - - # Wait for data provider to initialize and fetch initial data - max_retries = 10 - retry_delay = 2 - - for attempt in range(max_retries): - test_df = data_provider.get_historical_data(symbol, '1m', limit=10) - if test_df is not None and len(test_df) > 0: - logger.info("[SUCCESS] Data connection verified") - logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation") - break - else: - if attempt < max_retries - 1: - logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})") - await asyncio.sleep(retry_delay) - else: - logger.warning("[WARNING] Data connection verification failed, but continuing with system startup") - logger.warning("The system will attempt to fetch data as needed during operation") - - # Load model registry for integrated pipeline - try: - from models import get_model_registry - model_registry = {} # Use simple dict for now - logger.info("[MODELS] Model registry initialized for training") - except ImportError: - model_registry = {} - logger.warning("Model registry not available, using empty registry") - - # Initialize checkpoint management - checkpoint_manager = get_checkpoint_manager() - training_integration = get_training_integration() - logger.info("Checkpoint management initialized for training pipeline") - - # Create unified orchestrator with full ML pipeline - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True, - model_registry={} - ) - logger.info("Unified Trading Orchestrator initialized with full ML pipeline") - logger.info("Data Bus -> Models (DQN + CNN + COB) -> Decision Model -> Trading Signals") - - # Checkpoint management will be handled in the training loop - logger.info("Checkpoint management will be initialized in training loop") - - # Unified orchestrator includes COB integration as part of data bus - logger.info("COB Integration available - feeds into unified data bus") - - # Create trading executor for live execution - trading_executor = TradingExecutor() - - # Start the training and monitoring loop - logger.info(f"Starting Enhanced Training Pipeline") - logger.info("Live Data Processing: ENABLED") - logger.info("COB Integration: ENABLED (Real-time market microstructure)") - logger.info("Integrated CNN Training: ENABLED") - logger.info("Integrated RL Training: ENABLED") - logger.info("Real-time Indicators & Pivots: ENABLED") - logger.info("Live Trading Execution: ENABLED") - logger.info("2-Action System: BUY/SELL with position intelligence") - logger.info("Always Invested: Different thresholds for entry/exit") - logger.info("Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution") - logger.info("Starting training loop...") - - # Start the training loop - logger.info("About to start training loop...") - await start_training_loop(orchestrator, trading_executor) - - except Exception as e: - logger.error(f"Error in streamlined dashboard: {e}") - logger.error("Training stopped") - import traceback - logger.error(traceback.format_exc()) - -def start_web_ui(port=8051): - """Start the main TradingDashboard UI in a separate thread""" - try: - logger.info("=" * 50) - logger.info("Starting Main Trading Dashboard UI...") - logger.info(f"Trading Dashboard: http://127.0.0.1:{port}") - logger.info("COB Integration: ENABLED (Real-time order book visualization)") - logger.info("=" * 50) - - # Import and create the Clean Trading Dashboard - from web.clean_dashboard import CleanTradingDashboard - from core.standardized_data_provider import StandardizedDataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - - # Initialize components for the dashboard - config = get_config() - data_provider = StandardizedDataProvider() - - # Start real-time streaming for BOM caching (non-blocking) - try: - import threading - def start_streaming(): - import asyncio - asyncio.run(data_provider.start_real_time_streaming()) - - streaming_thread = threading.Thread(target=start_streaming, daemon=True) - streaming_thread.start() - logger.info("[SUCCESS] Real-time streaming thread started for dashboard") - except Exception as e: - logger.warning(f"[WARNING] Dashboard streaming setup failed: {e}") - - # Load model registry for enhanced features - try: - from models import get_model_registry - model_registry = {} # Use simple dict for now - except ImportError: - model_registry = {} - - # Initialize checkpoint management for dashboard - dashboard_checkpoint_manager = get_checkpoint_manager() - dashboard_training_integration = get_training_integration() - - # Create unified orchestrator for the dashboard - dashboard_orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True, - model_registry={} - ) - - trading_executor = TradingExecutor("config.yaml") - - # Create the clean trading dashboard with enhanced features - dashboard = CleanTradingDashboard( - data_provider=data_provider, - orchestrator=dashboard_orchestrator, - trading_executor=trading_executor - ) - - logger.info("Clean Trading Dashboard created successfully") - logger.info("Features: Live trading, COB visualization, ML pipeline monitoring, Position management") - logger.info("✅ Unified orchestrator with decision-making model and checkpoint management") - - # Run the dashboard server (COB integration will start automatically) - dashboard.run_server(host='127.0.0.1', port=port, debug=False) - - except Exception as e: - logger.error(f"Error starting main trading dashboard UI: {e}") - import traceback - logger.error(traceback.format_exc()) - -async def start_training_loop(orchestrator, trading_executor): - """Start the main training and monitoring loop with checkpoint management""" - logger.info("=" * 70) - logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION") - logger.info("=" * 70) - - logger.info("Training loop function entered successfully") - - # Initialize checkpoint management for training loop - checkpoint_manager = get_checkpoint_manager() - training_integration = get_training_integration() - - # Training statistics for checkpoint management - training_stats = { - 'iteration_count': 0, - 'total_decisions': 0, - 'successful_trades': 0, - 'best_performance': 0.0, - 'last_checkpoint_iteration': 0 - } - - try: - # Start real-time processing (Basic orchestrator doesn't have this method) - logger.info("Checking for real-time processing capabilities...") - try: - if hasattr(orchestrator, 'start_realtime_processing'): - logger.info("Starting real-time processing...") - await orchestrator.start_realtime_processing() - logger.info("Real-time processing started") - else: - logger.info("Basic orchestrator - no real-time processing method available") - except Exception as e: - logger.warning(f"Real-time processing not available: {e}") - - logger.info("About to enter main training loop...") - - # Main training loop - iteration = 0 - while True: - iteration += 1 - training_stats['iteration_count'] = iteration - - logger.info(f"Training iteration {iteration}") - - # Make trading decisions using Basic orchestrator (single symbol method) - decisions = {} - symbols = ['ETH/USDT'] # Focus on ETH only for training - - for symbol in symbols: - try: - decision = await orchestrator.make_trading_decision(symbol) - decisions[symbol] = decision - except Exception as e: - logger.warning(f"Error making decision for {symbol}: {e}") - decisions[symbol] = None - - # Process decisions and collect training metrics - iteration_decisions = 0 - iteration_performance = 0.0 - - # Log decisions and performance - for symbol, decision in decisions.items(): - if decision: - iteration_decisions += 1 - logger.info(f"{symbol}: {decision.action} (confidence: {decision.confidence:.3f})") - - # Track performance for checkpoint management - iteration_performance += decision.confidence - - # Execute if confidence is high enough - if decision.confidence > 0.7: - logger.info(f"Executing {symbol}: {decision.action}") - training_stats['successful_trades'] += 1 - # trading_executor.execute_action(decision) - - # Update training statistics - training_stats['total_decisions'] += iteration_decisions - if iteration_performance > training_stats['best_performance']: - training_stats['best_performance'] = iteration_performance - - # Save checkpoint every 50 iterations or when performance improves significantly - should_save_checkpoint = ( - iteration % 50 == 0 or # Regular interval - iteration_performance > training_stats['best_performance'] * 1.1 or # 10% improvement - iteration - training_stats['last_checkpoint_iteration'] >= 100 # Force save every 100 iterations - ) - - if should_save_checkpoint: - try: - # Create performance metrics for checkpoint - performance_metrics = { - 'avg_confidence': iteration_performance / max(iteration_decisions, 1), - 'success_rate': training_stats['successful_trades'] / max(training_stats['total_decisions'], 1), - 'total_decisions': training_stats['total_decisions'], - 'iteration': iteration - } - - # Save orchestrator state (if it has models) - if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent: - saved = orchestrator.rl_agent.save_checkpoint(iteration_performance) - if saved: - logger.info(f"✅ RL Agent checkpoint saved at iteration {iteration}") - - if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model: - # Simulate CNN checkpoint save - logger.info(f"✅ CNN Model training state saved at iteration {iteration}") - - if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer: - saved = orchestrator.extrema_trainer.save_checkpoint() - if saved: - logger.info(f"✅ ExtremaTrainer checkpoint saved at iteration {iteration}") - - training_stats['last_checkpoint_iteration'] = iteration - logger.info(f"📊 Checkpoint management completed for iteration {iteration}") - - except Exception as e: - logger.warning(f"Checkpoint saving failed at iteration {iteration}: {e}") - - # Log performance metrics every 10 iterations - if iteration % 10 == 0: - metrics = orchestrator.get_performance_metrics() - logger.info(f"Performance metrics: {metrics}") - - # Log training statistics - logger.info(f"Training stats: {training_stats}") - - # Log checkpoint statistics - checkpoint_stats = checkpoint_manager.get_checkpoint_stats() - logger.info(f"Checkpoints: {checkpoint_stats['total_checkpoints']} total, " - f"{checkpoint_stats['total_size_mb']:.2f} MB") - - # Log COB integration status (Basic orchestrator doesn't have COB features) - symbols = getattr(orchestrator, 'symbols', ['ETH/USDT']) - if hasattr(orchestrator, 'latest_cob_features'): - for symbol in symbols: - cob_features = orchestrator.latest_cob_features.get(symbol) - cob_state = orchestrator.latest_cob_state.get(symbol) - if cob_features is not None: - logger.info(f"{symbol} COB: CNN features {cob_features.shape}, DQN state {cob_state.shape if cob_state is not None else 'None'}") - else: - logger.debug("Basic orchestrator - no COB integration features available") - - # Sleep between iterations - await asyncio.sleep(5) # 5 second intervals - - except KeyboardInterrupt: - logger.info("Training interrupted by user") - except Exception as e: - logger.error(f"Error in training loop: {e}") - import traceback - logger.error(traceback.format_exc()) - finally: - # Save final checkpoints before shutdown - try: - logger.info("Saving final checkpoints before shutdown...") - - if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent: - orchestrator.rl_agent.save_checkpoint(0.0, force_save=True) - logger.info("✅ Final RL Agent checkpoint saved") - - if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer: - orchestrator.extrema_trainer.save_checkpoint(force_save=True) - logger.info("✅ Final ExtremaTrainer checkpoint saved") - - # Log final checkpoint statistics - final_stats = checkpoint_manager.get_checkpoint_stats() - logger.info(f"📊 Final checkpoint stats: {final_stats['total_checkpoints']} checkpoints, " - f"{final_stats['total_size_mb']:.2f} MB total") - - except Exception as e: - logger.warning(f"Error saving final checkpoints: {e}") - - # Stop real-time processing (Basic orchestrator doesn't have these methods) - try: - if hasattr(orchestrator, 'stop_realtime_processing'): - await orchestrator.stop_realtime_processing() - except Exception as e: - logger.warning(f"Error stopping real-time processing: {e}") - - try: - if hasattr(orchestrator, 'stop_cob_integration'): - await orchestrator.stop_cob_integration() - except Exception as e: - logger.warning(f"Error stopping COB integration: {e}") - logger.info("Training loop stopped with checkpoint management") - -async def main(): - """Main entry point with both training loop and web dashboard""" - parser = argparse.ArgumentParser(description='Streamlined Trading System - Training + Web Dashboard') - parser.add_argument('--symbol', type=str, default='ETH/USDT', - help='Primary trading symbol (default: ETH/USDT)') - parser.add_argument('--port', type=int, default=8050, - help='Web dashboard port (default: 8050)') - parser.add_argument('--debug', action='store_true', - help='Enable debug mode') - - args = parser.parse_args() - - # Setup logging and ensure directories exist - Path("logs").mkdir(exist_ok=True) - Path("NN/models/saved").mkdir(parents=True, exist_ok=True) - setup_safe_logging() - - try: - logger.info("=" * 70) - logger.info("STREAMLINED TRADING SYSTEM - TRAINING + MAIN DASHBOARD") - logger.info(f"Primary Symbol: {args.symbol}") - logger.info(f"Training Port: {args.port}") - logger.info(f"Main Trading Dashboard: http://127.0.0.1:{args.port}") - logger.info("2-Action System: BUY/SELL with intelligent position management") - logger.info("Always Invested: Learning to spot high risk/reward setups") - logger.info("Flow: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution") - logger.info("Main Dashboard: Live trading, RL monitoring, Position management") - logger.info("🔄 Checkpoint Management: Automatic training state persistence") - # logger.info("📊 W&B Integration: Optional experiment tracking") - logger.info("💾 Model Rotation: Keep best 5 checkpoints per model") - logger.info("=" * 70) - - # Start main trading dashboard UI in a separate thread - web_thread = Thread(target=lambda: start_web_ui(args.port), daemon=True) - web_thread.start() - logger.info("Main trading dashboard UI thread started") - - # Give web UI time to start - await asyncio.sleep(2) - - # Run the training loop (this will run indefinitely) - await run_web_dashboard() - - logger.info("[SUCCESS] Operation completed successfully!") - - except KeyboardInterrupt: - logger.info("System shutdown requested by user") - except Exception as e: - logger.error(f"Fatal error: {e}") - import traceback - logger.error(traceback.format_exc()) - return 1 - - return 0 - -if __name__ == "__main__": - sys.exit(asyncio.run(main())) \ No newline at end of file diff --git a/main_clean.py b/main_clean.py deleted file mode 100644 index 3965ef2..0000000 --- a/main_clean.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python3 -""" -Clean Main Entry Point for Enhanced Trading Dashboard - -This is the main entry point that safely launches the clean dashboard -with proper error handling and optimized settings. -""" - -import os -import sys -import logging -import argparse -from typing import Optional - -# Add project root to path -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -# Import core components -try: - from core.config import setup_logging - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - from web.clean_dashboard import create_clean_dashboard -except ImportError as e: - print(f"Error importing core modules: {e}") - sys.exit(1) - -logger = logging.getLogger(__name__) - -def create_safe_orchestrator() -> Optional[TradingOrchestrator]: - """Create orchestrator with safe CNN model handling""" - try: - # Create orchestrator with basic configuration (uses correct constructor parameters) - orchestrator = TradingOrchestrator( - enhanced_rl_training=False # Disable problematic training initially - ) - - logger.info("Trading orchestrator created successfully") - return orchestrator - - except Exception as e: - logger.error(f"Error creating orchestrator: {e}") - logger.info("Continuing without orchestrator - dashboard will run in view-only mode") - return None - -def create_safe_trading_executor() -> Optional[TradingExecutor]: - """Create trading executor with safe configuration""" - try: - # TradingExecutor only accepts config_path parameter - trading_executor = TradingExecutor(config_path="config.yaml") - - logger.info("Trading executor created successfully") - return trading_executor - - except Exception as e: - logger.error(f"Error creating trading executor: {e}") - logger.info("Continuing without trading executor - dashboard will be view-only") - return None - -def main(): - """Main entry point for clean dashboard""" - parser = argparse.ArgumentParser(description='Enhanced Trading Dashboard') - parser.add_argument('--port', type=int, default=8050, help='Dashboard port (default: 8050)') - parser.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host (default: 127.0.0.1)') - parser.add_argument('--debug', action='store_true', help='Enable debug mode') - parser.add_argument('--no-training', action='store_true', help='Disable ML training for stability') - - args = parser.parse_args() - - # Setup logging - try: - setup_logging() - logger.info("================================================================================") - logger.info("CLEAN ENHANCED TRADING DASHBOARD") - logger.info("================================================================================") - logger.info(f"Starting on http://{args.host}:{args.port}") - logger.info("Features: Real-time Charts, Trading Interface, Model Monitoring") - logger.info("================================================================================") - except Exception as e: - print(f"Error setting up logging: {e}") - # Continue without logging setup - - # Set environment variables for optimization - os.environ['ENABLE_REALTIME_CHARTS'] = '1' - if not args.no_training: - os.environ['ENABLE_NN_MODELS'] = '1' - - try: - # Create data provider - logger.info("Initializing data provider...") - data_provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT']) - - # Create orchestrator (with safe CNN handling) - logger.info("Initializing trading orchestrator...") - orchestrator = create_safe_orchestrator() - - # Create trading executor - logger.info("Initializing trading executor...") - trading_executor = create_safe_trading_executor() - - # Create and run dashboard - logger.info("Creating clean dashboard...") - dashboard = create_clean_dashboard( - data_provider=data_provider, - orchestrator=orchestrator, - trading_executor=trading_executor - ) - - # Start the dashboard server - logger.info(f"Starting dashboard server on http://{args.host}:{args.port}") - dashboard.run_server( - host=args.host, - port=args.port, - debug=args.debug - ) - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"Error running dashboard: {e}") - - # Try to provide helpful error message - if "model.fit" in str(e) or "CNN" in str(e): - logger.error("CNN model training error detected. Try running with --no-training flag") - logger.error("Command: python main_clean.py --no-training") - - sys.exit(1) - finally: - logger.info("Clean dashboard shutdown complete") - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/migrate_existing_models.py b/migrate_existing_models.py deleted file mode 100644 index 0950b37..0000000 --- a/migrate_existing_models.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" -Migrate Existing Models to Checkpoint System - -This script migrates existing model files to the new checkpoint system -and creates proper database metadata entries. -""" - -import os -import shutil -import logging -from datetime import datetime -from pathlib import Path -from utils.database_manager import get_database_manager, CheckpointMetadata -from utils.checkpoint_manager import save_checkpoint -from utils.text_logger import get_text_logger - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def migrate_existing_models(): - """Migrate existing models to checkpoint system""" - print("=== Migrating Existing Models to Checkpoint System ===") - - db_manager = get_database_manager() - text_logger = get_text_logger() - - # Define model migrations - migrations = [ - { - 'model_name': 'enhanced_cnn', - 'model_type': 'cnn', - 'source_file': 'models/enhanced_cnn/ETH_USDT_cnn.pth', - 'performance_metrics': {'loss': 0.0187, 'accuracy': 0.92}, - 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True} - }, - { - 'model_name': 'dqn_agent', - 'model_type': 'rl', - 'source_file': 'models/enhanced_rl/ETH_USDT_dqn_policy.pth', - 'performance_metrics': {'loss': 0.0234, 'reward': 145.2}, - 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'policy'} - }, - { - 'model_name': 'dqn_agent_target', - 'model_type': 'rl', - 'source_file': 'models/enhanced_rl/ETH_USDT_dqn_target.pth', - 'performance_metrics': {'loss': 0.0234, 'reward': 145.2}, - 'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'target'} - } - ] - - migrated_count = 0 - - for migration in migrations: - source_path = Path(migration['source_file']) - - if not source_path.exists(): - logger.warning(f"Source file not found: {source_path}") - continue - - try: - # Create checkpoint directory - checkpoint_dir = Path("models/checkpoints") / migration['model_name'] - checkpoint_dir.mkdir(parents=True, exist_ok=True) - - # Create checkpoint filename - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_id = f"{migration['model_name']}_{timestamp}" - checkpoint_file = checkpoint_dir / f"{checkpoint_id}.pt" - - # Copy model file to checkpoint location - shutil.copy2(source_path, checkpoint_file) - logger.info(f"Copied {source_path} -> {checkpoint_file}") - - # Calculate file size - file_size_mb = checkpoint_file.stat().st_size / (1024 * 1024) - - # Create checkpoint metadata - metadata = CheckpointMetadata( - checkpoint_id=checkpoint_id, - model_name=migration['model_name'], - model_type=migration['model_type'], - timestamp=datetime.now(), - performance_metrics=migration['performance_metrics'], - training_metadata=migration['training_metadata'], - file_path=str(checkpoint_file), - file_size_mb=file_size_mb, - is_active=True - ) - - # Save to database - if db_manager.save_checkpoint_metadata(metadata): - logger.info(f"Saved checkpoint metadata: {checkpoint_id}") - - # Log to text file - text_logger.log_checkpoint_event( - model_name=migration['model_name'], - event_type="MIGRATED", - checkpoint_id=checkpoint_id, - details=f"from {source_path}, size={file_size_mb:.1f}MB" - ) - - migrated_count += 1 - else: - logger.error(f"Failed to save checkpoint metadata: {checkpoint_id}") - - except Exception as e: - logger.error(f"Failed to migrate {migration['model_name']}: {e}") - - print(f"\nMigration completed: {migrated_count} models migrated") - - # Show current checkpoint status - print("\n=== Current Checkpoint Status ===") - for model_name in ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']: - checkpoints = db_manager.list_checkpoints(model_name) - if checkpoints: - print(f"{model_name}: {len(checkpoints)} checkpoints") - for checkpoint in checkpoints[:2]: # Show first 2 - print(f" - {checkpoint.checkpoint_id} ({checkpoint.file_size_mb:.1f}MB)") - else: - print(f"{model_name}: No checkpoints") - -def verify_checkpoint_system(): - """Verify the checkpoint system is working""" - print("\n=== Verifying Checkpoint System ===") - - db_manager = get_database_manager() - - # Test loading checkpoints - for model_name in ['dqn_agent', 'enhanced_cnn']: - metadata = db_manager.get_best_checkpoint_metadata(model_name) - if metadata: - file_exists = Path(metadata.file_path).exists() - print(f"{model_name}: ✅ Metadata found, File exists: {file_exists}") - if file_exists: - print(f" -> {metadata.checkpoint_id} ({metadata.file_size_mb:.1f}MB)") - else: - print(f" -> ERROR: File missing: {metadata.file_path}") - else: - print(f"{model_name}: ❌ No checkpoint metadata found") - -def create_test_checkpoint(): - """Create a test checkpoint to verify saving works""" - print("\n=== Testing Checkpoint Saving ===") - - try: - import torch - import torch.nn as nn - - # Create a simple test model - class TestModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(10, 1) - - def forward(self, x): - return self.linear(x) - - test_model = TestModel() - - # Save using the checkpoint system - from utils.checkpoint_manager import save_checkpoint - - result = save_checkpoint( - model=test_model, - model_name="test_model", - model_type="test", - performance_metrics={"loss": 0.1, "accuracy": 0.95}, - training_metadata={"test": True, "created": datetime.now().isoformat()} - ) - - if result: - print(f"✅ Test checkpoint saved successfully: {result.checkpoint_id}") - - # Verify it exists - db_manager = get_database_manager() - metadata = db_manager.get_best_checkpoint_metadata("test_model") - if metadata and Path(metadata.file_path).exists(): - print(f"✅ Test checkpoint verified: {metadata.file_path}") - - # Clean up test checkpoint - Path(metadata.file_path).unlink() - print("🧹 Test checkpoint cleaned up") - else: - print("❌ Test checkpoint verification failed") - else: - print("❌ Test checkpoint saving failed") - - except Exception as e: - print(f"❌ Test checkpoint creation failed: {e}") - -def main(): - """Main migration process""" - migrate_existing_models() - verify_checkpoint_system() - create_test_checkpoint() - - print("\n=== Migration Complete ===") - print("The checkpoint system should now work properly!") - print("Existing models have been migrated and the system is ready for use.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/model_manager.py b/model_manager.py deleted file mode 100644 index b09ddfc..0000000 --- a/model_manager.py +++ /dev/null @@ -1,558 +0,0 @@ -""" -Enhanced Model Management System for Trading Dashboard - -This system provides: -- Automatic cleanup of old model checkpoints -- Best model tracking with performance metrics -- Configurable retention policies -- Startup model loading -- Performance-based model selection -""" - -import os -import json -import shutil -import logging -import torch -import glob -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass, asdict -from pathlib import Path -import numpy as np - -logger = logging.getLogger(__name__) - -@dataclass -class ModelMetrics: - """Performance metrics for model evaluation""" - accuracy: float = 0.0 - profit_factor: float = 0.0 - win_rate: float = 0.0 - sharpe_ratio: float = 0.0 - max_drawdown: float = 0.0 - total_trades: int = 0 - avg_trade_duration: float = 0.0 - confidence_score: float = 0.0 - - def get_composite_score(self) -> float: - """Calculate composite performance score""" - # Weighted composite score - weights = { - 'profit_factor': 0.3, - 'sharpe_ratio': 0.25, - 'win_rate': 0.2, - 'accuracy': 0.15, - 'confidence_score': 0.1 - } - - # Normalize values to 0-1 range - normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0 - normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1 - normalized_win_rate = self.win_rate - normalized_accuracy = self.accuracy - normalized_confidence = self.confidence_score - - # Apply penalties for poor performance - drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown - - score = ( - weights['profit_factor'] * normalized_pf + - weights['sharpe_ratio'] * normalized_sharpe + - weights['win_rate'] * normalized_win_rate + - weights['accuracy'] * normalized_accuracy + - weights['confidence_score'] * normalized_confidence - ) * drawdown_penalty - - return min(max(score, 0), 1) - -@dataclass -class ModelInfo: - """Complete model information and metadata""" - model_type: str # 'cnn', 'rl', 'transformer' - model_name: str - file_path: str - creation_time: datetime - last_updated: datetime - file_size_mb: float - metrics: ModelMetrics - training_episodes: int = 0 - model_version: str = "1.0" - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for JSON serialization""" - data = asdict(self) - data['creation_time'] = self.creation_time.isoformat() - data['last_updated'] = self.last_updated.isoformat() - return data - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo': - """Create from dictionary""" - data['creation_time'] = datetime.fromisoformat(data['creation_time']) - data['last_updated'] = datetime.fromisoformat(data['last_updated']) - data['metrics'] = ModelMetrics(**data['metrics']) - return cls(**data) - -class ModelManager: - """Enhanced model management system""" - - def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None): - self.base_dir = Path(base_dir) - self.config = config or self._get_default_config() - - # Model directories - self.models_dir = self.base_dir / "models" - self.nn_models_dir = self.base_dir / "NN" / "models" - self.registry_file = self.models_dir / "model_registry.json" - self.best_models_dir = self.models_dir / "best_models" - - # Create directories - self.best_models_dir.mkdir(parents=True, exist_ok=True) - - # Model registry - self.model_registry: Dict[str, ModelInfo] = {} - self._load_registry() - - logger.info(f"Model Manager initialized - Base: {self.base_dir}") - logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type") - - def _get_default_config(self) -> Dict[str, Any]: - """Get default configuration""" - return { - 'max_models_per_type': 3, # Keep top 3 models per type - 'max_total_models': 10, # Maximum total models to keep - 'cleanup_frequency_hours': 24, # Cleanup every 24 hours - 'min_performance_threshold': 0.3, # Minimum composite score - 'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days - 'auto_cleanup_enabled': True, - 'backup_before_cleanup': True, - 'model_size_limit_mb': 100, # Individual model size limit - 'total_storage_limit_gb': 5.0 # Total storage limit - } - - def _load_registry(self): - """Load model registry from file""" - try: - if self.registry_file.exists(): - with open(self.registry_file, 'r') as f: - data = json.load(f) - self.model_registry = { - k: ModelInfo.from_dict(v) for k, v in data.items() - } - logger.info(f"Loaded {len(self.model_registry)} models from registry") - else: - logger.info("No existing model registry found") - except Exception as e: - logger.error(f"Error loading model registry: {e}") - self.model_registry = {} - - def _save_registry(self): - """Save model registry to file""" - try: - self.models_dir.mkdir(parents=True, exist_ok=True) - with open(self.registry_file, 'w') as f: - data = {k: v.to_dict() for k, v in self.model_registry.items()} - json.dump(data, f, indent=2, default=str) - logger.info(f"Saved registry with {len(self.model_registry)} models") - except Exception as e: - logger.error(f"Error saving model registry: {e}") - - def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]: - """ - Clean up all existing model files and prepare for 2-action system training - - Args: - confirm: If True, perform the cleanup. If False, return what would be cleaned - - Returns: - Dict with cleanup statistics - """ - cleanup_stats = { - 'files_found': 0, - 'files_deleted': 0, - 'directories_cleaned': 0, - 'space_freed_mb': 0.0, - 'errors': [] - } - - # Model file patterns for both 2-action and legacy 3-action systems - model_patterns = [ - "**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model", - "**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*" - ] - - # Directories to clean - model_directories = [ - "models/saved", - "NN/models/saved", - "NN/models/saved/checkpoints", - "NN/models/saved/realtime_checkpoints", - "NN/models/saved/realtime_ticks_checkpoints", - "model_backups" - ] - - try: - # Scan for files to be cleaned - for directory in model_directories: - dir_path = Path(self.base_dir) / directory - if dir_path.exists(): - for pattern in model_patterns: - for file_path in dir_path.glob(pattern): - if file_path.is_file(): - cleanup_stats['files_found'] += 1 - file_size = file_path.stat().st_size / (1024 * 1024) # MB - cleanup_stats['space_freed_mb'] += file_size - - if confirm: - try: - file_path.unlink() - cleanup_stats['files_deleted'] += 1 - logger.info(f"Deleted model file: {file_path}") - except Exception as e: - cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}") - - # Clean up empty checkpoint directories - for directory in model_directories: - dir_path = Path(self.base_dir) / directory - if dir_path.exists(): - for subdir in dir_path.rglob("*"): - if subdir.is_dir() and not any(subdir.iterdir()): - if confirm: - try: - subdir.rmdir() - cleanup_stats['directories_cleaned'] += 1 - logger.info(f"Removed empty directory: {subdir}") - except Exception as e: - cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}") - - if confirm: - # Clear the registry for fresh start with 2-action system - self.model_registry = { - 'models': {}, - 'metadata': { - 'last_updated': datetime.now().isoformat(), - 'total_models': 0, - 'system_type': '2_action', # Mark as 2-action system - 'action_space': ['SELL', 'BUY'], - 'version': '2.0' - } - } - self._save_registry() - - logger.info("=" * 60) - logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY") - logger.info(f"Files deleted: {cleanup_stats['files_deleted']}") - logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB") - logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}") - logger.info("Registry reset for 2-action system (BUY/SELL)") - logger.info("Ready for fresh training with intelligent position management") - logger.info("=" * 60) - else: - logger.info("=" * 60) - logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION") - logger.info(f"Files to delete: {cleanup_stats['files_found']}") - logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB") - logger.info("Run with confirm=True to perform cleanup") - logger.info("=" * 60) - - except Exception as e: - cleanup_stats['errors'].append(f"Cleanup error: {e}") - logger.error(f"Error during model cleanup: {e}") - - return cleanup_stats - - def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str: - """ - Register a new model in the 2-action system - - Args: - model_path: Path to the model file - model_type: Type of model ('cnn', 'rl', 'transformer') - metrics: Performance metrics - - Returns: - str: Unique model name/ID - """ - if not Path(model_path).exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") - - # Generate unique model name - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - model_name = f"{model_type}_2action_{timestamp}" - - # Get file info - file_path = Path(model_path) - file_size_mb = file_path.stat().st_size / (1024 * 1024) - - # Default metrics for 2-action system - if metrics is None: - metrics = ModelMetrics( - accuracy=0.0, - profit_factor=1.0, - win_rate=0.5, - sharpe_ratio=0.0, - max_drawdown=0.0, - confidence_score=0.5 - ) - - # Create model info - model_info = ModelInfo( - model_type=model_type, - model_name=model_name, - file_path=str(file_path.absolute()), - creation_time=datetime.now(), - last_updated=datetime.now(), - file_size_mb=file_size_mb, - metrics=metrics, - model_version="2.0" # 2-action system version - ) - - # Add to registry - self.model_registry['models'][model_name] = model_info.to_dict() - self.model_registry['metadata']['total_models'] = len(self.model_registry['models']) - self.model_registry['metadata']['last_updated'] = datetime.now().isoformat() - self.model_registry['metadata']['system_type'] = '2_action' - self.model_registry['metadata']['action_space'] = ['SELL', 'BUY'] - - self._save_registry() - - # Cleanup old models if necessary - self._cleanup_models_by_type(model_type) - - logger.info(f"Registered 2-action model: {model_name}") - logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB") - logger.info(f"Performance score: {metrics.get_composite_score():.4f}") - - return model_name - - def _should_keep_model(self, model_info: ModelInfo) -> bool: - """Determine if model should be kept based on performance""" - score = model_info.metrics.get_composite_score() - - # Check minimum threshold - if score < self.config['min_performance_threshold']: - return False - - # Check size limit - if model_info.file_size_mb > self.config['model_size_limit_mb']: - logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB") - return False - - # Check if better than existing models of same type - existing_models = self.get_models_by_type(model_info.model_type) - if len(existing_models) >= self.config['max_models_per_type']: - # Find worst performing model - worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score()) - if score <= worst_model.metrics.get_composite_score(): - return False - - return True - - def _cleanup_models_by_type(self, model_type: str): - """Cleanup old models of specific type, keeping only the best ones""" - models_of_type = self.get_models_by_type(model_type) - max_keep = self.config['max_models_per_type'] - - if len(models_of_type) <= max_keep: - return - - # Sort by performance score - sorted_models = sorted( - models_of_type.items(), - key=lambda x: x[1].metrics.get_composite_score(), - reverse=True - ) - - # Keep only the best models - models_to_keep = sorted_models[:max_keep] - models_to_remove = sorted_models[max_keep:] - - for model_name, model_info in models_to_remove: - try: - # Remove file - model_path = Path(model_info.file_path) - if model_path.exists(): - model_path.unlink() - - # Remove from registry - del self.model_registry[model_name] - - logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})") - - except Exception as e: - logger.error(f"Error removing model {model_name}: {e}") - - def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]: - """Get all models of a specific type""" - return { - name: info for name, info in self.model_registry.items() - if info.model_type == model_type - } - - def get_best_model(self, model_type: str) -> Optional[ModelInfo]: - """Get the best performing model of a specific type""" - models_of_type = self.get_models_by_type(model_type) - - if not models_of_type: - return None - - return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score()) - - def load_best_models(self) -> Dict[str, Any]: - """Load the best models for each type""" - loaded_models = {} - - for model_type in ['cnn', 'rl', 'transformer']: - best_model = self.get_best_model(model_type) - - if best_model: - try: - model_path = Path(best_model.file_path) - if model_path.exists(): - # Load the model - model_data = torch.load(model_path, map_location='cpu') - loaded_models[model_type] = { - 'model': model_data, - 'info': best_model, - 'path': str(model_path) - } - logger.info(f"Loaded best {model_type} model: {best_model.model_name} " - f"(Score: {best_model.metrics.get_composite_score():.3f})") - else: - logger.warning(f"Best {model_type} model file not found: {model_path}") - except Exception as e: - logger.error(f"Error loading {model_type} model: {e}") - else: - logger.info(f"No {model_type} model available") - - return loaded_models - - def update_model_performance(self, model_name: str, metrics: ModelMetrics): - """Update performance metrics for a model""" - if model_name in self.model_registry: - self.model_registry[model_name].metrics = metrics - self.model_registry[model_name].last_updated = datetime.now() - self._save_registry() - - logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}") - else: - logger.warning(f"Model {model_name} not found in registry") - - def get_storage_stats(self) -> Dict[str, Any]: - """Get storage usage statistics""" - total_size_mb = 0 - model_count = 0 - - for model_info in self.model_registry.values(): - total_size_mb += model_info.file_size_mb - model_count += 1 - - # Check actual storage usage - actual_size_mb = 0 - if self.best_models_dir.exists(): - actual_size_mb = sum( - f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file() - ) / 1024 / 1024 - - return { - 'total_models': model_count, - 'registered_size_mb': total_size_mb, - 'actual_size_mb': actual_size_mb, - 'storage_limit_gb': self.config['total_storage_limit_gb'], - 'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100, - 'models_by_type': { - model_type: len(self.get_models_by_type(model_type)) - for model_type in ['cnn', 'rl', 'transformer'] - } - } - - def get_model_leaderboard(self) -> List[Dict[str, Any]]: - """Get model performance leaderboard""" - leaderboard = [] - - for model_name, model_info in self.model_registry.items(): - leaderboard.append({ - 'name': model_name, - 'type': model_info.model_type, - 'score': model_info.metrics.get_composite_score(), - 'profit_factor': model_info.metrics.profit_factor, - 'win_rate': model_info.metrics.win_rate, - 'sharpe_ratio': model_info.metrics.sharpe_ratio, - 'size_mb': model_info.file_size_mb, - 'age_days': (datetime.now() - model_info.creation_time).days, - 'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M') - }) - - # Sort by score - leaderboard.sort(key=lambda x: x['score'], reverse=True) - - return leaderboard - - def cleanup_checkpoints(self) -> Dict[str, Any]: - """Clean up old checkpoint files""" - cleanup_summary = { - 'deleted_files': 0, - 'freed_space_mb': 0, - 'errors': [] - } - - cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days']) - - # Search for checkpoint files - checkpoint_patterns = [ - "**/checkpoint_*.pt", - "**/model_*.pt", - "**/*checkpoint*", - "**/epoch_*.pt" - ] - - for pattern in checkpoint_patterns: - for file_path in self.base_dir.rglob(pattern): - if "best_models" not in str(file_path) and file_path.is_file(): - try: - file_time = datetime.fromtimestamp(file_path.stat().st_mtime) - if file_time < cutoff_date: - size_mb = file_path.stat().st_size / 1024 / 1024 - file_path.unlink() - cleanup_summary['deleted_files'] += 1 - cleanup_summary['freed_space_mb'] += size_mb - except Exception as e: - error_msg = f"Error deleting checkpoint {file_path}: {e}" - logger.error(error_msg) - cleanup_summary['errors'].append(error_msg) - - if cleanup_summary['deleted_files'] > 0: - logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, " - f"freed {cleanup_summary['freed_space_mb']:.1f}MB") - - return cleanup_summary - -def create_model_manager() -> ModelManager: - """Create and initialize the global model manager""" - return ModelManager() - -# Example usage -if __name__ == "__main__": - # Configure logging - logging.basicConfig(level=logging.INFO) - - # Create model manager - manager = ModelManager() - - # Clean up all existing models (with confirmation) - print("WARNING: This will delete ALL existing models!") - print("Type 'CONFIRM' to proceed:") - user_input = input().strip() - - if user_input == "CONFIRM": - cleanup_result = manager.cleanup_all_existing_models(confirm=True) - print(f"\nCleanup complete:") - print(f"- Deleted {cleanup_result['files_deleted']} files") - print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space") - print(f"- Cleaned {cleanup_result['directories_cleaned']} directories") - - if cleanup_result['errors']: - print(f"- {len(cleanup_result['errors'])} errors occurred") - else: - print("Cleanup cancelled") \ No newline at end of file diff --git a/position_sync_enhancement.py b/position_sync_enhancement.py deleted file mode 100644 index 2020d94..0000000 --- a/position_sync_enhancement.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -""" -Position Sync Enhancement - Fix P&L and Win Rate Calculation - -This script enhances the position synchronization and P&L calculation -to properly account for leverage in the trading system. -""" - -import os -import sys -import logging -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging -from core.trading_executor import TradingExecutor, TradeRecord - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -def analyze_trade_records(): - """Analyze trade records for P&L calculation issues""" - logger.info("Analyzing trade records for P&L calculation issues...") - - # Initialize trading executor - trading_executor = TradingExecutor() - - # Get trade records - trade_records = trading_executor.trade_records - - if not trade_records: - logger.warning("No trade records found.") - return - - logger.info(f"Found {len(trade_records)} trade records.") - - # Analyze P&L calculation - total_pnl = 0.0 - total_gross_pnl = 0.0 - total_fees = 0.0 - winning_trades = 0 - losing_trades = 0 - breakeven_trades = 0 - - for trade in trade_records: - # Calculate correct P&L with leverage - entry_value = trade.entry_price * trade.quantity - exit_value = trade.exit_price * trade.quantity - - if trade.side == 'LONG': - gross_pnl = (exit_value - entry_value) * trade.leverage - else: # SHORT - gross_pnl = (entry_value - exit_value) * trade.leverage - - # Calculate fees - fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit - - # Calculate net P&L - net_pnl = gross_pnl - fees - - # Compare with stored values - pnl_diff = abs(net_pnl - trade.pnl) - if pnl_diff > 0.01: # More than 1 cent difference - logger.warning(f"P&L calculation issue detected for trade {trade.entry_time}:") - logger.warning(f" Stored P&L: ${trade.pnl:.2f}") - logger.warning(f" Calculated P&L: ${net_pnl:.2f}") - logger.warning(f" Difference: ${pnl_diff:.2f}") - logger.warning(f" Leverage used: {trade.leverage}x") - - # Update statistics - total_pnl += net_pnl - total_gross_pnl += gross_pnl - total_fees += fees - - if net_pnl > 0.01: # More than 1 cent profit - winning_trades += 1 - elif net_pnl < -0.01: # More than 1 cent loss - losing_trades += 1 - else: - breakeven_trades += 1 - - # Calculate win rate - total_trades = winning_trades + losing_trades + breakeven_trades - win_rate = (winning_trades / total_trades * 100) if total_trades > 0 else 0.0 - - logger.info("\nTrade Analysis Results:") - logger.info(f" Total trades: {total_trades}") - logger.info(f" Winning trades: {winning_trades}") - logger.info(f" Losing trades: {losing_trades}") - logger.info(f" Breakeven trades: {breakeven_trades}") - logger.info(f" Win rate: {win_rate:.1f}%") - logger.info(f" Total P&L: ${total_pnl:.2f}") - logger.info(f" Total gross P&L: ${total_gross_pnl:.2f}") - logger.info(f" Total fees: ${total_fees:.2f}") - - # Check for leverage issues - leverage_issues = False - for trade in trade_records: - if trade.leverage <= 1.0: - leverage_issues = True - logger.warning(f"Low leverage detected: {trade.leverage}x for trade at {trade.entry_time}") - - if leverage_issues: - logger.warning("\nLeverage issues detected. Consider fixing the leverage calculation.") - logger.info("Recommended fix: Ensure leverage is properly set in the trading executor.") - else: - logger.info("\nNo leverage issues detected.") - -def fix_leverage_calculation(): - """Fix leverage calculation in the trading executor""" - logger.info("Fixing leverage calculation in the trading executor...") - - # Initialize trading executor - trading_executor = TradingExecutor() - - # Get current leverage - current_leverage = trading_executor.current_leverage - logger.info(f"Current leverage setting: {current_leverage}x") - - # Check if leverage is properly set - if current_leverage <= 1: - logger.warning("Leverage is set too low. Updating to 20x...") - trading_executor.current_leverage = 20 - logger.info(f"Updated leverage to {trading_executor.current_leverage}x") - else: - logger.info("Leverage is already set correctly.") - - # Update trade records with correct leverage - updated_count = 0 - for i, trade in enumerate(trading_executor.trade_records): - if trade.leverage <= 1.0: - # Create updated trade record - updated_trade = TradeRecord( - symbol=trade.symbol, - side=trade.side, - quantity=trade.quantity, - entry_price=trade.entry_price, - exit_price=trade.exit_price, - entry_time=trade.entry_time, - exit_time=trade.exit_time, - pnl=trade.pnl, - fees=trade.fees, - confidence=trade.confidence, - hold_time_seconds=trade.hold_time_seconds, - leverage=trading_executor.current_leverage, # Use current leverage setting - position_size_usd=trade.position_size_usd, - gross_pnl=trade.gross_pnl, - net_pnl=trade.net_pnl - ) - - # Recalculate P&L with correct leverage - entry_value = updated_trade.entry_price * updated_trade.quantity - exit_value = updated_trade.exit_price * updated_trade.quantity - - if updated_trade.side == 'LONG': - updated_trade.gross_pnl = (exit_value - entry_value) * updated_trade.leverage - else: # SHORT - updated_trade.gross_pnl = (entry_value - exit_value) * updated_trade.leverage - - # Recalculate fees - updated_trade.fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit - - # Recalculate net P&L - updated_trade.net_pnl = updated_trade.gross_pnl - updated_trade.fees - updated_trade.pnl = updated_trade.net_pnl - - # Update trade record - trading_executor.trade_records[i] = updated_trade - updated_count += 1 - - logger.info(f"Updated {updated_count} trade records with correct leverage.") - - # Save updated trade records - # Note: This is a placeholder. In a real implementation, you would need to - # persist the updated trade records to storage. - logger.info("Changes will take effect on next dashboard restart.") - - return updated_count > 0 - -if __name__ == "__main__": - logger.info("=" * 70) - logger.info("POSITION SYNC ENHANCEMENT") - logger.info("=" * 70) - - if len(sys.argv) > 1 and sys.argv[1] == 'fix': - fix_leverage_calculation() - else: - analyze_trade_records() \ No newline at end of file diff --git a/read_logs.py b/read_logs.py deleted file mode 100644 index e120f51..0000000 --- a/read_logs.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python -""" -Log Reader Utility - -This script provides a convenient way to read and filter log files during -development. -""" - -import os -import sys -import time -import argparse -from datetime import datetime - -def parse_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser(description='Read and filter log files') - parser.add_argument('--file', type=str, help='Log file to read (defaults to most recent .log file)') - parser.add_argument('--tail', type=int, default=50, help='Number of lines to show from the end') - parser.add_argument('--follow', '-f', action='store_true', help='Follow the file as it grows') - parser.add_argument('--filter', type=str, help='Only show lines containing this string') - parser.add_argument('--list', action='store_true', help='List all log files sorted by modification time') - return parser.parse_args() - -def get_most_recent_log(): - """Find the most recently modified log file""" - log_files = [f for f in os.listdir('.') if f.endswith('.log')] - if not log_files: - print("No log files found in current directory.") - sys.exit(1) - - # Sort by modification time (newest first) - log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True) - return log_files[0] - -def list_log_files(): - """List all log files sorted by modification time""" - log_files = [f for f in os.listdir('.') if f.endswith('.log')] - if not log_files: - print("No log files found in current directory.") - sys.exit(1) - - # Sort by modification time (newest first) - log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True) - - print(f"{'LAST MODIFIED':<20} {'SIZE':<10} FILENAME") - print("-" * 60) - for log_file in log_files: - mtime = datetime.fromtimestamp(os.path.getmtime(log_file)) - size = os.path.getsize(log_file) - size_str = f"{size / 1024:.1f} KB" if size > 1024 else f"{size} B" - print(f"{mtime.strftime('%Y-%m-%d %H:%M:%S'):<20} {size_str:<10} {log_file}") - -def read_log_tail(file_path, num_lines, filter_text=None): - """Read the last N lines of a file""" - try: - with open(file_path, 'r', encoding='utf-8') as f: - # Read all lines (inefficient but simple) - lines = f.readlines() - - # Filter if needed - if filter_text: - lines = [line for line in lines if filter_text in line] - - # Get the last N lines - last_lines = lines[-num_lines:] if len(lines) > num_lines else lines - return last_lines - except Exception as e: - print(f"Error reading file: {str(e)}") - sys.exit(1) - -def follow_log(file_path, filter_text=None): - """Follow the log file as it grows (like tail -f)""" - try: - with open(file_path, 'r', encoding='utf-8') as f: - # Go to the end of the file - f.seek(0, 2) - - while True: - line = f.readline() - if line: - if not filter_text or filter_text in line: - # Remove newlines at the end to avoid double spacing - print(line.rstrip()) - else: - time.sleep(0.1) # Sleep briefly to avoid consuming CPU - except KeyboardInterrupt: - print("\nLog reading stopped.") - except Exception as e: - print(f"Error following file: {str(e)}") - sys.exit(1) - -def main(): - """Main function""" - args = parse_args() - - # List all log files if requested - if args.list: - list_log_files() - return - - # Determine which file to read - file_path = args.file - if not file_path: - file_path = get_most_recent_log() - print(f"Reading most recent log file: {file_path}") - - # Follow mode (like tail -f) - if args.follow: - print(f"Following {file_path} (Press Ctrl+C to stop)...") - # First print the tail - for line in read_log_tail(file_path, args.tail, args.filter): - print(line.rstrip()) - print("-" * 80) - print("Waiting for new content...") - # Then follow - follow_log(file_path, args.filter) - else: - # Just print the tail - for line in read_log_tail(file_path, args.tail, args.filter): - print(line.rstrip()) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/reset_db_manager.py b/reset_db_manager.py deleted file mode 100644 index cec1b0a..0000000 --- a/reset_db_manager.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 -""" -Script to reset the database manager instance to trigger migration in running system -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from utils.database_manager import reset_database_manager -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def main(): - """Reset the database manager to trigger migration""" - try: - logger.info("Resetting database manager to trigger migration...") - reset_database_manager() - logger.info("✅ Database manager reset successfully!") - logger.info("The migration will run automatically on the next database access.") - return True - except Exception as e: - logger.error(f"❌ Failed to reset database manager: {e}") - return False - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/reset_models_and_fix_mapping.py b/reset_models_and_fix_mapping.py deleted file mode 100644 index 59871b2..0000000 --- a/reset_models_and_fix_mapping.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" -Reset Models and Fix Action Mapping - -This script: -1. Deletes existing model files -2. Creates new model files with consistent action mapping -3. Updates action mapping in key files -""" - -import os -import shutil -import logging -import sys -import torch -import numpy as np -from datetime import datetime - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def ensure_directory(directory): - """Ensure directory exists""" - if not os.path.exists(directory): - os.makedirs(directory) - logger.info(f"Created directory: {directory}") - -def delete_directory_contents(directory): - """Delete all files in a directory""" - if os.path.exists(directory): - for filename in os.listdir(directory): - file_path = os.path.join(directory, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - logger.info(f"Deleted: {file_path}") - except Exception as e: - logger.error(f"Failed to delete {file_path}. Reason: {e}") - -def create_backup_directory(): - """Create a backup directory with timestamp""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_dir = f"models/backup_{timestamp}" - ensure_directory(backup_dir) - return backup_dir - -def backup_models(): - """Backup existing models""" - backup_dir = create_backup_directory() - - # List of model directories to backup - model_dirs = [ - "models/enhanced_rl", - "models/enhanced_cnn", - "models/realtime_rl_cob", - "models/rl", - "models/cnn" - ] - - for model_dir in model_dirs: - if os.path.exists(model_dir): - dest_dir = os.path.join(backup_dir, os.path.basename(model_dir)) - ensure_directory(dest_dir) - - # Copy files - for filename in os.listdir(model_dir): - file_path = os.path.join(model_dir, filename) - if os.path.isfile(file_path): - shutil.copy2(file_path, dest_dir) - logger.info(f"Backed up: {file_path} to {dest_dir}") - - return backup_dir - -def initialize_dqn_model(): - """Initialize a new DQN model with consistent action mapping""" - try: - # Import necessary modules - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - from NN.models.dqn_agent import DQNAgent - - # Define state shape for BTC and ETH - state_shape = (100,) # Default feature dimension - - # Create models directory - ensure_directory("models/enhanced_rl") - - # Initialize DQN with 3 actions (BUY=0, SELL=1, HOLD=2) - dqn_btc = DQNAgent( - state_shape=state_shape, - n_actions=3, # BUY=0, SELL=1, HOLD=2 - learning_rate=0.001, - epsilon=0.5, # Start with moderate exploration - epsilon_min=0.01, - epsilon_decay=0.995, - model_name="BTC_USDT_dqn" - ) - - dqn_eth = DQNAgent( - state_shape=state_shape, - n_actions=3, # BUY=0, SELL=1, HOLD=2 - learning_rate=0.001, - epsilon=0.5, # Start with moderate exploration - epsilon_min=0.01, - epsilon_decay=0.995, - model_name="ETH_USDT_dqn" - ) - - # Save initial models - torch.save(dqn_btc.policy_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_policy.pth") - torch.save(dqn_btc.target_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_target.pth") - torch.save(dqn_eth.policy_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_policy.pth") - torch.save(dqn_eth.target_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_target.pth") - - logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)") - return True - except Exception as e: - logger.error(f"Failed to initialize DQN models: {e}") - return False - -def initialize_cnn_model(): - """Initialize a new CNN model with consistent action mapping""" - try: - # Import necessary modules - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - from NN.models.enhanced_cnn import EnhancedCNN - - # Define input dimension and number of actions - input_dim = 100 # Default feature dimension - n_actions = 3 # BUY=0, SELL=1, HOLD=2 - - # Create models directory - ensure_directory("models/enhanced_cnn") - - # Initialize CNN models for BTC and ETH - cnn_btc = EnhancedCNN(input_dim, n_actions) - cnn_eth = EnhancedCNN(input_dim, n_actions) - - # Save initial models - torch.save(cnn_btc.state_dict(), "models/enhanced_cnn/BTC_USDT_cnn.pth") - torch.save(cnn_eth.state_dict(), "models/enhanced_cnn/ETH_USDT_cnn.pth") - - logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)") - return True - except Exception as e: - logger.error(f"Failed to initialize CNN models: {e}") - return False - -def initialize_realtime_rl_model(): - """Initialize a new realtime RL model with consistent action mapping""" - try: - # Create models directory - ensure_directory("models/realtime_rl_cob") - - # Create empty model files to ensure directory is not empty - with open("models/realtime_rl_cob/README.txt", "w") as f: - f.write("Realtime RL COB models will be saved here.\n") - f.write("Action mapping: BUY=0, SELL=1, HOLD=2\n") - - logger.info("Initialized realtime RL model directory") - return True - except Exception as e: - logger.error(f"Failed to initialize realtime RL models: {e}") - return False - -def main(): - """Main function to reset models and fix action mapping""" - logger.info("Starting model reset and action mapping fix") - - # Backup existing models - backup_dir = backup_models() - logger.info(f"Backed up existing models to {backup_dir}") - - # Delete existing model files - model_dirs = [ - "models/enhanced_rl", - "models/enhanced_cnn", - "models/realtime_rl_cob" - ] - - for model_dir in model_dirs: - delete_directory_contents(model_dir) - logger.info(f"Deleted contents of {model_dir}") - - # Initialize new models with consistent action mapping - dqn_success = initialize_dqn_model() - cnn_success = initialize_cnn_model() - rl_success = initialize_realtime_rl_model() - - if dqn_success and cnn_success and rl_success: - logger.info("Successfully reset models and fixed action mapping") - logger.info("New action mapping: BUY=0, SELL=1, HOLD=2") - else: - logger.error("Failed to reset models and fix action mapping") - - logger.info("Model reset complete") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_clean_dashboard.py b/run_clean_dashboard.py index 40202b5..d1b38b3 100644 --- a/run_clean_dashboard.py +++ b/run_clean_dashboard.py @@ -1,325 +1,26 @@ -#!/usr/bin/env python3 -""" -Run Clean Trading Dashboard with Full Training Pipeline -Integrated system with both training loop and clean web dashboard -""" - -import os -# Fix OpenMP library conflicts before importing other modules -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '4' - -# Fix matplotlib backend issue - set non-interactive backend before any imports -import matplotlib -matplotlib.use('Agg') # Use non-interactive Agg backend - -import asyncio import logging -import sys -import platform -from safe_logging import setup_safe_logging -import threading -import time -from pathlib import Path - -# Windows-specific async event loop configuration -if platform.system() == "Windows": - # Use ProactorEventLoop on Windows for better I/O handling - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging -from core.data_provider import DataProvider - -# Import checkpoint management -from utils.checkpoint_manager import get_checkpoint_manager -from utils.training_integration import get_training_integration - -# Setup logging -setup_safe_logging() -logger = logging.getLogger(__name__) - -async def start_training_pipeline(orchestrator, trading_executor): - """Start the training pipeline in the background with comprehensive error handling""" - logger.info("=" * 70) - logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD") - logger.info("=" * 70) - - # Set up async exception handler - def handle_async_exception(loop, context): - """Handle uncaught async exceptions""" - exception = context.get('exception') - if exception: - logger.error(f"Uncaught async exception: {exception}") - logger.error(f"Context: {context}") - else: - logger.error(f"Async error: {context.get('message', 'Unknown error')}") - - # Get current event loop and set exception handler - loop = asyncio.get_running_loop() - loop.set_exception_handler(handle_async_exception) - - # Initialize checkpoint management - checkpoint_manager = get_checkpoint_manager() - training_integration = get_training_integration() - - # Training statistics - training_stats = { - 'iteration_count': 0, - 'total_decisions': 0, - 'successful_trades': 0, - 'best_performance': 0.0, - 'last_checkpoint_iteration': 0 - } - - try: - # Start real-time processing with error handling - try: - if hasattr(orchestrator, 'start_realtime_processing'): - await orchestrator.start_realtime_processing() - logger.info("Real-time processing started") - except Exception as e: - logger.error(f"Error starting real-time processing: {e}") - - # Start COB integration with error handling - try: - if hasattr(orchestrator, 'start_cob_integration'): - await orchestrator.start_cob_integration() - logger.info("COB integration started - 5-minute data matrix active") - else: - logger.info("COB integration not available") - except Exception as e: - logger.error(f"Error starting COB integration: {e}") - - # Main training loop - iteration = 0 - last_checkpoint_time = time.time() - - while True: - try: - iteration += 1 - training_stats['iteration_count'] = iteration - - # Get symbols to process - symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT'] - - # Process each symbol - for symbol in symbols: - try: - # Make trading decision (this triggers model training) - decision = await orchestrator.make_trading_decision(symbol) - if decision: - training_stats['total_decisions'] += 1 - logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}") - - except Exception as e: - logger.warning(f"Error processing {symbol}: {e}") - - # Status logging every 100 iterations - if iteration % 100 == 0: - current_time = time.time() - elapsed = current_time - last_checkpoint_time - - logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s") - - # Models will save their own checkpoints when performance improves - training_stats['last_checkpoint_iteration'] = iteration - last_checkpoint_time = current_time - - # Brief pause to prevent overwhelming the system - await asyncio.sleep(0.1) # 100ms between iterations - - except Exception as e: - logger.error(f"Training loop error: {e}") - await asyncio.sleep(5) # Wait longer on error - - except Exception as e: - logger.error(f"Training pipeline error: {e}") - import traceback - logger.error(traceback.format_exc()) - -def start_clean_dashboard_with_training(): - """Start clean dashboard with full training pipeline""" - try: - logger.info("=" * 80) - logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE") - logger.info("=" * 80) - logger.info("Features: Real-time Training, COB Integration, Clean UI") - logger.info("Universal Data Stream: ENABLED") - logger.info("Neural Decision Fusion: ENABLED") - logger.info("COB Integration: ENABLED") - logger.info("GPU Training: ENABLED") - logger.info("TensorBoard Integration: ENABLED") - logger.info("Multi-symbol: ETH/USDT, BTC/USDT") - - # Get port from environment or use default - dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051')) - tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006')) - logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}") - logger.info(f"TensorBoard: http://127.0.0.1:{tensorboard_port}") - logger.info("=" * 80) - - # Check environment variables - enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1' - enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1' - enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1' - - logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}") - logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}") - logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}") - - # Get configuration - config = get_config() - - # Initialize core components with standardized versions - from core.standardized_data_provider import StandardizedDataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - - # Create standardized data provider - data_provider = StandardizedDataProvider() - logger.info("StandardizedDataProvider created with BaseDataInput support") - - # Create enhanced orchestrator with standardized data provider - orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True) - logger.info("Enhanced Trading Orchestrator created with COB integration") - - # Create trading executor - trading_executor = TradingExecutor(config_path="config.yaml") - logger.info(f"Creating trading executor with {trading_executor.primary_name} configuration...") - - - # Connect trading executor to orchestrator - orchestrator.trading_executor = trading_executor - logger.info("Trading Executor connected to Orchestrator") - - # Initialize system resource monitoring - from utils.system_monitor import start_system_monitoring - system_monitor = start_system_monitoring() - - # Set up cleanup callback for memory management - def cleanup_callback(): - """Custom cleanup for memory management""" - try: - # Clear orchestrator caches - if hasattr(orchestrator, 'recent_decisions'): - for symbol in orchestrator.recent_decisions: - if len(orchestrator.recent_decisions[symbol]) > 50: - orchestrator.recent_decisions[symbol] = orchestrator.recent_decisions[symbol][-25:] - - # Clear data provider caches - if hasattr(data_provider, 'clear_old_data'): - data_provider.clear_old_data() - - logger.info("Custom memory cleanup completed") - except Exception as e: - logger.error(f"Error in custom cleanup: {e}") - - system_monitor.set_callbacks(cleanup=cleanup_callback) - logger.info("System resource monitoring started with memory cleanup") - - # Import clean dashboard - from web.clean_dashboard import create_clean_dashboard - - # Create clean dashboard - logger.info("Creating clean dashboard...") - dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor) - logger.info("Clean Trading Dashboard created") - - # Add memory cleanup method to dashboard - def cleanup_dashboard_memory(): - """Clean up dashboard memory caches""" - try: - if hasattr(dashboard, 'recent_decisions'): - dashboard.recent_decisions = dashboard.recent_decisions[-50:] # Keep last 50 - if hasattr(dashboard, 'closed_trades'): - dashboard.closed_trades = dashboard.closed_trades[-100:] # Keep last 100 - if hasattr(dashboard, 'tick_cache'): - dashboard.tick_cache = dashboard.tick_cache[-1000:] # Keep last 1000 - logger.debug("Dashboard memory cleanup completed") - except Exception as e: - logger.error(f"Error in dashboard memory cleanup: {e}") - - # Set cleanup method on dashboard - dashboard.cleanup_memory = cleanup_dashboard_memory - - # Start training pipeline in background thread with enhanced error handling - def training_worker(): - """Run training pipeline in background with comprehensive error handling""" - try: - asyncio.run(start_training_pipeline(orchestrator, trading_executor)) - except KeyboardInterrupt: - logger.info("Training worker stopped by user") - except Exception as e: - logger.error(f"Training worker error: {e}") - import traceback - logger.error(f"Training worker traceback: {traceback.format_exc()}") - # Don't exit - let main thread handle restart - - training_thread = threading.Thread(target=training_worker, daemon=True) - training_thread.start() - logger.info("Training pipeline started in background with error handling") - - # Wait a moment for training to initialize - time.sleep(3) - - # Start TensorBoard in background - from web.tensorboard_integration import get_tensorboard_integration - tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006')) - tensorboard_integration = get_tensorboard_integration(log_dir="runs", port=tensorboard_port) - - # Start TensorBoard server - tensorboard_started = tensorboard_integration.start_tensorboard(open_browser=False) - if tensorboard_started: - logger.info(f"TensorBoard started at {tensorboard_integration.get_tensorboard_url()}") - else: - logger.warning("Failed to start TensorBoard - training metrics will not be visualized") - - # Start dashboard server with error handling (this blocks) - logger.info("Starting Clean Dashboard Server with error handling...") - try: - dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False) - except Exception as e: - logger.error(f"Dashboard server error: {e}") - import traceback - logger.error(f"Dashboard server traceback: {traceback.format_exc()}") - raise # Re-raise to trigger main error handling - - except KeyboardInterrupt: - logger.info("System stopped by user") - # Stop TensorBoard - try: - tensorboard_integration = get_tensorboard_integration() - tensorboard_integration.stop_tensorboard() - except: - pass - except Exception as e: - logger.error(f"Error running clean dashboard with training: {e}") - import traceback - traceback.print_exc() - sys.exit(1) +from core.config import setup_logging, get_config +from core.trading_executor import TradingExecutor +from core.orchestrator import TradingOrchestrator +from core.standardized_data_provider import StandardizedDataProvider +from web.clean_dashboard import CleanTradingDashboard def main(): - """Main function with comprehensive error handling""" - try: - start_clean_dashboard_with_training() - except KeyboardInterrupt: - logger.info("Dashboard stopped by user (Ctrl+C)") - sys.exit(0) - except Exception as e: - logger.error(f"Critical error in main: {e}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) + setup_logging() + cfg = get_config() + + data_provider = StandardizedDataProvider() + trading_executor = TradingExecutor() + orchestrator = TradingOrchestrator(data_provider=data_provider) + + dashboard = CleanTradingDashboard( + data_provider=data_provider, + orchestrator=orchestrator, + trading_executor=trading_executor + ) + logging.getLogger(__name__).info("Starting Clean Trading Dashboard at http://127.0.0.1:8050") + dashboard.run_server(host='127.0.0.1', port=8050, debug=False) + +if __name__ == '__main__': + main() -if __name__ == "__main__": - # Ensure logging is flushed on exit - import atexit - def flush_logs(): - logging.shutdown() - atexit.register(flush_logs) - - main() \ No newline at end of file diff --git a/run_continuous_training.py b/run_continuous_training.py deleted file mode 100644 index 86c5c69..0000000 --- a/run_continuous_training.py +++ /dev/null @@ -1,501 +0,0 @@ -#!/usr/bin/env python3 -""" -Continuous Full Training System (RL + CNN) - -This system runs continuous training for both RL and CNN models using the enhanced -DataProvider for consistent data streaming to both models and the dashboard. - -Features: -- Single DataProvider instance for all data needs -- Continuous RL training with real-time market data -- CNN training with perfect move detection -- Real-time performance monitoring -- Automatic model checkpointing -- Integration with live trading dashboard -""" - -import asyncio -import logging -import time -import signal -import sys -from datetime import datetime, timedelta -from threading import Thread, Event -from typing import Dict, Any - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/continuous_training.log'), - logging.StreamHandler() - ] -) -logger = logging.getLogger(__name__) - -# Import our components -from core.config import get_config -from core.data_provider import DataProvider, MarketTick -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard - -# Import checkpoint management -from utils.checkpoint_manager import get_checkpoint_manager -from utils.training_integration import get_training_integration - -class ContinuousTrainingSystem: - """Comprehensive continuous training system for RL + CNN models""" - - def __init__(self): - """Initialize the continuous training system""" - self.config = get_config() - - # Single DataProvider instance for all data needs - self.data_provider = DataProvider( - symbols=['ETH/USDT', 'BTC/USDT'], - timeframes=['1s', '1m', '1h', '1d'] - ) - - # Enhanced orchestrator for AI trading - self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) - - # Dashboard for monitoring - self.dashboard = None - - # Training control - self.running = False - self.shutdown_event = Event() - - # Checkpoint management - self.checkpoint_manager = get_checkpoint_manager() - self.training_integration = get_training_integration() - - # Performance tracking - self.training_stats = { - 'start_time': None, - 'rl_training_cycles': 0, - 'cnn_training_cycles': 0, - 'perfect_moves_detected': 0, - 'total_ticks_processed': 0, - 'models_saved': 0, - 'last_checkpoint': None, - 'best_rl_reward': float('-inf'), - 'best_cnn_accuracy': 0.0 - } - - # Training intervals - self.rl_training_interval = 300 # 5 minutes - self.cnn_training_interval = 600 # 10 minutes - self.checkpoint_interval = 1800 # 30 minutes - - logger.info("Continuous Training System initialized with checkpoint management") - logger.info(f"RL training interval: {self.rl_training_interval}s") - logger.info(f"CNN training interval: {self.cnn_training_interval}s") - logger.info(f"Checkpoint interval: {self.checkpoint_interval}s") - - async def start(self, run_dashboard: bool = True): - """Start the continuous training system""" - logger.info("Starting Continuous Training System...") - self.running = True - self.training_stats['start_time'] = datetime.now() - - try: - # Start DataProvider streaming - logger.info("Starting DataProvider real-time streaming...") - await self.data_provider.start_real_time_streaming() - - # Subscribe to tick data for training - subscriber_id = self.data_provider.subscribe_to_ticks( - callback=self._handle_training_tick, - symbols=['ETH/USDT', 'BTC/USDT'], - subscriber_name="ContinuousTraining" - ) - logger.info(f"Subscribed to training tick stream: {subscriber_id}") - - # Start training threads - training_tasks = [ - asyncio.create_task(self._rl_training_loop()), - asyncio.create_task(self._cnn_training_loop()), - asyncio.create_task(self._checkpoint_loop()), - asyncio.create_task(self._monitoring_loop()) - ] - - # Start dashboard if requested - if run_dashboard: - dashboard_task = asyncio.create_task(self._run_dashboard()) - training_tasks.append(dashboard_task) - - logger.info("All training components started successfully") - - # Wait for shutdown signal - await self._wait_for_shutdown() - - except Exception as e: - logger.error(f"Error in continuous training system: {e}") - raise - finally: - await self.stop() - - def _handle_training_tick(self, tick: MarketTick): - """Handle incoming tick data for training""" - try: - self.training_stats['total_ticks_processed'] += 1 - - # Process tick through orchestrator for RL training - if self.orchestrator and hasattr(self.orchestrator, 'process_tick'): - self.orchestrator.process_tick(tick) - - # Log every 1000 ticks - if self.training_stats['total_ticks_processed'] % 1000 == 0: - logger.info(f"Processed {self.training_stats['total_ticks_processed']} training ticks") - - except Exception as e: - logger.warning(f"Error processing training tick: {e}") - - async def _rl_training_loop(self): - """Continuous RL training loop""" - logger.info("Starting RL training loop...") - - while self.running: - try: - start_time = time.time() - - # Perform RL training cycle - if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent: - logger.info("Starting RL training cycle...") - - # Get recent market data for training - training_data = self._prepare_rl_training_data() - - if training_data is not None: - # Train RL agent - training_results = await self._train_rl_agent(training_data) - - if training_results: - self.training_stats['rl_training_cycles'] += 1 - logger.info(f"RL training cycle {self.training_stats['rl_training_cycles']} completed") - logger.info(f"Training results: {training_results}") - else: - logger.warning("No training data available for RL agent") - - # Wait for next training cycle - elapsed = time.time() - start_time - sleep_time = max(0, self.rl_training_interval - elapsed) - await asyncio.sleep(sleep_time) - - except Exception as e: - logger.error(f"Error in RL training loop: {e}") - await asyncio.sleep(60) # Wait before retrying - - async def _cnn_training_loop(self): - """Continuous CNN training loop""" - logger.info("Starting CNN training loop...") - - while self.running: - try: - start_time = time.time() - - # Perform CNN training cycle - if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: - logger.info("Starting CNN training cycle...") - - # Detect perfect moves for CNN training - perfect_moves = self._detect_perfect_moves() - - if perfect_moves: - self.training_stats['perfect_moves_detected'] += len(perfect_moves) - - # Train CNN with perfect moves - training_results = await self._train_cnn_model(perfect_moves) - - if training_results: - self.training_stats['cnn_training_cycles'] += 1 - logger.info(f"CNN training cycle {self.training_stats['cnn_training_cycles']} completed") - logger.info(f"Perfect moves processed: {len(perfect_moves)}") - else: - logger.info("No perfect moves detected for CNN training") - - # Wait for next training cycle - elapsed = time.time() - start_time - sleep_time = max(0, self.cnn_training_interval - elapsed) - await asyncio.sleep(sleep_time) - - except Exception as e: - logger.error(f"Error in CNN training loop: {e}") - await asyncio.sleep(60) # Wait before retrying - - async def _checkpoint_loop(self): - """Automatic model checkpointing loop""" - logger.info("Starting checkpoint loop...") - - while self.running: - try: - await asyncio.sleep(self.checkpoint_interval) - - logger.info("Creating model checkpoints...") - - # Save RL model - if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent: - rl_checkpoint = await self._save_rl_checkpoint() - if rl_checkpoint: - logger.info(f"RL checkpoint saved: {rl_checkpoint}") - - # Save CNN model - if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: - cnn_checkpoint = await self._save_cnn_checkpoint() - if cnn_checkpoint: - logger.info(f"CNN checkpoint saved: {cnn_checkpoint}") - - self.training_stats['models_saved'] += 1 - self.training_stats['last_checkpoint'] = datetime.now() - - except Exception as e: - logger.error(f"Error in checkpoint loop: {e}") - - async def _monitoring_loop(self): - """System monitoring and performance tracking loop""" - logger.info("Starting monitoring loop...") - - while self.running: - try: - await asyncio.sleep(300) # Monitor every 5 minutes - - # Log system statistics - uptime = datetime.now() - self.training_stats['start_time'] - - logger.info("=== CONTINUOUS TRAINING SYSTEM STATUS ===") - logger.info(f"Uptime: {uptime}") - logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}") - logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}") - logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}") - logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}") - logger.info(f"Models saved: {self.training_stats['models_saved']}") - - # DataProvider statistics - if hasattr(self.data_provider, 'get_subscriber_stats'): - subscriber_stats = self.data_provider.get_subscriber_stats() - logger.info(f"Active subscribers: {subscriber_stats.get('active_subscribers', 0)}") - logger.info(f"Total ticks distributed: {subscriber_stats.get('distribution_stats', {}).get('total_ticks_distributed', 0)}") - - # Orchestrator performance - if hasattr(self.orchestrator, 'get_performance_metrics'): - perf_metrics = self.orchestrator.get_performance_metrics() - logger.info(f"Orchestrator performance: {perf_metrics}") - - logger.info("==========================================") - - except Exception as e: - logger.error(f"Error in monitoring loop: {e}") - - async def _run_dashboard(self): - """Run the dashboard in a separate thread""" - try: - logger.info("Starting live trading dashboard...") - - def run_dashboard(): - self.dashboard = RealTimeScalpingDashboard( - data_provider=self.data_provider, - orchestrator=self.orchestrator - ) - self.dashboard.run(host='127.0.0.1', port=8051, debug=False) - - dashboard_thread = Thread(target=run_dashboard, daemon=True) - dashboard_thread.start() - - logger.info("Dashboard started at http://127.0.0.1:8051") - - # Keep dashboard thread alive - while self.running: - await asyncio.sleep(10) - - except Exception as e: - logger.error(f"Error running dashboard: {e}") - - def _prepare_rl_training_data(self) -> Dict[str, Any]: - """Prepare training data for RL agent""" - try: - # Get recent market data from DataProvider - eth_data = self.data_provider.get_latest_candles('ETH/USDT', '1m', limit=1000) - btc_data = self.data_provider.get_latest_candles('BTC/USDT', '1m', limit=1000) - - if eth_data is not None and not eth_data.empty: - return { - 'eth_data': eth_data, - 'btc_data': btc_data, - 'timestamp': datetime.now() - } - - return None - - except Exception as e: - logger.error(f"Error preparing RL training data: {e}") - return None - - def _detect_perfect_moves(self) -> list: - """Detect perfect trading moves for CNN training""" - try: - # Get recent tick data - recent_ticks = self.data_provider.get_recent_ticks('ETHUSDT', count=500) - - if not recent_ticks: - return [] - - # Simple perfect move detection (can be enhanced) - perfect_moves = [] - - for i in range(1, len(recent_ticks) - 1): - prev_tick = recent_ticks[i-1] - curr_tick = recent_ticks[i] - next_tick = recent_ticks[i+1] - - # Detect significant price movements - price_change = (next_tick.price - curr_tick.price) / curr_tick.price - - if abs(price_change) > 0.001: # 0.1% movement - perfect_moves.append({ - 'timestamp': curr_tick.timestamp, - 'price': curr_tick.price, - 'action': 'BUY' if price_change > 0 else 'SELL', - 'confidence': min(abs(price_change) * 100, 1.0) - }) - - return perfect_moves[-10:] # Return last 10 perfect moves - - except Exception as e: - logger.error(f"Error detecting perfect moves: {e}") - return [] - - async def _train_rl_agent(self, training_data: Dict[str, Any]) -> Dict[str, Any]: - """Train the RL agent with market data""" - try: - # Placeholder for RL training logic - # This would integrate with the actual RL agent - - logger.info("Training RL agent with market data...") - - # Simulate training time - await asyncio.sleep(1) - - return { - 'loss': 0.05, - 'reward': 0.75, - 'episodes': 100 - } - - except Exception as e: - logger.error(f"Error training RL agent: {e}") - return None - - async def _train_cnn_model(self, perfect_moves: list) -> Dict[str, Any]: - """Train the CNN model with perfect moves""" - try: - # Placeholder for CNN training logic - # This would integrate with the actual CNN model - - logger.info(f"Training CNN model with {len(perfect_moves)} perfect moves...") - - # Simulate training time - await asyncio.sleep(2) - - return { - 'accuracy': 0.92, - 'loss': 0.08, - 'perfect_moves_processed': len(perfect_moves) - } - - except Exception as e: - logger.error(f"Error training CNN model: {e}") - return None - - async def _save_rl_checkpoint(self) -> str: - """Save RL model checkpoint""" - try: - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - checkpoint_path = f"models/rl/checkpoint_rl_{timestamp}.pt" - - # Placeholder for actual model saving - logger.info(f"Saving RL checkpoint to {checkpoint_path}") - - return checkpoint_path - - except Exception as e: - logger.error(f"Error saving RL checkpoint: {e}") - return None - - async def _save_cnn_checkpoint(self) -> str: - """Save CNN model checkpoint""" - try: - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - checkpoint_path = f"models/cnn/checkpoint_cnn_{timestamp}.pt" - - # Placeholder for actual model saving - logger.info(f"Saving CNN checkpoint to {checkpoint_path}") - - return checkpoint_path - - except Exception as e: - logger.error(f"Error saving CNN checkpoint: {e}") - return None - - async def _wait_for_shutdown(self): - """Wait for shutdown signal""" - def signal_handler(signum, frame): - logger.info(f"Received signal {signum}, shutting down...") - self.shutdown_event.set() - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - # Wait for shutdown event - while not self.shutdown_event.is_set(): - await asyncio.sleep(1) - - async def stop(self): - """Stop the continuous training system""" - logger.info("Stopping Continuous Training System...") - self.running = False - - try: - # Stop DataProvider streaming - if self.data_provider: - await self.data_provider.stop_real_time_streaming() - - # Final checkpoint - logger.info("Creating final checkpoints...") - await self._save_rl_checkpoint() - await self._save_cnn_checkpoint() - - # Log final statistics - uptime = datetime.now() - self.training_stats['start_time'] - logger.info("=== FINAL TRAINING STATISTICS ===") - logger.info(f"Total uptime: {uptime}") - logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}") - logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}") - logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}") - logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}") - logger.info(f"Models saved: {self.training_stats['models_saved']}") - logger.info("=================================") - - except Exception as e: - logger.error(f"Error during shutdown: {e}") - - logger.info("Continuous Training System stopped") - -async def main(): - """Main entry point""" - logger.info("Starting Continuous Full Training System (RL + CNN)") - - # Create and start the training system - training_system = ContinuousTrainingSystem() - - try: - await training_system.start(run_dashboard=True) - except KeyboardInterrupt: - logger.info("Interrupted by user") - except Exception as e: - logger.error(f"Fatal error: {e}") - sys.exit(1) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/run_crash_safe_dashboard.py b/run_crash_safe_dashboard.py deleted file mode 100644 index 2c228cb..0000000 --- a/run_crash_safe_dashboard.py +++ /dev/null @@ -1,269 +0,0 @@ -#!/usr/bin/env python3 -""" -Crash-Safe Dashboard Runner - -This runner is designed to prevent crashes by: -1. Isolating imports with try/except blocks -2. Minimal initialization -3. Graceful error handling -4. No complex training loops -5. Safe component loading -""" - -import os -import sys -import logging -import traceback -from pathlib import Path - -# Fix environment issues before any imports -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '1' # Minimal threads -os.environ['MPLBACKEND'] = 'Agg' - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup basic logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# Reduce noise from other loggers -logging.getLogger('werkzeug').setLevel(logging.ERROR) -logging.getLogger('dash').setLevel(logging.ERROR) -logging.getLogger('matplotlib').setLevel(logging.ERROR) - -class CrashSafeDashboard: - """Crash-safe dashboard with minimal dependencies""" - - def __init__(self): - """Initialize with safe error handling""" - self.components = {} - self.dashboard_app = None - self.initialization_errors = [] - - logger.info("Initializing crash-safe dashboard...") - - def safe_import(self, module_name, class_name=None): - """Safely import modules with error handling""" - try: - if class_name: - module = __import__(module_name, fromlist=[class_name]) - return getattr(module, class_name) - else: - return __import__(module_name) - except Exception as e: - error_msg = f"Failed to import {module_name}.{class_name if class_name else ''}: {e}" - logger.error(error_msg) - self.initialization_errors.append(error_msg) - return None - - def initialize_core_components(self): - """Initialize core components safely""" - logger.info("Initializing core components...") - - # Try to import and initialize config - try: - from core.config import get_config, setup_logging - setup_logging() - self.components['config'] = get_config() - logger.info("✓ Config loaded") - except Exception as e: - logger.error(f"✗ Config failed: {e}") - self.initialization_errors.append(f"Config: {e}") - - # Try to initialize data provider - try: - from core.data_provider import DataProvider - self.components['data_provider'] = DataProvider() - logger.info("✓ Data provider initialized") - except Exception as e: - logger.error(f"✗ Data provider failed: {e}") - self.initialization_errors.append(f"Data provider: {e}") - - # Try to initialize trading executor - try: - from core.trading_executor import TradingExecutor - self.components['trading_executor'] = TradingExecutor() - logger.info("✓ Trading executor initialized") - except Exception as e: - logger.error(f"✗ Trading executor failed: {e}") - self.initialization_errors.append(f"Trading executor: {e}") - - # Try to initialize orchestrator (WITHOUT training to avoid crashes) - try: - from core.orchestrator import TradingOrchestrator - self.components['orchestrator'] = TradingOrchestrator( - data_provider=self.components.get('data_provider'), - enhanced_rl_training=False # DISABLED to prevent crashes - ) - logger.info("✓ Orchestrator initialized (training disabled)") - except Exception as e: - logger.error(f"✗ Orchestrator failed: {e}") - self.initialization_errors.append(f"Orchestrator: {e}") - - def create_minimal_dashboard(self): - """Create minimal dashboard without complex features""" - try: - import dash - from dash import html, dcc - - # Create minimal Dash app - self.dashboard_app = dash.Dash(__name__) - - # Create simple layout - self.dashboard_app.layout = html.Div([ - html.H1("Trading Dashboard - Safe Mode", style={'textAlign': 'center'}), - html.Hr(), - - # Status section - html.Div([ - html.H3("System Status"), - html.Div(id="system-status", children=self._get_system_status()), - ], style={'margin': '20px'}), - - # Error section - html.Div([ - html.H3("Initialization Status"), - html.Div(id="init-status", children=self._get_init_status()), - ], style={'margin': '20px'}), - - # Simple refresh interval - dcc.Interval( - id='interval-component', - interval=5000, # Update every 5 seconds - n_intervals=0 - ) - ]) - - # Add simple callback - @self.dashboard_app.callback( - [dash.dependencies.Output('system-status', 'children'), - dash.dependencies.Output('init-status', 'children')], - [dash.dependencies.Input('interval-component', 'n_intervals')] - ) - def update_status(n): - try: - return self._get_system_status(), self._get_init_status() - except Exception as e: - logger.error(f"Callback error: {e}") - return f"Callback error: {e}", "Error in callback" - - logger.info("✓ Minimal dashboard created") - return True - - except Exception as e: - logger.error(f"✗ Dashboard creation failed: {e}") - logger.error(traceback.format_exc()) - return False - - def _get_system_status(self): - """Get system status for display""" - try: - status_items = [] - - # Check components - for name, component in self.components.items(): - if component is not None: - status_items.append(html.P(f"✓ {name.replace('_', ' ').title()}: OK", - style={'color': 'green'})) - else: - status_items.append(html.P(f"✗ {name.replace('_', ' ').title()}: Failed", - style={'color': 'red'})) - - # Add timestamp - status_items.append(html.P(f"Last update: {datetime.now().strftime('%H:%M:%S')}", - style={'color': 'gray', 'fontSize': '12px'})) - - return status_items - - except Exception as e: - return [html.P(f"Status error: {e}", style={'color': 'red'})] - - def _get_init_status(self): - """Get initialization status for display""" - try: - if not self.initialization_errors: - return [html.P("✓ All components initialized successfully", style={'color': 'green'})] - - error_items = [html.P("⚠️ Some components failed to initialize:", style={'color': 'orange'})] - - for error in self.initialization_errors: - error_items.append(html.P(f"• {error}", style={'color': 'red', 'fontSize': '12px'})) - - return error_items - - except Exception as e: - return [html.P(f"Init status error: {e}", style={'color': 'red'})] - - def run(self, port=8051): - """Run the crash-safe dashboard""" - try: - logger.info("=" * 60) - logger.info("CRASH-SAFE DASHBOARD") - logger.info("=" * 60) - logger.info("Mode: Safe mode with minimal features") - logger.info("Training: Completely disabled") - logger.info("Focus: System stability and basic monitoring") - logger.info("=" * 60) - - # Initialize components - self.initialize_core_components() - - # Create dashboard - if not self.create_minimal_dashboard(): - logger.error("Failed to create dashboard") - return False - - # Report initialization status - if self.initialization_errors: - logger.warning(f"Dashboard starting with {len(self.initialization_errors)} component failures") - for error in self.initialization_errors: - logger.warning(f" - {error}") - else: - logger.info("All components initialized successfully") - - # Start dashboard - logger.info(f"Starting dashboard on http://127.0.0.1:{port}") - logger.info("Press Ctrl+C to stop") - - self.dashboard_app.run_server( - host='127.0.0.1', - port=port, - debug=False, - use_reloader=False, - threaded=True - ) - - return True - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - return True - except Exception as e: - logger.error(f"Dashboard failed: {e}") - logger.error(traceback.format_exc()) - return False - -def main(): - """Main function with comprehensive error handling""" - try: - dashboard = CrashSafeDashboard() - success = dashboard.run() - - if success: - logger.info("Dashboard completed successfully") - else: - logger.error("Dashboard failed") - - except Exception as e: - logger.error(f"Fatal error: {e}") - logger.error(traceback.format_exc()) - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_enhanced_rl_training.py b/run_enhanced_rl_training.py deleted file mode 100644 index 743d924..0000000 --- a/run_enhanced_rl_training.py +++ /dev/null @@ -1,525 +0,0 @@ -#!/usr/bin/env python3 -""" -Enhanced RL Training Launcher with Real Data Integration - -This script launches the comprehensive RL training system that uses: -- Real-time tick data (300s window for momentum detection) -- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) -- BTC reference data for correlation -- CNN hidden features and predictions -- Williams Market Structure pivot points -- Market microstructure analysis - -The RL model will receive ~13,400 features instead of the previous ~100 basic features. -Training metrics are automatically logged to TensorBoard for visualization. -""" - -import asyncio -import logging -import time -import signal -import sys -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('enhanced_rl_training.log'), - logging.StreamHandler(sys.stdout) - ] -) - -logger = logging.getLogger(__name__) - -# Import our enhanced components -from core.config import get_config -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from training.enhanced_rl_trainer import EnhancedRLTrainer -from training.enhanced_rl_state_builder import EnhancedRLStateBuilder -from training.williams_market_structure import WilliamsMarketStructure -from training.cnn_rl_bridge import CNNRLBridge -from utils.tensorboard_logger import TensorBoardLogger - -class EnhancedRLTrainingSystem: - """Comprehensive RL training system with real data integration""" - - def __init__(self): - """Initialize the enhanced RL training system""" - self.config = get_config() - self.running = False - self.data_provider = None - self.orchestrator = None - self.rl_trainer = None - - # Performance tracking - self.training_stats = { - 'training_sessions': 0, - 'total_experiences': 0, - 'avg_state_size': 0, - 'data_quality_score': 0.0, - 'last_training_time': None - } - - # Initialize TensorBoard logger - experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - self.tb_logger = TensorBoardLogger( - log_dir="runs", - experiment_name=experiment_name, - enabled=True - ) - - logger.info("Enhanced RL Training System initialized") - logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}") - logger.info("Features:") - logger.info("- Real-time tick data processing (300s window)") - logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)") - logger.info("- BTC correlation analysis") - logger.info("- CNN feature integration") - logger.info("- Williams Market Structure pivot points") - logger.info("- ~13,400 feature state vector (vs previous ~100)") - -# async def initialize(self): -# """Initialize all components""" -# try: -# logger.info("Initializing enhanced RL training components...") - -# # Initialize data provider with real-time streaming -# logger.info("Setting up data provider with real-time streaming...") -# self.data_provider = DataProvider( -# symbols=self.config.symbols, -# timeframes=self.config.timeframes -# ) - -# # Start real-time data streaming -# await self.data_provider.start_real_time_streaming() -# logger.info("Real-time data streaming started") - -# # Wait for initial data collection -# logger.info("Collecting initial market data...") -# await asyncio.sleep(30) # Allow 30 seconds for data collection - -# # Initialize enhanced orchestrator -# logger.info("Initializing enhanced orchestrator...") -# self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) - -# # Initialize enhanced RL trainer with comprehensive state building -# logger.info("Initializing enhanced RL trainer...") -# self.rl_trainer = EnhancedRLTrainer( -# config=self.config, -# orchestrator=self.orchestrator -# ) - -# # Verify data availability -# data_status = await self._verify_data_availability() -# if not data_status['has_sufficient_data']: -# logger.warning("Insufficient data detected. Continuing with limited training.") -# logger.warning(f"Data status: {data_status}") -# else: -# logger.info("Sufficient data available for comprehensive RL training") -# logger.info(f"Tick data: {data_status['tick_count']} ticks") -# logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars") - -# self.running = True -# logger.info("Enhanced RL training system initialized successfully") - -# except Exception as e: -# logger.error(f"Error during initialization: {e}") -# raise - -# async def _verify_data_availability(self) -> Dict[str, any]: -# """Verify that we have sufficient data for training""" -# try: -# data_status = { -# 'has_sufficient_data': False, -# 'tick_count': 0, -# 'ohlcv_bars': 0, -# 'symbols_with_data': [], -# 'missing_data': [] -# } - -# for symbol in self.config.symbols: -# # Check tick data -# recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100) -# tick_count = len(recent_ticks) - -# # Check OHLCV data -# ohlcv_bars = 0 -# for timeframe in ['1s', '1m', '1h', '1d']: -# try: -# df = self.data_provider.get_historical_data( -# symbol=symbol, -# timeframe=timeframe, -# limit=50, -# refresh=True -# ) -# if df is not None and not df.empty: -# ohlcv_bars += len(df) -# except Exception as e: -# logger.warning(f"Error checking {timeframe} data for {symbol}: {e}") - -# data_status['tick_count'] += tick_count -# data_status['ohlcv_bars'] += ohlcv_bars - -# if tick_count >= 50 and ohlcv_bars >= 100: -# data_status['symbols_with_data'].append(symbol) -# else: -# data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars") - -# # Consider data sufficient if we have at least one symbol with good data -# data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0 - -# return data_status - -# except Exception as e: -# logger.error(f"Error verifying data availability: {e}") -# return {'has_sufficient_data': False, 'error': str(e)} - -# async def run_training_loop(self): -# """Run the main training loop with real data""" -# logger.info("Starting enhanced RL training loop...") - -# training_cycle = 0 -# last_state_size_log = time.time() - -# try: -# while self.running: -# training_cycle += 1 -# cycle_start_time = time.time() - -# logger.info(f"Training cycle {training_cycle} started") - -# # Get comprehensive market states with real data -# market_states = await self._get_comprehensive_market_states() - -# if not market_states: -# logger.warning("No market states available. Waiting for data...") -# await asyncio.sleep(60) -# continue - -# # Train RL agents with comprehensive states -# training_results = await self._train_rl_agents(market_states) - -# # Update performance tracking -# self._update_training_stats(training_results, market_states) - -# # Log training progress -# cycle_duration = time.time() - cycle_start_time -# logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s") - -# # Log state size periodically -# if time.time() - last_state_size_log > 300: # Every 5 minutes -# self._log_state_size_info(market_states) -# last_state_size_log = time.time() - -# # Save models periodically -# if training_cycle % 10 == 0: -# await self._save_training_progress() - -# # Wait before next training cycle -# await asyncio.sleep(300) # Train every 5 minutes - -# except Exception as e: -# logger.error(f"Error in training loop: {e}") -# raise - -# async def _get_comprehensive_market_states(self) -> Dict[str, any]: -# """Get comprehensive market states with all required data""" -# try: -# # Get market states from orchestrator -# universal_stream = self.orchestrator.universal_adapter.get_universal_stream() -# market_states = await self.orchestrator._get_all_market_states_universal(universal_stream) - -# # Verify data quality -# quality_score = self._calculate_data_quality(market_states) -# self.training_stats['data_quality_score'] = quality_score - -# if quality_score < 0.5: -# logger.warning(f"Low data quality detected: {quality_score:.2f}") - -# return market_states - -# except Exception as e: -# logger.error(f"Error getting comprehensive market states: {e}") -# return {} - -# def _calculate_data_quality(self, market_states: Dict[str, any]) -> float: -# """Calculate data quality score based on available data""" -# try: -# if not market_states: -# return 0.0 - -# total_score = 0.0 -# total_symbols = len(market_states) - -# for symbol, state in market_states.items(): -# symbol_score = 0.0 - -# # Score based on tick data availability -# if hasattr(state, 'raw_ticks') and state.raw_ticks: -# tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks -# symbol_score += tick_score * 0.3 - -# # Score based on OHLCV data availability -# if hasattr(state, 'ohlcv_data') and state.ohlcv_data: -# ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes -# symbol_score += min(ohlcv_score, 1.0) * 0.4 - -# # Score based on CNN features -# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: -# symbol_score += 0.15 - -# # Score based on pivot points -# if hasattr(state, 'pivot_points') and state.pivot_points: -# symbol_score += 0.15 - -# total_score += symbol_score - -# return total_score / total_symbols if total_symbols > 0 else 0.0 - -# except Exception as e: -# logger.warning(f"Error calculating data quality: {e}") -# return 0.5 # Default to medium quality - - async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]: - """Train RL agents with comprehensive market states""" - try: - training_results = { - 'symbols_trained': [], - 'total_experiences': 0, - 'avg_state_size': 0, - 'training_errors': [], - 'losses': {}, - 'rewards': {} - } - - for symbol, market_state in market_states.items(): - try: - # Convert market state to comprehensive RL state - rl_state = self.rl_trainer._market_state_to_rl_state(market_state) - - if rl_state is not None and len(rl_state) > 0: - # Record state size - state_size = len(rl_state) - training_results['avg_state_size'] += state_size - - # Log state size to TensorBoard - self.tb_logger.log_scalar( - f'State/{symbol}/Size', - state_size, - self.training_stats['training_sessions'] - ) - - # Simulate trading action for experience generation - # In real implementation, this would be actual trading decisions - action = self._simulate_trading_action(symbol, rl_state) - - # Generate reward based on market outcome - reward = self._calculate_training_reward(symbol, market_state, action) - - # Store reward for TensorBoard logging - training_results['rewards'][symbol] = reward - - # Log action and reward to TensorBoard - self.tb_logger.log_scalars(f'Actions/{symbol}', { - 'action': action, - 'reward': reward - }, self.training_stats['training_sessions']) - - # Add experience to RL agent - agent = self.rl_trainer.agents.get(symbol) - if agent: - # Create next state (would be actual next market state in real scenario) - next_state = rl_state # Simplified for now - - agent.remember( - state=rl_state, - action=action, - reward=reward, - next_state=next_state, - done=False - ) - - # Train agent if enough experiences - if len(agent.replay_buffer) >= agent.batch_size: - loss = agent.replay() - if loss is not None: - logger.debug(f"Agent {symbol} training loss: {loss:.4f}") - - # Store loss for TensorBoard logging - training_results['losses'][symbol] = loss - - # Log loss to TensorBoard - self.tb_logger.log_scalar( - f'Training/{symbol}/Loss', - loss, - self.training_stats['training_sessions'] - ) - - training_results['symbols_trained'].append(symbol) - training_results['total_experiences'] += 1 - - except Exception as e: - error_msg = f"Error training {symbol}: {e}" - logger.warning(error_msg) - training_results['training_errors'].append(error_msg) - - # Calculate average state size - if len(training_results['symbols_trained']) > 0: - training_results['avg_state_size'] /= len(training_results['symbols_trained']) - - # Log overall training metrics to TensorBoard - self.tb_logger.log_scalars('Training/Overall', { - 'symbols_trained': len(training_results['symbols_trained']), - 'experiences': training_results['total_experiences'], - 'avg_state_size': training_results['avg_state_size'], - 'errors': len(training_results['training_errors']) - }, self.training_stats['training_sessions']) - - return training_results - - except Exception as e: - logger.error(f"Error training RL agents: {e}") - return {'error': str(e)} - -# def _simulate_trading_action(self, symbol: str, rl_state) -> int: -# """Simulate trading action for training (would be real decision in production)""" -# # Simple simulation based on state features -# if len(rl_state) > 100: -# # Use momentum features to decide action -# momentum_features = rl_state[:100] # First 100 features assumed to be momentum -# avg_momentum = sum(momentum_features) / len(momentum_features) - -# if avg_momentum > 0.6: -# return 1 # BUY -# elif avg_momentum < 0.4: -# return 2 # SELL -# else: -# return 0 # HOLD -# else: -# return 0 # HOLD as default - -# def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float: -# """Calculate training reward based on market state and action""" -# try: -# # Simple reward calculation based on market conditions -# base_reward = 0.0 - -# # Reward based on volatility alignment -# if hasattr(market_state, 'volatility'): -# if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility -# base_reward += 0.1 -# elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility -# base_reward += 0.1 - -# # Reward based on trend alignment -# if hasattr(market_state, 'trend_strength'): -# if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend -# base_reward += 0.2 -# elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend -# base_reward += 0.2 - -# return base_reward - -# except Exception as e: -# logger.warning(f"Error calculating reward for {symbol}: {e}") -# return 0.0 - -# def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]): -# """Update training statistics""" -# self.training_stats['training_sessions'] += 1 -# self.training_stats['total_experiences'] += training_results.get('total_experiences', 0) -# self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0) -# self.training_stats['last_training_time'] = datetime.now() - -# # Log statistics periodically -# if self.training_stats['training_sessions'] % 10 == 0: -# logger.info("Training Statistics:") -# logger.info(f" Sessions: {self.training_stats['training_sessions']}") -# logger.info(f" Total Experiences: {self.training_stats['total_experiences']}") -# logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}") -# logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}") - -# def _log_state_size_info(self, market_states: Dict[str, any]): -# """Log information about state sizes for debugging""" -# for symbol, state in market_states.items(): -# info = [] - -# if hasattr(state, 'raw_ticks'): -# info.append(f"ticks: {len(state.raw_ticks)}") - -# if hasattr(state, 'ohlcv_data'): -# total_bars = sum(len(bars) for bars in state.ohlcv_data.values()) -# info.append(f"OHLCV bars: {total_bars}") - -# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: -# info.append("CNN features: available") - -# if hasattr(state, 'pivot_points') and state.pivot_points: -# info.append("pivot points: available") - -# logger.info(f"{symbol} state data: {', '.join(info)}") - -# async def _save_training_progress(self): -# """Save training progress and models""" -# try: -# if self.rl_trainer: -# self.rl_trainer._save_all_models() -# logger.info("Training progress saved") -# except Exception as e: -# logger.error(f"Error saving training progress: {e}") - -# async def shutdown(self): -# """Graceful shutdown""" -# logger.info("Shutting down enhanced RL training system...") -# self.running = False - -# # Save final state -# await self._save_training_progress() - -# # Stop data provider -# if self.data_provider: -# await self.data_provider.stop_real_time_streaming() - -# logger.info("Enhanced RL training system shutdown complete") - -# async def main(): -# """Main function to run enhanced RL training""" -# system = None - -# def signal_handler(signum, frame): -# logger.info("Received shutdown signal") -# if system: -# asyncio.create_task(system.shutdown()) - -# # Set up signal handlers -# signal.signal(signal.SIGINT, signal_handler) -# signal.signal(signal.SIGTERM, signal_handler) - -# try: -# # Create and initialize the training system -# system = EnhancedRLTrainingSystem() -# await system.initialize() - -# logger.info("Enhanced RL Training System is now running...") -# logger.info("The RL model now receives ~13,400 features instead of ~100!") -# logger.info("Press Ctrl+C to stop") - -# # Run the training loop -# await system.run_training_loop() - -# except KeyboardInterrupt: -# logger.info("Training interrupted by user") -# except Exception as e: -# logger.error(f"Error in main training loop: {e}") -# raise -# finally: -# if system: -# await system.shutdown() - -# if __name__ == "__main__": -# asyncio.run(main()) \ No newline at end of file diff --git a/run_enhanced_training_dashboard.py b/run_enhanced_training_dashboard.py deleted file mode 100644 index 3422636..0000000 --- a/run_enhanced_training_dashboard.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 -""" -Run Dashboard with Enhanced Training System Enabled - -This script starts the trading dashboard with the enhanced real-time -training system automatically enabled and running. -""" - -import sys -import os -import asyncio -import logging -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -from core.trading_executor import TradingExecutor -from web.clean_dashboard import create_clean_dashboard - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def main(): - """Start dashboard with enhanced training enabled""" - try: - logger.info("=" * 70) - logger.info("STARTING DASHBOARD WITH ENHANCED TRAINING SYSTEM") - logger.info("=" * 70) - - # 1. Initialize components with enhanced training - logger.info("1. Initializing components...") - data_provider = DataProvider() - trading_executor = TradingExecutor() - - # 2. Create orchestrator with enhanced training ENABLED - logger.info("2. Creating orchestrator with enhanced training...") - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True # 🔥 THIS ENABLES ENHANCED TRAINING - ) - - # 3. Verify enhanced training is available - logger.info("3. Verifying enhanced training system...") - if orchestrator.enhanced_training_system: - logger.info("✅ Enhanced training system available") - logger.info(f" - Training enabled: {orchestrator.training_enabled}") - - # 4. Start enhanced training - logger.info("4. Starting enhanced training system...") - start_result = orchestrator.start_enhanced_training() - if start_result: - logger.info("✅ Enhanced training started successfully") - else: - logger.warning("⚠️ Enhanced training start failed") - else: - logger.warning("⚠️ Enhanced training system not available") - - # 5. Create dashboard - logger.info("5. Creating dashboard...") - dashboard = create_clean_dashboard( - data_provider=data_provider, - orchestrator=orchestrator, - trading_executor=trading_executor - ) - - # 6. Connect training system to dashboard - logger.info("6. Connecting training system to dashboard...") - orchestrator.set_training_dashboard(dashboard) - - # 7. Start dashboard - logger.info("7. Starting dashboard...") - logger.info("🎉 Dashboard with enhanced training is now running!") - logger.info(" - Enhanced training: ENABLED") - logger.info(" - Real-time learning: ACTIVE") - logger.info(" - Dashboard URL: http://127.0.0.1:8051") - - # Keep running - await asyncio.sleep(3600) # Run for 1 hour - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"Error starting dashboard: {e}") - import traceback - logger.error(traceback.format_exc()) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/run_integrated_rl_cob_dashboard.py b/run_integrated_rl_cob_dashboard.py deleted file mode 100644 index 2620117..0000000 --- a/run_integrated_rl_cob_dashboard.py +++ /dev/null @@ -1,510 +0,0 @@ -#!/usr/bin/env python3 -""" -Integrated Real-time RL COB Trading System with Dashboard - -This script starts both: -1. RealtimeRLCOBTrader - 1B parameter RL model with real-time training -2. COB Dashboard - Real-time visualization with RL predictions - -The RL predictions are integrated into the dashboard for live visualization. -""" - -import asyncio -import logging -import signal -import sys -import json -import os -from datetime import datetime -from typing import Dict, Any -from aiohttp import web - -# Local imports -from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult -from core.trading_executor import TradingExecutor -from core.config import load_config -from web.cob_realtime_dashboard import COBDashboardServer - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/integrated_rl_cob_system.log'), - logging.StreamHandler(sys.stdout) - ] -) - -logger = logging.getLogger(__name__) - -class IntegratedRLCOBSystem: - """ - Integrated Real-time RL COB Trading System with Dashboard - """ - - def __init__(self, config_path: str = "config.yaml"): - """Initialize integrated system with configuration""" - self.config = load_config(config_path) - self.trader = None - self.dashboard = None - self.trading_executor = None - self.running = False - - # RL prediction storage for dashboard - self.rl_predictions: Dict[str, list] = {} - self.prediction_history: Dict[str, list] = {} - - # Setup signal handlers for graceful shutdown - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) - - logger.info("IntegratedRLCOBSystem initialized") - - def _signal_handler(self, signum, frame): - """Handle shutdown signals""" - logger.info(f"Received signal {signum}, initiating graceful shutdown...") - self.running = False - - async def start(self): - """Start the integrated RL COB trading system with dashboard""" - try: - logger.info("=" * 60) - logger.info("INTEGRATED RL COB SYSTEM STARTING") - logger.info("Real-time RL Trading + Dashboard") - logger.info("=" * 60) - - # Initialize trading executor - await self._initialize_trading_executor() - - # Initialize RL trader with prediction callback - await self._initialize_rl_trader() - - # Initialize dashboard with RL integration - await self._initialize_dashboard() - - # Start the integrated system - await self._start_integrated_system() - - # Run main loop - await self._run_main_loop() - - except Exception as e: - logger.error(f"Critical error in integrated system: {e}") - raise - finally: - await self.stop() - - async def _initialize_trading_executor(self): - """Initialize the trading executor""" - logger.info("Initializing Trading Executor...") - - # Get trading configuration - trading_config = self.config.get('trading', {}) - mexc_config = self.config.get('mexc', {}) - - # Determine if we should run in simulation mode - simulation_mode = mexc_config.get('simulation_mode', True) - - if simulation_mode: - logger.info("Running in SIMULATION mode - no real trades will be executed") - else: - logger.warning("Running in LIVE TRADING mode - real money at risk!") - - # Add safety confirmation for live trading - confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ") - if confirmation != 'CONFIRM_LIVE_TRADING': - logger.info("Live trading not confirmed, switching to simulation mode") - simulation_mode = True - - # Initialize trading executor with config path - self.trading_executor = TradingExecutor("config.yaml") - - logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode") - - async def _initialize_rl_trader(self): - """Initialize the RL trader with prediction callbacks""" - logger.info("Initializing Real-time RL COB Trader...") - - # Get RL configuration - rl_config = self.config.get('realtime_rl', {}) - - # Trading symbols - symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT']) - - # Initialize prediction storage - for symbol in symbols: - self.rl_predictions[symbol] = [] - self.prediction_history[symbol] = [] - - # RL parameters - inference_interval_ms = rl_config.get('inference_interval_ms', 200) - min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7) - required_confident_predictions = rl_config.get('required_confident_predictions', 3) - model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob') - - # Initialize RL trader - self.trader = RealtimeRLCOBTrader( - symbols=symbols, - trading_executor=self.trading_executor, - model_checkpoint_dir=model_checkpoint_dir, - inference_interval_ms=inference_interval_ms, - min_confidence_threshold=min_confidence_threshold, - required_confident_predictions=required_confident_predictions - ) - - # Monkey-patch the trader to capture predictions - original_add_signal = self.trader._add_signal - def enhanced_add_signal(symbol: str, prediction: PredictionResult): - # Call original method - original_add_signal(symbol, prediction) - # Capture prediction for dashboard - self._on_rl_prediction(symbol, prediction) - - self.trader._add_signal = enhanced_add_signal - - logger.info(f"RL Trader initialized for symbols: {symbols}") - logger.info(f"Inference interval: {inference_interval_ms}ms") - logger.info(f"Confidence threshold: {min_confidence_threshold}") - logger.info(f"Required predictions: {required_confident_predictions}") - - def _on_rl_prediction(self, symbol: str, prediction: PredictionResult): - """Handle RL predictions for dashboard integration""" - try: - # Convert prediction to dashboard format - prediction_data = { - 'timestamp': prediction.timestamp.isoformat(), - 'direction': prediction.predicted_direction, # 0=DOWN, 1=SIDEWAYS, 2=UP - 'confidence': prediction.confidence, - 'predicted_change': prediction.predicted_change, - 'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][prediction.predicted_direction], - 'color': ['red', 'gray', 'green'][prediction.predicted_direction] - } - - # Add to current predictions (for live display) - self.rl_predictions[symbol].append(prediction_data) - if len(self.rl_predictions[symbol]) > 100: # Keep last 100 - self.rl_predictions[symbol] = self.rl_predictions[symbol][-100:] - - # Add to history (for chart overlay) - self.prediction_history[symbol].append(prediction_data) - if len(self.prediction_history[symbol]) > 1000: # Keep last 1000 - self.prediction_history[symbol] = self.prediction_history[symbol][-1000:] - - logger.debug(f"Captured RL prediction for {symbol}: {prediction.predicted_direction} " - f"(confidence: {prediction.confidence:.3f})") - - except Exception as e: - logger.error(f"Error capturing RL prediction: {e}") - - async def _initialize_dashboard(self): - """Initialize the COB dashboard with RL integration""" - logger.info("Initializing COB Dashboard with RL Integration...") - - # Get dashboard configuration - dashboard_config = self.config.get('dashboard', {}) - host = dashboard_config.get('host', 'localhost') - port = dashboard_config.get('port', 8053) - - # Create enhanced dashboard server - self.dashboard = EnhancedCOBDashboardServer( - host=host, - port=port, - rl_system=self # Pass reference to get predictions - ) - - logger.info(f"COB Dashboard initialized at http://{host}:{port}") - - async def _start_integrated_system(self): - """Start the complete integrated system""" - logger.info("Starting Integrated RL COB System...") - - # Start RL trader first (this initializes COB integration) - await self.trader.start() - logger.info("RL Trader started") - - # Start dashboard (uses same COB integration) - await self.dashboard.start() - logger.info("COB Dashboard started") - - self.running = True - - logger.info("INTEGRATED SYSTEM FULLY OPERATIONAL!") - logger.info("1B parameter RL model: ACTIVE") - logger.info("Real-time COB data: STREAMING") - logger.info("Signal accumulation: ACTIVE") - logger.info("Live predictions: VISIBLE IN DASHBOARD") - logger.info("Continuous training: ACTIVE") - logger.info(f"Dashboard URL: http://{self.dashboard.host}:{self.dashboard.port}") - - async def _run_main_loop(self): - """Main monitoring and statistics loop""" - logger.info("Starting integrated system monitoring...") - - last_stats_time = datetime.now() - stats_interval = 60 # Print stats every 60 seconds - - while self.running: - try: - # Sleep for a bit - await asyncio.sleep(10) - - # Print periodic statistics - current_time = datetime.now() - if (current_time - last_stats_time).total_seconds() >= stats_interval: - await self._print_integrated_stats() - last_stats_time = current_time - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in main loop: {e}") - await asyncio.sleep(5) - - logger.info("Integrated system monitoring stopped") - - async def _print_integrated_stats(self): - """Print comprehensive integrated system statistics""" - try: - logger.info("=" * 80) - logger.info("INTEGRATED RL COB SYSTEM STATISTICS") - logger.info("=" * 80) - - # RL Trader Statistics - if self.trader: - rl_stats = self.trader.get_performance_stats() - logger.info("\n🤖 RL TRADER PERFORMANCE:") - - for symbol in self.trader.symbols: - training_stats = rl_stats.get('training_stats', {}).get(symbol, {}) - inference_stats = rl_stats.get('inference_stats', {}).get(symbol, {}) - signal_stats = rl_stats.get('signal_stats', {}).get(symbol, {}) - - logger.info(f"\n 📈 {symbol}:") - logger.info(f" Predictions: {training_stats.get('total_predictions', 0)}") - logger.info(f" Success Rate: {signal_stats.get('success_rate', 0):.1%}") - logger.info(f" Avg Inference: {inference_stats.get('average_inference_time_ms', 0):.1f}ms") - logger.info(f" Current Signals: {signal_stats.get('current_signals', 0)}") - - # RL prediction stats for dashboard - recent_predictions = len(self.rl_predictions.get(symbol, [])) - total_predictions = len(self.prediction_history.get(symbol, [])) - logger.info(f" Dashboard Predictions: {recent_predictions} recent, {total_predictions} total") - - # Dashboard Statistics - if self.dashboard: - logger.info(f"\nDASHBOARD STATISTICS:") - logger.info(f" Active Connections: {len(self.dashboard.websocket_connections)}") - logger.info(f" Server Status: {'RUNNING' if self.dashboard.site else 'STOPPED'}") - logger.info(f" URL: http://{self.dashboard.host}:{self.dashboard.port}") - - # Trading Executor Statistics - if self.trading_executor: - positions = self.trading_executor.get_positions() - trade_history = self.trading_executor.get_trade_history() - - logger.info(f"\n💰 TRADING STATISTICS:") - logger.info(f" Active Positions: {len(positions)}") - logger.info(f" Total Trades: {len(trade_history)}") - - if trade_history: - total_pnl = sum(trade.pnl for trade in trade_history) - profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0) - win_rate = (profitable_trades / len(trade_history)) * 100 - - logger.info(f" Total P&L: ${total_pnl:.2f}") - logger.info(f" Win Rate: {win_rate:.1f}%") - - logger.info("=" * 80) - - except Exception as e: - logger.error(f"Error printing integrated stats: {e}") - - async def stop(self): - """Stop the integrated system gracefully""" - if not self.running: - return - - logger.info("Stopping Integrated RL COB System...") - - self.running = False - - # Stop dashboard - if self.dashboard: - await self.dashboard.stop() - logger.info("Dashboard stopped") - - # Stop RL trader - if self.trader: - await self.trader.stop() - logger.info("RL Trader stopped") - - logger.info("🏁 Integrated system stopped successfully") - - def get_rl_predictions(self, symbol: str) -> Dict[str, Any]: - """Get RL predictions for dashboard display""" - return { - 'recent_predictions': self.rl_predictions.get(symbol, []), - 'prediction_history': self.prediction_history.get(symbol, []), - 'total_predictions': len(self.prediction_history.get(symbol, [])), - 'recent_count': len(self.rl_predictions.get(symbol, [])) - } - -class EnhancedCOBDashboardServer(COBDashboardServer): - """Enhanced COB Dashboard with RL prediction integration""" - - def __init__(self, host: str = 'localhost', port: int = 8053, rl_system: IntegratedRLCOBSystem = None): - super().__init__(host, port) - self.rl_system = rl_system - - # Add RL prediction routes - self._setup_rl_routes() - - logger.info("Enhanced COB Dashboard with RL predictions initialized") - - async def serve_dashboard(self, request): - """Serve the enhanced dashboard HTML with RL predictions""" - try: - # Read the enhanced dashboard HTML - dashboard_path = os.path.join(os.path.dirname(__file__), 'enhanced_cob_dashboard.html') - - if os.path.exists(dashboard_path): - with open(dashboard_path, 'r', encoding='utf-8') as f: - html_content = f.read() - return web.Response(text=html_content, content_type='text/html') - else: - # Fallback to basic dashboard - logger.warning("Enhanced dashboard HTML not found, using basic dashboard") - return await super().serve_dashboard(request) - - except Exception as e: - logger.error(f"Error serving enhanced dashboard: {e}") - return web.Response(text="Dashboard error", status=500) - - def _setup_rl_routes(self): - """Setup additional routes for RL predictions""" - self.app.router.add_get('/api/rl-predictions/{symbol}', self.get_rl_predictions) - self.app.router.add_get('/api/rl-status', self.get_rl_status) - - async def get_rl_predictions(self, request): - """Get RL predictions for a symbol""" - try: - symbol = request.match_info['symbol'] - symbol = symbol.replace('%2F', '/') - - if symbol not in self.symbols: - return web.json_response({ - 'error': f'Symbol {symbol} not supported' - }, status=400) - - if not self.rl_system: - return web.json_response({ - 'error': 'RL system not available' - }, status=503) - - predictions = self.rl_system.get_rl_predictions(symbol) - - return web.json_response({ - 'symbol': symbol, - 'timestamp': datetime.now().isoformat(), - 'predictions': predictions - }) - - except Exception as e: - logger.error(f"Error getting RL predictions: {e}") - return web.json_response({ - 'error': str(e) - }, status=500) - - async def get_rl_status(self, request): - """Get RL system status""" - try: - if not self.rl_system or not self.rl_system.trader: - return web.json_response({ - 'status': 'inactive', - 'error': 'RL system not available' - }) - - rl_stats = self.rl_system.trader.get_performance_stats() - - status = { - 'status': 'active', - 'symbols': self.rl_system.trader.symbols, - 'model_info': rl_stats.get('model_info', {}), - 'inference_interval_ms': self.rl_system.trader.inference_interval_ms, - 'confidence_threshold': self.rl_system.trader.min_confidence_threshold, - 'required_predictions': self.rl_system.trader.required_confident_predictions, - 'device': str(self.rl_system.trader.device), - 'running': self.rl_system.trader.running - } - - return web.json_response(status) - - except Exception as e: - logger.error(f"Error getting RL status: {e}") - return web.json_response({ - 'status': 'error', - 'error': str(e) - }, status=500) - - async def _broadcast_cob_update(self, symbol: str, data: Dict): - """Enhanced COB update broadcast with RL predictions""" - try: - # Get RL predictions if available - rl_data = {} - if self.rl_system: - rl_predictions = self.rl_system.get_rl_predictions(symbol) - rl_data = { - 'rl_predictions': rl_predictions.get('recent_predictions', [])[-10:], # Last 10 - 'prediction_count': rl_predictions.get('total_predictions', 0) - } - - # Enhanced data with RL predictions - enhanced_data = { - **data, - 'rl_data': rl_data - } - - # Broadcast to all WebSocket connections - message = json.dumps({ - 'type': 'cob_update', - 'symbol': symbol, - 'timestamp': datetime.now().isoformat(), - 'data': enhanced_data - }, default=str) - - # Send to all connected clients - disconnected = [] - for ws in self.websocket_connections: - try: - await ws.send_str(message) - except Exception as e: - logger.debug(f"WebSocket send failed: {e}") - disconnected.append(ws) - - # Remove disconnected clients - for ws in disconnected: - self.websocket_connections.discard(ws) - - except Exception as e: - logger.error(f"Error broadcasting enhanced COB update: {e}") - -async def main(): - """Main entry point for integrated RL COB system""" - # Create logs directory - os.makedirs('logs', exist_ok=True) - - # Initialize and start integrated system - system = IntegratedRLCOBSystem() - - try: - await system.start() - except KeyboardInterrupt: - logger.info("Received keyboard interrupt") - except Exception as e: - logger.error(f"System error: {e}") - raise - finally: - await system.stop() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/run_mexc_browser.py b/run_mexc_browser.py deleted file mode 100644 index 76eb3dc..0000000 --- a/run_mexc_browser.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -""" -One-Click MEXC Browser Launcher - -Simply run this script to start capturing MEXC futures trading requests. -""" - -import sys -import os - -# Add project root to path -project_root = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, project_root) - -def main(): - """Launch MEXC browser automation""" - print("🚀 MEXC Futures Request Interceptor") - print("=" * 50) - print("This will automatically:") - print("✅ Install ChromeDriver") - print("✅ Open MEXC futures page") - print("✅ Capture all API requests") - print("✅ Extract session cookies") - print("✅ Save data to JSON files") - print("\nRequirements will be installed automatically if missing.") - - try: - # First try to run the auto browser directly - from core.mexc_webclient.auto_browser import main as run_auto_browser - run_auto_browser() - - except ImportError as e: - print(f"\n⚠️ Import error: {e}") - print("Installing requirements first...") - - # Try to install requirements and run setup - try: - from setup_mexc_browser import main as setup_main - setup_main() - except ImportError: - print("❌ Could not find setup script") - print("Please run: pip install selenium webdriver-manager") - - except Exception as e: - print(f"❌ Error: {e}") - print("\nTroubleshooting:") - print("1. Make sure you have Chrome browser installed") - print("2. Check your internet connection") - print("3. Try running: pip install selenium webdriver-manager") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_optimized_cob_system.py b/run_optimized_cob_system.py deleted file mode 100644 index 405ef11..0000000 --- a/run_optimized_cob_system.py +++ /dev/null @@ -1,451 +0,0 @@ -#!/usr/bin/env python3 -""" -Optimized COB System - Eliminates Redundant Implementations - -This optimized script runs both the COB dashboard and 1B RL trading system -in a single process with shared data sources to eliminate redundancies: - -BEFORE (Redundant): -- Dashboard: Own COBIntegration instance -- RL Trader: Own COBIntegration instance -- Training: Own COBIntegration instance -= 3x WebSocket connections, 3x order book processing - -AFTER (Optimized): -- Shared COBIntegration instance -- Single WebSocket connection per exchange -- Shared order book processing and caching -= 1x connections, 1x processing, shared memory - -Resource savings: ~60% memory, ~70% network bandwidth -""" - -import asyncio -import logging -import signal -import sys -import json -import os -from datetime import datetime -from typing import Dict, Any, List, Optional -from aiohttp import web -import threading - -# Local imports -from core.cob_integration import COBIntegration -from core.data_provider import DataProvider -from core.trading_executor import TradingExecutor -from core.config import load_config -from web.cob_realtime_dashboard import COBDashboardServer - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/optimized_cob_system.log'), - logging.StreamHandler(sys.stdout) - ] -) - -logger = logging.getLogger(__name__) - -class OptimizedCOBSystem: - """ - Optimized COB System - Single COB instance shared across all components - """ - - def __init__(self, config_path: str = "config.yaml"): - """Initialize optimized system with shared resources""" - self.config = load_config(config_path) - self.running = False - - # Shared components (eliminate redundancy) - self.data_provider = DataProvider() - self.shared_cob_integration: Optional[COBIntegration] = None - self.trading_executor: Optional[TradingExecutor] = None - - # Dashboard using shared COB - self.dashboard_server: Optional[COBDashboardServer] = None - - # Performance tracking - self.performance_stats = { - 'start_time': None, - 'cob_updates_processed': 0, - 'dashboard_connections': 0, - 'memory_saved_mb': 0 - } - - # Setup signal handlers - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) - - logger.info("OptimizedCOBSystem initialized - Eliminating redundant implementations") - - def _signal_handler(self, signum, frame): - """Handle shutdown signals""" - logger.info(f"Received signal {signum}, initiating graceful shutdown...") - self.running = False - - async def start(self): - """Start the optimized COB system""" - try: - logger.info("=" * 70) - logger.info("🚀 OPTIMIZED COB SYSTEM STARTING") - logger.info("=" * 70) - logger.info("Eliminating redundant COB implementations...") - logger.info("Single shared COB integration for all components") - logger.info("=" * 70) - - # Initialize shared components - await self._initialize_shared_components() - - # Initialize dashboard with shared COB - await self._initialize_optimized_dashboard() - - # Start the integrated system - await self._start_optimized_system() - - # Run main monitoring loop - await self._run_optimized_loop() - - except Exception as e: - logger.error(f"Critical error in optimized system: {e}") - import traceback - logger.error(traceback.format_exc()) - raise - finally: - await self.stop() - - async def _initialize_shared_components(self): - """Initialize shared components (eliminates redundancy)""" - logger.info("1. Initializing shared COB integration...") - - # Single COB integration instance for entire system - self.shared_cob_integration = COBIntegration( - data_provider=self.data_provider, - symbols=['BTC/USDT', 'ETH/USDT'] - ) - - # Start the shared COB integration - await self.shared_cob_integration.start() - - logger.info("2. Initializing trading executor...") - - # Trading executor configuration - trading_config = self.config.get('trading', {}) - mexc_config = self.config.get('mexc', {}) - simulation_mode = mexc_config.get('simulation_mode', True) - - self.trading_executor = TradingExecutor() - - logger.info("✅ Shared components initialized") - logger.info(f" Single COB integration: {len(self.shared_cob_integration.symbols)} symbols") - logger.info(f" Trading mode: {'SIMULATION' if simulation_mode else 'LIVE'}") - - async def _initialize_optimized_dashboard(self): - """Initialize dashboard that uses shared COB (no redundant instance)""" - logger.info("3. Initializing optimized dashboard...") - - # Create dashboard and replace its COB with our shared one - self.dashboard_server = COBDashboardServer(host='localhost', port=8053) - - # Replace the dashboard's COB integration with our shared one - self.dashboard_server.cob_integration = self.shared_cob_integration - - logger.info("✅ Optimized dashboard initialized with shared COB") - - async def _start_optimized_system(self): - """Start the optimized system with shared resources""" - logger.info("4. Starting optimized system...") - - self.running = True - self.performance_stats['start_time'] = datetime.now() - - # Start dashboard server with shared COB - await self.dashboard_server.start() - - # Estimate memory savings - # Start RL trader - await self.rl_trader.start() - - # Estimate memory savings - estimated_savings = self._calculate_memory_savings() - self.performance_stats['memory_saved_mb'] = estimated_savings - - logger.info("🚀 Optimized COB System started successfully!") - logger.info(f"💾 Estimated memory savings: {estimated_savings:.0f} MB") - logger.info(f"🌐 Dashboard: http://localhost:8053") - logger.info(f"🤖 RL Training: Active with 1B parameters") - logger.info(f"📊 Shared COB: Single integration for all components") - logger.info("🔄 System Status: OPTIMIZED - No redundant implementations") - - def _calculate_memory_savings(self) -> float: - """Calculate estimated memory savings from eliminating redundancy""" - # Estimates based on typical COB memory usage - cob_integration_memory_mb = 512 # Order books, caches, connections - websocket_connection_memory_mb = 64 # Per exchange connection - - # Before: 3 separate COB integrations (dashboard + RL trader + training) - before_memory = 3 * cob_integration_memory_mb + 3 * websocket_connection_memory_mb - - # After: 1 shared COB integration - after_memory = 1 * cob_integration_memory_mb + 1 * websocket_connection_memory_mb - - savings = before_memory - after_memory - return savings - - async def _run_optimized_loop(self): - """Main optimized monitoring loop""" - logger.info("Starting optimized monitoring loop...") - - last_stats_time = datetime.now() - stats_interval = 60 # Print stats every 60 seconds - - while self.running: - try: - # Sleep for a bit - await asyncio.sleep(10) - - # Update performance stats - self._update_performance_stats() - - # Print periodic statistics - current_time = datetime.now() - if (current_time - last_stats_time).total_seconds() >= stats_interval: - await self._print_optimized_stats() - last_stats_time = current_time - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in optimized loop: {e}") - await asyncio.sleep(5) - - logger.info("Optimized monitoring loop stopped") - - def _update_performance_stats(self): - """Update performance statistics""" - try: - # Get stats from shared COB integration - if self.shared_cob_integration: - cob_stats = self.shared_cob_integration.get_statistics() - self.performance_stats['cob_updates_processed'] = cob_stats.get('total_signals', {}).get('BTC/USDT', 0) - - # Get stats from dashboard - if self.dashboard_server: - dashboard_stats = self.dashboard_server.get_stats() - self.performance_stats['dashboard_connections'] = dashboard_stats.get('active_connections', 0) - - # Get stats from RL trader - if self.rl_trader: - rl_stats = self.rl_trader.get_stats() - self.performance_stats['rl_predictions'] = rl_stats.get('total_predictions', 0) - - # Get stats from trading executor - if self.trading_executor: - trade_history = self.trading_executor.get_trade_history() - self.performance_stats['trades_executed'] = len(trade_history) - - except Exception as e: - logger.warning(f"Error updating performance stats: {e}") - - async def _print_optimized_stats(self): - """Print comprehensive optimized system statistics""" - try: - stats = self.performance_stats - uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0 - - logger.info("=" * 80) - logger.info("🚀 OPTIMIZED COB SYSTEM PERFORMANCE STATISTICS") - logger.info("=" * 80) - - logger.info("📊 Resource Optimization:") - logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB") - logger.info(f" Uptime: {uptime:.0f} seconds") - logger.info(f" COB Updates: {stats['cob_updates_processed']}") - - logger.info("\n🌐 Dashboard Statistics:") - logger.info(f" Active Connections: {stats['dashboard_connections']}") - logger.info(f" Server Status: {'RUNNING' if self.dashboard_server else 'STOPPED'}") - - logger.info("\n🤖 RL Trading Statistics:") - logger.info(f" Total Predictions: {stats['rl_predictions']}") - logger.info(f" Trades Executed: {stats['trades_executed']}") - logger.info(f" Trainer Status: {'ACTIVE' if self.rl_trader else 'STOPPED'}") - - # Shared COB statistics - if self.shared_cob_integration: - cob_stats = self.shared_cob_integration.get_statistics() - logger.info("\n📈 Shared COB Integration:") - logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}") - logger.info(f" Streaming: {cob_stats.get('is_streaming', False)}") - logger.info(f" CNN Callbacks: {cob_stats.get('cnn_callbacks', 0)}") - logger.info(f" DQN Callbacks: {cob_stats.get('dqn_callbacks', 0)}") - logger.info(f" Dashboard Callbacks: {cob_stats.get('dashboard_callbacks', 0)}") - - logger.info("=" * 80) - logger.info("✅ OPTIMIZATION STATUS: Redundancy eliminated, shared resources active") - - except Exception as e: - logger.error(f"Error printing optimized stats: {e}") - - async def stop(self): - """Stop the optimized system gracefully""" - if not self.running: - return - - logger.info("Stopping Optimized COB System...") - - self.running = False - - # Stop RL trader - if self.rl_trader: - await self.rl_trader.stop() - logger.info("✅ RL Trader stopped") - - # Stop dashboard - if self.dashboard_server: - await self.dashboard_server.stop() - logger.info("✅ Dashboard stopped") - - # Stop shared COB integration (last, as others depend on it) - if self.shared_cob_integration: - await self.shared_cob_integration.stop() - logger.info("✅ Shared COB integration stopped") - - # Print final optimization report - await self._print_final_optimization_report() - - logger.info("Optimized COB System stopped successfully") - - async def _print_final_optimization_report(self): - """Print final optimization report""" - stats = self.performance_stats - uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0 - - logger.info("\n📊 FINAL OPTIMIZATION REPORT:") - logger.info(f" Total Runtime: {uptime:.0f} seconds") - logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB") - logger.info(f" COB Updates Processed: {stats['cob_updates_processed']}") - logger.info(f" RL Predictions Made: {stats['rl_predictions']}") - logger.info(f" Trades Executed: {stats['trades_executed']}") - logger.info(" ✅ Redundant implementations eliminated") - logger.info(" ✅ Shared COB integration successful") - - -# Simplified components that use shared COB (no redundant integrations) - -class EnhancedCOBDashboard(COBDashboardServer): - """Enhanced dashboard that uses shared COB integration""" - - def __init__(self, host: str = 'localhost', port: int = 8053, - shared_cob: COBIntegration = None, performance_tracker: Dict = None): - # Initialize parent without creating new COB integration - self.shared_cob = shared_cob - self.performance_tracker = performance_tracker or {} - super().__init__(host, port) - - # Use shared COB instead of creating new one - self.cob_integration = shared_cob - logger.info("Enhanced dashboard using shared COB integration (no redundancy)") - - def get_stats(self) -> Dict[str, Any]: - """Get dashboard statistics""" - return { - 'active_connections': len(self.websocket_connections), - 'using_shared_cob': self.shared_cob is not None, - 'server_running': self.runner is not None - } - -class OptimizedRLTrader: - """Optimized RL trader that uses shared COB integration""" - - def __init__(self, symbols: List[str], shared_cob: COBIntegration, - trading_executor: TradingExecutor, performance_tracker: Dict = None): - self.symbols = symbols - self.shared_cob = shared_cob - self.trading_executor = trading_executor - self.performance_tracker = performance_tracker or {} - self.running = False - - # Subscribe to shared COB updates instead of creating new integration - self.subscription_id = None - self.prediction_count = 0 - - logger.info("Optimized RL trader using shared COB integration (no redundancy)") - - async def start(self): - """Start RL trader with shared COB""" - self.running = True - - # Subscribe to shared COB updates - self.subscription_id = self.shared_cob.add_dqn_callback(self._on_cob_update) - - # Start prediction loop - asyncio.create_task(self._prediction_loop()) - - logger.info("Optimized RL trader started with shared COB subscription") - - async def stop(self): - """Stop RL trader""" - self.running = False - logger.info("Optimized RL trader stopped") - - async def _on_cob_update(self, symbol: str, data: Dict): - """Handle COB updates from shared integration""" - try: - # Process RL prediction using shared data - self.prediction_count += 1 - - # Simple prediction logic (placeholder) - confidence = 0.75 # Example confidence - - if self.prediction_count % 100 == 0: - logger.info(f"RL Prediction #{self.prediction_count} for {symbol} (confidence: {confidence:.2f})") - - except Exception as e: - logger.error(f"Error in RL update: {e}") - - async def _prediction_loop(self): - """Main prediction loop""" - while self.running: - try: - # RL model inference would go here - await asyncio.sleep(0.2) # 200ms inference interval - except Exception as e: - logger.error(f"Error in prediction loop: {e}") - await asyncio.sleep(1) - - def get_stats(self) -> Dict[str, Any]: - """Get RL trader statistics""" - return { - 'total_predictions': self.prediction_count, - 'using_shared_cob': self.shared_cob is not None, - 'subscription_active': self.subscription_id is not None - } - - -async def main(): - """Main entry point for optimized COB system""" - try: - # Create logs directory - os.makedirs('logs', exist_ok=True) - - # Initialize and start optimized system - system = OptimizedCOBSystem() - await system.start() - - except KeyboardInterrupt: - logger.info("Received keyboard interrupt, shutting down...") - except Exception as e: - logger.error(f"Critical error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - # Set event loop policy for Windows compatibility - if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'): - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - - asyncio.run(main()) \ No newline at end of file diff --git a/run_realtime_rl_cob_trader.py b/run_realtime_rl_cob_trader.py deleted file mode 100644 index 57e425e..0000000 --- a/run_realtime_rl_cob_trader.py +++ /dev/null @@ -1,324 +0,0 @@ -#!/usr/bin/env python3 -""" -Real-time RL COB Trader Launcher - -Launch script for the real-time reinforcement learning trader that: -1. Uses COB data for training a 1B parameter model -2. Performs inference every 200ms -3. Accumulates confident signals for trade execution -4. Trains continuously in real-time based on outcomes - -This script provides a complete trading system integration. -""" - -import asyncio -import logging -import signal -import sys -import json -import os -from datetime import datetime -from typing import Dict, Any, Optional - -# Local imports -from core.realtime_rl_cob_trader import RealtimeRLCOBTrader -from core.trading_executor import TradingExecutor -from core.config import load_config - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/realtime_rl_cob_trader.log'), - logging.StreamHandler(sys.stdout) - ] -) - -logger = logging.getLogger(__name__) - -class RealtimeRLCOBTraderLauncher: - """ - Launcher for Real-time RL COB Trader system - """ - - def __init__(self, config_path: str = "config.yaml"): - """Initialize launcher with configuration""" - self.config_path = config_path - self.config = load_config(config_path) - self.trader: Optional[RealtimeRLCOBTrader] = None - self.trading_executor: Optional[TradingExecutor] = None - self.running = False - - # Setup signal handlers for graceful shutdown - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) - - logger.info("RealtimeRLCOBTraderLauncher initialized") - - def _signal_handler(self, signum, frame): - """Handle shutdown signals""" - logger.info(f"Received signal {signum}, initiating graceful shutdown...") - self.running = False - - async def start(self): - """Start the real-time RL COB trading system""" - try: - logger.info("=" * 60) - logger.info("REAL-TIME RL COB TRADER SYSTEM STARTING") - logger.info("=" * 60) - - # Initialize trading executor - await self._initialize_trading_executor() - - # Initialize RL trader - await self._initialize_rl_trader() - - # Start the trading system - await self._start_trading_system() - - # Run main loop - await self._run_main_loop() - - except Exception as e: - logger.error(f"Critical error in trader launcher: {e}") - raise - finally: - await self.stop() - - async def _initialize_trading_executor(self): - """Initialize the trading executor""" - logger.info("Initializing Trading Executor...") - - # Get trading configuration - trading_config = self.config.get('trading', {}) - mexc_config = self.config.get('mexc', {}) - - # Determine if we should run in simulation mode - simulation_mode = mexc_config.get('simulation_mode', True) - - if simulation_mode: - logger.info("Running in SIMULATION mode - no real trades will be executed") - else: - logger.warning("Running in LIVE TRADING mode - real money at risk!") - - # Add safety confirmation for live trading - confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ") - if confirmation != 'CONFIRM_LIVE_TRADING': - logger.info("Live trading not confirmed, switching to simulation mode") - simulation_mode = True - - # Initialize trading executor - self.trading_executor = TradingExecutor(self.config_path) - - logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode") - - async def _initialize_rl_trader(self): - """Initialize the RL trader""" - logger.info("Initializing Real-time RL COB Trader...") - - # Get RL configuration - rl_config = self.config.get('realtime_rl', {}) - - # Trading symbols - symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT']) - - # RL parameters - inference_interval_ms = rl_config.get('inference_interval_ms', 200) - min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7) - required_confident_predictions = rl_config.get('required_confident_predictions', 3) - model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob') - - # Initialize RL trader - if self.trading_executor is None: - raise RuntimeError("Trading executor not initialized") - - self.trader = RealtimeRLCOBTrader( - symbols=symbols, - trading_executor=self.trading_executor, - model_checkpoint_dir=model_checkpoint_dir, - inference_interval_ms=inference_interval_ms, - min_confidence_threshold=min_confidence_threshold, - required_confident_predictions=required_confident_predictions - ) - - logger.info(f"RL Trader initialized for symbols: {symbols}") - logger.info(f"Inference interval: {inference_interval_ms}ms") - logger.info(f"Confidence threshold: {min_confidence_threshold}") - logger.info(f"Required predictions: {required_confident_predictions}") - - async def _start_trading_system(self): - """Start the complete trading system""" - logger.info("Starting Real-time RL COB Trading System...") - - # Start RL trader (this will start COB integration internally) - if self.trader is None: - raise RuntimeError("RL trader not initialized") - await self.trader.start() - - self.running = True - - logger.info("✅ Real-time RL COB Trading System started successfully!") - logger.info("🔥 1B parameter model training and inference active") - logger.info("📊 COB data streaming and processing") - logger.info("🎯 Signal accumulation and trade execution ready") - logger.info("⚡ Real-time training on prediction outcomes") - - async def _run_main_loop(self): - """Main monitoring and statistics loop""" - logger.info("Starting main monitoring loop...") - - last_stats_time = datetime.now() - stats_interval = 60 # Print stats every 60 seconds - - while self.running: - try: - # Sleep for a bit - await asyncio.sleep(10) - - # Print periodic statistics - current_time = datetime.now() - if (current_time - last_stats_time).total_seconds() >= stats_interval: - await self._print_performance_stats() - last_stats_time = current_time - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in main loop: {e}") - await asyncio.sleep(5) - - logger.info("Main monitoring loop stopped") - - async def _print_performance_stats(self): - """Print comprehensive performance statistics""" - try: - if not self.trader: - return - - stats = self.trader.get_performance_stats() - - logger.info("=" * 80) - logger.info("🔥 REAL-TIME RL COB TRADER PERFORMANCE STATISTICS") - logger.info("=" * 80) - - # Model information - logger.info("📊 Model Information:") - for symbol, model_info in stats.get('model_info', {}).items(): - total_params = model_info.get('total_parameters', 0) - logger.info(f" {symbol}: {total_params:,} parameters ({total_params/1e9:.2f}B)") - - # Training statistics - logger.info("\n🧠 Training Statistics:") - for symbol, training_stats in stats.get('training_stats', {}).items(): - total_preds = training_stats.get('total_predictions', 0) - successful_preds = training_stats.get('successful_predictions', 0) - success_rate = (successful_preds / max(1, total_preds)) * 100 - avg_loss = training_stats.get('average_loss', 0.0) - training_steps = training_stats.get('total_training_steps', 0) - last_training = training_stats.get('last_training_time') - - logger.info(f" {symbol}:") - logger.info(f" Predictions: {total_preds} (Success: {success_rate:.1f}%)") - logger.info(f" Training Steps: {training_steps}") - logger.info(f" Average Loss: {avg_loss:.6f}") - if last_training: - logger.info(f" Last Training: {last_training}") - - # Inference statistics - logger.info("\n⚡ Inference Statistics:") - for symbol, inference_stats in stats.get('inference_stats', {}).items(): - total_inferences = inference_stats.get('total_inferences', 0) - avg_time = inference_stats.get('average_inference_time_ms', 0.0) - last_inference = inference_stats.get('last_inference_time') - - logger.info(f" {symbol}:") - logger.info(f" Total Inferences: {total_inferences}") - logger.info(f" Average Time: {avg_time:.1f}ms") - if last_inference: - logger.info(f" Last Inference: {last_inference}") - - # Signal statistics - logger.info("\n🎯 Signal Accumulation:") - for symbol, signal_stats in stats.get('signal_stats', {}).items(): - current_signals = signal_stats.get('current_signals', 0) - confidence_sum = signal_stats.get('confidence_sum', 0.0) - success_rate = signal_stats.get('success_rate', 0.0) * 100 - - logger.info(f" {symbol}:") - logger.info(f" Current Signals: {current_signals}") - logger.info(f" Confidence Sum: {confidence_sum:.2f}") - logger.info(f" Historical Success Rate: {success_rate:.1f}%") - - # Trading executor statistics - if self.trading_executor: - positions = self.trading_executor.get_positions() - trade_history = self.trading_executor.get_trade_history() - - logger.info("\n💰 Trading Statistics:") - logger.info(f" Active Positions: {len(positions)}") - logger.info(f" Total Trades: {len(trade_history)}") - - if trade_history: - # Calculate P&L statistics - total_pnl = sum(trade.pnl for trade in trade_history) - profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0) - win_rate = (profitable_trades / len(trade_history)) * 100 - - logger.info(f" Total P&L: ${total_pnl:.2f}") - logger.info(f" Win Rate: {win_rate:.1f}%") - - # Show active positions - if positions: - logger.info("\n📍 Active Positions:") - for symbol, position in positions.items(): - logger.info(f" {symbol}: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}") - - logger.info("=" * 80) - - except Exception as e: - logger.error(f"Error printing performance stats: {e}") - - async def stop(self): - """Stop the trading system gracefully""" - if not self.running: - return - - logger.info("Stopping Real-time RL COB Trading System...") - - self.running = False - - # Stop RL trader - if self.trader: - await self.trader.stop() - logger.info("✅ RL Trader stopped") - - # Print final statistics - if self.trader: - logger.info("\n📊 Final Performance Summary:") - await self._print_performance_stats() - - logger.info("Real-time RL COB Trading System stopped successfully") - -async def main(): - """Main entry point""" - try: - # Create logs directory if it doesn't exist - os.makedirs('logs', exist_ok=True) - - # Initialize and start launcher - launcher = RealtimeRLCOBTraderLauncher() - await launcher.start() - - except KeyboardInterrupt: - logger.info("Received keyboard interrupt, shutting down...") - except Exception as e: - logger.error(f"Critical error: {e}") - raise - -if __name__ == "__main__": - # Set event loop policy for Windows compatibility - if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'): - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - - asyncio.run(main()) \ No newline at end of file diff --git a/run_simple_dashboard.py b/run_simple_dashboard.py deleted file mode 100644 index 9bf4909..0000000 --- a/run_simple_dashboard.py +++ /dev/null @@ -1,218 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple Dashboard Runner - Fixed version for testing -""" - -import os -import sys -import logging -import time -import threading -from pathlib import Path - -# Fix OpenMP library conflicts -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '4' - -# Fix matplotlib backend -import matplotlib -matplotlib.use('Agg') - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def create_simple_dashboard(): - """Create a simple working dashboard""" - try: - import dash - from dash import html, dcc, Input, Output - import plotly.graph_objs as go - import pandas as pd - import numpy as np - from datetime import datetime, timedelta - - # Create Dash app - app = dash.Dash(__name__) - - # Simple layout - app.layout = html.Div([ - html.H1("Trading System Dashboard", style={'textAlign': 'center', 'color': '#2c3e50'}), - - html.Div([ - html.Div([ - html.H3("System Status", style={'color': '#27ae60'}), - html.P(id='system-status', children="System: RUNNING", style={'fontSize': '18px'}), - html.P(id='current-time', children=f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"), - ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}), - - html.Div([ - html.H3("Trading Stats", style={'color': '#3498db'}), - html.P("Total Trades: 0"), - html.P("Success Rate: 0%"), - html.P("Current PnL: $0.00"), - ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'}), - ]), - - html.Div([ - dcc.Graph(id='price-chart'), - ], style={'padding': '20px'}), - - html.Div([ - dcc.Graph(id='performance-chart'), - ], style={'padding': '20px'}), - - # Auto-refresh component - dcc.Interval( - id='interval-component', - interval=5000, # Update every 5 seconds - n_intervals=0 - ) - ]) - - # Callback for updating time - @app.callback( - Output('current-time', 'children'), - Input('interval-component', 'n_intervals') - ) - def update_time(n): - return f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - - # Callback for price chart - @app.callback( - Output('price-chart', 'figure'), - Input('interval-component', 'n_intervals') - ) - def update_price_chart(n): - # Generate sample data - dates = pd.date_range(start=datetime.now() - timedelta(hours=24), - end=datetime.now(), freq='1H') - prices = 3000 + np.cumsum(np.random.randn(len(dates)) * 10) - - fig = go.Figure() - fig.add_trace(go.Scatter( - x=dates, - y=prices, - mode='lines', - name='ETH/USDT', - line=dict(color='#3498db', width=2) - )) - - fig.update_layout( - title='ETH/USDT Price Chart (24H)', - xaxis_title='Time', - yaxis_title='Price (USD)', - template='plotly_white', - height=400 - ) - - return fig - - # Callback for performance chart - @app.callback( - Output('performance-chart', 'figure'), - Input('interval-component', 'n_intervals') - ) - def update_performance_chart(n): - # Generate sample performance data - dates = pd.date_range(start=datetime.now() - timedelta(days=7), - end=datetime.now(), freq='1D') - performance = np.cumsum(np.random.randn(len(dates)) * 0.02) * 100 - - fig = go.Figure() - fig.add_trace(go.Scatter( - x=dates, - y=performance, - mode='lines+markers', - name='Portfolio Performance', - line=dict(color='#27ae60', width=3), - marker=dict(size=6) - )) - - fig.update_layout( - title='Portfolio Performance (7 Days)', - xaxis_title='Date', - yaxis_title='Performance (%)', - template='plotly_white', - height=400 - ) - - return fig - - return app - - except Exception as e: - logger.error(f"Error creating dashboard: {e}") - import traceback - logger.error(traceback.format_exc()) - return None - -def test_data_provider(): - """Test data provider in background""" - try: - from core.data_provider import DataProvider - from core.api_rate_limiter import get_rate_limiter - - logger.info("Testing data provider...") - - # Create data provider - data_provider = DataProvider( - symbols=['ETH/USDT'], - timeframes=['1m', '5m'] - ) - - # Test getting data - df = data_provider.get_historical_data('ETH/USDT', '1m', limit=10) - if df is not None and len(df) > 0: - logger.info(f"✓ Data provider working: {len(df)} candles retrieved") - else: - logger.warning("⚠ Data provider returned no data (rate limiting)") - - # Test rate limiter status - rate_limiter = get_rate_limiter() - status = rate_limiter.get_all_endpoint_status() - logger.info(f"Rate limiter status: {status}") - - except Exception as e: - logger.error(f"Data provider test error: {e}") - -def main(): - """Main function""" - logger.info("=" * 60) - logger.info("SIMPLE DASHBOARD RUNNER - TESTING SYSTEM") - logger.info("=" * 60) - - # Test data provider in background - data_thread = threading.Thread(target=test_data_provider, daemon=True) - data_thread.start() - - # Create and run dashboard - app = create_simple_dashboard() - if app is None: - logger.error("Failed to create dashboard") - return - - try: - logger.info("Starting dashboard server...") - logger.info("Dashboard URL: http://127.0.0.1:8050") - logger.info("Press Ctrl+C to stop") - - # Run the dashboard - app.run(debug=False, host='127.0.0.1', port=8050, use_reloader=False) - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"Dashboard error: {e}") - import traceback - logger.error(traceback.format_exc()) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_stable_dashboard.py b/run_stable_dashboard.py deleted file mode 100644 index a98e2d8..0000000 --- a/run_stable_dashboard.py +++ /dev/null @@ -1,275 +0,0 @@ -#!/usr/bin/env python3 -""" -Stable Dashboard Runner - Prioritizes System Stability - -This runner focuses on: -1. System stability and reliability -2. Core trading functionality -3. Minimal resource usage -4. Robust error handling -5. Graceful degradation - -Deferred features (until stability is achieved): -- TensorBoard integration -- Complex training loops -- Advanced visualizations -""" - -import os -import sys -import time -import logging -import threading -import signal -from pathlib import Path - -# Fix environment issues before imports -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '2' # Reduced from 4 for stability - -# Fix matplotlib backend -import matplotlib -matplotlib.use('Agg') - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging, get_config -from system_stability_audit import SystemStabilityAuditor - -# Setup logging with reduced verbosity for stability -setup_logging() -logger = logging.getLogger(__name__) - -# Reduce logging noise from other modules -logging.getLogger('werkzeug').setLevel(logging.ERROR) -logging.getLogger('dash').setLevel(logging.ERROR) -logging.getLogger('matplotlib').setLevel(logging.ERROR) - -class StableDashboardRunner: - """ - Stable dashboard runner with focus on reliability - """ - - def __init__(self): - """Initialize stable dashboard runner""" - self.config = get_config() - self.running = False - self.dashboard = None - self.stability_auditor = None - - # Core components - self.data_provider = None - self.orchestrator = None - self.trading_executor = None - - # Stability monitoring - self.last_health_check = time.time() - self.health_check_interval = 30 # Check every 30 seconds - - logger.info("Stable Dashboard Runner initialized") - - def initialize_components(self): - """Initialize core components with error handling""" - try: - logger.info("Initializing core components...") - - # Initialize data provider - from core.data_provider import DataProvider - self.data_provider = DataProvider() - logger.info("✓ Data provider initialized") - - # Initialize trading executor - from core.trading_executor import TradingExecutor - self.trading_executor = TradingExecutor() - logger.info("✓ Trading executor initialized") - - # Initialize orchestrator with minimal features for stability - from core.orchestrator import TradingOrchestrator - self.orchestrator = TradingOrchestrator( - data_provider=self.data_provider, - enhanced_rl_training=False # Disabled for stability - ) - logger.info("✓ Orchestrator initialized (training disabled for stability)") - - # Initialize dashboard - from web.clean_dashboard import CleanTradingDashboard - self.dashboard = CleanTradingDashboard( - data_provider=self.data_provider, - orchestrator=self.orchestrator, - trading_executor=self.trading_executor - ) - logger.info("✓ Dashboard initialized") - - return True - - except Exception as e: - logger.error(f"Error initializing components: {e}") - return False - - def start_stability_monitoring(self): - """Start system stability monitoring""" - try: - self.stability_auditor = SystemStabilityAuditor() - self.stability_auditor.start_monitoring() - logger.info("✓ Stability monitoring started") - except Exception as e: - logger.error(f"Error starting stability monitoring: {e}") - - def health_check(self): - """Perform system health check""" - try: - current_time = time.time() - if current_time - self.last_health_check < self.health_check_interval: - return - - self.last_health_check = current_time - - # Check stability score - if self.stability_auditor: - report = self.stability_auditor.get_stability_report() - stability_score = report.get('stability_score', 0) - - if stability_score < 50: - logger.warning(f"Low stability score: {stability_score:.1f}/100") - # Attempt to fix issues - self.stability_auditor.fix_common_issues() - elif stability_score < 80: - logger.info(f"Moderate stability: {stability_score:.1f}/100") - else: - logger.debug(f"Good stability: {stability_score:.1f}/100") - - # Check component health - if self.dashboard and hasattr(self.dashboard, 'app'): - logger.debug("✓ Dashboard responsive") - - if self.data_provider: - logger.debug("✓ Data provider active") - - if self.orchestrator: - logger.debug("✓ Orchestrator active") - - except Exception as e: - logger.error(f"Error in health check: {e}") - - def run(self): - """Run the stable dashboard""" - try: - logger.info("=" * 60) - logger.info("STABLE TRADING DASHBOARD") - logger.info("=" * 60) - logger.info("Priority: System Stability & Core Functionality") - logger.info("Training: Disabled (will be enabled after stability)") - logger.info("TensorBoard: Deferred (documented in design)") - logger.info("Focus: Dashboard, Data, Basic Trading") - logger.info("=" * 60) - - # Initialize components - if not self.initialize_components(): - logger.error("Failed to initialize components") - return False - - # Start stability monitoring - self.start_stability_monitoring() - - # Start health check thread - health_thread = threading.Thread(target=self._health_check_loop, daemon=True) - health_thread.start() - - # Get dashboard port - port = int(os.environ.get('DASHBOARD_PORT', '8051')) - - logger.info(f"Starting dashboard on http://127.0.0.1:{port}") - logger.info("Press Ctrl+C to stop") - - self.running = True - - # Start dashboard (this blocks) - if self.dashboard and hasattr(self.dashboard, 'app'): - self.dashboard.app.run_server( - host='127.0.0.1', - port=port, - debug=False, - use_reloader=False, # Disable reloader for stability - threaded=True - ) - else: - logger.error("Dashboard not properly initialized") - return False - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - self.shutdown() - return True - except Exception as e: - logger.error(f"Error running dashboard: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - - def _health_check_loop(self): - """Health check loop running in background""" - while self.running: - try: - self.health_check() - time.sleep(self.health_check_interval) - except Exception as e: - logger.error(f"Error in health check loop: {e}") - time.sleep(60) # Wait longer on error - - def shutdown(self): - """Graceful shutdown""" - try: - logger.info("Shutting down stable dashboard...") - self.running = False - - # Stop stability monitoring - if self.stability_auditor: - self.stability_auditor.stop_monitoring() - logger.info("✓ Stability monitoring stopped") - - # Stop components - if self.orchestrator and hasattr(self.orchestrator, 'stop'): - self.orchestrator.stop() - logger.info("✓ Orchestrator stopped") - - if self.data_provider and hasattr(self.data_provider, 'stop'): - self.data_provider.stop() - logger.info("✓ Data provider stopped") - - logger.info("Stable dashboard shutdown complete") - - except Exception as e: - logger.error(f"Error during shutdown: {e}") - -def signal_handler(signum, frame): - """Handle shutdown signals""" - logger.info("Received shutdown signal") - sys.exit(0) - -def main(): - """Main function""" - # Setup signal handlers - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - runner = StableDashboardRunner() - success = runner.run() - - if success: - logger.info("Dashboard completed successfully") - sys.exit(0) - else: - logger.error("Dashboard failed") - sys.exit(1) - - except Exception as e: - logger.error(f"Fatal error: {e}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_templated_dashboard.py b/run_templated_dashboard.py deleted file mode 100644 index 3d6d390..0000000 --- a/run_templated_dashboard.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -""" -Run Templated Trading Dashboard -Demonstrates the new MVC template-based architecture -""" -import logging -import sys -import os - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from web.templated_dashboard import create_templated_dashboard -from web.dashboard_model import create_sample_dashboard_data -from web.template_renderer import DashboardTemplateRenderer - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - - -def main(): - """Main function to run the templated dashboard""" - try: - logger.info("=== TEMPLATED DASHBOARD DEMO ===") - - # Test the template system first - logger.info("Testing template system...") - - # Create sample data - sample_data = create_sample_dashboard_data() - logger.info(f"Created sample data with {len(sample_data.metrics)} metrics") - - # Test template renderer - renderer = DashboardTemplateRenderer() - logger.info("Template renderer initialized") - - # Create templated dashboard - logger.info("Creating templated dashboard...") - dashboard = create_templated_dashboard() - - logger.info("Dashboard created successfully!") - logger.info("Template-based MVC architecture features:") - logger.info(" ✓ HTML templates separated from Python code") - logger.info(" ✓ Data models for structured data") - logger.info(" ✓ Template renderer for clean separation") - logger.info(" ✓ Easy to modify HTML without touching Python") - logger.info(" ✓ Reusable components and templates") - - # Run the dashboard - logger.info("Starting templated dashboard server...") - dashboard.run_server(host='127.0.0.1', port=8052, debug=False) - - except Exception as e: - logger.error(f"Error running templated dashboard: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/run_tensorboard.py b/run_tensorboard.py deleted file mode 100644 index 13abf27..0000000 --- a/run_tensorboard.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env python3 -""" -TensorBoard Launch Script - -Starts TensorBoard server for monitoring training progress. -Visualizes training metrics, rewards, state information, and model performance. - -This script can be run standalone or integrated with the dashboard. -""" - -import subprocess -import sys -import os -import time -import webbrowser -import argparse -from pathlib import Path -import logging - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def start_tensorboard(logdir="runs", port=6006, open_browser=True): - """ - Start TensorBoard server programmatically - - Args: - logdir: Directory containing TensorBoard logs - port: Port to run TensorBoard on - open_browser: Whether to open browser automatically - - Returns: - subprocess.Popen: TensorBoard process - """ - # Set log directory - runs_dir = Path(logdir) - if not runs_dir.exists(): - logger.warning(f"No '{logdir}' directory found. Creating it.") - runs_dir.mkdir(parents=True, exist_ok=True) - - # Check if there are any log directories - log_dirs = list(runs_dir.glob("*")) - if not log_dirs: - logger.warning(f"No training logs found in '{logdir}' directory.") - else: - logger.info(f"Found {len(log_dirs)} training sessions") - - # List available sessions - logger.info("Available training sessions:") - for i, log_dir in enumerate(sorted(log_dirs), 1): - logger.info(f" {i}. {log_dir.name}") - - try: - logger.info(f"Starting TensorBoard on port {port}...") - - # Try to open browser automatically if requested - if open_browser: - try: - webbrowser.open(f"http://localhost:{port}") - logger.info("Browser opened automatically") - except Exception as e: - logger.warning(f"Could not open browser automatically: {e}") - - # Start TensorBoard process with enhanced options - cmd = [ - sys.executable, - "-m", - "tensorboard.main", - "--logdir", str(runs_dir), - "--port", str(port), - "--samples_per_plugin", "images=100,audio=100,text=100", - "--reload_interval", "5", # Reload data every 5 seconds - "--reload_multifile", "true" # Better handling of multiple log files - ] - - logger.info("TensorBoard is running with enhanced training visualization!") - logger.info(f"View training metrics at: http://localhost:{port}") - logger.info("Available dashboards:") - logger.info(" - SCALARS: Training metrics, rewards, and losses") - logger.info(" - HISTOGRAMS: Feature distributions and model weights") - logger.info(" - TIME SERIES: Training progress over time") - - # Start TensorBoard process - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - - # Return process for management - return process - - except FileNotFoundError: - logger.error("TensorBoard not found. Install with: pip install tensorboard") - return None - except Exception as e: - logger.error(f"Error starting TensorBoard: {e}") - return None - -def main(): - """Launch TensorBoard with enhanced visualization options""" - - # Parse command line arguments - parser = argparse.ArgumentParser(description="Launch TensorBoard for training visualization") - parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on") - parser.add_argument("--logdir", type=str, default="runs", help="Directory containing TensorBoard logs") - parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically") - parser.add_argument("--dashboard-integration", action="store_true", help="Run in dashboard integration mode") - args = parser.parse_args() - - # Start TensorBoard - process = start_tensorboard( - logdir=args.logdir, - port=args.port, - open_browser=not args.no_browser - ) - - if process is None: - return 1 - - # If running in dashboard integration mode, return immediately - if args.dashboard_integration: - return 0 - - # Otherwise, wait for process to complete - try: - print("\n" + "="*70) - print("🔥 TensorBoard is running with enhanced training visualization!") - print(f"📈 View training metrics at: http://localhost:{args.port}") - print("⏹️ Press Ctrl+C to stop TensorBoard") - print("="*70 + "\n") - - # Wait for process to complete or user interrupt - process.wait() - return 0 - - except KeyboardInterrupt: - print("\n🛑 TensorBoard stopped") - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - return 0 - except Exception as e: - print(f"❌ Error: {e}") - return 1 - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/run_tests.py b/run_tests.py deleted file mode 100644 index cdb0737..0000000 --- a/run_tests.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -""" -Unified Test Runner for Trading System - -This script provides a unified interface to run all tests in the system: -- Essential functionality tests -- Model persistence tests -- Training integration tests -- Indicators and signals tests -- Remaining individual test files - -Usage: - python run_tests.py # Run all tests - python run_tests.py essential # Run essential tests only - python run_tests.py persistence # Run model persistence tests only - python run_tests.py training # Run training integration tests only - python run_tests.py indicators # Run indicators and signals tests only - python run_tests.py individual # Run individual test files only -""" - -import sys -import os -import subprocess -import logging -from pathlib import Path -from safe_logging import setup_safe_logging - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging - -logger = logging.getLogger(__name__) - -def run_test_module(module_path, test_type="all"): - """Run a specific test module""" - try: - cmd = [sys.executable, str(module_path)] - if test_type != "all": - cmd.append(test_type) - - logger.info(f"Running: {' '.join(cmd)}") - result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_root) - - if result.returncode == 0: - logger.info(f"✅ {module_path.name} passed") - if result.stdout: - logger.info(result.stdout) - return True - else: - logger.error(f"❌ {module_path.name} failed") - if result.stderr: - logger.error(result.stderr) - if result.stdout: - logger.error(result.stdout) - return False - - except Exception as e: - logger.error(f"❌ Error running {module_path}: {e}") - return False - -def run_essential_tests(): - """Run essential functionality tests""" - logger.info("=== Running Essential Tests ===") - return run_test_module(project_root / "tests" / "test_essential.py") - -def run_persistence_tests(): - """Run model persistence tests""" - logger.info("=== Running Model Persistence Tests ===") - return run_test_module(project_root / "tests" / "test_model_persistence.py") - -def run_training_tests(): - """Run training integration tests""" - logger.info("=== Running Training Integration Tests ===") - return run_test_module(project_root / "tests" / "test_training_integration.py") - -def run_indicators_tests(): - """Run indicators and signals tests""" - logger.info("=== Running Indicators and Signals Tests ===") - return run_test_module(project_root / "tests" / "test_indicators_and_signals.py") - -def run_individual_tests(): - """Run remaining individual test files""" - logger.info("=== Running Individual Test Files ===") - - individual_tests = [ - "test_positions.py", - "test_tick_cache.py", - "test_timestamps.py" - ] - - results = [] - for test_file in individual_tests: - test_path = project_root / test_file - if test_path.exists(): - logger.info(f"Running {test_file}...") - result = run_test_module(test_path) - results.append(result) - else: - logger.warning(f"Test file not found: {test_file}") - results.append(False) - - return all(results) - -def run_all_tests(): - """Run all test suites""" - logger.info("🧪 Running All Trading System Tests") - logger.info("=" * 60) - - test_suites = [ - ("Essential Tests", run_essential_tests), - ("Model Persistence Tests", run_persistence_tests), - ("Training Integration Tests", run_training_tests), - ("Indicators and Signals Tests", run_indicators_tests), - ("Individual Tests", run_individual_tests), - ] - - results = [] - for suite_name, suite_func in test_suites: - logger.info(f"\n📋 {suite_name}") - logger.info("-" * 40) - try: - result = suite_func() - results.append((suite_name, result)) - except Exception as e: - logger.error(f"❌ {suite_name} crashed: {e}") - results.append((suite_name, False)) - - # Print summary - logger.info("\n" + "=" * 60) - logger.info("📊 TEST RESULTS SUMMARY") - logger.info("=" * 60) - - passed = 0 - for suite_name, result in results: - status = "✅ PASS" if result else "❌ FAIL" - logger.info(f"{status}: {suite_name}") - if result: - passed += 1 - - logger.info(f"\nPassed: {passed}/{len(results)} test suites") - - if passed == len(results): - logger.info("🎉 All tests passed! Trading system is working correctly.") - return True - else: - logger.warning(f"⚠️ {len(results) - passed} test suite(s) failed. Please check the issues above.") - return False - -def main(): - """Main test runner""" - setup_safe_logging() - - # Parse command line arguments - if len(sys.argv) > 1: - test_type = sys.argv[1].lower() - - if test_type == "essential": - success = run_essential_tests() - elif test_type == "persistence": - success = run_persistence_tests() - elif test_type == "training": - success = run_training_tests() - elif test_type == "indicators": - success = run_indicators_tests() - elif test_type == "individual": - success = run_individual_tests() - elif test_type in ["help", "-h", "--help"]: - print(__doc__) - return 0 - else: - logger.error(f"Unknown test type: {test_type}") - print(__doc__) - return 1 - else: - success = run_all_tests() - - return 0 if success else 1 - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/scripts/kill_stale_processes.py b/scripts/kill_stale_processes.py deleted file mode 100644 index 92ec71c..0000000 --- a/scripts/kill_stale_processes.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 -""" -Kill Stale Processes Script - -Safely terminates stale Python processes related to the trading dashboard -with proper error handling and graceful termination. -""" - -import os -import sys -import time -import signal -from pathlib import Path -import threading - -# Global timeout flag -timeout_reached = False - -def timeout_handler(): - """Handler for overall script timeout""" - global timeout_reached - timeout_reached = True - print("\n⚠️ WARNING: Script timeout reached (10s) - forcing exit") - os._exit(0) # Force exit - -def kill_stale_processes(): - """Kill stale trading dashboard processes safely""" - global timeout_reached - - # Set up overall timeout (10 seconds) - timer = threading.Timer(10.0, timeout_handler) - timer.daemon = True - timer.start() - - try: - import psutil - except ImportError: - print("psutil not available - using fallback method") - return kill_stale_fallback() - - current_pid = os.getpid() - killed_processes = [] - failed_processes = [] - - # Keywords to identify trading dashboard processes - target_keywords = [ - 'dashboard', 'scalping', 'trading', 'tensorboard', - 'run_clean', 'run_main', 'gogo2', 'mexc' - ] - - try: - print("Scanning for stale processes...") - - # Get all Python processes with timeout - python_processes = [] - scan_start = time.time() - - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - if timeout_reached or (time.time() - scan_start) > 3.0: # 3s max for scanning - print("Process scanning timeout - proceeding with found processes") - break - - try: - if proc.info['pid'] == current_pid: - continue - - name = proc.info['name'].lower() - if 'python' in name or 'tensorboard' in name: - cmdline_str = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else '' - - # Check if this is a target process - if any(keyword in cmdline_str.lower() for keyword in target_keywords): - python_processes.append({ - 'proc': proc, - 'pid': proc.info['pid'], - 'name': proc.info['name'], - 'cmdline': cmdline_str - }) - - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): - continue - - if not python_processes: - print("No stale processes found") - timer.cancel() # Cancel the timeout - return True - - print(f"Found {len(python_processes)} target processes to terminate:") - for p in python_processes[:5]: # Show max 5 to save time - print(f" - PID {p['pid']}: {p['name']} - {p['cmdline'][:80]}...") - if len(python_processes) > 5: - print(f" ... and {len(python_processes) - 5} more") - - # Graceful termination first (with reduced wait time) - print("\nAttempting graceful termination...") - termination_start = time.time() - - for p in python_processes: - if timeout_reached or (time.time() - termination_start) > 2.0: - print("Termination timeout - moving to force kill") - break - - try: - proc = p['proc'] - if proc.is_running(): - proc.terminate() - print(f" Sent SIGTERM to PID {p['pid']}") - except Exception as e: - failed_processes.append(f"Failed to terminate PID {p['pid']}: {e}") - - # Wait for graceful shutdown (reduced from 2.0 to 1.0) - time.sleep(1.0) - - # Force kill remaining processes - print("\nChecking for remaining processes...") - kill_start = time.time() - - for p in python_processes: - if timeout_reached or (time.time() - kill_start) > 2.0: - print("Force kill timeout - exiting") - break - - try: - proc = p['proc'] - if proc.is_running(): - print(f" Force killing PID {p['pid']} ({p['name']})") - proc.kill() - killed_processes.append(f"Force killed PID {p['pid']} ({p['name']})") - else: - killed_processes.append(f"Gracefully terminated PID {p['pid']} ({p['name']})") - except (psutil.NoSuchProcess, psutil.AccessDenied): - killed_processes.append(f"Process PID {p['pid']} already terminated") - except Exception as e: - failed_processes.append(f"Failed to kill PID {p['pid']}: {e}") - - # Results (quick summary) - print(f"\n=== Quick Results ===") - print(f"✓ Cleaned up {len(killed_processes)} processes") - if failed_processes: - print(f"✗ Failed: {len(failed_processes)} processes") - - timer.cancel() # Cancel the timeout if we finished early - return len(failed_processes) == 0 - - except Exception as e: - print(f"Error during process cleanup: {e}") - timer.cancel() - return False - -def kill_stale_fallback(): - """Fallback method using basic OS commands""" - print("Using fallback process killing method...") - - try: - if os.name == 'nt': # Windows - import subprocess - # Kill Python processes with dashboard keywords (with timeout) - result = subprocess.run([ - 'taskkill', '/f', '/im', 'python.exe' - ], capture_output=True, text=True, timeout=5.0) - - if result.returncode == 0: - print("Windows: Killed all Python processes") - else: - print("Windows: No Python processes to kill or access denied") - - else: # Unix/Linux - import subprocess - # More targeted approach for Unix (with timeouts) - subprocess.run(['pkill', '-f', 'dashboard'], capture_output=True, timeout=2.0) - subprocess.run(['pkill', '-f', 'scalping'], capture_output=True, timeout=2.0) - subprocess.run(['pkill', '-f', 'tensorboard'], capture_output=True, timeout=2.0) - print("Unix: Killed dashboard-related processes") - - return True - - except subprocess.TimeoutExpired: - print("Fallback method timed out") - return False - except Exception as e: - print(f"Fallback method failed: {e}") - return False - -if __name__ == "__main__": - print("=" * 50) - print("STALE PROCESS CLEANUP (10s timeout)") - print("=" * 50) - - start_time = time.time() - success = kill_stale_processes() - elapsed = time.time() - start_time - - exit_code = 0 if success else 1 - - print(f"Completed in {elapsed:.1f}s") - print("=" * 50) - sys.exit(exit_code) \ No newline at end of file diff --git a/scripts/restart_dashboard_with_learning.py b/scripts/restart_dashboard_with_learning.py deleted file mode 100644 index a272719..0000000 --- a/scripts/restart_dashboard_with_learning.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 -""" -Restart Dashboard with Forced Learning Enabled -Simple script to start dashboard with all learning features enabled -""" - -import sys -import logging -from datetime import datetime - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def main(): - """Start dashboard with forced learning""" - logger.info("🚀 Starting Dashboard with FORCED LEARNING ENABLED") - logger.info("=" * 60) - logger.info(f"Timestamp: {datetime.now()}") - logger.info("Fixes Applied:") - logger.info("✅ Enhanced RL: FORCED ENABLED") - logger.info("✅ CNN Training: FORCED ENABLED") - logger.info("✅ Williams Pivots: CNN INTEGRATED") - logger.info("✅ Learning Pipeline: ACTIVE") - logger.info("=" * 60) - - try: - # Import and run main - from main_clean import run_web_dashboard - logger.info("Starting web dashboard...") - run_web_dashboard() - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"Dashboard failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/scripts/restart_main_overnight.py b/scripts/restart_main_overnight.py deleted file mode 100644 index a80c6b3..0000000 --- a/scripts/restart_main_overnight.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -""" -Overnight Training Restart Script -Keeps main.py running continuously, restarting it if it crashes. -Designed for overnight training sessions with unstable code. - -Usage: - python restart_main_overnight.py - -Press Ctrl+C to stop the restart loop. -""" - -import subprocess -import sys -import time -import logging -from datetime import datetime -from pathlib import Path -import signal -import os - -# Setup logging for the restart script -def setup_restart_logging(): - """Setup logging for restart events""" - log_dir = Path("logs") - log_dir.mkdir(exist_ok=True) - - # Create restart log file with timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = log_dir / f"restart_main_{timestamp}.log" - - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_file, encoding='utf-8'), - logging.StreamHandler(sys.stdout) - ] - ) - - logger = logging.getLogger(__name__) - logger.info(f"Restart script logging to: {log_file}") - return logger - -def kill_existing_processes(logger): - """Kill any existing main.py processes to avoid conflicts""" - try: - if os.name == 'nt': # Windows - # Kill any existing Python processes running main.py - subprocess.run(['taskkill', '/f', '/im', 'python.exe'], - capture_output=True, check=False) - subprocess.run(['taskkill', '/f', '/im', 'pythonw.exe'], - capture_output=True, check=False) - time.sleep(2) - except Exception as e: - logger.warning(f"Could not kill existing processes: {e}") - -def run_main_with_restart(logger): - """Main restart loop""" - restart_count = 0 - consecutive_fast_exits = 0 - start_time = datetime.now() - - logger.info("=" * 60) - logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED") - logger.info("=" * 60) - logger.info("Press Ctrl+C to stop the restart loop") - logger.info("Main script: main.py") - logger.info("Restart delay on crash: 10 seconds") - logger.info("Fast exit protection: Enabled") - logger.info("=" * 60) - - # Kill any existing processes - kill_existing_processes(logger) - - while True: - try: - restart_count += 1 - run_start_time = datetime.now() - - logger.info(f"[RESTART #{restart_count}] Starting main.py at {run_start_time.strftime('%H:%M:%S')}") - - # Start main.py as subprocess - process = subprocess.Popen([ - sys.executable, "main.py" - ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - universal_newlines=True, bufsize=1) - - logger.info(f"[PROCESS] main.py started with PID: {process.pid}") - - # Stream output from main.py - try: - if process.stdout: - while True: - output = process.stdout.readline() - if output == '' and process.poll() is not None: - break - if output: - # Forward output from main.py (remove extra newlines) - print(f"[MAIN] {output.rstrip()}") - else: - # If no stdout, just wait for process to complete - process.wait() - except KeyboardInterrupt: - logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...") - process.terminate() - try: - process.wait(timeout=10) - except subprocess.TimeoutExpired: - logger.warning("[FORCE KILL] Process didn't terminate, force killing...") - process.kill() - raise - - # Process has exited - exit_code = process.poll() - run_end_time = datetime.now() - run_duration = (run_end_time - run_start_time).total_seconds() - - logger.info(f"[EXIT] main.py exited with code {exit_code}") - logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds") - - # Check for fast exits (potential configuration issues) - if run_duration < 30: # Less than 30 seconds - consecutive_fast_exits += 1 - logger.warning(f"[FAST EXIT] Process exited quickly ({consecutive_fast_exits} consecutive)") - - if consecutive_fast_exits >= 5: - logger.error("[ABORT] Too many consecutive fast exits (5+)") - logger.error("This indicates a configuration or startup problem") - logger.error("Please check the main.py script manually") - break - - # Longer delay for fast exits - delay = min(60, 10 * consecutive_fast_exits) - logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...") - time.sleep(delay) - else: - consecutive_fast_exits = 0 # Reset counter - logger.info("[DELAY] Waiting 10 seconds before restart...") - time.sleep(10) - - # Log session statistics every 10 restarts - if restart_count % 10 == 0: - total_duration = (datetime.now() - start_time).total_seconds() - logger.info(f"[STATS] Session: {restart_count} restarts in {total_duration/3600:.1f} hours") - - except KeyboardInterrupt: - logger.info("[SHUTDOWN] Restart loop interrupted by user") - break - except Exception as e: - logger.error(f"[ERROR] Unexpected error in restart loop: {e}") - logger.error("Continuing restart loop after 30 second delay...") - time.sleep(30) - - total_duration = (datetime.now() - start_time).total_seconds() - logger.info("=" * 60) - logger.info("OVERNIGHT TRAINING SESSION COMPLETE") - logger.info(f"Total restarts: {restart_count}") - logger.info(f"Total session time: {total_duration/3600:.1f} hours") - logger.info("=" * 60) - -def main(): - """Main entry point""" - # Setup signal handlers for clean shutdown - def signal_handler(signum, frame): - logger.info(f"[SIGNAL] Received signal {signum}, shutting down...") - sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, 'SIGTERM'): - signal.signal(signal.SIGTERM, signal_handler) - - # Setup logging - global logger - logger = setup_restart_logging() - - try: - run_main_with_restart(logger) - except Exception as e: - logger.error(f"[FATAL] Fatal error in restart script: {e}") - import traceback - logger.error(traceback.format_exc()) - return 1 - - return 0 - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/setup_mexc_browser.py b/setup_mexc_browser.py deleted file mode 100644 index dc9b697..0000000 --- a/setup_mexc_browser.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -""" -MEXC Browser Setup & Runner - -This script automatically installs dependencies and runs the MEXC browser automation. -""" - -import subprocess -import sys -import os -import importlib - -def check_and_install_requirements(): - """Check and install required packages""" - required_packages = [ - 'selenium', - 'webdriver-manager', - 'requests' - ] - - print("🔍 Checking required packages...") - - missing_packages = [] - for package in required_packages: - try: - importlib.import_module(package.replace('-', '_')) - print(f"✅ {package} - already installed") - except ImportError: - missing_packages.append(package) - print(f"❌ {package} - missing") - - if missing_packages: - print(f"\n📦 Installing missing packages: {', '.join(missing_packages)}") - - for package in missing_packages: - try: - subprocess.check_call([sys.executable, '-m', 'pip', 'install', package]) - print(f"✅ Successfully installed {package}") - except subprocess.CalledProcessError as e: - print(f"❌ Failed to install {package}: {e}") - return False - - print("✅ All requirements satisfied!") - return True - -def run_browser_automation(): - """Run the MEXC browser automation""" - try: - # Import and run the auto browser - from core.mexc_webclient.auto_browser import main as auto_browser_main - auto_browser_main() - except ImportError: - print("❌ Could not import auto browser module") - print("Make sure core/mexc_webclient/auto_browser.py exists") - except Exception as e: - print(f"❌ Error running browser automation: {e}") - -def main(): - """Main setup and run function""" - print("🚀 MEXC Browser Automation Setup") - print("=" * 40) - - # Check Python version - if sys.version_info < (3, 7): - print("❌ Python 3.7+ required") - return - - print(f"✅ Python {sys.version.split()[0]} detected") - - # Install requirements - if not check_and_install_requirements(): - print("❌ Failed to install requirements") - return - - print("\n🌐 Starting browser automation...") - print("This will:") - print("• Download ChromeDriver automatically") - print("• Open MEXC futures page") - print("• Capture all trading requests") - print("• Extract session cookies") - - input("\nPress Enter to continue...") - - # Run the automation - run_browser_automation() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/start_monitoring.py b/start_monitoring.py deleted file mode 100644 index 36d8c28..0000000 --- a/start_monitoring.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 -""" -Helper script to start monitoring services for RL training -""" - -import subprocess -import sys -import time -import requests -import os -import json -from pathlib import Path - -# Available ports to try for TensorBoard -TENSORBOARD_PORTS = [6006, 6007, 6008, 6009, 6010, 6011, 6012] - -def check_port(port, service_name): - """Check if a service is running on the specified port""" - try: - response = requests.get(f"http://localhost:{port}", timeout=3) - print(f"✅ {service_name} is running on port {port}") - return True - except requests.exceptions.RequestException: - return False - -def is_port_in_use(port): - """Check if a port is already in use""" - import socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(('localhost', port)) - return False - except OSError: - return True - -def find_available_port(ports_list, service_name): - """Find an available port from the list""" - for port in ports_list: - if not is_port_in_use(port): - print(f"🔍 Found available port {port} for {service_name}") - return port - else: - print(f"⚠️ Port {port} is already in use") - return None - -def save_port_config(tensorboard_port): - """Save the port configuration to a file""" - config = { - "tensorboard_port": tensorboard_port, - "web_dashboard_port": 8051 - } - with open("monitoring_ports.json", "w") as f: - json.dump(config, f, indent=2) - print(f"💾 Port configuration saved to monitoring_ports.json") - -def start_tensorboard(): - """Start TensorBoard in background on an available port""" - try: - # First check if TensorBoard is already running on any of our ports - for port in TENSORBOARD_PORTS: - if check_port(port, "TensorBoard"): - print(f"✅ TensorBoard already running on port {port}") - save_port_config(port) - return port - - # Find an available port - port = find_available_port(TENSORBOARD_PORTS, "TensorBoard") - if port is None: - print(f"❌ No available ports found in range {TENSORBOARD_PORTS}") - return None - - print(f"🚀 Starting TensorBoard on port {port}...") - - # Create runs directory if it doesn't exist - Path("runs").mkdir(exist_ok=True) - - # Start TensorBoard - if os.name == 'nt': # Windows - subprocess.Popen([ - sys.executable, "-m", "tensorboard", - "--logdir=runs", f"--port={port}", "--reload_interval=1" - ], creationflags=subprocess.CREATE_NEW_CONSOLE) - else: # Linux/Mac - subprocess.Popen([ - sys.executable, "-m", "tensorboard", - "--logdir=runs", f"--port={port}", "--reload_interval=1" - ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - - # Wait for TensorBoard to start - print(f"⏳ Waiting for TensorBoard to start on port {port}...") - for i in range(15): - time.sleep(2) - if check_port(port, "TensorBoard"): - save_port_config(port) - return port - - print(f"⚠️ TensorBoard failed to start on port {port} within 30 seconds") - return None - - except Exception as e: - print(f"❌ Error starting TensorBoard: {e}") - return None - -def check_web_dashboard_port(): - """Check if web dashboard port is available""" - port = 8051 - if is_port_in_use(port): - print(f"⚠️ Web dashboard port {port} is in use") - # Try alternative ports - for alt_port in [8052, 8053, 8054, 8055]: - if not is_port_in_use(alt_port): - print(f"🔍 Alternative port {alt_port} available for web dashboard") - return alt_port - print("❌ No alternative ports found for web dashboard") - return port - else: - print(f"✅ Web dashboard port {port} is available") - return port - -def main(): - """Main function""" - print("=" * 60) - print("🎯 RL TRAINING MONITORING SETUP") - print("=" * 60) - - # Check web dashboard port - web_port = check_web_dashboard_port() - - # Start TensorBoard - tensorboard_port = start_tensorboard() - - print("\n" + "=" * 60) - print("📊 MONITORING STATUS") - print("=" * 60) - - if tensorboard_port: - print(f"✅ TensorBoard: http://localhost:{tensorboard_port}") - # Update port config - save_port_config(tensorboard_port) - else: - print("❌ TensorBoard: Failed to start") - print(" Manual start: python -m tensorboard --logdir=runs --port=6007") - - if web_port: - print(f"✅ Web Dashboard: Ready on port {web_port}") - - print(f"\n🎯 Ready to start RL training!") - if tensorboard_port and web_port != 8051: - print(f"Run: python train_realtime_with_tensorboard.py --episodes 10 --web-port {web_port}") - else: - print("Run: python train_realtime_with_tensorboard.py --episodes 10") - - print(f"\n📋 Available URLs:") - if tensorboard_port: - print(f" 📊 TensorBoard: http://localhost:{tensorboard_port}") - if web_port: - print(f" 🌐 Web Dashboard: http://localhost:{web_port} (starts with training)") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/start_overnight_training.py b/start_overnight_training.py deleted file mode 100644 index 2479e85..0000000 --- a/start_overnight_training.py +++ /dev/null @@ -1,179 +0,0 @@ -#!/usr/bin/env python3 -""" -Start Overnight Training Session - -This script starts a comprehensive overnight training session that: -1. Ensures CNN and COB RL training processes are implemented and running -2. Executes training passes on each signal when predictions change -3. Calculates PnL and records trades in SIM mode -4. Tracks model performance statistics -5. Converts signals to actual trades for performance tracking -""" - -import os -import sys -import time -import logging -from datetime import datetime -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(f'overnight_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'), - logging.StreamHandler() - ] -) - -logger = logging.getLogger(__name__) - -def main(): - """Start the overnight training session""" - try: - logger.info("🌙 STARTING OVERNIGHT TRAINING SESSION") - logger.info("=" * 80) - - # Import required components - from core.config import get_config, setup_logging - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - from web.clean_dashboard import CleanTradingDashboard - - # Setup logging - setup_logging() - - # Initialize components - logger.info("Initializing components...") - - # Create data provider - data_provider = DataProvider() - logger.info("✅ Data Provider initialized") - - # Create trading executor in simulation mode - trading_executor = TradingExecutor() - trading_executor.simulation_mode = True # Ensure we're in simulation mode - logger.info("✅ Trading Executor initialized (SIMULATION MODE)") - - # Create orchestrator with enhanced training - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True - ) - logger.info("✅ Trading Orchestrator initialized") - - # Connect trading executor to orchestrator - if hasattr(orchestrator, 'set_trading_executor'): - orchestrator.set_trading_executor(trading_executor) - logger.info("✅ Trading Executor connected to Orchestrator") - - # Create dashboard (this initializes the overnight training coordinator) - dashboard = CleanTradingDashboard( - data_provider=data_provider, - orchestrator=orchestrator, - trading_executor=trading_executor - ) - logger.info("✅ Dashboard initialized with Overnight Training Coordinator") - - # Start the overnight training session - logger.info("Starting overnight training session...") - success = dashboard.start_overnight_training() - - if success: - logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED SUCCESSFULLY") - logger.info("=" * 80) - logger.info("Training Features Active:") - logger.info("✅ CNN training on signal changes") - logger.info("✅ COB RL training on market microstructure") - logger.info("✅ DQN training on trading decisions") - logger.info("✅ Trade execution and recording (SIMULATION)") - logger.info("✅ Performance tracking and statistics") - logger.info("✅ Model checkpointing every 50 trades") - logger.info("✅ Signal-to-trade conversion with PnL calculation") - logger.info("=" * 80) - - # Monitor training progress - logger.info("Monitoring training progress...") - logger.info("Press Ctrl+C to stop the training session") - - # Keep the session running and periodically report progress - start_time = datetime.now() - last_report_time = start_time - - while True: - try: - time.sleep(60) # Check every minute - - current_time = datetime.now() - elapsed_time = current_time - start_time - - # Get performance summary every 10 minutes - if (current_time - last_report_time).total_seconds() >= 600: # 10 minutes - performance = dashboard.get_training_performance_summary() - - logger.info("=" * 60) - logger.info(f"🌙 TRAINING PROGRESS REPORT - {elapsed_time}") - logger.info("=" * 60) - logger.info(f"Total Signals: {performance.get('total_signals', 0)}") - logger.info(f"Total Trades: {performance.get('total_trades', 0)}") - logger.info(f"Successful Trades: {performance.get('successful_trades', 0)}") - logger.info(f"Success Rate: {performance.get('success_rate', 0):.1%}") - logger.info(f"Total P&L: ${performance.get('total_pnl', 0):.2f}") - logger.info(f"Models Trained: {', '.join(performance.get('models_trained', []))}") - logger.info(f"Training Status: {'ACTIVE' if performance.get('is_running', False) else 'INACTIVE'}") - logger.info("=" * 60) - - last_report_time = current_time - - except KeyboardInterrupt: - logger.info("\n🛑 Training session interrupted by user") - break - except Exception as e: - logger.error(f"Error during training monitoring: {e}") - time.sleep(30) # Wait 30 seconds before retrying - - # Stop the training session - logger.info("Stopping overnight training session...") - dashboard.stop_overnight_training() - - # Final report - final_performance = dashboard.get_training_performance_summary() - total_time = datetime.now() - start_time - - logger.info("=" * 80) - logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED") - logger.info("=" * 80) - logger.info(f"Total Duration: {total_time}") - logger.info(f"Final Statistics:") - logger.info(f" Total Signals: {final_performance.get('total_signals', 0)}") - logger.info(f" Total Trades: {final_performance.get('total_trades', 0)}") - logger.info(f" Successful Trades: {final_performance.get('successful_trades', 0)}") - logger.info(f" Success Rate: {final_performance.get('success_rate', 0):.1%}") - logger.info(f" Total P&L: ${final_performance.get('total_pnl', 0):.2f}") - logger.info(f" Models Trained: {', '.join(final_performance.get('models_trained', []))}") - logger.info("=" * 80) - - else: - logger.error("❌ Failed to start overnight training session") - return 1 - - except KeyboardInterrupt: - logger.info("\n🛑 Training session interrupted by user") - return 0 - except Exception as e: - logger.error(f"❌ Error in overnight training session: {e}") - import traceback - traceback.print_exc() - return 1 - - return 0 - -if __name__ == "__main__": - exit_code = main() - sys.exit(exit_code) \ No newline at end of file diff --git a/system_stability_audit.py b/system_stability_audit.py deleted file mode 100644 index 26fbe94..0000000 --- a/system_stability_audit.py +++ /dev/null @@ -1,426 +0,0 @@ -#!/usr/bin/env python3 -""" -System Stability Audit and Monitoring - -This script performs a comprehensive audit of the trading system to identify -and fix stability issues, memory leaks, and performance bottlenecks. -""" - -import os -import sys -import psutil -import logging -import time -import threading -import gc -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Any -import traceback - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging, get_config - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -class SystemStabilityAuditor: - """ - Comprehensive system stability auditor and monitor - - Monitors: - - Memory usage and leaks - - CPU usage and performance - - Thread health and deadlocks - - Model performance and stability - - Dashboard responsiveness - - Data provider health - """ - - def __init__(self): - """Initialize the stability auditor""" - self.config = get_config() - self.monitoring_active = False - self.monitoring_thread = None - - # Performance baselines - self.baseline_memory = psutil.virtual_memory().used - self.baseline_cpu = psutil.cpu_percent() - - # Monitoring data - self.memory_history = [] - self.cpu_history = [] - self.thread_history = [] - self.error_history = [] - - # Stability metrics - self.stability_score = 100.0 - self.critical_issues = [] - self.warnings = [] - - logger.info("System Stability Auditor initialized") - - def start_monitoring(self): - """Start continuous system monitoring""" - if self.monitoring_active: - logger.warning("Monitoring already active") - return - - self.monitoring_active = True - self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True) - self.monitoring_thread.start() - - logger.info("System stability monitoring started") - - def stop_monitoring(self): - """Stop system monitoring""" - self.monitoring_active = False - if self.monitoring_thread: - self.monitoring_thread.join(timeout=5) - - logger.info("System stability monitoring stopped") - - def _monitoring_loop(self): - """Main monitoring loop""" - while self.monitoring_active: - try: - # Collect system metrics - self._collect_system_metrics() - - # Check for memory leaks - self._check_memory_leaks() - - # Check CPU usage - self._check_cpu_usage() - - # Check thread health - self._check_thread_health() - - # Check for deadlocks - self._check_for_deadlocks() - - # Update stability score - self._update_stability_score() - - # Log status every 60 seconds - if len(self.memory_history) % 12 == 0: # Every 12 * 5s = 60s - self._log_stability_status() - - time.sleep(5) # Check every 5 seconds - - except Exception as e: - logger.error(f"Error in monitoring loop: {e}") - self.error_history.append({ - 'timestamp': datetime.now(), - 'error': str(e), - 'traceback': traceback.format_exc() - }) - time.sleep(10) # Wait longer on error - - def _collect_system_metrics(self): - """Collect system performance metrics""" - try: - # Memory metrics - memory = psutil.virtual_memory() - memory_data = { - 'timestamp': datetime.now(), - 'used_gb': memory.used / (1024**3), - 'available_gb': memory.available / (1024**3), - 'percent': memory.percent - } - self.memory_history.append(memory_data) - - # Keep only last 720 entries (1 hour at 5s intervals) - if len(self.memory_history) > 720: - self.memory_history = self.memory_history[-720:] - - # CPU metrics - cpu_percent = psutil.cpu_percent(interval=1) - cpu_data = { - 'timestamp': datetime.now(), - 'percent': cpu_percent, - 'cores': psutil.cpu_count() - } - self.cpu_history.append(cpu_data) - - # Keep only last 720 entries - if len(self.cpu_history) > 720: - self.cpu_history = self.cpu_history[-720:] - - # Thread metrics - thread_count = threading.active_count() - thread_data = { - 'timestamp': datetime.now(), - 'count': thread_count, - 'threads': [t.name for t in threading.enumerate()] - } - self.thread_history.append(thread_data) - - # Keep only last 720 entries - if len(self.thread_history) > 720: - self.thread_history = self.thread_history[-720:] - - except Exception as e: - logger.error(f"Error collecting system metrics: {e}") - - def _check_memory_leaks(self): - """Check for memory leaks""" - try: - if len(self.memory_history) < 10: - return - - # Check if memory usage is consistently increasing - recent_memory = [m['used_gb'] for m in self.memory_history[-10:]] - memory_trend = sum(recent_memory[-5:]) / 5 - sum(recent_memory[:5]) / 5 - - # If memory increased by more than 100MB in last 10 checks - if memory_trend > 0.1: - warning = f"Potential memory leak detected: +{memory_trend:.2f}GB in last 50s" - if warning not in self.warnings: - self.warnings.append(warning) - logger.warning(warning) - - # Force garbage collection - gc.collect() - logger.info("Forced garbage collection to free memory") - - # Check for excessive memory usage - current_memory = self.memory_history[-1]['percent'] - if current_memory > 85: - critical = f"High memory usage: {current_memory:.1f}%" - if critical not in self.critical_issues: - self.critical_issues.append(critical) - logger.error(critical) - - except Exception as e: - logger.error(f"Error checking memory leaks: {e}") - - def _check_cpu_usage(self): - """Check CPU usage patterns""" - try: - if len(self.cpu_history) < 10: - return - - # Check for sustained high CPU usage - recent_cpu = [c['percent'] for c in self.cpu_history[-10:]] - avg_cpu = sum(recent_cpu) / len(recent_cpu) - - if avg_cpu > 90: - critical = f"Sustained high CPU usage: {avg_cpu:.1f}%" - if critical not in self.critical_issues: - self.critical_issues.append(critical) - logger.error(critical) - elif avg_cpu > 75: - warning = f"High CPU usage: {avg_cpu:.1f}%" - if warning not in self.warnings: - self.warnings.append(warning) - logger.warning(warning) - - except Exception as e: - logger.error(f"Error checking CPU usage: {e}") - - def _check_thread_health(self): - """Check thread health and detect issues""" - try: - if len(self.thread_history) < 5: - return - - current_threads = self.thread_history[-1]['count'] - - # Check for thread explosion - if current_threads > 50: - critical = f"Thread explosion detected: {current_threads} active threads" - if critical not in self.critical_issues: - self.critical_issues.append(critical) - logger.error(critical) - - # Log thread names for debugging - thread_names = self.thread_history[-1]['threads'] - logger.error(f"Active threads: {thread_names}") - - # Check for thread leaks (gradually increasing thread count) - if len(self.thread_history) >= 10: - thread_counts = [t['count'] for t in self.thread_history[-10:]] - thread_trend = sum(thread_counts[-5:]) / 5 - sum(thread_counts[:5]) / 5 - - if thread_trend > 2: # More than 2 threads increase on average - warning = f"Potential thread leak: +{thread_trend:.1f} threads in last 50s" - if warning not in self.warnings: - self.warnings.append(warning) - logger.warning(warning) - - except Exception as e: - logger.error(f"Error checking thread health: {e}") - - def _check_for_deadlocks(self): - """Check for potential deadlocks""" - try: - # Simple deadlock detection based on thread states - all_threads = threading.enumerate() - blocked_threads = [] - - for thread in all_threads: - if hasattr(thread, '_is_stopped') and not thread._is_stopped: - # Thread is running but might be blocked - # This is a simplified check - real deadlock detection is complex - pass - - # For now, just check if we have threads that haven't been active - # More sophisticated deadlock detection would require thread state analysis - - except Exception as e: - logger.error(f"Error checking for deadlocks: {e}") - - def _update_stability_score(self): - """Update overall system stability score""" - try: - score = 100.0 - - # Deduct points for critical issues - score -= len(self.critical_issues) * 20 - - # Deduct points for warnings - score -= len(self.warnings) * 5 - - # Deduct points for recent errors - recent_errors = [e for e in self.error_history - if e['timestamp'] > datetime.now() - timedelta(minutes=10)] - score -= len(recent_errors) * 10 - - # Deduct points for high resource usage - if self.memory_history: - current_memory = self.memory_history[-1]['percent'] - if current_memory > 80: - score -= (current_memory - 80) * 2 - - if self.cpu_history: - current_cpu = self.cpu_history[-1]['percent'] - if current_cpu > 80: - score -= (current_cpu - 80) * 1 - - self.stability_score = max(0, score) - - except Exception as e: - logger.error(f"Error updating stability score: {e}") - - def _log_stability_status(self): - """Log current stability status""" - try: - logger.info("=" * 50) - logger.info("SYSTEM STABILITY STATUS") - logger.info("=" * 50) - logger.info(f"Stability Score: {self.stability_score:.1f}/100") - - if self.memory_history: - mem = self.memory_history[-1] - logger.info(f"Memory: {mem['used_gb']:.1f}GB used ({mem['percent']:.1f}%)") - - if self.cpu_history: - cpu = self.cpu_history[-1] - logger.info(f"CPU: {cpu['percent']:.1f}%") - - if self.thread_history: - threads = self.thread_history[-1] - logger.info(f"Threads: {threads['count']} active") - - if self.critical_issues: - logger.error(f"Critical Issues ({len(self.critical_issues)}):") - for issue in self.critical_issues[-5:]: # Show last 5 - logger.error(f" - {issue}") - - if self.warnings: - logger.warning(f"Warnings ({len(self.warnings)}):") - for warning in self.warnings[-5:]: # Show last 5 - logger.warning(f" - {warning}") - - logger.info("=" * 50) - - except Exception as e: - logger.error(f"Error logging stability status: {e}") - - def get_stability_report(self) -> Dict[str, Any]: - """Get comprehensive stability report""" - try: - return { - 'stability_score': self.stability_score, - 'critical_issues': self.critical_issues, - 'warnings': self.warnings, - 'memory_usage': self.memory_history[-1] if self.memory_history else None, - 'cpu_usage': self.cpu_history[-1] if self.cpu_history else None, - 'thread_count': self.thread_history[-1]['count'] if self.thread_history else 0, - 'recent_errors': len([e for e in self.error_history - if e['timestamp'] > datetime.now() - timedelta(minutes=10)]), - 'monitoring_active': self.monitoring_active - } - except Exception as e: - logger.error(f"Error generating stability report: {e}") - return {'error': str(e)} - - def fix_common_issues(self): - """Attempt to fix common stability issues""" - try: - logger.info("Attempting to fix common stability issues...") - - # Force garbage collection - gc.collect() - logger.info("✓ Forced garbage collection") - - # Clear old history to free memory - if len(self.memory_history) > 360: # Keep only 30 minutes - self.memory_history = self.memory_history[-360:] - if len(self.cpu_history) > 360: - self.cpu_history = self.cpu_history[-360:] - if len(self.thread_history) > 360: - self.thread_history = self.thread_history[-360:] - - logger.info("✓ Cleared old monitoring history") - - # Clear old errors - cutoff_time = datetime.now() - timedelta(hours=1) - self.error_history = [e for e in self.error_history if e['timestamp'] > cutoff_time] - logger.info("✓ Cleared old error history") - - # Reset warnings and critical issues that might be stale - self.warnings = [] - self.critical_issues = [] - logger.info("✓ Reset stale warnings and critical issues") - - logger.info("Common stability fixes applied") - - except Exception as e: - logger.error(f"Error fixing common issues: {e}") - -def main(): - """Main function for standalone execution""" - try: - logger.info("Starting System Stability Audit") - - auditor = SystemStabilityAuditor() - auditor.start_monitoring() - - # Run for 5 minutes then generate report - time.sleep(300) - - report = auditor.get_stability_report() - logger.info("FINAL STABILITY REPORT:") - logger.info(f"Stability Score: {report['stability_score']:.1f}/100") - logger.info(f"Critical Issues: {len(report['critical_issues'])}") - logger.info(f"Warnings: {len(report['warnings'])}") - - # Attempt fixes if needed - if report['stability_score'] < 80: - auditor.fix_common_issues() - - auditor.stop_monitoring() - - except KeyboardInterrupt: - logger.info("Audit interrupted by user") - except Exception as e: - logger.error(f"Error in stability audit: {e}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_build_base_data_performance.py b/test_build_base_data_performance.py deleted file mode 100644 index 112c4f7..0000000 --- a/test_build_base_data_performance.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Build Base Data Performance - -This script tests the performance of build_base_data_input to ensure it's instantaneous. -""" - -import sys -import os -import time -import logging -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.config import get_config - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_build_base_data_performance(): - """Test the performance of build_base_data_input""" - - logger.info("=== Testing Build Base Data Performance ===") - - try: - # Initialize orchestrator - config = get_config() - orchestrator = TradingOrchestrator( - symbol="ETH/USDT", - config=config - ) - - # Start the orchestrator to initialize data - orchestrator.start() - logger.info("✅ Orchestrator started") - - # Wait a bit for data to be populated - time.sleep(2) - - # Test performance of build_base_data_input - symbol = "ETH/USDT" - num_tests = 10 - total_time = 0 - - logger.info(f"Running {num_tests} performance tests...") - - for i in range(num_tests): - start_time = time.time() - - base_data = orchestrator.build_base_data_input(symbol) - - end_time = time.time() - duration = (end_time - start_time) * 1000 # Convert to milliseconds - total_time += duration - - if base_data: - logger.info(f"Test {i+1}: {duration:.2f}ms - ✅ Success") - else: - logger.warning(f"Test {i+1}: {duration:.2f}ms - ❌ Failed (no data)") - - avg_time = total_time / num_tests - - logger.info(f"=== Performance Results ===") - logger.info(f"Average time: {avg_time:.2f}ms") - logger.info(f"Total time: {total_time:.2f}ms") - - # Performance thresholds - if avg_time < 10: # Less than 10ms is excellent - logger.info("🎉 EXCELLENT: Build time is under 10ms") - elif avg_time < 50: # Less than 50ms is good - logger.info("✅ GOOD: Build time is under 50ms") - elif avg_time < 100: # Less than 100ms is acceptable - logger.info("⚠️ ACCEPTABLE: Build time is under 100ms") - else: - logger.error("❌ SLOW: Build time is over 100ms - needs optimization") - - # Test with multiple symbols - logger.info("Testing with multiple symbols...") - symbols = ["ETH/USDT", "BTC/USDT"] - - for symbol in symbols: - start_time = time.time() - base_data = orchestrator.build_base_data_input(symbol) - end_time = time.time() - duration = (end_time - start_time) * 1000 - - logger.info(f"{symbol}: {duration:.2f}ms") - - # Stop orchestrator - orchestrator.stop() - logger.info("✅ Orchestrator stopped") - - return avg_time < 100 # Return True if performance is acceptable - - except Exception as e: - logger.error(f"❌ Performance test failed: {e}") - import traceback - traceback.print_exc() - return False - -def test_cache_effectiveness(): - """Test that caching is working effectively""" - - logger.info("=== Testing Cache Effectiveness ===") - - try: - # Initialize orchestrator - config = get_config() - orchestrator = TradingOrchestrator( - symbol="ETH/USDT", - config=config - ) - - orchestrator.start() - time.sleep(2) # Let data populate - - symbol = "ETH/USDT" - - # First call (should build cache) - start_time = time.time() - base_data1 = orchestrator.build_base_data_input(symbol) - first_call_time = (time.time() - start_time) * 1000 - - # Second call (should use cache) - start_time = time.time() - base_data2 = orchestrator.build_base_data_input(symbol) - second_call_time = (time.time() - start_time) * 1000 - - # Third call (should still use cache) - start_time = time.time() - base_data3 = orchestrator.build_base_data_input(symbol) - third_call_time = (time.time() - start_time) * 1000 - - logger.info(f"First call (build cache): {first_call_time:.2f}ms") - logger.info(f"Second call (use cache): {second_call_time:.2f}ms") - logger.info(f"Third call (use cache): {third_call_time:.2f}ms") - - # Cache should make subsequent calls faster - if second_call_time < first_call_time * 0.5: - logger.info("✅ Cache is working effectively") - cache_effective = True - else: - logger.warning("⚠️ Cache may not be working as expected") - cache_effective = False - - # Verify data consistency - if base_data1 and base_data2 and base_data3: - # Check that we get consistent data structure - if (len(base_data1.ohlcv_1s) == len(base_data2.ohlcv_1s) == len(base_data3.ohlcv_1s)): - logger.info("✅ Data consistency maintained") - else: - logger.warning("⚠️ Data consistency issues detected") - - orchestrator.stop() - - return cache_effective - - except Exception as e: - logger.error(f"❌ Cache effectiveness test failed: {e}") - return False - -def main(): - """Run all performance tests""" - - logger.info("Starting Build Base Data Performance Tests") - - # Test 1: Basic performance - test1_passed = test_build_base_data_performance() - - # Test 2: Cache effectiveness - test2_passed = test_cache_effectiveness() - - # Summary - logger.info("=== Test Summary ===") - logger.info(f"Performance Test: {'✅ PASSED' if test1_passed else '❌ FAILED'}") - logger.info(f"Cache Effectiveness: {'✅ PASSED' if test2_passed else '❌ FAILED'}") - - if test1_passed and test2_passed: - logger.info("🎉 All tests passed! build_base_data_input is optimized.") - logger.info("The system now:") - logger.info(" - Builds BaseDataInput in under 100ms") - logger.info(" - Uses effective caching for repeated calls") - logger.info(" - Maintains data consistency") - else: - logger.error("❌ Some tests failed. Performance optimization needed.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_bybit_eth_futures.py b/test_bybit_eth_futures.py deleted file mode 100644 index 8a8f085..0000000 --- a/test_bybit_eth_futures.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for Bybit ETH futures position opening/closing -""" - -import os -import sys -import time -import logging -from datetime import datetime - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -# Load environment variables from .env file -try: - from dotenv import load_dotenv - load_dotenv() -except ImportError: - # If dotenv is not available, try to load .env manually - if os.path.exists('.env'): - with open('.env', 'r') as f: - for line in f: - if line.strip() and not line.startswith('#'): - key, value = line.strip().split('=', 1) - os.environ[key] = value - -from core.exchanges.bybit_interface import BybitInterface - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class BybitEthFuturesTest: - """Test class for Bybit ETH futures trading""" - - def __init__(self, test_mode=True): - self.test_mode = test_mode - self.bybit = BybitInterface(test_mode=test_mode) - self.test_symbol = 'ETHUSDT' - self.test_quantity = 0.01 # Small test amount - - def run_tests(self): - """Run all tests""" - print("=" * 60) - print("BYBIT ETH FUTURES POSITION TESTING") - print("=" * 60) - print(f"Test mode: {'TESTNET' if self.test_mode else 'LIVE'}") - print(f"Symbol: {self.test_symbol}") - print(f"Test quantity: {self.test_quantity} ETH") - print("=" * 60) - - # Test 1: Connection - if not self.test_connection(): - print("❌ Connection failed - stopping tests") - return False - - # Test 2: Check balance - if not self.test_balance(): - print("❌ Balance check failed - stopping tests") - return False - - # Test 3: Check current positions - self.test_current_positions() - - # Test 4: Get ticker - if not self.test_ticker(): - print("❌ Ticker test failed - stopping tests") - return False - - # Test 5: Open a long position - long_order = self.test_open_long_position() - if not long_order: - print("❌ Open long position failed") - return False - - # Test 6: Check position after opening - time.sleep(2) # Wait for position to be reflected - if not self.test_position_after_open(): - print("❌ Position check after opening failed") - return False - - # Test 7: Close the position - if not self.test_close_position(): - print("❌ Close position failed") - return False - - # Test 8: Check position after closing - time.sleep(2) # Wait for position to be reflected - self.test_position_after_close() - - print("\n" + "=" * 60) - print("✅ ALL TESTS COMPLETED SUCCESSFULLY") - print("=" * 60) - return True - - def test_connection(self): - """Test connection to Bybit""" - print("\n📡 Testing connection to Bybit...") - - # First test simple connectivity without auth - print("Testing basic API connectivity...") - try: - from core.exchanges.bybit_rest_client import BybitRestClient - client = BybitRestClient( - api_key="dummy", - api_secret="dummy", - testnet=True - ) - - # Test public endpoint (server time) - server_time = client.get_server_time() - print(f"✅ Public API working - Server time: {server_time.get('result', {}).get('timeSecond')}") - - except Exception as e: - print(f"❌ Public API failed: {e}") - return False - - # Now test with actual credentials - print("Testing with API credentials...") - try: - connected = self.bybit.connect() - if connected: - print("✅ Successfully connected to Bybit with credentials") - return True - else: - print("❌ Failed to connect to Bybit with credentials") - print("This might be due to:") - print("- Invalid API credentials") - print("- Credentials not enabled for testnet") - print("- Missing required permissions") - return False - except Exception as e: - print(f"❌ Connection error: {e}") - return False - - def test_balance(self): - """Test getting account balance""" - print("\n💰 Testing account balance...") - - try: - # Get USDT balance (for margin) - usdt_balance = self.bybit.get_balance('USDT') - print(f"USDT Balance: {usdt_balance}") - - # Get all balances - all_balances = self.bybit.get_all_balances() - print("All balances:") - for asset, balance in all_balances.items(): - if balance['total'] > 0: - print(f" {asset}: Free={balance['free']}, Locked={balance['locked']}, Total={balance['total']}") - - if usdt_balance > 10: # Need at least $10 for testing - print("✅ Sufficient balance for testing") - return True - else: - print("❌ Insufficient USDT balance for testing (need at least $10)") - return False - - except Exception as e: - print(f"❌ Balance check error: {e}") - return False - - def test_current_positions(self): - """Test getting current positions""" - print("\n📊 Checking current positions...") - - try: - positions = self.bybit.get_positions() - if positions: - print(f"Found {len(positions)} open positions:") - for pos in positions: - print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}") - print(f" PnL: ${pos['unrealized_pnl']:.2f} ({pos['percentage']:.2f}%)") - else: - print("No open positions found") - - except Exception as e: - print(f"❌ Position check error: {e}") - - def test_ticker(self): - """Test getting ticker information""" - print(f"\n📈 Testing ticker for {self.test_symbol}...") - - try: - ticker = self.bybit.get_ticker(self.test_symbol) - if ticker: - print(f"✅ Ticker data received:") - print(f" Last Price: ${ticker['last_price']:.2f}") - print(f" Bid: ${ticker['bid_price']:.2f}") - print(f" Ask: ${ticker['ask_price']:.2f}") - print(f" 24h Volume: {ticker['volume_24h']:.2f}") - print(f" 24h Change: {ticker['change_24h']:.4f}%") - return True - else: - print("❌ Failed to get ticker data") - return False - - except Exception as e: - print(f"❌ Ticker error: {e}") - return False - - def test_open_long_position(self): - """Test opening a long position""" - print(f"\n🚀 Opening long position for {self.test_quantity} {self.test_symbol}...") - - try: - # Place market buy order - order = self.bybit.place_order( - symbol=self.test_symbol, - side='buy', - order_type='market', - quantity=self.test_quantity - ) - - if 'error' in order: - print(f"❌ Order failed: {order['error']}") - return None - - print("✅ Long position opened successfully:") - print(f" Order ID: {order['order_id']}") - print(f" Symbol: {order['symbol']}") - print(f" Side: {order['side']}") - print(f" Quantity: {order['quantity']}") - print(f" Status: {order['status']}") - - return order - - except Exception as e: - print(f"❌ Open position error: {e}") - return None - - def test_position_after_open(self): - """Test checking position after opening""" - print(f"\n📊 Checking position after opening...") - - try: - positions = self.bybit.get_positions(self.test_symbol) - if positions: - position = positions[0] - print("✅ Position found:") - print(f" Symbol: {position['symbol']}") - print(f" Side: {position['side']}") - print(f" Size: {position['size']}") - print(f" Entry Price: ${position['entry_price']:.2f}") - print(f" Mark Price: ${position['mark_price']:.2f}") - print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}") - print(f" Percentage: {position['percentage']:.2f}%") - print(f" Leverage: {position['leverage']}x") - return True - else: - print("❌ No position found after opening") - return False - - except Exception as e: - print(f"❌ Position check error: {e}") - return False - - def test_close_position(self): - """Test closing the position""" - print(f"\n🔄 Closing position for {self.test_symbol}...") - - try: - # Close the position - close_order = self.bybit.close_position(self.test_symbol) - - if 'error' in close_order: - print(f"❌ Close order failed: {close_order['error']}") - return False - - print("✅ Position closed successfully:") - print(f" Order ID: {close_order['order_id']}") - print(f" Symbol: {close_order['symbol']}") - print(f" Side: {close_order['side']}") - print(f" Quantity: {close_order['quantity']}") - print(f" Status: {close_order['status']}") - - return True - - except Exception as e: - print(f"❌ Close position error: {e}") - return False - - def test_position_after_close(self): - """Test checking position after closing""" - print(f"\n📊 Checking position after closing...") - - try: - positions = self.bybit.get_positions(self.test_symbol) - if positions: - position = positions[0] - print("⚠️ Position still exists (may be partially closed):") - print(f" Symbol: {position['symbol']}") - print(f" Side: {position['side']}") - print(f" Size: {position['size']}") - print(f" Entry Price: ${position['entry_price']:.2f}") - print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}") - else: - print("✅ Position successfully closed - no open positions") - - except Exception as e: - print(f"❌ Position check error: {e}") - - def test_order_history(self): - """Test getting order history""" - print(f"\n📋 Checking recent orders...") - - try: - # Get open orders - open_orders = self.bybit.get_open_orders(self.test_symbol) - print(f"Open orders: {len(open_orders)}") - for order in open_orders: - print(f" {order['order_id']}: {order['side']} {order['quantity']} @ ${order['price']:.2f} - {order['status']}") - - except Exception as e: - print(f"❌ Order history error: {e}") - -def main(): - """Main function""" - print("Starting Bybit ETH Futures Test...") - - # Check if API credentials are set - api_key = os.getenv('BYBIT_API_KEY') - api_secret = os.getenv('BYBIT_API_SECRET') - - if not api_key or not api_secret: - print("❌ Please set BYBIT_API_KEY and BYBIT_API_SECRET environment variables") - return False - - # Create test instance - test = BybitEthFuturesTest(test_mode=True) # Always use testnet for safety - - # Run tests - success = test.run_tests() - - if success: - print("\n🎉 All tests passed!") - else: - print("\n💥 Some tests failed!") - - return success - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/test_bybit_eth_futures_fixed.py b/test_bybit_eth_futures_fixed.py deleted file mode 100644 index ec07c56..0000000 --- a/test_bybit_eth_futures_fixed.py +++ /dev/null @@ -1,304 +0,0 @@ -#!/usr/bin/env python3 -""" -Fixed Bybit ETH futures trading test with proper minimum order size handling -""" - -import os -import sys -import time -import logging -import json - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -# Load environment variables -try: - from dotenv import load_dotenv - load_dotenv() -except ImportError: - if os.path.exists('.env'): - with open('.env', 'r') as f: - for line in f: - if line.strip() and not line.startswith('#'): - key, value = line.strip().split('=', 1) - os.environ[key] = value - -from core.exchanges.bybit_interface import BybitInterface - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def get_instrument_info(bybit: BybitInterface, symbol: str) -> dict: - """Get instrument information including minimum order size""" - try: - instruments = bybit.get_instruments("linear") - for instrument in instruments: - if instrument.get('symbol') == symbol: - return instrument - return {} - except Exception as e: - logger.error(f"Error getting instrument info: {e}") - return {} - -def test_eth_futures_trading(): - """Test ETH futures trading with proper minimum order size""" - print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...") - print("=" * 60) - print("BYBIT ETH FUTURES LIVE TRADING TEST (FIXED)") - print("=" * 60) - print("⚠️ This uses LIVE environment with real money!") - print("⚠️ Will check minimum order size first") - print("=" * 60) - - # Check if API credentials are set - api_key = os.getenv('BYBIT_API_KEY') - api_secret = os.getenv('BYBIT_API_SECRET') - - if not api_key or not api_secret: - print("❌ API credentials not found in environment") - return False - - # Create Bybit interface with live environment - bybit = BybitInterface( - api_key=api_key, - api_secret=api_secret, - test_mode=False # Use live environment - ) - - symbol = 'ETHUSDT' - - # Test 1: Connection - print(f"\n📡 Testing connection to Bybit live environment...") - try: - if not bybit.connect(): - print("❌ Failed to connect to Bybit") - return False - print("✅ Successfully connected to Bybit live environment") - except Exception as e: - print(f"❌ Connection error: {e}") - return False - - # Test 2: Get instrument information to check minimum order size - print(f"\n📋 Getting instrument information for {symbol}...") - try: - instrument_info = get_instrument_info(bybit, symbol) - if not instrument_info: - print(f"❌ Failed to get instrument info for {symbol}") - return False - - print("✅ Instrument information retrieved:") - print(f" Symbol: {instrument_info.get('symbol')}") - print(f" Status: {instrument_info.get('status')}") - print(f" Base Coin: {instrument_info.get('baseCoin')}") - print(f" Quote Coin: {instrument_info.get('quoteCoin')}") - - # Extract minimum order size - lot_size_filter = instrument_info.get('lotSizeFilter', {}) - min_order_qty = float(lot_size_filter.get('minOrderQty', 0.01)) - max_order_qty = float(lot_size_filter.get('maxOrderQty', 10000)) - qty_step = float(lot_size_filter.get('qtyStep', 0.01)) - - print(f" Minimum Order Qty: {min_order_qty}") - print(f" Maximum Order Qty: {max_order_qty}") - print(f" Quantity Step: {qty_step}") - - # Use minimum order size for testing - test_quantity = min_order_qty - print(f" Using test quantity: {test_quantity} ETH") - - except Exception as e: - print(f"❌ Instrument info error: {e}") - return False - - # Test 3: Get account balance - print(f"\n💰 Checking account balance...") - try: - usdt_balance = bybit.get_balance('USDT') - print(f"USDT Balance: ${usdt_balance:.2f}") - - # Calculate required balance (with some buffer) - current_price_data = bybit.get_ticker(symbol) - if not current_price_data: - print("❌ Failed to get current ETH price") - return False - - current_price = current_price_data['last_price'] - required_balance = current_price * test_quantity * 1.1 # 10% buffer - - print(f"Current ETH price: ${current_price:.2f}") - print(f"Required balance: ${required_balance:.2f}") - - if usdt_balance < required_balance: - print(f"❌ Insufficient USDT balance for testing (need at least ${required_balance:.2f})") - return False - - print("✅ Sufficient balance for testing") - - except Exception as e: - print(f"❌ Balance check error: {e}") - return False - - # Test 4: Check existing positions - print(f"\n📊 Checking existing positions...") - try: - positions = bybit.get_positions(symbol) - if positions: - print(f"Found {len(positions)} existing positions:") - for pos in positions: - print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}") - print(f" PnL: ${pos['unrealized_pnl']:.2f}") - else: - print("No existing positions found") - except Exception as e: - print(f"❌ Position check error: {e}") - return False - - # Test 5: Ask user confirmation before trading - print(f"\n⚠️ TRADING CONFIRMATION") - print(f" Symbol: {symbol}") - print(f" Quantity: {test_quantity} ETH") - print(f" Estimated cost: ${current_price * test_quantity:.2f}") - print(f" Environment: LIVE (real money)") - print(f" Minimum order size confirmed: {min_order_qty}") - - response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower() - if response != 'y' and response != 'yes': - print("❌ Trading test cancelled by user") - return False - - # Test 6: Open a small long position - print(f"\n🚀 Opening small long position...") - try: - order = bybit.place_order( - symbol=symbol, - side='buy', - order_type='market', - quantity=test_quantity - ) - - if 'error' in order: - print(f"❌ Order failed: {order['error']}") - return False - - print("✅ Long position opened successfully:") - print(f" Order ID: {order['order_id']}") - print(f" Symbol: {order['symbol']}") - print(f" Side: {order['side']}") - print(f" Quantity: {order['quantity']}") - print(f" Status: {order['status']}") - - order_id = order['order_id'] - - except Exception as e: - print(f"❌ Order placement error: {e}") - return False - - # Test 7: Wait a moment and check position - print(f"\n⏳ Waiting 5 seconds for position to be reflected...") - time.sleep(5) - - try: - positions = bybit.get_positions(symbol) - if positions: - position = positions[0] - print("✅ Position confirmed:") - print(f" Symbol: {position['symbol']}") - print(f" Side: {position['side']}") - print(f" Size: {position['size']}") - print(f" Entry Price: ${position['entry_price']:.2f}") - print(f" Current PnL: ${position['unrealized_pnl']:.2f}") - print(f" Leverage: {position['leverage']}x") - else: - print("⚠️ No position found (may already be closed)") - - except Exception as e: - print(f"❌ Position check error: {e}") - - # Test 8: Close the position - print(f"\n🔄 Closing the position...") - try: - close_order = bybit.close_position(symbol) - - if 'error' in close_order: - print(f"❌ Close order failed: {close_order['error']}") - # Don't return False here, as the position might still exist - print("⚠️ You may need to manually close the position") - else: - print("✅ Position closed successfully:") - print(f" Order ID: {close_order['order_id']}") - print(f" Symbol: {close_order['symbol']}") - print(f" Side: {close_order['side']}") - print(f" Quantity: {close_order['quantity']}") - print(f" Status: {close_order['status']}") - - except Exception as e: - print(f"❌ Close position error: {e}") - print("⚠️ You may need to manually close the position") - - # Test 9: Final position check - print(f"\n📊 Final position check...") - time.sleep(3) - - try: - positions = bybit.get_positions(symbol) - if positions: - position = positions[0] - print("⚠️ Position still exists:") - print(f" Size: {position['size']}") - print(f" PnL: ${position['unrealized_pnl']:.2f}") - print("💡 You may want to manually close this position") - else: - print("✅ No open positions - trading test completed successfully") - - except Exception as e: - print(f"❌ Final position check error: {e}") - - # Test 10: Final balance check - print(f"\n💰 Final balance check...") - try: - final_balance = bybit.get_balance('USDT') - print(f"Final USDT Balance: ${final_balance:.2f}") - - balance_change = final_balance - usdt_balance - if balance_change > 0: - print(f"💰 Profit: +${balance_change:.2f}") - elif balance_change < 0: - print(f"📉 Loss: ${balance_change:.2f}") - else: - print(f"🔄 No change: ${balance_change:.2f}") - - except Exception as e: - print(f"❌ Final balance check error: {e}") - - return True - -def main(): - """Main function""" - print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...") - - success = test_eth_futures_trading() - - if success: - print("\n" + "=" * 60) - print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED") - print("=" * 60) - print("🎯 Your Bybit integration is fully functional!") - print("🔄 Position opening and closing works correctly") - print("💰 Account balance integration works") - print("📊 All trading functions are operational") - print("📏 Minimum order size handling works") - print("=" * 60) - else: - print("\n💥 Trading test failed!") - print("🔍 Check the error messages above for details") - - return success - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_bybit_eth_live.py b/test_bybit_eth_live.py deleted file mode 100644 index 9554a91..0000000 --- a/test_bybit_eth_live.py +++ /dev/null @@ -1,249 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Bybit ETH futures trading with live environment -""" - -import os -import sys -import time -import logging - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -# Load environment variables -try: - from dotenv import load_dotenv - load_dotenv() -except ImportError: - if os.path.exists('.env'): - with open('.env', 'r') as f: - for line in f: - if line.strip() and not line.startswith('#'): - key, value = line.strip().split('=', 1) - os.environ[key] = value - -from core.exchanges.bybit_interface import BybitInterface - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_eth_futures_trading(): - """Test ETH futures trading with live environment""" - print("=" * 60) - print("BYBIT ETH FUTURES LIVE TRADING TEST") - print("=" * 60) - print("⚠️ This uses LIVE environment with real money!") - print("⚠️ Test amount: 0.001 ETH (very small)") - print("=" * 60) - - # Check if API credentials are set - api_key = os.getenv('BYBIT_API_KEY') - api_secret = os.getenv('BYBIT_API_SECRET') - - if not api_key or not api_secret: - print("❌ API credentials not found in environment") - return False - - # Create Bybit interface with live environment - bybit = BybitInterface( - api_key=api_key, - api_secret=api_secret, - test_mode=False # Use live environment - ) - - symbol = 'ETHUSDT' - test_quantity = 0.01 # Minimum order size for ETH futures - - # Test 1: Connection - print(f"\n📡 Testing connection to Bybit live environment...") - try: - if not bybit.connect(): - print("❌ Failed to connect to Bybit") - return False - print("✅ Successfully connected to Bybit live environment") - except Exception as e: - print(f"❌ Connection error: {e}") - return False - - # Test 2: Get account balance - print(f"\n💰 Checking account balance...") - try: - usdt_balance = bybit.get_balance('USDT') - print(f"USDT Balance: ${usdt_balance:.2f}") - - if usdt_balance < 5: - print("❌ Insufficient USDT balance for testing (need at least $5)") - return False - - print("✅ Sufficient balance for testing") - except Exception as e: - print(f"❌ Balance check error: {e}") - return False - - # Test 3: Get current ETH price - print(f"\n📈 Getting current ETH price...") - try: - ticker = bybit.get_ticker(symbol) - if not ticker: - print("❌ Failed to get ticker") - return False - - current_price = ticker['last_price'] - print(f"Current ETH price: ${current_price:.2f}") - print(f"Test order value: ${current_price * test_quantity:.2f}") - - except Exception as e: - print(f"❌ Ticker error: {e}") - return False - - # Test 4: Check existing positions - print(f"\n📊 Checking existing positions...") - try: - positions = bybit.get_positions(symbol) - if positions: - print(f"Found {len(positions)} existing positions:") - for pos in positions: - print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}") - print(f" PnL: ${pos['unrealized_pnl']:.2f}") - else: - print("No existing positions found") - except Exception as e: - print(f"❌ Position check error: {e}") - return False - - # Test 5: Ask user confirmation before trading - print(f"\n⚠️ TRADING CONFIRMATION") - print(f" Symbol: {symbol}") - print(f" Quantity: {test_quantity} ETH") - print(f" Estimated cost: ${current_price * test_quantity:.2f}") - print(f" Environment: LIVE (real money)") - - response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower() - if response != 'y' and response != 'yes': - print("❌ Trading test cancelled by user") - return False - - # Test 6: Open a small long position - print(f"\n🚀 Opening small long position...") - try: - order = bybit.place_order( - symbol=symbol, - side='buy', - order_type='market', - quantity=test_quantity - ) - - if 'error' in order: - print(f"❌ Order failed: {order['error']}") - return False - - print("✅ Long position opened successfully:") - print(f" Order ID: {order['order_id']}") - print(f" Symbol: {order['symbol']}") - print(f" Side: {order['side']}") - print(f" Quantity: {order['quantity']}") - print(f" Status: {order['status']}") - - order_id = order['order_id'] - - except Exception as e: - print(f"❌ Order placement error: {e}") - return False - - # Test 7: Wait a moment and check position - print(f"\n⏳ Waiting 3 seconds for position to be reflected...") - time.sleep(3) - - try: - positions = bybit.get_positions(symbol) - if positions: - position = positions[0] - print("✅ Position confirmed:") - print(f" Symbol: {position['symbol']}") - print(f" Side: {position['side']}") - print(f" Size: {position['size']}") - print(f" Entry Price: ${position['entry_price']:.2f}") - print(f" Current PnL: ${position['unrealized_pnl']:.2f}") - print(f" Leverage: {position['leverage']}x") - else: - print("⚠️ No position found (may already be closed)") - - except Exception as e: - print(f"❌ Position check error: {e}") - - # Test 8: Close the position - print(f"\n🔄 Closing the position...") - try: - close_order = bybit.close_position(symbol) - - if 'error' in close_order: - print(f"❌ Close order failed: {close_order['error']}") - return False - - print("✅ Position closed successfully:") - print(f" Order ID: {close_order['order_id']}") - print(f" Symbol: {close_order['symbol']}") - print(f" Side: {close_order['side']}") - print(f" Quantity: {close_order['quantity']}") - print(f" Status: {close_order['status']}") - - except Exception as e: - print(f"❌ Close position error: {e}") - return False - - # Test 9: Final position check - print(f"\n📊 Final position check...") - time.sleep(2) - - try: - positions = bybit.get_positions(symbol) - if positions: - position = positions[0] - print("⚠️ Position still exists:") - print(f" Size: {position['size']}") - print(f" PnL: ${position['unrealized_pnl']:.2f}") - else: - print("✅ No open positions - trading test completed successfully") - - except Exception as e: - print(f"❌ Final position check error: {e}") - - # Test 10: Final balance check - print(f"\n💰 Final balance check...") - try: - final_balance = bybit.get_balance('USDT') - print(f"Final USDT Balance: ${final_balance:.2f}") - - except Exception as e: - print(f"❌ Final balance check error: {e}") - - return True - -def main(): - """Main function""" - print("🚀 Starting Bybit ETH Futures Live Trading Test...") - - success = test_eth_futures_trading() - - if success: - print("\n" + "=" * 60) - print("✅ BYBIT ETH FUTURES TRADING TEST COMPLETED") - print("=" * 60) - print("🎯 Your Bybit integration is fully functional!") - print("🔄 Position opening and closing works correctly") - print("💰 Account balance integration works") - print("📊 All trading functions are operational") - print("=" * 60) - else: - print("\n💥 Trading test failed!") - - return success - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/test_bybit_public_api.py b/test_bybit_public_api.py deleted file mode 100644 index de58964..0000000 --- a/test_bybit_public_api.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Bybit public API functionality (no authentication required) -""" - -import os -import sys -import time -import logging - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.exchanges.bybit_rest_client import BybitRestClient - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_public_api(): - """Test public API endpoints""" - print("=" * 60) - print("BYBIT PUBLIC API TEST") - print("=" * 60) - - # Test both testnet and live for public endpoints - for testnet in [True, False]: - env_name = "TESTNET" if testnet else "LIVE" - print(f"\n🔄 Testing {env_name} environment...") - - client = BybitRestClient( - api_key="dummy", - api_secret="dummy", - testnet=testnet - ) - - # Test 1: Server time - try: - server_time = client.get_server_time() - time_second = server_time.get('result', {}).get('timeSecond') - print(f"✅ Server time: {time_second}") - except Exception as e: - print(f"❌ Server time failed: {e}") - continue - - # Test 2: Get ticker for ETHUSDT - try: - ticker = client.get_ticker('ETHUSDT', 'linear') - ticker_data = ticker.get('result', {}).get('list', []) - if ticker_data: - data = ticker_data[0] - print(f"✅ ETH/USDT ticker:") - print(f" Last Price: ${float(data.get('lastPrice', 0)):.2f}") - print(f" 24h Volume: {float(data.get('volume24h', 0)):.2f}") - print(f" 24h Change: {float(data.get('price24hPcnt', 0)) * 100:.2f}%") - else: - print("❌ No ticker data received") - except Exception as e: - print(f"❌ Ticker failed: {e}") - - # Test 3: Get instruments info - try: - instruments = client.get_instruments_info('linear') - instruments_list = instruments.get('result', {}).get('list', []) - eth_instruments = [i for i in instruments_list if 'ETH' in i.get('symbol', '')] - print(f"✅ Found {len(eth_instruments)} ETH instruments") - for instr in eth_instruments[:3]: # Show first 3 - print(f" {instr.get('symbol')} - Status: {instr.get('status')}") - except Exception as e: - print(f"❌ Instruments failed: {e}") - - # Test 4: Get orderbook - try: - orderbook = client.get_orderbook('ETHUSDT', 'linear', 5) - ob_data = orderbook.get('result', {}) - bids = ob_data.get('b', []) - asks = ob_data.get('a', []) - - if bids and asks: - print(f"✅ Orderbook (top 3):") - print(f" Best bid: ${float(bids[0][0]):.2f} (qty: {float(bids[0][1]):.4f})") - print(f" Best ask: ${float(asks[0][0]):.2f} (qty: {float(asks[0][1]):.4f})") - spread = float(asks[0][0]) - float(bids[0][0]) - print(f" Spread: ${spread:.2f}") - else: - print("❌ No orderbook data received") - except Exception as e: - print(f"❌ Orderbook failed: {e}") - - print(f"📊 {env_name} environment test completed") - -def test_live_authentication(): - """Test live authentication (if user wants to test with live credentials)""" - print("\n" + "=" * 60) - print("BYBIT LIVE AUTHENTICATION TEST") - print("=" * 60) - print("⚠️ This will test with LIVE credentials (not testnet)") - - # Load environment variables - try: - from dotenv import load_dotenv - load_dotenv() - except ImportError: - # If dotenv is not available, try to load .env manually - if os.path.exists('.env'): - with open('.env', 'r') as f: - for line in f: - if line.strip() and not line.startswith('#'): - key, value = line.strip().split('=', 1) - os.environ[key] = value - - api_key = os.getenv('BYBIT_API_KEY') - api_secret = os.getenv('BYBIT_API_SECRET') - - if not api_key or not api_secret: - print("❌ No API credentials found in environment") - return - - print(f"🔑 Using API key: {api_key[:8]}...") - - # Test with live environment (testnet=False) - client = BybitRestClient( - api_key=api_key, - api_secret=api_secret, - testnet=False # Use live environment - ) - - # Test connectivity - try: - if client.test_connectivity(): - print("✅ Basic connectivity OK") - else: - print("❌ Basic connectivity failed") - return - except Exception as e: - print(f"❌ Connectivity error: {e}") - return - - # Test authentication - try: - if client.test_authentication(): - print("✅ Authentication successful!") - - # Get account info - account_info = client.get_account_info() - accounts = account_info.get('result', {}).get('list', []) - - if accounts: - print("📊 Account information:") - for account in accounts: - account_type = account.get('accountType', 'Unknown') - print(f" Account Type: {account_type}") - - coins = account.get('coin', []) - usdt_balance = None - for coin in coins: - if coin.get('coin') == 'USDT': - usdt_balance = float(coin.get('walletBalance', 0)) - break - - if usdt_balance: - print(f" USDT Balance: ${usdt_balance:.2f}") - - # Show positions if any - try: - positions = client.get_positions('linear') - pos_list = positions.get('result', {}).get('list', []) - active_positions = [p for p in pos_list if float(p.get('size', 0)) != 0] - - if active_positions: - print(f" Active Positions: {len(active_positions)}") - for pos in active_positions: - symbol = pos.get('symbol') - side = pos.get('side') - size = float(pos.get('size', 0)) - pnl = float(pos.get('unrealisedPnl', 0)) - print(f" {symbol}: {side} {size} (PnL: ${pnl:.2f})") - else: - print(" No active positions") - except Exception as e: - print(f" ⚠️ Could not get positions: {e}") - - return True - else: - print("❌ Authentication failed") - return False - - except Exception as e: - print(f"❌ Authentication error: {e}") - return False - -def main(): - """Main function""" - print("🚀 Starting Bybit API Tests...") - - # Test public API - test_public_api() - - # Ask user if they want to test live authentication - print("\n" + "=" * 60) - response = input("Do you want to test live authentication? (y/N): ").lower() - - if response == 'y' or response == 'yes': - success = test_live_authentication() - if success: - print("\n✅ Live authentication test passed!") - print("🎯 Your Bybit integration is working!") - else: - print("\n❌ Live authentication test failed") - else: - print("\n📋 Skipping live authentication test") - - print("\n🎉 Public API tests completed successfully!") - print("📈 Bybit integration is functional for market data") - -if __name__ == "__main__": - main() diff --git a/test_cache_fix.py b/test_cache_fix.py deleted file mode 100644 index 078af14..0000000 --- a/test_cache_fix.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Cache Fix - -Creates a corrupted Parquet file to test the fix mechanism -""" - -import os -import pandas as pd -from pathlib import Path -from utils.cache_manager import get_cache_manager - -def create_test_data(): - """Create test cache files including a corrupted one""" - print("Creating test cache files...") - - # Ensure cache directory exists - cache_dir = Path("data/cache") - cache_dir.mkdir(parents=True, exist_ok=True) - - # Create a valid Parquet file - valid_data = pd.DataFrame({ - 'timestamp': pd.date_range('2025-01-01', periods=100, freq='1min'), - 'open': [100.0 + i for i in range(100)], - 'high': [101.0 + i for i in range(100)], - 'low': [99.0 + i for i in range(100)], - 'close': [100.5 + i for i in range(100)], - 'volume': [1000 + i*10 for i in range(100)] - }) - - valid_file = cache_dir / "ETHUSDT_1m.parquet" - valid_data.to_parquet(valid_file, index=False) - print(f"Created valid file: {valid_file}") - - # Create a corrupted Parquet file by writing invalid data - corrupted_file = cache_dir / "BTCUSDT_1m.parquet" - with open(corrupted_file, 'wb') as f: - f.write(b"This is not a valid Parquet file - corrupted data") - print(f"Created corrupted file: {corrupted_file}") - - # Create an empty file - empty_file = cache_dir / "SOLUSDT_1m.parquet" - empty_file.touch() - print(f"Created empty file: {empty_file}") - -def test_cache_manager(): - """Test the cache manager's ability to detect and fix issues""" - print("\n=== Testing Cache Manager ===") - - cache_manager = get_cache_manager() - - # Scan health - print("1. Scanning cache health...") - health_summary = cache_manager.get_cache_summary() - - print(f"Total files: {health_summary['total_files']}") - print(f"Valid files: {health_summary['valid_files']}") - print(f"Corrupted files: {health_summary['corrupted_files']}") - print(f"Health percentage: {health_summary['health_percentage']:.1f}%") - - # Show corrupted files - for cache_dir, report in health_summary['directories'].items(): - if report['corrupted_files'] > 0: - print(f"\nCorrupted files in {cache_dir}:") - for corrupted in report['corrupted_files_list']: - print(f" - {corrupted['file']}: {corrupted['error']}") - - # Test cleanup - print("\n2. Testing cleanup...") - deleted_files = cache_manager.cleanup_corrupted_files(dry_run=False) - - deleted_count = 0 - for cache_dir, files in deleted_files.items(): - for file_info in files: - if "DELETED:" in file_info: - deleted_count += 1 - print(f" {file_info}") - - print(f"Deleted {deleted_count} corrupted files") - - # Verify cleanup - print("\n3. Verifying cleanup...") - health_summary_after = cache_manager.get_cache_summary() - print(f"Corrupted files after cleanup: {health_summary_after['corrupted_files']}") - -def test_data_provider_integration(): - """Test that the data provider can handle corrupted cache gracefully""" - print("\n=== Testing Data Provider Integration ===") - - # Create another corrupted file - cache_dir = Path("data/cache") - corrupted_file = cache_dir / "ETHUSDT_5m.parquet" - with open(corrupted_file, 'wb') as f: - f.write(b"PAR1\x00\x00corrupted thrift data that will cause deserialization error") - print(f"Created corrupted file with thrift error: {corrupted_file}") - - # Try to import and use data provider - try: - from core.data_provider import DataProvider - - # Create data provider (should auto-fix corrupted cache) - data_provider = DataProvider() - print("Data provider created successfully - auto-fix worked!") - - # Try to load data (should handle corruption gracefully) - try: - data = data_provider._load_from_cache("ETH/USDT", "5m") - if data is None: - print("Cache loading returned None (expected for corrupted file)") - else: - print(f"Loaded {len(data)} rows from cache") - except Exception as e: - print(f"Cache loading failed: {e}") - - except Exception as e: - print(f"Data provider test failed: {e}") - -def main(): - """Run all tests""" - print("=== Cache Fix Test Suite ===") - - # Clean up any existing test files - cache_dir = Path("data/cache") - if cache_dir.exists(): - for file in cache_dir.glob("*.parquet"): - file.unlink() - - # Run tests - create_test_data() - test_cache_manager() - test_data_provider_integration() - - print("\n=== Test Complete ===") - print("The cache fix system is working correctly!") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_cnn_integration.py b/test_cnn_integration.py deleted file mode 100644 index 46671dc..0000000 --- a/test_cnn_integration.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python3 -""" -Test CNN Integration - -This script tests if the CNN adapter is working properly and identifies issues. -""" - -import logging -import sys -import os -from datetime import datetime - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_cnn_adapter(): - """Test CNN adapter initialization and basic functionality""" - try: - logger.info("Testing CNN adapter initialization...") - - # Test 1: Import CNN adapter - from core.enhanced_cnn_adapter import EnhancedCNNAdapter - logger.info("✅ CNN adapter import successful") - - # Test 2: Initialize CNN adapter - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - logger.info("✅ CNN adapter initialization successful") - - # Test 3: Check adapter attributes - logger.info(f"CNN adapter model: {cnn_adapter.model}") - logger.info(f"CNN adapter device: {cnn_adapter.device}") - logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}") - - # Test 4: Check metrics tracking - logger.info(f"Inference count: {cnn_adapter.inference_count}") - logger.info(f"Training count: {cnn_adapter.training_count}") - logger.info(f"Training data length: {len(cnn_adapter.training_data)}") - - # Test 5: Test simple training sample addition - cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1) - logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}") - - # Test 6: Test training if we have enough samples - if len(cnn_adapter.training_data) >= 2: - # Add another sample to have minimum for training - cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05) - - # Try training - training_result = cnn_adapter.train(epochs=1) - logger.info(f"✅ Training successful: {training_result}") - - # Check if metrics were updated - logger.info(f"Last training time: {cnn_adapter.last_training_time}") - logger.info(f"Last training loss: {cnn_adapter.last_training_loss}") - logger.info(f"Training count: {cnn_adapter.training_count}") - else: - logger.info("⚠️ Not enough training samples for training test") - - return True - - except Exception as e: - logger.error(f"❌ CNN adapter test failed: {e}") - import traceback - traceback.print_exc() - return False - -def test_base_data_input(): - """Test BaseDataInput creation""" - try: - logger.info("Testing BaseDataInput creation...") - - # Test 1: Import BaseDataInput - from core.data_models import BaseDataInput, OHLCVBar, COBData - logger.info("✅ BaseDataInput import successful") - - # Test 2: Create sample OHLCV bars - sample_bars = [] - for i in range(10): # Create 10 sample bars - bar = OHLCVBar( - symbol="ETH/USDT", - timestamp=datetime.now(), - open=3500.0 + i, - high=3510.0 + i, - low=3490.0 + i, - close=3505.0 + i, - volume=1000.0, - timeframe="1s" - ) - sample_bars.append(bar) - - logger.info(f"✅ Created {len(sample_bars)} sample OHLCV bars") - - # Test 3: Create BaseDataInput - base_data = BaseDataInput( - symbol="ETH/USDT", - timestamp=datetime.now(), - ohlcv_1s=sample_bars, - ohlcv_1m=sample_bars, - ohlcv_1h=sample_bars, - ohlcv_1d=sample_bars, - btc_ohlcv_1s=sample_bars - ) - - logger.info("✅ BaseDataInput created successfully") - - # Test 4: Validate BaseDataInput - is_valid = base_data.validate() - logger.info(f"BaseDataInput validation: {is_valid}") - - # Test 5: Get feature vector - feature_vector = base_data.get_feature_vector() - logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}") - - return base_data - - except Exception as e: - logger.error(f"❌ BaseDataInput test failed: {e}") - import traceback - traceback.print_exc() - return None - -def test_cnn_prediction(): - """Test CNN prediction with BaseDataInput""" - try: - logger.info("Testing CNN prediction...") - - # Get CNN adapter and base data - from core.enhanced_cnn_adapter import EnhancedCNNAdapter - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - - base_data = test_base_data_input() - if not base_data: - logger.error("❌ Cannot test prediction without valid BaseDataInput") - return False - - # Test prediction - model_output = cnn_adapter.predict(base_data) - logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})") - - # Check if metrics were updated - logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}") - logger.info(f"Last inference time: {cnn_adapter.last_inference_time}") - logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}") - - return True - - except Exception as e: - logger.error(f"❌ CNN prediction test failed: {e}") - import traceback - traceback.print_exc() - return False - -def main(): - """Run all tests""" - logger.info("🧪 Starting CNN Integration Tests") - - # Test 1: CNN Adapter - if not test_cnn_adapter(): - logger.error("❌ CNN adapter test failed - stopping") - return False - - # Test 2: CNN Prediction - if not test_cnn_prediction(): - logger.error("❌ CNN prediction test failed - stopping") - return False - - logger.info("✅ All CNN integration tests passed!") - logger.info("🎯 The CNN adapter should now work properly in the dashboard") - - return True - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_cob_dashboard.py b/test_cob_dashboard.py deleted file mode 100644 index 87f1363..0000000 --- a/test_cob_dashboard.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 -""" -Test COB Dashboard with Enhanced WebSocket -""" - -import asyncio -import logging -from web.cob_realtime_dashboard import COBDashboardServer - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) - -async def main(): - """Test the COB dashboard""" - dashboard = COBDashboardServer(host='localhost', port=8053) - await dashboard.start() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/test_cob_data_quality.py b/test_cob_data_quality.py deleted file mode 100644 index 948434c..0000000 --- a/test_cob_data_quality.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for COB data quality and imbalance indicators -""" - -import time -import logging -from core.data_provider import DataProvider - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_cob_data_quality(): - """Test COB data quality and imbalance indicators""" - logger.info("Testing COB data quality and imbalance indicators...") - - # Initialize data provider - dp = DataProvider() - - # Wait for initial data load and COB connection - logger.info("Waiting for initial data load and COB connection...") - time.sleep(20) - - # Test 1: Check cached data summary - logger.info("\n=== Test 1: Cached Data Summary ===") - cache_summary = dp.get_cached_data_summary() - for symbol in cache_summary['cached_data']: - logger.info(f"\n{symbol}:") - for timeframe, info in cache_summary['cached_data'][symbol].items(): - if 'candle_count' in info and info['candle_count'] > 0: - logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}") - else: - logger.info(f" {timeframe}: {info.get('status', 'no data')}") - - # Test 2: Check COB data quality - logger.info("\n=== Test 2: COB Data Quality ===") - cob_quality = dp.get_cob_data_quality() - - for symbol in cob_quality['symbols']: - logger.info(f"\n{symbol} COB Data:") - - # Raw ticks - raw_info = cob_quality['raw_ticks'].get(symbol, {}) - logger.info(f" Raw ticks: {raw_info.get('count', 0)} ticks") - if raw_info.get('age_seconds') is not None: - logger.info(f" Raw data age: {raw_info['age_seconds']:.1f} seconds") - - # Aggregated 1s data - agg_info = cob_quality['aggregated_1s'].get(symbol, {}) - logger.info(f" Aggregated 1s: {agg_info.get('count', 0)} records") - if agg_info.get('age_seconds') is not None: - logger.info(f" Aggregated data age: {agg_info['age_seconds']:.1f} seconds") - - # Imbalance indicators - imbalance_info = cob_quality['imbalance_indicators'].get(symbol, {}) - if imbalance_info: - logger.info(f" Imbalance 1s: {imbalance_info.get('imbalance_1s', 0):.4f}") - logger.info(f" Imbalance 5s: {imbalance_info.get('imbalance_5s', 0):.4f}") - logger.info(f" Imbalance 15s: {imbalance_info.get('imbalance_15s', 0):.4f}") - logger.info(f" Imbalance 60s: {imbalance_info.get('imbalance_60s', 0):.4f}") - logger.info(f" Total volume: {imbalance_info.get('total_volume', 0):.2f}") - logger.info(f" Price buckets: {imbalance_info.get('bucket_count', 0)}") - - # Data freshness - freshness = cob_quality['data_freshness'].get(symbol, 'unknown') - logger.info(f" Data freshness: {freshness}") - - # Test 3: Get recent COB aggregated data - logger.info("\n=== Test 3: Recent COB Aggregated Data ===") - for symbol in ['ETH/USDT', 'BTC/USDT']: - recent_cob = dp.get_cob_1s_aggregated(symbol, count=5) - logger.info(f"\n{symbol} - Last 5 aggregated records:") - - for i, record in enumerate(recent_cob[-5:]): - timestamp = record.get('timestamp', 0) - imbalance_1s = record.get('imbalance_1s', 0) - imbalance_5s = record.get('imbalance_5s', 0) - total_volume = record.get('total_volume', 0) - bucket_count = len(record.get('bid_buckets', {})) + len(record.get('ask_buckets', {})) - - logger.info(f" [{i+1}] Time: {timestamp}, Imb1s: {imbalance_1s:.4f}, " - f"Imb5s: {imbalance_5s:.4f}, Vol: {total_volume:.2f}, Buckets: {bucket_count}") - - # Test 4: Monitor real-time updates - logger.info("\n=== Test 4: Real-time Updates (30 seconds) ===") - logger.info("Monitoring COB data updates...") - - initial_quality = dp.get_cob_data_quality() - time.sleep(30) - updated_quality = dp.get_cob_data_quality() - - for symbol in ['ETH/USDT', 'BTC/USDT']: - initial_count = initial_quality['raw_ticks'].get(symbol, {}).get('count', 0) - updated_count = updated_quality['raw_ticks'].get(symbol, {}).get('count', 0) - new_ticks = updated_count - initial_count - - initial_agg = initial_quality['aggregated_1s'].get(symbol, {}).get('count', 0) - updated_agg = updated_quality['aggregated_1s'].get(symbol, {}).get('count', 0) - new_agg = updated_agg - initial_agg - - logger.info(f"{symbol}: +{new_ticks} raw ticks, +{new_agg} aggregated records") - - # Show latest imbalances - latest_imbalances = updated_quality['imbalance_indicators'].get(symbol, {}) - if latest_imbalances: - logger.info(f" Latest imbalances: 1s={latest_imbalances.get('imbalance_1s', 0):.4f}, " - f"5s={latest_imbalances.get('imbalance_5s', 0):.4f}, " - f"15s={latest_imbalances.get('imbalance_15s', 0):.4f}, " - f"60s={latest_imbalances.get('imbalance_60s', 0):.4f}") - - # Clean shutdown - logger.info("\n=== Shutting Down ===") - dp.stop_automatic_data_maintenance() - logger.info("COB data quality test completed") - -if __name__ == "__main__": - test_cob_data_quality() \ No newline at end of file diff --git a/test_cob_websocket_only.py b/test_cob_websocket_only.py deleted file mode 100644 index 73859eb..0000000 --- a/test_cob_websocket_only.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 -""" -Test COB WebSocket Only Integration - -This script tests that COB integration works with Enhanced WebSocket only, -without falling back to REST API calls. -""" - -import asyncio -import time -from datetime import datetime -from typing import Dict -from core.cob_integration import COBIntegration - -async def test_cob_websocket_only(): - """Test COB integration with WebSocket only""" - print("=== Testing COB WebSocket Only Integration ===") - - # Initialize COB integration - print("1. Initializing COB integration...") - symbols = ['ETH/USDT', 'BTC/USDT'] - cob_integration = COBIntegration(symbols=symbols) - - # Track updates - update_count = 0 - last_update_time = None - - def dashboard_callback(symbol: str, data: Dict): - nonlocal update_count, last_update_time - update_count += 1 - last_update_time = datetime.now() - - if update_count <= 5: # Show first 5 updates - data_type = data.get('type', 'unknown') - if data_type == 'cob_update': - stats = data.get('data', {}).get('stats', {}) - mid_price = stats.get('mid_price', 0) - spread_bps = stats.get('spread_bps', 0) - source = stats.get('source', 'unknown') - print(f" Update #{update_count}: {symbol} - Price: ${mid_price:.2f}, Spread: {spread_bps:.1f}bps, Source: {source}") - elif data_type == 'websocket_status': - status_data = data.get('data', {}) - status = status_data.get('status', 'unknown') - print(f" Status #{update_count}: {symbol} - WebSocket: {status}") - - # Add dashboard callback - cob_integration.add_dashboard_callback(dashboard_callback) - - # Start COB integration - print("2. Starting COB integration...") - try: - # Start in background - start_task = asyncio.create_task(cob_integration.start()) - - # Wait for initialization - await asyncio.sleep(3) - - # Check if COB provider is disabled - print("3. Checking COB provider status:") - if cob_integration.cob_provider is None: - print(" ✅ COB provider is disabled (using Enhanced WebSocket only)") - else: - print(" ❌ COB provider is still active (may cause REST API fallback)") - - # Check Enhanced WebSocket status - print("4. Checking Enhanced WebSocket status:") - if cob_integration.enhanced_websocket: - print(" ✅ Enhanced WebSocket is initialized") - - # Check WebSocket status for each symbol - websocket_status = cob_integration.get_websocket_status() - for symbol, status in websocket_status.items(): - print(f" {symbol}: {status}") - else: - print(" ❌ Enhanced WebSocket is not initialized") - - # Monitor updates for a few seconds - print("5. Monitoring COB updates...") - initial_count = update_count - monitor_start = time.time() - - # Wait for updates - await asyncio.sleep(5) - - monitor_duration = time.time() - monitor_start - updates_received = update_count - initial_count - update_rate = updates_received / monitor_duration - - print(f" Received {updates_received} updates in {monitor_duration:.1f}s") - print(f" Update rate: {update_rate:.1f} updates/second") - - if update_rate >= 8: # Should be around 10 updates/second - print(" ✅ Update rate is excellent (8+ updates/second)") - elif update_rate >= 5: - print(" ✅ Update rate is good (5+ updates/second)") - elif update_rate >= 1: - print(" ⚠️ Update rate is low (1+ updates/second)") - else: - print(" ❌ Update rate is too low (<1 update/second)") - - # Check data quality - print("6. Data quality check:") - if last_update_time: - time_since_last = (datetime.now() - last_update_time).total_seconds() - if time_since_last < 1: - print(f" ✅ Recent data (last update {time_since_last:.1f}s ago)") - else: - print(f" ⚠️ Stale data (last update {time_since_last:.1f}s ago)") - else: - print(" ❌ No updates received") - - # Stop the integration - print("7. Stopping COB integration...") - await cob_integration.stop() - - # Cancel the start task - start_task.cancel() - try: - await start_task - except asyncio.CancelledError: - pass - - except Exception as e: - print(f" ❌ Error during COB integration test: {e}") - - print(f"\n✅ COB WebSocket only test completed!") - print(f"Total updates received: {update_count}") - print("Enhanced WebSocket is now the sole data source (no REST API fallback)") - -if __name__ == "__main__": - asyncio.run(test_cob_websocket_only()) \ No newline at end of file diff --git a/test_continuous_cnn_training.py b/test_continuous_cnn_training.py deleted file mode 100644 index 0a6778b..0000000 --- a/test_continuous_cnn_training.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Test Continuous CNN Training - -This script demonstrates how the CNN model can be trained with each new inference result -using collected data, implementing a continuous learning loop. -""" - -import logging -import time -from datetime import datetime -import random -import os - -from core.standardized_data_provider import StandardizedDataProvider -from core.enhanced_cnn_adapter import EnhancedCNNAdapter -from core.data_models import create_model_output - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def simulate_market_feedback(action, symbol): - """ - Simulate market feedback for a given action - - In a real system, this would be replaced with actual market performance data - - Args: - action: Trading action ('BUY', 'SELL', 'HOLD') - symbol: Trading symbol - - Returns: - tuple: (actual_action, reward) - """ - # Simulate market movement (random for demonstration) - market_direction = random.choice(['up', 'down', 'sideways']) - - # Determine actual best action based on market direction - if market_direction == 'up': - best_action = 'BUY' - elif market_direction == 'down': - best_action = 'SELL' - else: - best_action = 'HOLD' - - # Calculate reward based on whether the action matched the best action - if action == best_action: - reward = random.uniform(0.01, 0.1) # Positive reward for correct action - else: - reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action - - logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}") - - return best_action, reward - -def test_continuous_training(): - """Test continuous training of the CNN model with new inference results""" - try: - # Initialize data provider - symbols = ['ETH/USDT', 'BTC/USDT'] - timeframes = ['1s', '1m', '1h', '1d'] - data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) - - # Initialize CNN adapter - checkpoint_dir = "models/enhanced_cnn" - os.makedirs(checkpoint_dir, exist_ok=True) - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir) - - # Load best checkpoint if available - cnn_adapter.load_best_checkpoint() - - # Continuous learning loop - num_iterations = 10 - training_frequency = 3 # Train every N iterations - samples_collected = 0 - - logger.info(f"Starting continuous learning loop with {num_iterations} iterations") - - for i in range(num_iterations): - logger.info(f"\nIteration {i+1}/{num_iterations}") - - # Get standardized input data - symbol = random.choice(symbols) - logger.info(f"Getting data for {symbol}...") - base_data = data_provider.get_base_data_input(symbol) - - if base_data is None: - logger.warning(f"Failed to get base data input for {symbol}, skipping iteration") - continue - - # Make prediction - logger.info(f"Making prediction for {symbol}...") - model_output = cnn_adapter.predict(base_data) - - # Log prediction - action = model_output.predictions['action'] - confidence = model_output.confidence - logger.info(f"Prediction: {action} with confidence {confidence:.4f}") - - # Store model output - data_provider.store_model_output(model_output) - - # Simulate market feedback - best_action, reward = simulate_market_feedback(action, symbol) - - # Add training sample - logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}") - cnn_adapter.add_training_sample(base_data, best_action, reward) - samples_collected += 1 - - # Train model periodically - if (i + 1) % training_frequency == 0 and samples_collected >= 3: - logger.info(f"Training model with {samples_collected} samples...") - metrics = cnn_adapter.train(epochs=1) - - # Log training metrics - logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}") - - # Simulate time passing - time.sleep(1) - - logger.info("\nContinuous learning loop completed") - - # Final evaluation - logger.info("Performing final evaluation...") - - # Get data for evaluation - symbol = 'ETH/USDT' - base_data = data_provider.get_base_data_input(symbol) - - if base_data is not None: - # Make prediction - model_output = cnn_adapter.predict(base_data) - - # Log prediction - action = model_output.predictions['action'] - confidence = model_output.confidence - logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}") - - # Get model output manager - output_manager = data_provider.get_model_output_manager() - - # Evaluate model performance - metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name) - logger.info(f"Performance metrics: {metrics}") - else: - logger.warning(f"Failed to get base data input for final evaluation") - - logger.info("Test completed successfully") - - except Exception as e: - logger.error(f"Error in test: {e}", exc_info=True) - -if __name__ == "__main__": - test_continuous_training() \ No newline at end of file diff --git a/test_dashboard_data_flow.py b/test_dashboard_data_flow.py deleted file mode 100644 index 2d620ef..0000000 --- a/test_dashboard_data_flow.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to debug dashboard data flow issues - -This script tests if the dashboard can properly retrieve and display model data. -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import logging -logging.basicConfig(level=logging.DEBUG) - -from web.clean_dashboard import CleanTradingDashboard -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -def test_dashboard_data_flow(): - """Test if dashboard can retrieve model data correctly""" - - print("🧪 DASHBOARD DATA FLOW TEST") - print("=" * 50) - - try: - # Initialize components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - print(f"✅ Orchestrator initialized") - print(f" Model registry models: {list(orchestrator.model_registry.get_all_models().keys())}") - print(f" Model toggle states: {list(orchestrator.model_toggle_states.keys())}") - - # Initialize dashboard - dashboard = CleanTradingDashboard( - data_provider=data_provider, - orchestrator=orchestrator - ) - - print(f"✅ Dashboard initialized") - - # Test available models - available_models = dashboard._get_available_models() - print(f" Available models: {list(available_models.keys())}") - - # Test training metrics - print("\n📊 Testing training metrics...") - toggle_states = {} - for model_name in available_models.keys(): - toggle_states[model_name] = orchestrator.get_model_toggle_state(model_name) - - print(f" Toggle states: {list(toggle_states.keys())}") - - metrics_data = dashboard._get_training_metrics(toggle_states) - print(f" Metrics data type: {type(metrics_data)}") - - if metrics_data and isinstance(metrics_data, dict): - print(f" Metrics keys: {list(metrics_data.keys())}") - if 'loaded_models' in metrics_data: - loaded_models = metrics_data['loaded_models'] - print(f" Loaded models count: {len(loaded_models)}") - for model_name, model_info in loaded_models.items(): - print(f" - {model_name}: active={model_info.get('active', False)}") - else: - print(" ❌ No 'loaded_models' in metrics_data!") - else: - print(f" ❌ Invalid metrics_data: {metrics_data}") - - # Test component manager formatting - print("\n🎨 Testing component manager...") - formatted_components = dashboard.component_manager.format_training_metrics(metrics_data) - print(f" Formatted components type: {type(formatted_components)}") - print(f" Formatted components count: {len(formatted_components) if formatted_components else 0}") - - if formatted_components: - print(" ✅ Component manager returned formatted data") - else: - print(" ❌ Component manager returned empty data") - - print("\n🚀 Dashboard data flow test completed!") - return True - - except Exception as e: - print(f"❌ Test failed with error: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - test_dashboard_data_flow() \ No newline at end of file diff --git a/test_dashboard_performance.py b/test_dashboard_performance.py deleted file mode 100644 index 35de9d1..0000000 --- a/test_dashboard_performance.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -""" -Dashboard Performance Test - -Test the optimized callback structure to ensure we've reduced -the number of requests per second. -""" - -import time -from web.clean_dashboard import CleanTradingDashboard -from core.data_provider import DataProvider - -def test_callback_optimization(): - """Test that we've optimized the callback structure""" - print("=== Dashboard Performance Optimization Test ===") - - print("✅ BEFORE Optimization:") - print(" - 7 callbacks on 1-second interval = 7 requests/second") - print(" - Server overload with single client") - print(" - Poor user experience") - - print("\n✅ AFTER Optimization:") - print(" - Main interval: 2 seconds (reduced from 1s)") - print(" - Slow interval: 10 seconds (increased from 5s)") - print(" - Critical metrics: 2s interval (3 requests every 2s)") - print(" - Non-critical data: 10s interval (4 requests every 10s)") - - print("\n📊 Performance Improvement:") - print(" - Before: 7 requests/second = 420 requests/minute") - print(" - After: ~1.9 requests/second = 114 requests/minute") - print(" - Reduction: ~73% fewer requests") - - print("\n🎯 Callback Distribution:") - print(" Fast Interval (2s):") - print(" 1. update_metrics (price, PnL, position, status)") - print(" 2. update_price_chart (trading chart)") - print(" 3. update_cob_data (order book for trading)") - print(" ") - print(" Slow Interval (10s):") - print(" 4. update_recent_decisions (trading history)") - print(" 5. update_closed_trades (completed trades)") - print(" 6. update_pending_orders (pending orders)") - print(" 7. update_training_metrics (ML model stats)") - - print("\n✅ Benefits:") - print(" - Server can handle multiple clients") - print(" - Reduced CPU usage") - print(" - Better responsiveness") - print(" - Still real-time for critical trading data") - - return True - -def test_interval_configuration(): - """Test the interval configuration""" - print("\n=== Interval Configuration Test ===") - - try: - from web.layout_manager import DashboardLayoutManager - - # Create layout manager to test intervals - layout_manager = DashboardLayoutManager(100.0, None) - layout = layout_manager.create_main_layout() - - # Check if intervals are properly configured - print("✅ Layout created successfully") - print("✅ Intervals should be configured as:") - print(" - interval-component: 2000ms (2s)") - print(" - slow-interval-component: 10000ms (10s)") - - return True - - except Exception as e: - print(f"❌ Error testing interval configuration: {e}") - return False - -def calculate_performance_metrics(): - """Calculate the performance improvement metrics""" - print("\n=== Performance Metrics Calculation ===") - - # Old system - old_callbacks = 7 - old_interval = 1 # second - old_requests_per_second = old_callbacks / old_interval - old_requests_per_minute = old_requests_per_second * 60 - - # New system - fast_callbacks = 3 # metrics, chart, cob - fast_interval = 2 # seconds - slow_callbacks = 4 # decisions, trades, orders, training - slow_interval = 10 # seconds - - new_requests_per_second = (fast_callbacks / fast_interval) + (slow_callbacks / slow_interval) - new_requests_per_minute = new_requests_per_second * 60 - - reduction_percent = ((old_requests_per_second - new_requests_per_second) / old_requests_per_second) * 100 - - print(f"📊 Detailed Performance Analysis:") - print(f" Old System:") - print(f" - {old_callbacks} callbacks × {old_interval}s = {old_requests_per_second:.1f} req/s") - print(f" - {old_requests_per_minute:.0f} requests/minute") - print(f" ") - print(f" New System:") - print(f" - Fast: {fast_callbacks} callbacks ÷ {fast_interval}s = {fast_callbacks/fast_interval:.1f} req/s") - print(f" - Slow: {slow_callbacks} callbacks ÷ {slow_interval}s = {slow_callbacks/slow_interval:.1f} req/s") - print(f" - Total: {new_requests_per_second:.1f} req/s") - print(f" - {new_requests_per_minute:.0f} requests/minute") - print(f" ") - print(f" 🎉 Improvement: {reduction_percent:.1f}% reduction in requests") - - # Server capacity estimation - print(f"\n🖥️ Server Capacity Estimation:") - print(f" - Old: Could handle ~{100/old_requests_per_second:.0f} concurrent users") - print(f" - New: Can handle ~{100/new_requests_per_second:.0f} concurrent users") - print(f" - Capacity increase: {(100/new_requests_per_second)/(100/old_requests_per_second):.1f}x") - - return { - 'old_rps': old_requests_per_second, - 'new_rps': new_requests_per_second, - 'reduction_percent': reduction_percent, - 'capacity_multiplier': (100/new_requests_per_second)/(100/old_requests_per_second) - } - -def main(): - """Run all performance tests""" - print("=== Dashboard Performance Optimization Test Suite ===") - - tests = [ - ("Callback Optimization", test_callback_optimization), - ("Interval Configuration", test_interval_configuration) - ] - - passed = 0 - total = len(tests) - - for test_name, test_func in tests: - print(f"\n{'='*60}") - try: - if test_func(): - passed += 1 - print(f"✅ {test_name}: PASSED") - else: - print(f"❌ {test_name}: FAILED") - except Exception as e: - print(f"❌ {test_name}: ERROR - {e}") - - # Calculate performance metrics - metrics = calculate_performance_metrics() - - print(f"\n{'='*60}") - print(f"=== Test Results: {passed}/{total} passed ===") - - if passed == total: - print("\n🎉 ALL TESTS PASSED!") - print("✅ Dashboard performance optimized successfully") - print(f"✅ {metrics['reduction_percent']:.1f}% reduction in server requests") - print(f"✅ {metrics['capacity_multiplier']:.1f}x increase in server capacity") - print("✅ Better user experience with responsive UI") - print("✅ Ready for production with multiple users") - else: - print(f"\n⚠️ {total - passed} tests failed") - print("Check individual test results above") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_data_integration.py b/test_data_integration.py deleted file mode 100644 index 4a86d0a..0000000 --- a/test_data_integration.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Data Integration - -Test that the FIFO queues are properly populated from the data provider -""" - -import time -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -def test_data_provider_methods(): - """Test what methods are available in the data provider""" - print("=== Testing Data Provider Methods ===") - - try: - data_provider = DataProvider() - - # Check available methods - methods = [method for method in dir(data_provider) if not method.startswith('_') and callable(getattr(data_provider, method))] - data_methods = [method for method in methods if 'data' in method.lower() or 'ohlcv' in method.lower() or 'historical' in method.lower() or 'latest' in method.lower()] - - print("Available data-related methods:") - for method in sorted(data_methods): - print(f" - {method}") - - # Test getting historical data - print(f"\nTesting get_historical_data:") - try: - df = data_provider.get_historical_data('ETH/USDT', '1m', limit=5) - if df is not None and not df.empty: - print(f" ✅ Got {len(df)} rows of 1m data") - print(f" Columns: {list(df.columns)}") - print(f" Latest close: {df['close'].iloc[-1]}") - else: - print(f" ❌ No data returned") - except Exception as e: - print(f" ❌ Error: {e}") - - # Test getting latest candles if available - if hasattr(data_provider, 'get_latest_candles'): - print(f"\nTesting get_latest_candles:") - try: - df = data_provider.get_latest_candles('ETH/USDT', '1m', limit=5) - if df is not None and not df.empty: - print(f" ✅ Got {len(df)} rows of latest candles") - print(f" Latest close: {df['close'].iloc[-1]}") - else: - print(f" ❌ No data returned") - except Exception as e: - print(f" ❌ Error: {e}") - - return True - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def test_queue_population(): - """Test that queues get populated with data""" - print("\n=== Testing Queue Population ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Wait a moment for initial population - print("Waiting 3 seconds for initial data population...") - time.sleep(3) - - # Check queue status - print("\nQueue status after initialization:") - orchestrator.log_queue_status(detailed=True) - - # Check if we have minimum data - symbols_to_check = ['ETH/USDT', 'BTC/USDT'] - timeframes_to_check = ['1s', '1m', '1h', '1d'] - min_requirements = {'1s': 100, '1m': 50, '1h': 20, '1d': 10} - - print(f"\nChecking minimum data requirements:") - for symbol in symbols_to_check: - print(f"\n{symbol}:") - for timeframe in timeframes_to_check: - min_count = min_requirements.get(timeframe, 10) - has_min = orchestrator.ensure_minimum_data(f'ohlcv_{timeframe}', symbol, min_count) - actual_count = 0 - if f'ohlcv_{timeframe}' in orchestrator.data_queues and symbol in orchestrator.data_queues[f'ohlcv_{timeframe}']: - with orchestrator.data_queue_locks[f'ohlcv_{timeframe}'][symbol]: - actual_count = len(orchestrator.data_queues[f'ohlcv_{timeframe}'][symbol]) - - status = "✅" if has_min else "❌" - print(f" {timeframe}: {status} {actual_count}/{min_count}") - - # Test BaseDataInput building - print(f"\nTesting BaseDataInput building:") - base_data = orchestrator.build_base_data_input('ETH/USDT') - if base_data: - features = base_data.get_feature_vector() - print(f" ✅ BaseDataInput built successfully") - print(f" Feature vector size: {len(features)}") - print(f" OHLCV 1s bars: {len(base_data.ohlcv_1s)}") - print(f" OHLCV 1m bars: {len(base_data.ohlcv_1m)}") - print(f" BTC bars: {len(base_data.btc_ohlcv_1s)}") - else: - print(f" ❌ Failed to build BaseDataInput") - - return base_data is not None - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def test_polling_thread(): - """Test that the polling thread is working""" - print("\n=== Testing Polling Thread ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Get initial queue counts - initial_status = orchestrator.get_queue_status() - print("Initial queue counts:") - for data_type, symbols in initial_status.items(): - for symbol, count in symbols.items(): - if count > 0: - print(f" {data_type}/{symbol}: {count}") - - # Wait for polling thread to run - print("\nWaiting 10 seconds for polling thread...") - time.sleep(10) - - # Get updated queue counts - updated_status = orchestrator.get_queue_status() - print("\nUpdated queue counts:") - for data_type, symbols in updated_status.items(): - for symbol, count in symbols.items(): - if count > 0: - print(f" {data_type}/{symbol}: {count}") - - # Check if any queues grew - growth_detected = False - for data_type in initial_status: - for symbol in initial_status[data_type]: - initial_count = initial_status[data_type][symbol] - updated_count = updated_status[data_type][symbol] - if updated_count > initial_count: - print(f" ✅ Growth detected: {data_type}/{symbol} {initial_count} -> {updated_count}") - growth_detected = True - - if not growth_detected: - print(" ⚠️ No queue growth detected - polling may not be working") - - return growth_detected - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def main(): - """Run all data integration tests""" - print("=== Data Integration Test Suite ===") - - test1_passed = test_data_provider_methods() - test2_passed = test_queue_population() - test3_passed = test_polling_thread() - - print(f"\n=== Results ===") - print(f"Data provider methods: {'✅ PASSED' if test1_passed else '❌ FAILED'}") - print(f"Queue population: {'✅ PASSED' if test2_passed else '❌ FAILED'}") - print(f"Polling thread: {'✅ PASSED' if test3_passed else '❌ FAILED'}") - - if test1_passed and test2_passed: - print("\n✅ Data integration is working!") - print("✅ FIFO queues should be populated with data") - print("✅ Models should be able to make predictions") - else: - print("\n❌ Data integration issues detected") - print("❌ Check data provider connectivity and methods") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_data_provider_integration.py b/test_data_provider_integration.py deleted file mode 100644 index d3feff8..0000000 --- a/test_data_provider_integration.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python3 -""" -Integration test for the simplified data provider with other components -""" - -import time -import logging -import pandas as pd -from core.data_provider import DataProvider - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_integration(): - """Test integration with other components""" - logger.info("Testing DataProvider integration...") - - # Initialize data provider - dp = DataProvider() - - # Wait for initial data load - logger.info("Waiting for initial data load...") - time.sleep(15) - - # Test 1: Feature matrix generation - logger.info("\n=== Test 1: Feature Matrix Generation ===") - try: - feature_matrix = dp.get_feature_matrix('ETH/USDT', ['1m', '1h'], window_size=20) - if feature_matrix is not None: - logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}") - else: - logger.warning("❌ Feature matrix generation failed") - except Exception as e: - logger.error(f"❌ Feature matrix error: {e}") - - # Test 2: Multi-symbol data access - logger.info("\n=== Test 2: Multi-Symbol Data Access ===") - for symbol in ['ETH/USDT', 'BTC/USDT']: - for timeframe in ['1s', '1m', '1h', '1d']: - data = dp.get_historical_data(symbol, timeframe, limit=10) - if data is not None and not data.empty: - logger.info(f"✅ {symbol} {timeframe}: {len(data)} candles") - else: - logger.warning(f"❌ {symbol} {timeframe}: No data") - - # Test 3: Data consistency checks - logger.info("\n=== Test 3: Data Consistency ===") - eth_1m = dp.get_historical_data('ETH/USDT', '1m', limit=100) - if eth_1m is not None and not eth_1m.empty: - # Check for proper OHLCV structure - required_cols = ['open', 'high', 'low', 'close', 'volume'] - has_all_cols = all(col in eth_1m.columns for col in required_cols) - logger.info(f"✅ OHLCV columns present: {has_all_cols}") - - # Check data types - numeric_cols = eth_1m[required_cols].dtypes - all_numeric = all(pd.api.types.is_numeric_dtype(dtype) for dtype in numeric_cols) - logger.info(f"✅ All columns numeric: {all_numeric}") - - # Check for NaN values - has_nan = eth_1m[required_cols].isna().any().any() - logger.info(f"✅ No NaN values: {not has_nan}") - - # Check price relationships (high >= low, etc.) - price_valid = (eth_1m['high'] >= eth_1m['low']).all() - logger.info(f"✅ Price relationships valid: {price_valid}") - - # Test 4: Performance test - logger.info("\n=== Test 4: Performance Test ===") - start_time = time.time() - for i in range(100): - data = dp.get_historical_data('ETH/USDT', '1m', limit=50) - end_time = time.time() - avg_time = (end_time - start_time) / 100 * 1000 # ms - logger.info(f"✅ Average data access time: {avg_time:.2f}ms") - - # Test 5: Current price accuracy - logger.info("\n=== Test 5: Current Price Accuracy ===") - eth_price = dp.get_current_price('ETH/USDT') - eth_data = dp.get_historical_data('ETH/USDT', '1s', limit=1) - if eth_data is not None and not eth_data.empty: - latest_close = eth_data.iloc[-1]['close'] - price_match = abs(eth_price - latest_close) < 0.01 - logger.info(f"✅ Current price matches latest candle: {price_match}") - logger.info(f" Current price: ${eth_price}") - logger.info(f" Latest close: ${latest_close}") - - # Test 6: Cache efficiency - logger.info("\n=== Test 6: Cache Efficiency ===") - cache_summary = dp.get_cached_data_summary() - total_candles = 0 - for symbol_data in cache_summary['cached_data'].values(): - for tf_data in symbol_data.values(): - if isinstance(tf_data, dict) and 'candle_count' in tf_data: - total_candles += tf_data['candle_count'] - - logger.info(f"✅ Total cached candles: {total_candles}") - logger.info(f"✅ Data maintenance active: {cache_summary['data_maintenance_active']}") - - # Test 7: Memory usage estimation - logger.info("\n=== Test 7: Memory Usage Estimation ===") - # Rough estimation: 8 columns * 8 bytes * total_candles - estimated_memory_mb = (total_candles * 8 * 8) / (1024 * 1024) - logger.info(f"✅ Estimated memory usage: {estimated_memory_mb:.2f} MB") - - # Clean shutdown - dp.stop_automatic_data_maintenance() - logger.info("\n✅ Integration test completed successfully!") - -if __name__ == "__main__": - test_integration() \ No newline at end of file diff --git a/test_db_migration.py b/test_db_migration.py deleted file mode 100644 index ade072e..0000000 --- a/test_db_migration.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify database migration works correctly -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from utils.database_manager import get_database_manager, reset_database_manager -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_migration(): - """Test the database migration""" - try: - logger.info("Testing database migration...") - - # Reset the database manager to force re-initialization - reset_database_manager() - - # Get a new instance (this will trigger migration) - db_manager = get_database_manager() - - # Test if we can access the input_features_blob column - with db_manager._get_connection() as conn: - cursor = conn.execute("PRAGMA table_info(inference_records)") - columns = [row[1] for row in cursor.fetchall()] - - if 'input_features_blob' in columns: - logger.info("✅ input_features_blob column exists - migration successful!") - return True - else: - logger.error("❌ input_features_blob column missing - migration failed!") - return False - - except Exception as e: - logger.error(f"❌ Migration test failed: {e}") - return False - -if __name__ == "__main__": - success = test_migration() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_deribit_integration.py b/test_deribit_integration.py deleted file mode 100644 index 24ebb18..0000000 --- a/test_deribit_integration.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Deribit Integration -Test the new DeribitInterface and ExchangeFactory -""" -import os -import sys -import logging -from dotenv import load_dotenv - -# Load environment variables -load_dotenv() - -# Add project paths -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) -sys.path.append(os.path.join(os.path.dirname(__file__), 'core')) - -from core.exchanges.exchange_factory import ExchangeFactory -from core.exchanges.deribit_interface import DeribitInterface -from core.config import get_config - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_deribit_credentials(): - """Test Deribit API credentials""" - api_key = os.getenv('DERIBIT_API_CLIENTID') - api_secret = os.getenv('DERIBIT_API_SECRET') - - logger.info(f"Deribit API Key: {'*' * 8 + api_key[-4:] if api_key and len(api_key) > 4 else 'Not set'}") - logger.info(f"Deribit API Secret: {'*' * 8 + api_secret[-4:] if api_secret and len(api_secret) > 4 else 'Not set'}") - - return bool(api_key and api_secret) - -def test_deribit_interface(): - """Test DeribitInterface directly""" - logger.info("Testing DeribitInterface directly...") - - try: - # Create Deribit interface - deribit = DeribitInterface(test_mode=True) - - # Test connection - if deribit.connect(): - logger.info("✓ Successfully connected to Deribit testnet") - - # Test getting instruments - btc_instruments = deribit.get_instruments('BTC') - logger.info(f"✓ Found {len(btc_instruments)} BTC instruments") - - # Test getting ticker - ticker = deribit.get_ticker('BTC-PERPETUAL') - if ticker: - logger.info(f"✓ BTC-PERPETUAL ticker: ${ticker.get('last_price', 'N/A')}") - - # Test getting account summary (if authenticated) - account = deribit.get_account_summary('BTC') - if account: - logger.info(f"✓ BTC account balance: {account.get('available_funds', 'N/A')}") - - return True - else: - logger.error("✗ Failed to connect to Deribit") - return False - - except Exception as e: - logger.error(f"✗ Error testing DeribitInterface: {e}") - return False - -def test_exchange_factory(): - """Test ExchangeFactory with config""" - logger.info("Testing ExchangeFactory...") - - try: - # Load config - config = get_config() - exchanges_config = config.get('exchanges', {}) - - logger.info(f"Primary exchange: {exchanges_config.get('primary', 'Not set')}") - - # Test creating primary exchange - primary_exchange = ExchangeFactory.get_primary_exchange(exchanges_config) - if primary_exchange: - logger.info(f"✓ Successfully created primary exchange: {type(primary_exchange).__name__}") - - # Test basic operations - if hasattr(primary_exchange, 'get_ticker'): - ticker = primary_exchange.get_ticker('BTC-PERPETUAL') - if ticker: - logger.info(f"✓ Primary exchange ticker test successful") - - return True - else: - logger.error("✗ Failed to create primary exchange") - return False - - except Exception as e: - logger.error(f"✗ Error testing ExchangeFactory: {e}") - return False - -def test_multiple_exchanges(): - """Test creating multiple exchanges""" - logger.info("Testing multiple exchanges...") - - try: - config = get_config() - exchanges_config = config.get('exchanges', {}) - - # Create all configured exchanges - exchanges = ExchangeFactory.create_multiple_exchanges(exchanges_config) - - logger.info(f"✓ Created {len(exchanges)} exchange interfaces:") - for name, exchange in exchanges.items(): - logger.info(f" - {name}: {type(exchange).__name__}") - - return len(exchanges) > 0 - - except Exception as e: - logger.error(f"✗ Error testing multiple exchanges: {e}") - return False - -def main(): - """Run all tests""" - logger.info("=" * 50) - logger.info("TESTING DERIBIT INTEGRATION") - logger.info("=" * 50) - - tests = [ - ("Credentials", test_deribit_credentials), - ("DeribitInterface", test_deribit_interface), - ("ExchangeFactory", test_exchange_factory), - ("Multiple Exchanges", test_multiple_exchanges) - ] - - results = [] - for test_name, test_func in tests: - logger.info(f"\n--- Testing {test_name} ---") - try: - result = test_func() - results.append((test_name, result)) - status = "PASS" if result else "FAIL" - logger.info(f"{test_name}: {status}") - except Exception as e: - logger.error(f"{test_name}: ERROR - {e}") - results.append((test_name, False)) - - # Summary - logger.info("\n" + "=" * 50) - logger.info("TEST SUMMARY") - logger.info("=" * 50) - - passed = sum(1 for _, result in results if result) - total = len(results) - - for test_name, result in results: - status = "✓ PASS" if result else "✗ FAIL" - logger.info(f"{status}: {test_name}") - - logger.info(f"\nOverall: {passed}/{total} tests passed") - - if passed == total: - logger.info("🎉 All tests passed! Deribit integration is working.") - return True - else: - logger.error("❌ Some tests failed. Check the logs above.") - return False - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_device_fix.py b/test_device_fix.py deleted file mode 100644 index 20b0c51..0000000 --- a/test_device_fix.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify device mismatch fixes for GPU training -""" - -import torch -import logging -import sys -import os - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from NN.models.enhanced_cnn import EnhancedCNN -from core.data_models import BaseDataInput, OHLCVBar - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_device_consistency(): - """Test that all tensors are on the same device""" - - logger.info("Testing device consistency for EnhancedCNN...") - - # Check if CUDA is available - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - logger.info(f"Using device: {device}") - - try: - # Initialize the adapter - adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - - # Verify adapter device - logger.info(f"Adapter device: {adapter.device}") - logger.info(f"Model device: {next(adapter.model.parameters()).device}") - - # Create sample data - sample_ohlcv = [ - OHLCVBar( - symbol="ETH/USDT", - timeframe="1s", - timestamp=1640995200.0, # 2022-01-01 - open=50000.0, - high=51000.0, - low=49000.0, - close=50500.0, - volume=1000.0 - ) - ] * 300 # 300 frames - - base_data = BaseDataInput( - symbol="ETH/USDT", - timestamp=1640995200.0, - ohlcv_1s=sample_ohlcv, - ohlcv_1m=sample_ohlcv, - ohlcv_5m=sample_ohlcv, - ohlcv_15m=sample_ohlcv, - btc_ohlcv=sample_ohlcv, - cob_data={}, - ma_data={}, - technical_indicators={}, - last_predictions={} - ) - - # Test prediction - logger.info("Testing prediction...") - prediction = adapter.predict(base_data) - logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})") - - # Test training sample addition - logger.info("Testing training sample addition...") - adapter.add_training_sample(base_data, "BUY", 0.1) - adapter.add_training_sample(base_data, "SELL", -0.05) - adapter.add_training_sample(base_data, "HOLD", 0.02) - - # Test training - logger.info("Testing training...") - training_results = adapter.train(epochs=1) - logger.info(f"Training results: {training_results}") - - logger.info("✅ All device consistency tests passed!") - return True - - except Exception as e: - logger.error(f"❌ Device consistency test failed: {e}") - import traceback - traceback.print_exc() - return False - -def test_orchestrator_inference_history(): - """Test that orchestrator properly initializes inference history""" - - logger.info("Testing orchestrator inference history initialization...") - - try: - from core.orchestrator import TradingOrchestrator - from core.data_provider import DataProvider - - # Initialize orchestrator - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Check if inference history is initialized - logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}") - - # Check if models are registered - logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}") - - # Verify each registered model has inference history - for model_name in orchestrator.model_registry.models.keys(): - if model_name in orchestrator.inference_history: - logger.info(f"✅ {model_name} has inference history initialized") - else: - logger.warning(f"❌ {model_name} missing inference history") - - logger.info("✅ Orchestrator inference history test completed!") - return True - - except Exception as e: - logger.error(f"❌ Orchestrator test failed: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - logger.info("Starting device fix verification tests...") - - # Test 1: Device consistency - test1_passed = test_device_consistency() - - # Test 2: Orchestrator inference history - test2_passed = test_orchestrator_inference_history() - - # Summary - if test1_passed and test2_passed: - logger.info("🎉 All tests passed! Device issues should be fixed.") - sys.exit(0) - else: - logger.error("❌ Some tests failed. Please check the logs above.") - sys.exit(1) \ No newline at end of file diff --git a/test_device_training_fix.py b/test_device_training_fix.py deleted file mode 100644 index b883043..0000000 --- a/test_device_training_fix.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify device handling and training sample population fixes -""" - -import logging -import asyncio -import torch -from datetime import datetime - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_device_handling(): - """Test that device handling is working correctly""" - try: - logger.info("Testing device handling...") - - # Test 1: Check CUDA availability - cuda_available = torch.cuda.is_available() - device = torch.device("cuda" if cuda_available else "cpu") - logger.info(f"CUDA available: {cuda_available}") - logger.info(f"Using device: {device}") - - # Test 2: Initialize CNN adapter - from core.enhanced_cnn_adapter import EnhancedCNNAdapter - - logger.info("Initializing CNN adapter...") - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - - logger.info(f"CNN adapter device: {cnn_adapter.device}") - logger.info(f"CNN model device: {cnn_adapter.model.device}") - - # Test 3: Create test data - from core.data_models import BaseDataInput - - logger.info("Creating test BaseDataInput...") - base_data = BaseDataInput( - symbol="ETH/USDT", - timestamp=datetime.now(), - ohlcv_1s=[], - ohlcv_1m=[], - ohlcv_1h=[], - ohlcv_1d=[], - btc_ohlcv_1s=[], - cob_data=None, - technical_indicators={}, - last_predictions={} - ) - - # Test 4: Make prediction (this should not cause device mismatch) - logger.info("Making prediction...") - prediction = cnn_adapter.predict(base_data) - - logger.info(f"Prediction successful: {prediction.predictions['action']}") - logger.info(f"Confidence: {prediction.confidence:.4f}") - - # Test 5: Add training samples - logger.info("Adding training samples...") - cnn_adapter.add_training_sample(base_data, "BUY", 0.1) - cnn_adapter.add_training_sample(base_data, "SELL", -0.05) - cnn_adapter.add_training_sample(base_data, "HOLD", 0.02) - - logger.info(f"Training samples added: {len(cnn_adapter.training_data)}") - - # Test 6: Try training if we have enough samples - if len(cnn_adapter.training_data) >= 2: - logger.info("Attempting training...") - training_results = cnn_adapter.train(epochs=1) - logger.info(f"Training results: {training_results}") - else: - logger.info("Not enough samples for training") - - logger.info("✅ Device handling test passed!") - return True - - except Exception as e: - logger.error(f"❌ Device handling test failed: {e}") - import traceback - traceback.print_exc() - return False - -async def test_orchestrator_training(): - """Test that orchestrator properly adds training samples""" - try: - logger.info("Testing orchestrator training integration...") - - # Test 1: Initialize orchestrator - from core.orchestrator import TradingOrchestrator - from core.standardized_data_provider import StandardizedDataProvider - - logger.info("Initializing data provider...") - data_provider = StandardizedDataProvider() - - logger.info("Initializing orchestrator...") - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Test 2: Check if CNN adapter is available - if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter: - logger.info(f"✅ CNN adapter available in orchestrator") - logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}") - else: - logger.warning("⚠️ CNN adapter not available in orchestrator") - return False - - # Test 3: Make a trading decision (this should add training samples) - logger.info("Making trading decision...") - decision = await orchestrator.make_trading_decision("ETH/USDT") - - if decision: - logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})") - logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}") - else: - logger.warning("No decision made") - - # Test 4: Check inference history - logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}") - for model_name, history in orchestrator.inference_history.items(): - logger.info(f" {model_name}: {len(history)} records") - - logger.info("✅ Orchestrator training test passed!") - return True - - except Exception as e: - logger.error(f"❌ Orchestrator training test failed: {e}") - import traceback - traceback.print_exc() - return False - -async def main(): - """Run all tests""" - logger.info("Starting device and training fix tests...") - - # Test 1: Device handling - test1_passed = test_device_handling() - - # Test 2: Orchestrator training - test2_passed = await test_orchestrator_training() - - # Summary - logger.info("\n" + "="*50) - logger.info("TEST SUMMARY:") - logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}") - logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}") - - if test1_passed and test2_passed: - logger.info("🎉 All tests passed! Device and training issues should be fixed.") - else: - logger.error("❌ Some tests failed. Please check the logs above.") - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/test_enhanced_cnn_adapter.py b/test_enhanced_cnn_adapter.py deleted file mode 100644 index a03ff1b..0000000 --- a/test_enhanced_cnn_adapter.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Test Enhanced CNN Adapter - -This script tests the EnhancedCNNAdapter with standardized input format. -""" - -import logging -import time -from datetime import datetime - -from core.standardized_data_provider import StandardizedDataProvider -from core.enhanced_cnn_adapter import EnhancedCNNAdapter -from core.data_models import create_model_output - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_cnn_adapter(): - """Test the EnhancedCNNAdapter with standardized input format""" - try: - # Initialize data provider - symbols = ['ETH/USDT', 'BTC/USDT'] - timeframes = ['1s', '1m', '1h', '1d'] - data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) - - # Initialize CNN adapter - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - - # Load best checkpoint if available - cnn_adapter.load_best_checkpoint() - - # Get standardized input data - logger.info("Getting standardized input data...") - base_data = data_provider.get_base_data_input('ETH/USDT') - - if base_data is None: - logger.error("Failed to get base data input") - return - - # Make prediction - logger.info("Making prediction...") - model_output = cnn_adapter.predict(base_data) - - # Log prediction - logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") - - # Store model output - data_provider.store_model_output(model_output) - - # Add training sample (simulated) - logger.info("Adding training sample...") - cnn_adapter.add_training_sample(base_data, 'BUY', 0.05) - - # Train model - logger.info("Training model...") - metrics = cnn_adapter.train(epochs=1) - - # Log training metrics - logger.info(f"Training metrics: {metrics}") - - # Make another prediction - logger.info("Making another prediction...") - model_output = cnn_adapter.predict(base_data) - - # Log prediction - logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") - - # Test model output manager - logger.info("Testing model output manager...") - output_manager = data_provider.get_model_output_manager() - - # Get current outputs - current_outputs = output_manager.get_all_current_outputs('ETH/USDT') - logger.info(f"Current outputs: {len(current_outputs)} models") - - # Evaluate model performance - metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1') - logger.info(f"Performance metrics: {metrics}") - - logger.info("Test completed successfully") - - except Exception as e: - logger.error(f"Error in test: {e}", exc_info=True) - -if __name__ == "__main__": - test_cnn_adapter() \ No newline at end of file diff --git a/test_enhanced_cob_websocket.py b/test_enhanced_cob_websocket.py deleted file mode 100644 index 367ec95..0000000 --- a/test_enhanced_cob_websocket.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced COB WebSocket Implementation - -This script tests the enhanced COB WebSocket system to ensure: -1. WebSocket connections work properly -2. Fallback to REST API when WebSocket fails -3. Dashboard status updates are working -4. Clear error messages and warnings are displayed -""" - -import asyncio -import logging -import sys -import time -from datetime import datetime - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# Import the enhanced COB WebSocket -try: - from core.enhanced_cob_websocket import EnhancedCOBWebSocket, get_enhanced_cob_websocket - print("✅ Enhanced COB WebSocket imported successfully") -except ImportError as e: - print(f"❌ Failed to import Enhanced COB WebSocket: {e}") - sys.exit(1) - -async def test_dashboard_callback(status_data): - """Test dashboard callback function""" - print(f"📊 Dashboard callback received: {status_data}") - -async def test_cob_callback(symbol, cob_data): - """Test COB data callback function""" - stats = cob_data.get('stats', {}) - mid_price = stats.get('mid_price', 0) - bid_levels = len(cob_data.get('bids', [])) - ask_levels = len(cob_data.get('asks', [])) - source = cob_data.get('source', 'unknown') - - print(f"📈 COB data for {symbol}: ${mid_price:.2f}, {bid_levels} bids, {ask_levels} asks (via {source})") - -async def main(): - """Main test function""" - print("🚀 Testing Enhanced COB WebSocket System") - print("=" * 60) - - # Test 1: Initialize Enhanced COB WebSocket - print("\n1. Initializing Enhanced COB WebSocket...") - try: - cob_ws = EnhancedCOBWebSocket( - symbols=['BTC/USDT', 'ETH/USDT'], - dashboard_callback=test_dashboard_callback - ) - - # Add callbacks - cob_ws.add_cob_callback(test_cob_callback) - - print("✅ Enhanced COB WebSocket initialized") - except Exception as e: - print(f"❌ Failed to initialize: {e}") - return - - # Test 2: Start WebSocket connections - print("\n2. Starting WebSocket connections...") - try: - await cob_ws.start() - print("✅ WebSocket connections started") - except Exception as e: - print(f"❌ Failed to start connections: {e}") - return - - # Test 3: Monitor connections for 30 seconds - print("\n3. Monitoring connections for 30 seconds...") - start_time = time.time() - - while time.time() - start_time < 30: - try: - # Get status summary - status = cob_ws.get_status_summary() - overall_status = status.get('overall_status', 'unknown') - - print(f"⏱️ Status: {overall_status}") - - # Print symbol-specific status - for symbol, symbol_status in status.get('symbols', {}).items(): - connected = symbol_status.get('connected', False) - fallback = symbol_status.get('rest_fallback_active', False) - messages = symbol_status.get('messages_received', 0) - - if connected: - print(f" {symbol}: ✅ Connected ({messages} messages)") - elif fallback: - print(f" {symbol}: ⚠️ REST fallback active") - else: - error = symbol_status.get('last_error', 'Unknown error') - print(f" {symbol}: ❌ Error - {error}") - - await asyncio.sleep(5) # Check every 5 seconds - - except KeyboardInterrupt: - print("\n⏹️ Test interrupted by user") - break - except Exception as e: - print(f"❌ Error during monitoring: {e}") - break - - # Test 4: Final status check - print("\n4. Final status check...") - try: - final_status = cob_ws.get_status_summary() - print(f"Final overall status: {final_status.get('overall_status', 'unknown')}") - - for symbol, symbol_status in final_status.get('symbols', {}).items(): - print(f" {symbol}:") - print(f" Connected: {symbol_status.get('connected', False)}") - print(f" Messages received: {symbol_status.get('messages_received', 0)}") - print(f" REST fallback: {symbol_status.get('rest_fallback_active', False)}") - if symbol_status.get('last_error'): - print(f" Last error: {symbol_status.get('last_error')}") - - except Exception as e: - print(f"❌ Error getting final status: {e}") - - # Test 5: Stop connections - print("\n5. Stopping connections...") - try: - await cob_ws.stop() - print("✅ Connections stopped successfully") - except Exception as e: - print(f"❌ Error stopping connections: {e}") - - print("\n" + "=" * 60) - print("🏁 Enhanced COB WebSocket test completed") - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\n⏹️ Test interrupted") - except Exception as e: - print(f"❌ Test failed: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/test_enhanced_data_provider_websocket.py b/test_enhanced_data_provider_websocket.py deleted file mode 100644 index 116766e..0000000 --- a/test_enhanced_data_provider_websocket.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Data Provider WebSocket Integration - -This script tests the integration between the Enhanced COB WebSocket and the Data Provider. -""" - -import asyncio -import logging -import sys -import time -from datetime import datetime - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# Import the enhanced data provider -try: - from core.data_provider import DataProvider - print("✅ Enhanced Data Provider imported successfully") -except ImportError as e: - print(f"❌ Failed to import Enhanced Data Provider: {e}") - sys.exit(1) - -async def test_enhanced_websocket_integration(): - """Test the enhanced WebSocket integration with data provider""" - print("🚀 Testing Enhanced WebSocket Integration with Data Provider") - print("=" * 70) - - # Test 1: Initialize Data Provider - print("\n1. Initializing Data Provider...") - try: - data_provider = DataProvider( - symbols=['ETH/USDT', 'BTC/USDT'], - timeframes=['1m', '1h'] - ) - print("✅ Data Provider initialized") - except Exception as e: - print(f"❌ Failed to initialize Data Provider: {e}") - return - - # Test 2: Start Enhanced WebSocket Streaming - print("\n2. Starting Enhanced WebSocket streaming...") - try: - await data_provider.start_real_time_streaming() - print("✅ Enhanced WebSocket streaming started") - except Exception as e: - print(f"❌ Failed to start WebSocket streaming: {e}") - return - - # Test 3: Check WebSocket Status - print("\n3. Checking WebSocket status...") - try: - status = data_provider.get_cob_websocket_status() - overall_status = status.get('overall_status', 'unknown') - print(f"Overall WebSocket status: {overall_status}") - - for symbol, symbol_status in status.get('symbols', {}).items(): - connected = symbol_status.get('connected', False) - messages = symbol_status.get('messages_received', 0) - fallback = symbol_status.get('rest_fallback_active', False) - - if connected: - print(f" {symbol}: ✅ Connected ({messages} messages)") - elif fallback: - print(f" {symbol}: ⚠️ REST fallback active") - else: - print(f" {symbol}: ❌ Disconnected") - - except Exception as e: - print(f"❌ Error checking WebSocket status: {e}") - - # Test 4: Monitor COB Data for 30 seconds - print("\n4. Monitoring COB data for 30 seconds...") - start_time = time.time() - data_received = {'ETH/USDT': 0, 'BTC/USDT': 0} - - while time.time() - start_time < 30: - try: - for symbol in ['ETH/USDT', 'BTC/USDT']: - cob_data = data_provider.get_latest_cob_data(symbol) - if cob_data: - data_received[symbol] += 1 - if data_received[symbol] % 10 == 1: # Print every 10th update - bids = len(cob_data.get('bids', [])) - asks = len(cob_data.get('asks', [])) - source = cob_data.get('source', 'unknown') - mid_price = cob_data.get('stats', {}).get('mid_price', 0) - print(f" 📊 {symbol}: ${mid_price:.2f}, {bids} bids, {asks} asks (via {source})") - - await asyncio.sleep(2) # Check every 2 seconds - - except KeyboardInterrupt: - print("\n⏹️ Test interrupted by user") - break - except Exception as e: - print(f"❌ Error monitoring COB data: {e}") - break - - # Test 5: Final Status Check - print("\n5. Final status check...") - try: - for symbol in ['ETH/USDT', 'BTC/USDT']: - count = data_received[symbol] - if count > 0: - print(f" {symbol}: ✅ Received {count} COB updates") - else: - print(f" {symbol}: ❌ No COB data received") - - # Check overall WebSocket status again - final_status = data_provider.get_cob_websocket_status() - print(f"Final WebSocket status: {final_status.get('overall_status', 'unknown')}") - - except Exception as e: - print(f"❌ Error in final status check: {e}") - - # Test 6: Stop WebSocket Streaming - print("\n6. Stopping WebSocket streaming...") - try: - await data_provider.stop_real_time_streaming() - print("✅ WebSocket streaming stopped") - except Exception as e: - print(f"❌ Error stopping WebSocket streaming: {e}") - - print("\n" + "=" * 70) - print("🏁 Enhanced WebSocket Integration Test Completed") - - # Summary - total_updates = sum(data_received.values()) - if total_updates > 0: - print(f"✅ SUCCESS: Received {total_updates} total COB updates") - print("🎉 Enhanced WebSocket integration is working!") - else: - print("❌ FAILURE: No COB data received") - print("⚠️ Enhanced WebSocket integration needs investigation") - -if __name__ == "__main__": - try: - asyncio.run(test_enhanced_websocket_integration()) - except KeyboardInterrupt: - print("\n⏹️ Test interrupted") - except Exception as e: - print(f"❌ Test failed: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/test_enhanced_inference_logging.py b/test_enhanced_inference_logging.py deleted file mode 100644 index 3a656f9..0000000 --- a/test_enhanced_inference_logging.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Inference Logging - -This script tests the enhanced inference logging system that stores -full input features for training feedback. -""" - -import sys -import os -import logging -import numpy as np -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.enhanced_cnn_adapter import EnhancedCNNAdapter -from core.data_models import BaseDataInput, OHLCVBar -from utils.database_manager import get_database_manager -from utils.inference_logger import get_inference_logger - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def create_test_base_data(): - """Create test BaseDataInput with realistic data""" - - # Create OHLCV bars for different timeframes - def create_ohlcv_bars(symbol, timeframe, count=300): - bars = [] - base_price = 3000.0 if 'ETH' in symbol else 50000.0 - - for i in range(count): - price = base_price + np.random.normal(0, base_price * 0.01) - bars.append(OHLCVBar( - symbol=symbol, - timestamp=datetime.now(), - open=price, - high=price * 1.002, - low=price * 0.998, - close=price + np.random.normal(0, price * 0.005), - volume=np.random.uniform(100, 1000), - timeframe=timeframe - )) - return bars - - base_data = BaseDataInput( - symbol="ETH/USDT", - timestamp=datetime.now(), - ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300), - ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300), - ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300), - ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300), - btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300), - technical_indicators={ - 'rsi': 45.5, - 'macd': 0.12, - 'bb_upper': 3100.0, - 'bb_lower': 2900.0, - 'volume_ma': 500.0 - } - ) - - return base_data - -def test_enhanced_inference_logging(): - """Test the enhanced inference logging system""" - - logger.info("=== Testing Enhanced Inference Logging ===") - - try: - # Initialize CNN adapter - cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") - logger.info("✅ CNN adapter initialized") - - # Create test data - base_data = create_test_base_data() - logger.info("✅ Test data created") - - # Make a prediction (this should log inference data) - logger.info("Making prediction...") - model_output = cnn_adapter.predict(base_data) - logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})") - - # Verify inference was logged to database - db_manager = get_database_manager() - recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1) - - if recent_inferences: - latest_inference = recent_inferences[0] - logger.info(f"✅ Inference logged to database:") - logger.info(f" Model: {latest_inference.model_name}") - logger.info(f" Action: {latest_inference.action}") - logger.info(f" Confidence: {latest_inference.confidence:.3f}") - logger.info(f" Processing time: {latest_inference.processing_time_ms:.1f}ms") - logger.info(f" Has input features: {latest_inference.input_features is not None}") - - if latest_inference.input_features is not None: - logger.info(f" Input features shape: {latest_inference.input_features.shape}") - logger.info(f" Input features sample: {latest_inference.input_features[:5]}") - else: - logger.error("❌ No inference records found in database") - return False - - # Test training data loading from inference history - logger.info("Testing training data loading from inference history...") - original_training_count = len(cnn_adapter.training_data) - cnn_adapter._load_training_data_from_inference_history() - new_training_count = len(cnn_adapter.training_data) - - logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples") - - # Test prediction evaluation - logger.info("Testing prediction evaluation...") - evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1) - logger.info(f"✅ Evaluation metrics: {evaluation_metrics}") - - # Test training with inference data - if new_training_count >= cnn_adapter.batch_size: - logger.info("Testing training with inference data...") - training_metrics = cnn_adapter.train(epochs=1) - logger.info(f"✅ Training completed: {training_metrics}") - else: - logger.info("⚠️ Not enough training data for training test") - - return True - - except Exception as e: - logger.error(f"❌ Test failed: {e}") - import traceback - traceback.print_exc() - return False - -def test_database_query_methods(): - """Test the new database query methods""" - - logger.info("=== Testing Database Query Methods ===") - - try: - db_manager = get_database_manager() - - # Test getting inference records for training - training_records = db_manager.get_inference_records_for_training( - model_name="enhanced_cnn", - hours_back=24, - limit=10 - ) - - logger.info(f"✅ Found {len(training_records)} training records") - - for i, record in enumerate(training_records[:3]): # Show first 3 - logger.info(f" Record {i+1}:") - logger.info(f" Action: {record.action}") - logger.info(f" Confidence: {record.confidence:.3f}") - logger.info(f" Has features: {record.input_features is not None}") - if record.input_features is not None: - logger.info(f" Features shape: {record.input_features.shape}") - - return True - - except Exception as e: - logger.error(f"❌ Database query test failed: {e}") - return False - -def main(): - """Run all tests""" - - logger.info("Starting Enhanced Inference Logging Tests") - - # Test 1: Enhanced inference logging - test1_passed = test_enhanced_inference_logging() - - # Test 2: Database query methods - test2_passed = test_database_query_methods() - - # Summary - logger.info("=== Test Summary ===") - logger.info(f"Enhanced Inference Logging: {'✅ PASSED' if test1_passed else '❌ FAILED'}") - logger.info(f"Database Query Methods: {'✅ PASSED' if test2_passed else '❌ FAILED'}") - - if test1_passed and test2_passed: - logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.") - logger.info("The system now:") - logger.info(" - Stores full input features with each inference") - logger.info(" - Can retrieve inference data for training feedback") - logger.info(" - Supports continuous learning from inference history") - logger.info(" - Evaluates prediction accuracy over time") - else: - logger.error("❌ Some tests failed. Please check the implementation.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_enhanced_training_integration.py b/test_enhanced_training_integration.py deleted file mode 100644 index 3568fff..0000000 --- a/test_enhanced_training_integration.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Training Integration - -This script tests the integration of EnhancedRealtimeTrainingSystem -into the TradingOrchestrator to ensure it works correctly. -""" - -import sys -import os -import logging -import asyncio -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def test_enhanced_training_integration(): - """Test the enhanced training system integration""" - try: - logger.info("=" * 60) - logger.info("TESTING ENHANCED TRAINING INTEGRATION") - logger.info("=" * 60) - - # 1. Initialize orchestrator with enhanced training - logger.info("1. Initializing orchestrator with enhanced training...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True - ) - - # 2. Check if training system is available - logger.info("2. Checking training system availability...") - training_available = hasattr(orchestrator, 'enhanced_training_system') - training_enabled = getattr(orchestrator, 'training_enabled', False) - - logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}") - logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}") - - # 3. Test training system initialization - if training_available and orchestrator.enhanced_training_system: - logger.info("3. Testing training system methods...") - - # Test getting training statistics - stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Training stats retrieved: {len(stats)} fields") - logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}") - logger.info(f" - System available: {stats.get('system_available', False)}") - - # Test starting training - start_result = orchestrator.start_enhanced_training() - logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}") - - if start_result: - # Let it run for a few seconds - logger.info(" - Letting training run for 5 seconds...") - await asyncio.sleep(5) - - # Get updated stats - updated_stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}") - - # Stop training - stop_result = orchestrator.stop_enhanced_training() - logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}") - - else: - logger.warning("3. Training system not available - checking fallback behavior...") - - # Test methods when training system is not available - stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Fallback stats: {stats}") - - start_result = orchestrator.start_enhanced_training() - logger.info(f" - Fallback start result: {start_result}") - - # 4. Test dashboard connection method - logger.info("4. Testing dashboard connection method...") - try: - orchestrator.set_training_dashboard(None) # Test with None - logger.info(" - Dashboard connection method: ✅ Available") - except Exception as e: - logger.error(f" - Dashboard connection method error: {e}") - - # 5. Summary - logger.info("=" * 60) - logger.info("INTEGRATION TEST SUMMARY") - logger.info("=" * 60) - - if training_available and training_enabled: - logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL") - logger.info(" - Training system properly integrated") - logger.info(" - All methods available and functional") - logger.info(" - Ready for real-time training") - elif training_available: - logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED") - logger.info(" - Training system available but not enabled") - logger.info(" - Check EnhancedRealtimeTrainingSystem import") - else: - logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED") - logger.info(" - Training system not properly integrated") - logger.info(" - Methods missing or non-functional") - - return training_available and training_enabled - - except Exception as e: - logger.error(f"Error in integration test: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -async def main(): - """Main test function""" - try: - success = await test_enhanced_training_integration() - - if success: - logger.info("🎉 All tests passed! Enhanced training integration is working.") - return 0 - else: - logger.warning("⚠️ Some tests failed. Check the integration.") - return 1 - - except KeyboardInterrupt: - logger.info("Test interrupted by user") - return 0 - except Exception as e: - logger.error(f"Fatal error in test: {e}") - return 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file diff --git a/test_enhanced_training_simple.py b/test_enhanced_training_simple.py deleted file mode 100644 index f3f600c..0000000 --- a/test_enhanced_training_simple.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple Enhanced Training Test - -Quick test to verify enhanced training system can be enabled and controlled. -""" - -import sys -import os -import logging - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_enhanced_training(): - """Test enhanced training system""" - try: - logger.info("Testing Enhanced Training System...") - - # 1. Create data provider - data_provider = DataProvider() - - # 2. Create orchestrator with enhanced training ENABLED - logger.info("Creating orchestrator with enhanced_rl_training=True...") - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True # 🔥 THIS ENABLES IT - ) - - # 3. Check if training system is available - logger.info(f"Training system available: {orchestrator.enhanced_training_system is not None}") - logger.info(f"Training enabled: {orchestrator.training_enabled}") - - # 4. Get training stats - stats = orchestrator.get_enhanced_training_stats() - logger.info(f"Training stats: {stats}") - - # 5. Test start/stop - if orchestrator.enhanced_training_system: - logger.info("Testing start/stop functionality...") - - # Start training - start_result = orchestrator.start_enhanced_training() - logger.info(f"Start result: {start_result}") - - # Get updated stats - updated_stats = orchestrator.get_enhanced_training_stats() - logger.info(f"Updated stats: {updated_stats}") - - # Stop training - stop_result = orchestrator.stop_enhanced_training() - logger.info(f"Stop result: {stop_result}") - - logger.info("✅ Enhanced training system is working!") - return True - else: - logger.warning("❌ Enhanced training system not available") - return False - - except Exception as e: - logger.error(f"Error testing enhanced training: {e}") - return False - -if __name__ == "__main__": - success = test_enhanced_training() - if success: - print("\n🎉 Enhanced training system is ready to use!") - print("To enable it in your main system, use:") - print(" enhanced_rl_training=True when creating TradingOrchestrator") - else: - print("\n⚠️ Enhanced training system has issues. Check the logs above.") \ No newline at end of file diff --git a/test_fifo_queues.py b/test_fifo_queues.py deleted file mode 100644 index ea905c9..0000000 --- a/test_fifo_queues.py +++ /dev/null @@ -1,285 +0,0 @@ -#!/usr/bin/env python3 -""" -Test FIFO Queue System - -Verify that the orchestrator's FIFO queue system works correctly -""" - -import time -from datetime import datetime -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -from core.data_models import OHLCVBar - -def test_fifo_queue_operations(): - """Test basic FIFO queue operations""" - print("=== Testing FIFO Queue Operations ===") - - try: - # Create orchestrator - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Test queue status - status = orchestrator.get_queue_status() - print(f"Initial queue status: {status}") - - # Test adding data to queues - test_bar = OHLCVBar( - symbol="ETH/USDT", - timestamp=datetime.now(), - open=2500.0, - high=2510.0, - low=2490.0, - close=2505.0, - volume=1000.0, - timeframe="1s" - ) - - # Add test data - success = orchestrator.update_data_queue('ohlcv_1s', 'ETH/USDT', test_bar) - print(f"Added OHLCV data: {success}") - - # Check queue status after adding data - status = orchestrator.get_queue_status() - print(f"Queue status after adding data: {status}") - - # Test retrieving data - latest_data = orchestrator.get_latest_data('ohlcv_1s', 'ETH/USDT', 1) - print(f"Retrieved latest data: {len(latest_data)} items") - - if latest_data: - bar = latest_data[0] - print(f" Bar: {bar.symbol} {bar.close} @ {bar.timestamp}") - - # Test minimum data check - has_min_data = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 1) - print(f"Has minimum data (1): {has_min_data}") - - has_min_data_100 = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 100) - print(f"Has minimum data (100): {has_min_data_100}") - - return True - - except Exception as e: - print(f"❌ FIFO queue operations test failed: {e}") - return False - -def test_data_queue_filling(): - """Test filling queues with multiple data points""" - print("\n=== Testing Data Queue Filling ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Add multiple OHLCV bars - for i in range(150): # Add 150 bars - test_bar = OHLCVBar( - symbol="ETH/USDT", - timestamp=datetime.now(), - open=2500.0 + i, - high=2510.0 + i, - low=2490.0 + i, - close=2505.0 + i, - volume=1000.0 + i, - timeframe="1s" - ) - orchestrator.update_data_queue('ohlcv_1s', 'ETH/USDT', test_bar) - - # Check queue status - status = orchestrator.get_queue_status() - print(f"Queue status after adding 150 bars: {status}") - - # Test minimum data requirements - has_min_data = orchestrator.ensure_minimum_data('ohlcv_1s', 'ETH/USDT', 100) - print(f"Has minimum data (100): {has_min_data}") - - # Get all data - all_data = orchestrator.get_queue_data('ohlcv_1s', 'ETH/USDT') - print(f"Total data in queue: {len(all_data)} items") - - # Test max_items parameter - limited_data = orchestrator.get_queue_data('ohlcv_1s', 'ETH/USDT', max_items=50) - print(f"Limited data (50): {len(limited_data)} items") - - return True - - except Exception as e: - print(f"❌ Data queue filling test failed: {e}") - return False - -def test_base_data_input_building(): - """Test building BaseDataInput from FIFO queues""" - print("\n=== Testing BaseDataInput Building ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Fill queues with sufficient data - timeframes = ['1s', '1m', '1h', '1d'] - min_counts = [100, 50, 20, 10] - - for timeframe, min_count in zip(timeframes, min_counts): - for i in range(min_count + 10): # Add a bit more than minimum - test_bar = OHLCVBar( - symbol="ETH/USDT", - timestamp=datetime.now(), - open=2500.0 + i, - high=2510.0 + i, - low=2490.0 + i, - close=2505.0 + i, - volume=1000.0 + i, - timeframe=timeframe - ) - orchestrator.update_data_queue(f'ohlcv_{timeframe}', 'ETH/USDT', test_bar) - - # Add BTC data - for i in range(110): - btc_bar = OHLCVBar( - symbol="BTC/USDT", - timestamp=datetime.now(), - open=50000.0 + i, - high=50100.0 + i, - low=49900.0 + i, - close=50050.0 + i, - volume=100.0 + i, - timeframe="1s" - ) - orchestrator.update_data_queue('ohlcv_1s', 'BTC/USDT', btc_bar) - - # Add technical indicators - test_indicators = {'rsi': 50.0, 'macd': 0.1, 'bb_upper': 2520.0, 'bb_lower': 2480.0} - orchestrator.update_data_queue('technical_indicators', 'ETH/USDT', test_indicators) - - # Try to build BaseDataInput - base_data = orchestrator.build_base_data_input('ETH/USDT') - - if base_data: - print("✅ BaseDataInput built successfully") - - # Test feature vector - features = base_data.get_feature_vector() - print(f" Feature vector size: {len(features)}") - print(f" Symbol: {base_data.symbol}") - print(f" OHLCV 1s data: {len(base_data.ohlcv_1s)} bars") - print(f" OHLCV 1m data: {len(base_data.ohlcv_1m)} bars") - print(f" BTC data: {len(base_data.btc_ohlcv_1s)} bars") - print(f" Technical indicators: {len(base_data.technical_indicators)}") - - # Validate - is_valid = base_data.validate() - print(f" Validation: {is_valid}") - - return True - else: - print("❌ Failed to build BaseDataInput") - return False - - except Exception as e: - print(f"❌ BaseDataInput building test failed: {e}") - return False - -def test_consistent_feature_size(): - """Test that feature vectors are always the same size""" - print("\n=== Testing Consistent Feature Size ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Fill with minimal data first - for timeframe, min_count in [('1s', 100), ('1m', 50), ('1h', 20), ('1d', 10)]: - for i in range(min_count): - test_bar = OHLCVBar( - symbol="ETH/USDT", - timestamp=datetime.now(), - open=2500.0 + i, - high=2510.0 + i, - low=2490.0 + i, - close=2505.0 + i, - volume=1000.0 + i, - timeframe=timeframe - ) - orchestrator.update_data_queue(f'ohlcv_{timeframe}', 'ETH/USDT', test_bar) - - # Add BTC data - for i in range(100): - btc_bar = OHLCVBar( - symbol="BTC/USDT", - timestamp=datetime.now(), - open=50000.0 + i, - high=50100.0 + i, - low=49900.0 + i, - close=50050.0 + i, - volume=100.0 + i, - timeframe="1s" - ) - orchestrator.update_data_queue('ohlcv_1s', 'BTC/USDT', btc_bar) - - feature_sizes = [] - - # Test multiple scenarios - scenarios = [ - ("Minimal data", {}), - ("With indicators", {'rsi': 50.0, 'macd': 0.1}), - ("More indicators", {'rsi': 45.0, 'macd': 0.2, 'bb_upper': 2520.0, 'bb_lower': 2480.0, 'ema_20': 2500.0}) - ] - - for name, indicators in scenarios: - if indicators: - orchestrator.update_data_queue('technical_indicators', 'ETH/USDT', indicators) - - base_data = orchestrator.build_base_data_input('ETH/USDT') - if base_data: - features = base_data.get_feature_vector() - feature_sizes.append(len(features)) - print(f"{name}: {len(features)} features") - else: - print(f"{name}: Failed to build BaseDataInput") - return False - - # Check consistency - if len(set(feature_sizes)) == 1: - print(f"✅ All feature vectors have consistent size: {feature_sizes[0]}") - return True - else: - print(f"❌ Inconsistent feature sizes: {feature_sizes}") - return False - - except Exception as e: - print(f"❌ Consistent feature size test failed: {e}") - return False - -def main(): - """Run all FIFO queue tests""" - print("=== FIFO Queue System Test Suite ===\n") - - tests = [ - test_fifo_queue_operations, - test_data_queue_filling, - test_base_data_input_building, - test_consistent_feature_size - ] - - passed = 0 - total = len(tests) - - for test in tests: - if test(): - passed += 1 - print() - - print(f"=== Test Results: {passed}/{total} passed ===") - - if passed == total: - print("✅ ALL TESTS PASSED!") - print("✅ FIFO queue system is working correctly") - print("✅ Consistent data flow ensured") - print("✅ No more network rebuilding issues") - else: - print("❌ Some tests failed") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_hold_position_fix.py b/test_hold_position_fix.py deleted file mode 100644 index e69de29..0000000 diff --git a/test_imbalance_calculation.py b/test_imbalance_calculation.py deleted file mode 100644 index 4372985..0000000 --- a/test_imbalance_calculation.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for imbalance calculation logic -""" - -import time -import logging -from datetime import datetime -from core.data_provider import DataProvider - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_imbalance_calculation(): - """Test the imbalance calculation logic with mock data""" - logger.info("Testing imbalance calculation logic...") - - # Initialize data provider - dp = DataProvider() - - # Create mock COB tick data - mock_ticks = [] - current_time = time.time() - - # Create 10 mock ticks with different imbalances - for i in range(10): - tick = { - 'symbol': 'ETH/USDT', - 'timestamp': current_time - (10 - i), # 10 seconds ago to now - 'bids': [ - [3800 + i, 100 + i * 10], # Price, Volume - [3799 + i, 50 + i * 5], - [3798 + i, 25 + i * 2] - ], - 'asks': [ - [3801 + i, 80 + i * 8], # Price, Volume - [3802 + i, 40 + i * 4], - [3803 + i, 20 + i * 2] - ], - 'stats': { - 'mid_price': 3800.5 + i, - 'spread_bps': 2.5 + i * 0.1, - 'imbalance': (i - 5) * 0.1 # Varying imbalance from -0.5 to +0.4 - }, - 'source': 'mock' - } - mock_ticks.append(tick) - - # Add mock ticks to the data provider - for tick in mock_ticks: - dp.cob_raw_ticks['ETH/USDT'].append(tick) - - logger.info(f"Added {len(mock_ticks)} mock COB ticks") - - # Test the aggregation function - logger.info("\n=== Testing COB Aggregation ===") - target_second = int(current_time - 5) # 5 seconds ago - - # Manually call the aggregation function - dp._aggregate_cob_1s('ETH/USDT', target_second) - - # Check the results - aggregated_data = list(dp.cob_1s_aggregated['ETH/USDT']) - if aggregated_data: - latest = aggregated_data[-1] - logger.info(f"Aggregated data created:") - logger.info(f" Timestamp: {latest.get('timestamp')}") - logger.info(f" Tick count: {latest.get('tick_count')}") - logger.info(f" Current imbalance: {latest.get('imbalance', 0):.4f}") - logger.info(f" Total volume: {latest.get('total_volume', 0):.2f}") - logger.info(f" Bid buckets: {len(latest.get('bid_buckets', {}))}") - logger.info(f" Ask buckets: {len(latest.get('ask_buckets', {}))}") - - # Check multi-timeframe imbalances - logger.info(f" Imbalance 1s: {latest.get('imbalance_1s', 0):.4f}") - logger.info(f" Imbalance 5s: {latest.get('imbalance_5s', 0):.4f}") - logger.info(f" Imbalance 15s: {latest.get('imbalance_15s', 0):.4f}") - logger.info(f" Imbalance 60s: {latest.get('imbalance_60s', 0):.4f}") - else: - logger.warning("No aggregated data created") - - # Test multiple aggregations to build history - logger.info("\n=== Testing Multi-timeframe Imbalances ===") - for i in range(1, 6): - target_second = int(current_time - 5 + i) - dp._aggregate_cob_1s('ETH/USDT', target_second) - - # Check the final results - final_data = list(dp.cob_1s_aggregated['ETH/USDT']) - logger.info(f"Created {len(final_data)} aggregated records") - - if final_data: - latest = final_data[-1] - logger.info(f"Final imbalance indicators:") - logger.info(f" 1s: {latest.get('imbalance_1s', 0):.4f}") - logger.info(f" 5s: {latest.get('imbalance_5s', 0):.4f}") - logger.info(f" 15s: {latest.get('imbalance_15s', 0):.4f}") - logger.info(f" 60s: {latest.get('imbalance_60s', 0):.4f}") - - # Test the COB data quality function - logger.info("\n=== Testing COB Data Quality Function ===") - quality = dp.get_cob_data_quality() - - eth_quality = quality.get('imbalance_indicators', {}).get('ETH/USDT', {}) - if eth_quality: - logger.info("COB quality indicators:") - logger.info(f" Imbalance 1s: {eth_quality.get('imbalance_1s', 0):.4f}") - logger.info(f" Imbalance 5s: {eth_quality.get('imbalance_5s', 0):.4f}") - logger.info(f" Imbalance 15s: {eth_quality.get('imbalance_15s', 0):.4f}") - logger.info(f" Imbalance 60s: {eth_quality.get('imbalance_60s', 0):.4f}") - logger.info(f" Total volume: {eth_quality.get('total_volume', 0):.2f}") - logger.info(f" Bucket count: {eth_quality.get('bucket_count', 0)}") - - logger.info("\n✅ Imbalance calculation test completed successfully!") - -if __name__ == "__main__": - test_imbalance_calculation() \ No newline at end of file diff --git a/test_improved_data_integration.py b/test_improved_data_integration.py deleted file mode 100644 index c10bc32..0000000 --- a/test_improved_data_integration.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Improved Data Integration - -Test the enhanced data integration with fallback strategies -""" - -import time -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -def test_enhanced_data_population(): - """Test enhanced data population with fallback strategies""" - print("=== Testing Enhanced Data Population ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Wait for initial population - print("Waiting 5 seconds for enhanced data population...") - time.sleep(5) - - # Check detailed queue status - print("\nDetailed queue status after enhanced population:") - orchestrator.log_queue_status(detailed=True) - - # Check minimum data requirements - symbols_to_check = ['ETH/USDT', 'BTC/USDT'] - timeframes_to_check = ['1s', '1m', '1h', '1d'] - min_requirements = {'1s': 100, '1m': 50, '1h': 20, '1d': 10} - - print(f"\nChecking minimum data requirements with fallback:") - all_requirements_met = True - - for symbol in symbols_to_check: - print(f"\n{symbol}:") - symbol_requirements_met = True - - for timeframe in timeframes_to_check: - min_count = min_requirements.get(timeframe, 10) - has_min = orchestrator.ensure_minimum_data(f'ohlcv_{timeframe}', symbol, min_count) - actual_count = 0 - if f'ohlcv_{timeframe}' in orchestrator.data_queues and symbol in orchestrator.data_queues[f'ohlcv_{timeframe}']: - with orchestrator.data_queue_locks[f'ohlcv_{timeframe}'][symbol]: - actual_count = len(orchestrator.data_queues[f'ohlcv_{timeframe}'][symbol]) - - status = "✅" if has_min else "❌" - print(f" {timeframe}: {status} {actual_count}/{min_count}") - - if not has_min: - symbol_requirements_met = False - all_requirements_met = False - - # Check technical indicators - indicators_count = 0 - if 'technical_indicators' in orchestrator.data_queues and symbol in orchestrator.data_queues['technical_indicators']: - with orchestrator.data_queue_locks['technical_indicators'][symbol]: - indicators_data = list(orchestrator.data_queues['technical_indicators'][symbol]) - if indicators_data: - indicators_count = len(indicators_data[-1]) # Latest indicators dict - - indicators_status = "✅" if indicators_count > 0 else "❌" - print(f" indicators: {indicators_status} {indicators_count} calculated") - - # Test BaseDataInput building - print(f"\nTesting BaseDataInput building with fallback:") - for symbol in ['ETH/USDT', 'BTC/USDT']: - base_data = orchestrator.build_base_data_input(symbol) - if base_data: - features = base_data.get_feature_vector() - print(f" ✅ {symbol}: BaseDataInput built successfully") - print(f" Feature vector size: {len(features)}") - print(f" OHLCV 1s bars: {len(base_data.ohlcv_1s)}") - print(f" OHLCV 1m bars: {len(base_data.ohlcv_1m)}") - print(f" OHLCV 1h bars: {len(base_data.ohlcv_1h)}") - print(f" OHLCV 1d bars: {len(base_data.ohlcv_1d)}") - print(f" BTC bars: {len(base_data.btc_ohlcv_1s)}") - print(f" Technical indicators: {len(base_data.technical_indicators)}") - - # Validate feature vector - if len(features) == 7850: - print(f" ✅ Feature vector has correct size (7850)") - else: - print(f" ❌ Feature vector size mismatch: {len(features)} != 7850") - else: - print(f" ❌ {symbol}: Failed to build BaseDataInput") - - return all_requirements_met - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def test_fallback_strategies(): - """Test specific fallback strategies""" - print("\n=== Testing Fallback Strategies ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Wait for initial population - time.sleep(3) - - # Check if fallback strategies were used - print("Checking fallback strategy usage:") - - # Check ETH/USDT 1s data (likely to need fallback) - eth_1s_count = 0 - if 'ohlcv_1s' in orchestrator.data_queues and 'ETH/USDT' in orchestrator.data_queues['ohlcv_1s']: - with orchestrator.data_queue_locks['ohlcv_1s']['ETH/USDT']: - eth_1s_count = len(orchestrator.data_queues['ohlcv_1s']['ETH/USDT']) - - if eth_1s_count >= 100: - print(f" ✅ ETH/USDT 1s data: {eth_1s_count} bars (fallback likely used)") - else: - print(f" ❌ ETH/USDT 1s data: {eth_1s_count} bars (fallback may have failed)") - - # Check ETH/USDT 1h data (likely to need fallback) - eth_1h_count = 0 - if 'ohlcv_1h' in orchestrator.data_queues and 'ETH/USDT' in orchestrator.data_queues['ohlcv_1h']: - with orchestrator.data_queue_locks['ohlcv_1h']['ETH/USDT']: - eth_1h_count = len(orchestrator.data_queues['ohlcv_1h']['ETH/USDT']) - - if eth_1h_count >= 20: - print(f" ✅ ETH/USDT 1h data: {eth_1h_count} bars (fallback likely used)") - else: - print(f" ❌ ETH/USDT 1h data: {eth_1h_count} bars (fallback may have failed)") - - # Test manual fallback strategy - print(f"\nTesting manual fallback strategy:") - missing_data = [('ohlcv_1s', 0, 100), ('ohlcv_1h', 0, 20)] - fallback_success = orchestrator._try_fallback_data_strategy('ETH/USDT', missing_data) - print(f" Manual fallback result: {'✅ SUCCESS' if fallback_success else '❌ FAILED'}") - - return eth_1s_count >= 100 and eth_1h_count >= 20 - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def test_model_predictions(): - """Test that models can now make predictions with the improved data""" - print("\n=== Testing Model Predictions ===") - - try: - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Wait for data population - time.sleep(5) - - # Try to make predictions - print("Testing model prediction capability:") - - # Test CNN prediction - try: - base_data = orchestrator.build_base_data_input('ETH/USDT') - if base_data: - print(" ✅ BaseDataInput available for CNN") - - # Test feature vector - features = base_data.get_feature_vector() - if len(features) == 7850: - print(" ✅ Feature vector has correct size for CNN") - print(" ✅ CNN should be able to make predictions without rebuilding") - else: - print(f" ❌ Feature vector size issue: {len(features)} != 7850") - else: - print(" ❌ BaseDataInput not available for CNN") - except Exception as e: - print(f" ❌ CNN prediction test failed: {e}") - - # Test RL prediction - try: - base_data = orchestrator.build_base_data_input('ETH/USDT') - if base_data: - print(" ✅ BaseDataInput available for RL") - - # Test state features - state_features = base_data.get_feature_vector() - if len(state_features) == 7850: - print(" ✅ State features have correct size for RL") - else: - print(f" ❌ State features size issue: {len(state_features)} != 7850") - else: - print(" ❌ BaseDataInput not available for RL") - except Exception as e: - print(f" ❌ RL prediction test failed: {e}") - - return base_data is not None - - except Exception as e: - print(f"❌ Test failed: {e}") - return False - -def main(): - """Run all enhanced data integration tests""" - print("=== Enhanced Data Integration Test Suite ===") - - test1_passed = test_enhanced_data_population() - test2_passed = test_fallback_strategies() - test3_passed = test_model_predictions() - - print(f"\n=== Results ===") - print(f"Enhanced data population: {'✅ PASSED' if test1_passed else '❌ FAILED'}") - print(f"Fallback strategies: {'✅ PASSED' if test2_passed else '❌ FAILED'}") - print(f"Model predictions: {'✅ PASSED' if test3_passed else '❌ FAILED'}") - - if test1_passed and test2_passed and test3_passed: - print("\n✅ ALL TESTS PASSED!") - print("✅ Enhanced data integration is working!") - print("✅ Fallback strategies provide missing data") - print("✅ Models should be able to make predictions") - print("✅ No more 'Insufficient data' errors expected") - else: - print("\n⚠️ Some tests failed, but system may still work") - print("⚠️ Check specific failures above") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_integrated_standardized_provider.py b/test_integrated_standardized_provider.py deleted file mode 100644 index 87164e4..0000000 --- a/test_integrated_standardized_provider.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Test script for integrated StandardizedDataProvider with ModelOutputManager - -This script tests the complete standardized data provider with extensible model output storage -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import logging -from datetime import datetime -from core.standardized_data_provider import StandardizedDataProvider -from core.data_models import create_model_output - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_integrated_standardized_provider(): - """Test the integrated StandardizedDataProvider with ModelOutputManager""" - - print("Testing Integrated StandardizedDataProvider with ModelOutputManager...") - - # Initialize the provider - symbols = ['ETH/USDT', 'BTC/USDT'] - timeframes = ['1s', '1m', '1h', '1d'] - - provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) - - print("✅ StandardizedDataProvider initialized with ModelOutputManager") - - # Test 1: Store model outputs from different types - print("\n1. Testing model output storage integration...") - - # Create and store outputs from different model types - model_outputs = [ - create_model_output('cnn', 'enhanced_cnn_v1', 'ETH/USDT', 'BUY', 0.85), - create_model_output('rl', 'dqn_agent_v2', 'ETH/USDT', 'SELL', 0.72), - create_model_output('transformer', 'transformer_v1', 'ETH/USDT', 'BUY', 0.91), - create_model_output('orchestrator', 'main_orchestrator', 'ETH/USDT', 'BUY', 0.78) - ] - - for output in model_outputs: - provider.store_model_output(output) - print(f"✅ Stored {output.model_type} output: {output.predictions['action']} ({output.confidence})") - - # Test 2: Retrieve model outputs - print("\n2. Testing model output retrieval...") - - all_outputs = provider.get_model_outputs('ETH/USDT') - print(f"✅ Retrieved {len(all_outputs)} model outputs for ETH/USDT") - - for model_name, output in all_outputs.items(): - print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}") - - # Test 3: Test BaseDataInput with cross-model feeding - print("\n3. Testing BaseDataInput with cross-model predictions...") - - # Use real current prices only - no mock data - - base_input = provider.get_base_data_input('ETH/USDT') - - if base_input: - print("✅ BaseDataInput created with cross-model predictions!") - print(f" Symbol: {base_input.symbol}") - print(f" OHLCV frames: 1s={len(base_input.ohlcv_1s)}, 1m={len(base_input.ohlcv_1m)}, 1h={len(base_input.ohlcv_1h)}, 1d={len(base_input.ohlcv_1d)}") - print(f" BTC frames: {len(base_input.btc_ohlcv_1s)}") - print(f" COB data: {'Available' if base_input.cob_data else 'Not available'}") - print(f" Last predictions: {len(base_input.last_predictions)} models") - - # Show cross-model predictions - for model_name, prediction in base_input.last_predictions.items(): - print(f" {model_name}: {prediction.predictions['action']} ({prediction.confidence})") - - # Test feature vector creation - try: - feature_vector = base_input.get_feature_vector() - print(f"✅ Feature vector created: shape {feature_vector.shape}") - except Exception as e: - print(f"❌ Feature vector creation failed: {e}") - else: - print("⚠️ BaseDataInput creation failed - this may be due to insufficient historical data") - - # Test 4: Advanced ModelOutputManager features - print("\n4. Testing advanced model output manager features...") - - output_manager = provider.get_model_output_manager() - - # Test consensus prediction - consensus = output_manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.7) - if consensus: - print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})") - print(f" Votes: {consensus['votes']}") - print(f" Contributing models: {consensus['model_types']}") - else: - print("⚠️ No consensus reached") - - # Test cross-model states - cross_states = output_manager.get_cross_model_states('ETH/USDT', 'dqn_agent_v2') - print(f"✅ Cross-model states available for RL model: {len(cross_states)} models") - - # Test performance summary - performance = output_manager.get_performance_summary('ETH/USDT') - print(f"✅ Performance summary: {performance['active_models']} active models") - - # Test 5: Custom model type support - print("\n5. Testing custom model type extensibility...") - - # Add a custom model type - output_manager.add_custom_model_type('hybrid_lstm_transformer') - - # Create and store custom model output - custom_output = create_model_output( - model_type='hybrid_lstm_transformer', - model_name='hybrid_model_v1', - symbol='ETH/USDT', - action='BUY', - confidence=0.89, - metadata={'hybrid_components': ['lstm', 'transformer'], 'ensemble_weight': 0.6} - ) - - provider.store_model_output(custom_output) - print("✅ Custom model type 'hybrid_lstm_transformer' stored successfully") - - # Verify it's included in BaseDataInput - updated_base_input = provider.get_base_data_input('ETH/USDT') - if updated_base_input and 'hybrid_model_v1' in updated_base_input.last_predictions: - print("✅ Custom model output included in BaseDataInput cross-model feeding") - - print(f" Total supported model types: {len(output_manager.get_supported_model_types())}") - - # Test 6: Historical tracking - print("\n6. Testing historical output tracking...") - - # Store a few more outputs to build history - for i in range(3): - historical_output = create_model_output( - model_type='cnn', - model_name='enhanced_cnn_v1', - symbol='ETH/USDT', - action='HOLD', - confidence=0.6 + i * 0.05 - ) - provider.store_model_output(historical_output) - - history = output_manager.get_output_history('ETH/USDT', 'enhanced_cnn_v1', count=5) - print(f"✅ Historical tracking: {len(history)} outputs for enhanced_cnn_v1") - - # Test 7: Real-time data integration readiness - print("\n7. Testing real-time integration readiness...") - - print("✅ Real-time processing methods available:") - print(" - start_real_time_processing()") - print(" - stop_real_time_processing()") - print(" - COB provider integration ready") - print(" - Model output persistence enabled") - - print("\n✅ Integrated StandardizedDataProvider test completed successfully!") - print("\n🎯 Key achievements:") - print("✓ Standardized BaseDataInput format for all models") - print("✓ Extensible ModelOutput storage (CNN, RL, LSTM, Transformer, Custom)") - print("✓ Cross-model feeding with last predictions") - print("✓ COB data integration with moving averages") - print("✓ Consensus prediction calculation") - print("✓ Historical output tracking") - print("✓ Performance analytics") - print("✓ Thread-safe operations") - print("✓ Persistent storage capabilities") - - print("\n🚀 Ready for model integration:") - print("1. CNN models can use BaseDataInput and store ModelOutput") - print("2. RL models can access CNN hidden states via cross-model feeding") - print("3. Orchestrator can calculate consensus from all models") - print("4. New model types can be added without code changes") - print("5. All models receive identical standardized input format") - - return provider - -if __name__ == "__main__": - test_integrated_standardized_provider() \ No newline at end of file diff --git a/test_leverage_fix.py b/test_leverage_fix.py deleted file mode 100644 index 93c708c..0000000 --- a/test_leverage_fix.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Leverage Fix - -This script tests if the leverage is now being applied correctly to trade P&L calculations. -""" - -import sys -import os -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.trading_executor import TradingExecutor, Position - -def test_leverage_fix(): - """Test that leverage is now being applied correctly""" - print("🧪 Testing Leverage Fix") - print("=" * 50) - - # Create trading executor - executor = TradingExecutor() - - # Check current leverage setting - current_leverage = executor.get_leverage() - print(f"Current leverage setting: x{current_leverage}") - - # Test leverage in P&L calculation - position = Position( - symbol="ETH/USDT", - side="SHORT", - quantity=0.005, # 0.005 ETH - entry_price=3755.33, - entry_time=datetime.now(), - order_id="test_123" - ) - - # Test P&L calculation with current price - current_price = 3740.51 # Price went down, should be profitable for SHORT - - # Calculate P&L with leverage - pnl_with_leverage = position.calculate_pnl(current_price, leverage=current_leverage) - pnl_without_leverage = position.calculate_pnl(current_price, leverage=1.0) - - print(f"\nPosition: SHORT 0.005 ETH @ $3755.33") - print(f"Current price: $3740.51") - print(f"Price difference: ${3755.33 - 3740.51:.2f} (favorable for SHORT)") - - print(f"\nP&L without leverage (x1): ${pnl_without_leverage:.2f}") - print(f"P&L with leverage (x{current_leverage}): ${pnl_with_leverage:.2f}") - print(f"Leverage multiplier effect: {pnl_with_leverage / pnl_without_leverage:.1f}x") - - # Expected calculation - position_value = 0.005 * 3755.33 # ~$18.78 - price_diff = 3755.33 - 3740.51 # $14.82 favorable - raw_pnl = price_diff * 0.005 # ~$0.074 - leveraged_pnl = raw_pnl * current_leverage # ~$3.70 - - print(f"\nExpected calculation:") - print(f"Position value: ${position_value:.2f}") - print(f"Raw P&L: ${raw_pnl:.3f}") - print(f"Leveraged P&L (before fees): ${leveraged_pnl:.2f}") - - # Check if the calculation is correct - if abs(pnl_with_leverage - leveraged_pnl) < 0.1: # Allow for small fee differences - print("✅ Leverage calculation appears correct!") - else: - print("❌ Leverage calculation may have issues") - - print("\n" + "=" * 50) - print("Test completed. Check if new trades show leveraged P&L in dashboard.") - -if __name__ == "__main__": - test_leverage_fix() \ No newline at end of file diff --git a/test_massive_dqn.py b/test_massive_dqn.py deleted file mode 100644 index 3d03c69..0000000 --- a/test_massive_dqn.py +++ /dev/null @@ -1,232 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the massive 50M parameter DQN agent -Tests: -1. Model initialization and parameter count -2. Forward pass functionality -3. Gradient flow verification -4. Training step simulation -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import torch -import numpy as np -from NN.models.dqn_agent import DQNAgent, DQNNetwork -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_dqn_architecture(): - """Test the new massive DQN architecture""" - print("🔥 Testing Massive DQN Architecture (Target: 50M parameters)") - - # Test the network directly first - input_dim = 7850 # BaseDataInput feature size - n_actions = 3 # BUY, SELL, HOLD - - print(f"\n1. Creating DQN Network with input_dim={input_dim}, n_actions={n_actions}") - network = DQNNetwork(input_dim, n_actions) - - # Count parameters - total_params = sum(p.numel() for p in network.parameters()) - print(f" ✅ Total parameters: {total_params:,}") - print(f" 🎯 Target achieved: {total_params >= 50_000_000}") - - # Test forward pass - print(f"\n2. Testing forward pass...") - batch_size = 4 - test_input = torch.randn(batch_size, input_dim) - - with torch.no_grad(): - output = network(test_input) - - if isinstance(output, tuple): - q_values, regime_pred, price_pred, volatility_pred, features = output - print(f" ✅ Q-values shape: {q_values.shape}") - print(f" ✅ Regime prediction shape: {regime_pred.shape}") - print(f" ✅ Price prediction shape: {price_pred.shape}") - print(f" ✅ Volatility prediction shape: {volatility_pred.shape}") - print(f" ✅ Features shape: {features.shape}") - else: - print(f" ✅ Output shape: {output.shape}") - - return network - -def test_gradient_flow(): - """Test that gradients flow properly through the network""" - print(f"\n🧪 Testing Gradient Flow...") - - # Create agent - state_shape = (7850,) - agent = DQNAgent( - state_shape=state_shape, - n_actions=3, - learning_rate=0.001, - batch_size=16, - buffer_size=1000 - ) - - # Force disable mixed precision - agent.use_mixed_precision = False - print(f" ✅ Mixed precision disabled: {not agent.use_mixed_precision}") - - # Ensure model is in training mode - agent.policy_net.train() - print(f" ✅ Model in training mode: {agent.policy_net.training}") - - # Create test batch - batch_size = 8 - state_dim = 7850 - - states = torch.randn(batch_size, state_dim, requires_grad=True) - actions = torch.randint(0, 3, (batch_size,)) - rewards = torch.randn(batch_size) - next_states = torch.randn(batch_size, state_dim) - dones = torch.zeros(batch_size) - - print(f" 📊 Test batch created - states: {states.shape}, actions: {actions.shape}") - - # Test forward pass and check gradients - agent.optimizer.zero_grad() - - # Forward pass - output = agent.policy_net(states) - if isinstance(output, tuple): - q_values = output[0] - else: - q_values = output - - print(f" ✅ Forward pass successful - Q-values: {q_values.shape}") - print(f" ✅ Q-values require grad: {q_values.requires_grad}") - - # Gather Q-values for actions - current_q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1) - print(f" ✅ Gathered Q-values require grad: {current_q_values.requires_grad}") - - # Compute simple loss - target_q_values = rewards # Simplified target - loss = torch.nn.MSELoss()(current_q_values, target_q_values) - print(f" ✅ Loss computed: {loss.item():.6f}") - print(f" ✅ Loss requires grad: {loss.requires_grad}") - - # Backward pass - loss.backward() - - # Check if gradients exist and are finite - grad_norms = [] - params_with_grad = 0 - total_params = 0 - - for name, param in agent.policy_net.named_parameters(): - total_params += 1 - if param.grad is not None: - params_with_grad += 1 - grad_norm = param.grad.norm().item() - grad_norms.append(grad_norm) - if not torch.isfinite(param.grad).all(): - print(f" ❌ Non-finite gradients in {name}") - return False - - print(f" ✅ Parameters with gradients: {params_with_grad}/{total_params}") - print(f" ✅ Average gradient norm: {np.mean(grad_norms):.6f}") - print(f" ✅ Max gradient norm: {max(grad_norms):.6f}") - - # Test optimizer step - agent.optimizer.step() - print(f" ✅ Optimizer step completed successfully") - - return True - -def test_training_step(): - """Test a complete training step""" - print(f"\n🏋️ Testing Complete Training Step...") - - # Create agent - state_shape = (7850,) - agent = DQNAgent( - state_shape=state_shape, - n_actions=3, - learning_rate=0.001, - batch_size=8, - buffer_size=1000 - ) - - # Force disable mixed precision - agent.use_mixed_precision = False - - # Add some experiences - for i in range(20): - state = np.random.randn(7850).astype(np.float32) - action = np.random.randint(0, 3) - reward = np.random.randn() * 0.1 - next_state = np.random.randn(7850).astype(np.float32) - done = np.random.random() < 0.1 - - agent.remember(state, action, reward, next_state, done) - - print(f" ✅ Added {len(agent.memory)} experiences to memory") - - # Test replay training - if len(agent.memory) >= agent.batch_size: - loss = agent.replay() - print(f" ✅ Training completed with loss: {loss:.6f}") - - if loss > 0: - print(f" ✅ Training successful - non-zero loss indicates learning") - return True - else: - print(f" ❌ Training failed - zero loss indicates gradient issues") - return False - else: - print(f" ⚠️ Not enough experiences for training") - return True - -def main(): - """Run all tests""" - print("🚀 MASSIVE DQN AGENT TESTING SUITE") - print("=" * 50) - - # Test 1: Architecture - try: - network = test_dqn_architecture() - print(" ✅ Architecture test PASSED") - except Exception as e: - print(f" ❌ Architecture test FAILED: {e}") - return False - - # Test 2: Gradient flow - try: - gradient_success = test_gradient_flow() - if gradient_success: - print(" ✅ Gradient flow test PASSED") - else: - print(" ❌ Gradient flow test FAILED") - return False - except Exception as e: - print(f" ❌ Gradient flow test FAILED: {e}") - return False - - # Test 3: Training step - try: - training_success = test_training_step() - if training_success: - print(" ✅ Training step test PASSED") - else: - print(" ❌ Training step test FAILED") - return False - except Exception as e: - print(f" ❌ Training step test FAILED: {e}") - return False - - print("\n🎉 ALL TESTS PASSED!") - print("✅ Massive DQN agent is ready for 50M parameter learning!") - return True - -if __name__ == "__main__": - success = main() - exit(0 if success else 1) \ No newline at end of file diff --git a/test_mexc_order_fix.py b/test_mexc_order_fix.py deleted file mode 100644 index 6c3f1b6..0000000 --- a/test_mexc_order_fix.py +++ /dev/null @@ -1,174 +0,0 @@ -#!/usr/bin/env python3 -""" -Test MEXC Order Fix - -Tests the fixed MEXC interface to ensure order execution works correctly -""" - -import os -import sys -import logging -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) - -logger = logging.getLogger(__name__) - -def test_mexc_order_fix(): - """Test the fixed MEXC interface""" - print("Testing Fixed MEXC Interface") - print("=" * 50) - - # Import after path setup - try: - from core.exchanges.mexc_interface import MEXCInterface - except ImportError as e: - print(f"❌ Import error: {e}") - return False - - # Get API credentials - api_key = os.getenv('MEXC_API_KEY', '') - api_secret = os.getenv('MEXC_SECRET_KEY', '') - - if not api_key or not api_secret: - print("❌ No MEXC API credentials found") - print("Set MEXC_API_KEY and MEXC_SECRET_KEY environment variables") - return False - - # Initialize MEXC interface - mexc = MEXCInterface( - api_key=api_key, - api_secret=api_secret, - test_mode=False, # Use live API (MEXC doesn't have testnet) - trading_mode='live' - ) - - # Test 1: Connection - print("\n1. Testing connection...") - if mexc.connect(): - print("✅ Connection successful") - else: - print("❌ Connection failed") - return False - - # Test 2: Account info - print("\n2. Testing account info...") - account_info = mexc.get_account_info() - if account_info: - print("✅ Account info retrieved") - print(f"Account type: {account_info.get('accountType', 'N/A')}") - else: - print("❌ Failed to get account info") - return False - - # Test 3: Balance check - print("\n3. Testing balance retrieval...") - usdc_balance = mexc.get_balance('USDC') - usdt_balance = mexc.get_balance('USDT') - print(f"USDC balance: {usdc_balance}") - print(f"USDT balance: {usdt_balance}") - - if usdc_balance <= 0 and usdt_balance <= 0: - print("❌ No USDC or USDT balance for testing") - return False - - # Test 4: Symbol support check - print("\n4. Testing symbol support...") - symbol = 'ETH/USDT' # Will be converted to ETHUSDC internally - formatted_symbol = mexc._format_spot_symbol(symbol) - print(f"Symbol {symbol} formatted to: {formatted_symbol}") - - if mexc.is_symbol_supported(symbol): - print(f"✅ Symbol {formatted_symbol} is supported") - else: - print(f"❌ Symbol {formatted_symbol} is not supported") - print("Checking supported symbols...") - supported = mexc.get_api_symbols() - print(f"Found {len(supported)} supported symbols") - if 'ETHUSDC' in supported: - print("✅ ETHUSDC is in supported list") - else: - print("❌ ETHUSDC not in supported list") - - # Test 5: Get ticker - print("\n5. Testing ticker retrieval...") - ticker = mexc.get_ticker(symbol) - if ticker: - print(f"✅ Ticker retrieved for {symbol}") - print(f"Last price: ${ticker['last']:.2f}") - print(f"Bid: ${ticker['bid']:.2f}, Ask: ${ticker['ask']:.2f}") - else: - print(f"❌ Failed to get ticker for {symbol}") - return False - - # Test 6: Small test order (only if balance available) - print("\n6. Testing small order placement...") - if usdc_balance >= 10.0: # Need at least $10 for minimum order - try: - # Calculate small test quantity - test_price = ticker['last'] * 1.01 # 1% above market for quick execution - test_quantity = round(10.0 / test_price, 5) # $10 worth - - print(f"Attempting to place test order:") - print(f"- Symbol: {symbol} -> {formatted_symbol}") - print(f"- Side: BUY") - print(f"- Type: LIMIT") - print(f"- Quantity: {test_quantity}") - print(f"- Price: ${test_price:.2f}") - - # Note: This is a real order that will use real funds! - confirm = input("⚠️ This will place a REAL order with REAL funds! Continue? (yes/no): ") - if confirm.lower() != 'yes': - print("❌ Order test skipped by user") - return True - - order_result = mexc.place_order( - symbol=symbol, - side='BUY', - order_type='LIMIT', - quantity=test_quantity, - price=test_price - ) - - if order_result: - print("✅ Order placed successfully!") - print(f"Order ID: {order_result.get('orderId')}") - print(f"Order result: {order_result}") - - # Try to cancel the order immediately - order_id = order_result.get('orderId') - if order_id: - print(f"\n7. Testing order cancellation...") - cancel_result = mexc.cancel_order(symbol, str(order_id)) - if cancel_result: - print("✅ Order cancelled successfully") - else: - print("❌ Failed to cancel order") - print("⚠️ You may have an open order to manually cancel") - else: - print("❌ Order placement failed") - return False - - except Exception as e: - print(f"❌ Order test failed with exception: {e}") - return False - else: - print(f"⚠️ Insufficient balance for order test (need $10+, have ${usdc_balance:.2f} USDC)") - print("✅ All other tests passed - order API should work when balance is sufficient") - - print("\n" + "=" * 50) - print("✅ MEXC Interface Test Completed Successfully!") - print("✅ Order execution should now work correctly") - return True - -if __name__ == "__main__": - success = test_mexc_order_fix() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_model_output_manager.py b/test_model_output_manager.py deleted file mode 100644 index 85f6048..0000000 --- a/test_model_output_manager.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Test script for ModelOutputManager - -This script tests the extensible model output storage functionality -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import logging -from datetime import datetime -from core.model_output_manager import ModelOutputManager -from core.data_models import create_model_output, ModelOutput - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_model_output_manager(): - """Test the ModelOutputManager functionality""" - - print("Testing ModelOutputManager...") - - # Initialize the manager - manager = ModelOutputManager(cache_dir="test_cache/model_outputs", max_history=100) - - print(f"✅ ModelOutputManager initialized") - print(f" Supported model types: {manager.get_supported_model_types()}") - - # Test 1: Store outputs from different model types - print("\n1. Testing model output storage...") - - # Create outputs from different model types - models_to_test = [ - ('cnn', 'enhanced_cnn_v1', 'BUY', 0.85), - ('rl', 'dqn_agent_v2', 'SELL', 0.72), - ('lstm', 'lstm_predictor', 'HOLD', 0.65), - ('transformer', 'transformer_v1', 'BUY', 0.91), - ('orchestrator', 'main_orchestrator', 'BUY', 0.78) - ] - - symbol = 'ETH/USDT' - stored_outputs = [] - - for model_type, model_name, action, confidence in models_to_test: - # Create model output with hidden states for cross-model feeding - hidden_states = { - 'layer_1': [0.1, 0.2, 0.3], - 'layer_2': [0.4, 0.5, 0.6], - 'attention_weights': [0.7, 0.8, 0.9] - } if model_type in ['cnn', 'transformer'] else None - - metadata = { - 'model_version': '1.0', - 'training_iterations': 1000, - 'last_updated': datetime.now().isoformat() - } - - model_output = create_model_output( - model_type=model_type, - model_name=model_name, - symbol=symbol, - action=action, - confidence=confidence, - hidden_states=hidden_states, - metadata=metadata - ) - - # Store the output - success = manager.store_output(model_output) - if success: - print(f"✅ Stored {model_type} output: {action} ({confidence})") - stored_outputs.append(model_output) - else: - print(f"❌ Failed to store {model_type} output") - - # Test 2: Retrieve current outputs - print("\n2. Testing output retrieval...") - - all_current = manager.get_all_current_outputs(symbol) - print(f"✅ Retrieved {len(all_current)} current outputs for {symbol}") - - for model_name, output in all_current.items(): - print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}") - - # Test 3: Cross-model feeding - print("\n3. Testing cross-model feeding...") - - cross_model_states = manager.get_cross_model_states(symbol, 'dqn_agent_v2') - print(f"✅ Retrieved cross-model states for RL model: {len(cross_model_states)} models") - - for model_name, states in cross_model_states.items(): - if states: - print(f" {model_name}: {len(states)} hidden state layers") - - # Test 4: Consensus prediction - print("\n4. Testing consensus prediction...") - - consensus = manager.get_consensus_prediction(symbol, confidence_threshold=0.7) - if consensus: - print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})") - print(f" Votes: {consensus['votes']}") - print(f" Models: {consensus['model_types']}") - else: - print("⚠️ No consensus reached (insufficient high-confidence predictions)") - - # Test 5: Performance summary - print("\n5. Testing performance tracking...") - - performance = manager.get_performance_summary(symbol) - print(f"✅ Performance summary for {symbol}:") - print(f" Active models: {performance['active_models']}") - - for model_name, stats in performance['model_stats'].items(): - print(f" {model_name} ({stats['model_type']}): {stats['predictions']} predictions, " - f"avg confidence: {stats['avg_confidence']}") - - # Test 6: Custom model type support - print("\n6. Testing custom model type support...") - - # Add a custom model type - manager.add_custom_model_type('hybrid_ensemble') - - # Create output with custom model type - custom_output = create_model_output( - model_type='hybrid_ensemble', - model_name='custom_ensemble_v1', - symbol=symbol, - action='BUY', - confidence=0.88, - metadata={'ensemble_size': 5, 'voting_method': 'weighted'} - ) - - success = manager.store_output(custom_output) - if success: - print("✅ Custom model type 'hybrid_ensemble' stored successfully") - else: - print("❌ Failed to store custom model type") - - print(f" Updated supported types: {len(manager.get_supported_model_types())} types") - - # Test 7: Historical outputs - print("\n7. Testing historical output tracking...") - - # Store a few more outputs to build history - for i in range(3): - historical_output = create_model_output( - model_type='cnn', - model_name='enhanced_cnn_v1', - symbol=symbol, - action='HOLD', - confidence=0.6 + i * 0.1 - ) - manager.store_output(historical_output) - - history = manager.get_output_history(symbol, 'enhanced_cnn_v1', count=5) - print(f"✅ Retrieved {len(history)} historical outputs for enhanced_cnn_v1") - - for i, output in enumerate(history): - print(f" {i+1}. {output.predictions['action']} ({output.confidence}) at {output.timestamp}") - - # Test 8: Active model types - print("\n8. Testing active model type detection...") - - active_types = manager.get_model_types_active(symbol) - print(f"✅ Active model types for {symbol}: {active_types}") - - print("\n✅ ModelOutputManager test completed successfully!") - print("\nKey features verified:") - print("✓ Extensible model type support (CNN, RL, LSTM, Transformer, Custom)") - print("✓ Cross-model feeding with hidden states") - print("✓ Historical output tracking") - print("✓ Performance analytics") - print("✓ Consensus prediction calculation") - print("✓ Metadata management") - print("✓ Thread-safe storage operations") - - return manager - -if __name__ == "__main__": - test_model_output_manager() \ No newline at end of file diff --git a/test_model_registry.py b/test_model_registry.py deleted file mode 100644 index b3236bd..0000000 --- a/test_model_registry.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -import logging -import sys -import os - -# Add the project root to the path -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_model_registry(): - """Test the model registry state""" - try: - from core.orchestrator import TradingOrchestrator - from core.data_provider import DataProvider - - logger.info("Testing model registry...") - - # Initialize data provider - data_provider = DataProvider() - - # Initialize orchestrator - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Check model registry state - logger.info(f"Model registry models: {len(orchestrator.model_registry.models)}") - logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}") - - # Check if models were created - logger.info(f"RL Agent: {orchestrator.rl_agent is not None}") - logger.info(f"CNN Model: {orchestrator.cnn_model is not None}") - logger.info(f"CNN Adapter: {orchestrator.cnn_adapter is not None}") - - # Check model weights - logger.info(f"Model weights: {orchestrator.model_weights}") - - return True - - except Exception as e: - logger.error(f"Error testing model registry: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_model_registry() - if success: - logger.info("✅ Model registry test completed successfully") - else: - logger.error("❌ Model registry test failed") \ No newline at end of file diff --git a/test_model_statistics.py b/test_model_statistics.py deleted file mode 100644 index 2480c6c..0000000 --- a/test_model_statistics.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Model Statistics Implementation - -This script tests the new model statistics tracking functionality. -""" - -import asyncio -import time -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -async def test_model_statistics(): - """Test the model statistics tracking""" - print("=== Testing Model Statistics ===") - - # Initialize orchestrator - print("1. Initializing orchestrator...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Wait for initialization - await asyncio.sleep(2) - - # Test initial statistics - print("\n2. Initial model statistics:") - orchestrator.log_model_statistics() - - # Run some predictions to generate statistics - print("\n3. Running predictions to generate statistics...") - for i in range(5): - print(f" Running prediction batch {i+1}/5...") - predictions = await orchestrator._get_all_predictions('ETH/USDT') - print(f" Got {len(predictions)} predictions") - await asyncio.sleep(1) # Small delay between batches - - # Show updated statistics - print("\n4. Updated model statistics:") - orchestrator.log_model_statistics(detailed=True) - - # Test statistics summary - print("\n5. Statistics summary (JSON format):") - summary = orchestrator.get_model_statistics_summary() - for model_name, stats in summary.items(): - print(f" {model_name}: {stats}") - - # Test individual model statistics - print("\n6. Individual model statistics:") - for model_name in orchestrator.model_statistics.keys(): - stats = orchestrator.get_model_statistics(model_name) - if stats: - print(f" {model_name}: {stats.total_inferences} inferences, " - f"rate={stats.inference_rate_per_minute:.1f}/min") - - print("\n✅ Model statistics test completed successfully!") - -if __name__ == "__main__": - asyncio.run(test_model_statistics()) \ No newline at end of file diff --git a/test_model_stats.py b/test_model_stats.py deleted file mode 100644 index d9e1941..0000000 --- a/test_model_stats.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify model stats functionality -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import logging -from core.orchestrator import TradingOrchestrator - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_model_stats(): - """Test the model stats functionality""" - try: - logger.info("Testing model stats functionality...") - - # Create orchestrator instance (this will initialize model states) - orchestrator = TradingOrchestrator() - - # Sync with dashboard values - orchestrator.sync_model_states_with_dashboard() - - # Get current model stats - stats = orchestrator.get_model_training_stats() - - logger.info("Current model training stats:") - for model_name, model_stats in stats.items(): - if model_stats['current_loss'] is not None: - logger.info(f" {model_name.upper()}: {model_stats['current_loss']:.4f} loss, {model_stats['improvement_pct']:.1f}% improvement") - else: - logger.info(f" {model_name.upper()}: No training data yet") - - # Test updating a model loss - orchestrator.update_model_loss('cnn', 0.0001) - logger.info("Updated CNN loss to 0.0001") - - # Get updated stats - updated_stats = orchestrator.get_model_training_stats() - cnn_stats = updated_stats['cnn'] - logger.info(f"CNN updated: {cnn_stats['current_loss']:.4f} loss, {cnn_stats['improvement_pct']:.1f}% improvement") - - return True - - except Exception as e: - logger.error(f"❌ Model stats test failed: {e}") - return False - -if __name__ == "__main__": - success = test_model_stats() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_model_training.py b/test_model_training.py deleted file mode 100644 index f0488e0..0000000 --- a/test_model_training.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Model Training Implementation - -This script tests the improved model training functionality. -""" - -import asyncio -import time -import numpy as np -from datetime import datetime -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -async def test_model_training(): - """Test the improved model training system""" - print("=== Testing Model Training System ===") - - # Initialize orchestrator - print("1. Initializing orchestrator...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Wait for initialization - await asyncio.sleep(3) - - # Show initial model statistics - print("\n2. Initial model statistics:") - orchestrator.log_model_statistics() - - # Run predictions to generate training data - print("\n3. Running predictions to generate training data...") - predictions_data = [] - - for i in range(3): - print(f" Running prediction batch {i+1}/3...") - predictions = await orchestrator._get_all_predictions('ETH/USDT') - print(f" Got {len(predictions)} predictions") - - # Store prediction data for training simulation - for pred in predictions: - predictions_data.append({ - 'model_name': pred.model_name, - 'prediction': { - 'action': pred.action, - 'confidence': pred.confidence - }, - 'timestamp': pred.timestamp, - 'symbol': 'ETH/USDT' - }) - - await asyncio.sleep(1) - - print(f"\n4. Collected {len(predictions_data)} predictions for training") - - # Simulate training with different outcomes - print("\n5. Testing training with simulated outcomes...") - - for i, pred_data in enumerate(predictions_data[:6]): # Test first 6 predictions - # Simulate market outcome - was_correct = i % 2 == 0 # Alternate between correct and incorrect - price_change_pct = 0.5 if was_correct else -0.3 - sophisticated_reward = 1.0 if was_correct else -0.5 - - # Create training record - training_record = { - 'model_name': pred_data['model_name'], - 'model_input': np.random.randn(7850), # Simulate model input - 'prediction': pred_data['prediction'], - 'symbol': pred_data['symbol'], - 'timestamp': pred_data['timestamp'] - } - - print(f" Training {pred_data['model_name']}: " - f"action={pred_data['prediction']['action']}, " - f"correct={was_correct}, reward={sophisticated_reward}") - - # Test the training method - try: - await orchestrator._train_model_on_outcome( - training_record, was_correct, price_change_pct, sophisticated_reward - ) - print(f" ✅ Training completed for {pred_data['model_name']}") - except Exception as e: - print(f" ❌ Training failed for {pred_data['model_name']}: {e}") - - # Show updated statistics - print("\n6. Updated model statistics after training:") - orchestrator.log_model_statistics(detailed=True) - - # Test specific model training methods - print("\n7. Testing specific model training methods...") - - # Test DQN training - if 'dqn_agent' in orchestrator.model_statistics: - print(" Testing DQN agent training...") - dqn_record = { - 'model_name': 'dqn_agent', - 'model_input': np.random.randn(7850), - 'prediction': {'action': 'BUY', 'confidence': 0.8}, - 'symbol': 'ETH/USDT', - 'timestamp': datetime.now() - } - try: - await orchestrator._train_model_on_outcome(dqn_record, True, 0.5, 1.0) - print(" ✅ DQN training test passed") - except Exception as e: - print(f" ❌ DQN training test failed: {e}") - - # Test CNN training - if 'enhanced_cnn' in orchestrator.model_statistics: - print(" Testing CNN model training...") - cnn_record = { - 'model_name': 'enhanced_cnn', - 'model_input': np.random.randn(7850), - 'prediction': {'action': 'SELL', 'confidence': 0.6}, - 'symbol': 'ETH/USDT', - 'timestamp': datetime.now() - } - try: - await orchestrator._train_model_on_outcome(cnn_record, False, -0.3, -0.5) - print(" ✅ CNN training test passed") - except Exception as e: - print(f" ❌ CNN training test failed: {e}") - - # Show final statistics - print("\n8. Final model statistics:") - summary = orchestrator.get_model_statistics_summary() - for model_name, stats in summary.items(): - print(f" {model_name}:") - print(f" Inferences: {stats['total_inferences']}") - print(f" Rate: {stats['inference_rate_per_minute']:.1f}/min") - print(f" Current loss: {stats['current_loss']}") - print(f" Last prediction: {stats['last_prediction']} ({stats['last_confidence']})") - - print("\n✅ Model training test completed!") - -if __name__ == "__main__": - asyncio.run(test_model_training()) \ No newline at end of file diff --git a/test_orchestrator_fix.py b/test_orchestrator_fix.py deleted file mode 100644 index 6b23b82..0000000 --- a/test_orchestrator_fix.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify orchestrator fix -""" - -import logging -import os - -# Fix OpenMP library conflicts -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '4' - -# Fix matplotlib backend -import matplotlib -matplotlib.use('Agg') - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_orchestrator(): - """Test orchestrator initialization""" - try: - logger.info("Testing orchestrator initialization...") - - # Import required modules - from core.standardized_data_provider import StandardizedDataProvider - from core.orchestrator import TradingOrchestrator - - logger.info("Imports successful") - - # Create data provider - data_provider = StandardizedDataProvider() - logger.info("StandardizedDataProvider created") - - # Create orchestrator - orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True) - logger.info("TradingOrchestrator created successfully!") - - # Test basic functionality - status = orchestrator.get_queue_status() - logger.info(f"Queue status: {status}") - - logger.info("✅ Orchestrator test completed successfully!") - - except Exception as e: - logger.error(f"❌ Orchestrator test failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_orchestrator() \ No newline at end of file diff --git a/test_order_sync_and_fees.py b/test_order_sync_and_fees.py deleted file mode 100644 index 0323b3c..0000000 --- a/test_order_sync_and_fees.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Open Order Sync and Fee Calculation -Verify that open orders are properly synchronized and fees are correctly calculated in PnL -""" - -import os -import sys -import logging - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -# Load environment variables -try: - from dotenv import load_dotenv - load_dotenv() -except ImportError: - if os.path.exists('.env'): - with open('.env', 'r') as f: - for line in f: - if line.strip() and not line.startswith('#'): - key, value = line.strip().split('=', 1) - os.environ[key] = value - -from core.trading_executor import TradingExecutor - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_open_order_sync_and_fees(): - """Test open order synchronization and fee calculation""" - print("🧪 Testing Open Order Sync and Fee Calculation...") - print("=" * 70) - - try: - # Create trading executor - executor = TradingExecutor() - - print(f"📊 Current State Analysis:") - print(f" Open orders count: {executor._get_open_orders_count()}") - print(f" Max open orders: {executor.max_open_orders}") - print(f" Can place new order: {executor._can_place_new_order()}") - - # Test open order synchronization - print(f"\n🔍 Open Order Sync Analysis:") - print(f" - Current sync method: _get_open_orders_count()") - print(f" - Counts orders across all symbols") - print(f" - Real-time API queries") - print(f" - Handles API errors gracefully") - - # Check if there's a dedicated sync method - if hasattr(executor, 'sync_open_orders'): - print(f" ✅ Dedicated sync method exists") - else: - print(f" ⚠️ No dedicated sync method - using count method") - - # Test fee calculation in PnL - print(f"\n💰 Fee Calculation Analysis:") - - # Check fee calculation methods - if hasattr(executor, '_calculate_trading_fee'): - print(f" ✅ Fee calculation method exists") - else: - print(f" ❌ No dedicated fee calculation method") - - # Check if fees are included in PnL - print(f"\n📈 PnL Fee Integration:") - print(f" - TradeRecord includes fees field") - print(f" - PnL calculation: pnl = gross_pnl - fees") - print(f" - Fee rates from config: taker_fee, maker_fee") - - # Check fee sync - print(f"\n🔄 Fee Synchronization:") - if hasattr(executor, 'sync_fees_with_api'): - print(f" ✅ Fee sync method exists") - else: - print(f" ❌ No fee sync method") - - # Check config sync - if hasattr(executor, 'config_sync'): - print(f" ✅ Config synchronizer exists") - else: - print(f" ❌ No config synchronizer") - - print(f"\n📋 Issues Found:") - - # Issue 1: No dedicated open order sync method - if not hasattr(executor, 'sync_open_orders'): - print(f" ❌ Missing: Dedicated open order synchronization method") - print(f" Current: Only counts orders, doesn't sync state") - - # Issue 2: Fee calculation may not be comprehensive - print(f" ⚠️ Potential: Fee calculation uses simulated rates") - print(f" Should: Use actual API fees when available") - - # Issue 3: Check if fees are properly tracked - print(f" ✅ Good: Fees are tracked in TradeRecord") - print(f" ✅ Good: PnL includes fee deduction") - - print(f"\n🔧 Recommended Fixes:") - print(f" 1. Add dedicated open order sync method") - print(f" 2. Enhance fee calculation with real API data") - print(f" 3. Add periodic order state synchronization") - print(f" 4. Improve fee tracking accuracy") - - return True - - except Exception as e: - print(f"❌ Error testing order sync and fees: {e}") - return False - -if __name__ == "__main__": - success = test_open_order_sync_and_fees() - if success: - print(f"\n🎉 Order sync and fee test completed!") - else: - print(f"\n💥 Order sync and fee test failed!") \ No newline at end of file diff --git a/test_position_based_rewards.py b/test_position_based_rewards.py deleted file mode 100644 index 316da43..0000000 --- a/test_position_based_rewards.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for position-based reward system - -This script tests the enhanced reward calculations that incentivize: -1. Holding profitable positions (let winners run) -2. Closing losing positions (cut losses) -3. Taking action when appropriate based on P&L -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from NN.models.enhanced_cnn import EnhancedCNN -import numpy as np - -def test_position_reward_scenarios(): - """Test various position-based reward scenarios""" - - print("🧪 POSITION-BASED REWARD SYSTEM TEST") - print("=" * 50) - - # Initialize orchestrator - orchestrator = TradingOrchestrator() - - # Test scenarios - scenarios = [ - # (action, position_pnl, has_position, price_change_pct, description) - ("HOLD", 50.0, True, 0.5, "Hold profitable position with continued gains"), - ("HOLD", 50.0, True, -0.3, "Hold profitable position with small pullback"), - ("HOLD", -30.0, True, 0.8, "Hold losing position that recovers"), - ("HOLD", -30.0, True, -0.5, "Hold losing position that continues down"), - ("SELL", 50.0, True, 0.0, "Close profitable position"), - ("SELL", -30.0, True, 0.0, "Close losing position (good)"), - ("BUY", 0.0, False, 1.0, "New buy position with immediate gain"), - ("HOLD", 0.0, False, 0.1, "Hold with no position (stable market)"), - ] - - print("\n📊 SOPHISTICATED REWARD CALCULATION TESTS:") - print("-" * 80) - - for i, (action, position_pnl, has_position, price_change_pct, description) in enumerate(scenarios, 1): - # Test sophisticated reward calculation - reward, was_correct = orchestrator._calculate_sophisticated_reward( - predicted_action=action, - prediction_confidence=0.8, - price_change_pct=price_change_pct, - time_diff_minutes=5.0, - has_price_prediction=False, - symbol="ETH/USDT", - has_position=has_position, - current_position_pnl=position_pnl - ) - - print(f"{i:2d}. {description}") - print(f" Action: {action}, P&L: ${position_pnl:+.1f}, Price Change: {price_change_pct:+.1f}%") - print(f" Reward: {reward:+.3f}, Correct: {was_correct}") - print() - - print("\n🧠 CNN POSITION-ENHANCED REWARD TESTS:") - print("-" * 80) - - # Initialize CNN model - cnn_model = EnhancedCNN(input_shape=100, n_actions=3) - - for i, (action, position_pnl, has_position, _, description) in enumerate(scenarios, 1): - base_reward = 0.5 # Moderate base reward - enhanced_reward = cnn_model._calculate_position_enhanced_reward( - base_reward=base_reward, - action=action, - position_pnl=position_pnl, - has_position=has_position - ) - - enhancement = enhanced_reward - base_reward - print(f"{i:2d}. {description}") - print(f" Action: {action}, P&L: ${position_pnl:+.1f}") - print(f" Base Reward: {base_reward:+.3f} → Enhanced: {enhanced_reward:+.3f} (Δ{enhancement:+.3f})") - print() - - print("\n🤖 DQN POSITION-ENHANCED REWARD TESTS:") - print("-" * 80) - - for i, (action, position_pnl, has_position, _, description) in enumerate(scenarios, 1): - base_reward = 0.5 # Moderate base reward - enhanced_reward = orchestrator._calculate_position_enhanced_reward_for_dqn( - base_reward=base_reward, - action=action, - position_pnl=position_pnl, - has_position=has_position - ) - - enhancement = enhanced_reward - base_reward - print(f"{i:2d}. {description}") - print(f" Action: {action}, P&L: ${position_pnl:+.1f}") - print(f" Base Reward: {base_reward:+.3f} → Enhanced: {enhanced_reward:+.3f} (Δ{enhancement:+.3f})") - print() - -def test_reward_incentives(): - """Test that rewards properly incentivize desired behaviors""" - - print("\n🎯 REWARD INCENTIVE VALIDATION:") - print("-" * 50) - - orchestrator = TradingOrchestrator() - cnn_model = EnhancedCNN(input_shape=100, n_actions=3) - - # Test 1: Holding winners vs holding losers - print("1. HOLD action comparison:") - - hold_winner_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", 100.0, True) - hold_loser_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", -100.0, True) - - print(f" Hold profitable position (+$100): {hold_winner_reward:+.3f}") - print(f" Hold losing position (-$100): {hold_loser_reward:+.3f}") - print(f" ✅ Incentive correct: {hold_winner_reward > hold_loser_reward}") - - # Test 2: Closing losers vs closing winners - print("\n2. SELL action comparison:") - - sell_winner_reward = cnn_model._calculate_position_enhanced_reward(0.5, "SELL", 100.0, True) - sell_loser_reward = cnn_model._calculate_position_enhanced_reward(0.5, "SELL", -100.0, True) - - print(f" Sell profitable position (+$100): {sell_winner_reward:+.3f}") - print(f" Sell losing position (-$100): {sell_loser_reward:+.3f}") - print(f" ✅ Incentive correct: {sell_loser_reward > sell_winner_reward}") - - # Test 3: DQN reward scaling - print("\n3. DQN vs CNN reward scaling:") - - dqn_reward = orchestrator._calculate_position_enhanced_reward_for_dqn(0.5, "HOLD", -100.0, True) - cnn_reward = cnn_model._calculate_position_enhanced_reward(0.5, "HOLD", -100.0, True) - - print(f" DQN penalty for holding loser: {dqn_reward:+.3f}") - print(f" CNN penalty for holding loser: {cnn_reward:+.3f}") - print(f" ✅ DQN more sensitive: {abs(dqn_reward) > abs(cnn_reward)}") - -def main(): - """Run all position-based reward tests""" - try: - test_position_reward_scenarios() - test_reward_incentives() - - print("\n🚀 POSITION-BASED REWARD SYSTEM VALIDATION COMPLETE!") - print("✅ System properly incentivizes:") - print(" • Holding profitable positions (let winners run)") - print(" • Closing losing positions (cut losses)") - print(" • Taking appropriate action based on P&L") - print(" • Different reward scaling for CNN vs DQN models") - - except Exception as e: - print(f"❌ Test failed with error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_profitability_reward_system.py b/test_profitability_reward_system.py deleted file mode 100644 index 06b53c8..0000000 --- a/test_profitability_reward_system.py +++ /dev/null @@ -1,294 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the dynamic profitability reward system - -This script tests: -1. Fee reversion to normal 0.1% (0.001) -2. Dynamic profitability reward multiplier adjustment -3. Success rate calculation -4. Integration with dashboard display -""" - -import sys -import os -import time -from datetime import datetime, timedelta - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.trading_executor import TradingExecutor, TradeRecord -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -def test_fee_configuration(): - """Test that fees are reverted to normal 0.1%""" - print("=" * 60) - print("🧪 TESTING FEE CONFIGURATION") - print("=" * 60) - - executor = TradingExecutor() - - # Check fee configuration - expected_open_fee = 0.001 # 0.1% - expected_close_fee = 0.001 # 0.1% - expected_total_fee = 0.002 # 0.2% - - actual_open_fee = executor.trading_fees['open_fee_percent'] - actual_close_fee = executor.trading_fees['close_fee_percent'] - actual_total_fee = executor.trading_fees['total_round_trip_fee'] - - print(f"Expected Open Fee: {expected_open_fee} (0.1%)") - print(f"Actual Open Fee: {actual_open_fee} (0.1%)") - print(f"✅ Open Fee: {'PASS' if actual_open_fee == expected_open_fee else 'FAIL'}") - print() - - print(f"Expected Close Fee: {expected_close_fee} (0.1%)") - print(f"Actual Close Fee: {actual_close_fee} (0.1%)") - print(f"✅ Close Fee: {'PASS' if actual_close_fee == expected_close_fee else 'FAIL'}") - print() - - print(f"Expected Total Fee: {expected_total_fee} (0.2%)") - print(f"Actual Total Fee: {actual_total_fee} (0.2%)") - print(f"✅ Total Fee: {'PASS' if actual_total_fee == expected_total_fee else 'FAIL'}") - print() - - return actual_open_fee == expected_open_fee and actual_close_fee == expected_close_fee - -def test_profitability_multiplier_initialization(): - """Test profitability multiplier initialization""" - print("=" * 60) - print("🧪 TESTING PROFITABILITY MULTIPLIER INITIALIZATION") - print("=" * 60) - - executor = TradingExecutor() - - # Check initial values - initial_multiplier = executor.profitability_reward_multiplier - min_multiplier = executor.min_profitability_multiplier - max_multiplier = executor.max_profitability_multiplier - adjustment_step = executor.profitability_adjustment_step - - print(f"Initial Multiplier: {initial_multiplier} (should be 0.0)") - print(f"Min Multiplier: {min_multiplier} (should be 0.0)") - print(f"Max Multiplier: {max_multiplier} (should be 2.0)") - print(f"Adjustment Step: {adjustment_step} (should be 0.1)") - print() - - # Check thresholds - increase_threshold = executor.success_rate_increase_threshold - decrease_threshold = executor.success_rate_decrease_threshold - trades_window = executor.recent_trades_window - - print(f"Increase Threshold: {increase_threshold:.1%} (should be 60%)") - print(f"Decrease Threshold: {decrease_threshold:.1%} (should be 51%)") - print(f"Trades Window: {trades_window} (should be 20)") - print() - - # Test getter method - multiplier_from_getter = executor.get_profitability_reward_multiplier() - print(f"Multiplier via getter: {multiplier_from_getter}") - print(f"✅ Getter method: {'PASS' if multiplier_from_getter == initial_multiplier else 'FAIL'}") - - return (initial_multiplier == 0.0 and - min_multiplier == 0.0 and - max_multiplier == 2.0 and - adjustment_step == 0.1) - -def simulate_trades_and_test_adjustment(executor, winning_trades, total_trades): - """Simulate trades and test multiplier adjustment""" - print(f"📊 Simulating {winning_trades}/{total_trades} winning trades ({winning_trades/total_trades:.1%} success rate)") - - # Clear existing trade records - executor.trade_records = [] - - # Create simulated trade records - base_time = datetime.now() - timedelta(hours=1) - - for i in range(total_trades): - # Create winning or losing trade based on ratio - is_winning = i < winning_trades - pnl = 10.0 if is_winning else -5.0 # $10 profit or $5 loss - - trade_record = TradeRecord( - symbol="ETH/USDT", - side="LONG", - quantity=0.01, - entry_price=3000.0, - exit_price=3010.0 if is_winning else 2995.0, - entry_time=base_time + timedelta(minutes=i*2), - exit_time=base_time + timedelta(minutes=i*2+1), - pnl=pnl, - fees=2.0, - confidence=0.8, - net_pnl=pnl - 2.0 # After fees - ) - - executor.trade_records.append(trade_record) - - # Force adjustment by setting last adjustment time to past - executor.last_profitability_adjustment = datetime.now() - timedelta(minutes=10) - - # Get initial multiplier - initial_multiplier = executor.get_profitability_reward_multiplier() - - # Calculate success rate - success_rate = executor._calculate_recent_success_rate() - print(f"Calculated success rate: {success_rate:.1%}") - - # Trigger adjustment - executor._adjust_profitability_reward_multiplier() - - # Get new multiplier - new_multiplier = executor.get_profitability_reward_multiplier() - - print(f"Initial multiplier: {initial_multiplier:.1f}") - print(f"New multiplier: {new_multiplier:.1f}") - - # Determine expected change - if success_rate > executor.success_rate_increase_threshold: - expected_change = "increase" - expected_new = min(executor.max_profitability_multiplier, initial_multiplier + executor.profitability_adjustment_step) - elif success_rate < executor.success_rate_decrease_threshold: - expected_change = "decrease" - expected_new = max(executor.min_profitability_multiplier, initial_multiplier - executor.profitability_adjustment_step) - else: - expected_change = "no change" - expected_new = initial_multiplier - - print(f"Expected change: {expected_change}") - print(f"Expected new value: {expected_new:.1f}") - - success = abs(new_multiplier - expected_new) < 0.01 - print(f"✅ Adjustment: {'PASS' if success else 'FAIL'}") - print() - - return success - -def test_orchestrator_integration(): - """Test orchestrator integration with profitability multiplier""" - print("=" * 60) - print("🧪 TESTING ORCHESTRATOR INTEGRATION") - print("=" * 60) - - # Create components - data_provider = DataProvider() - executor = TradingExecutor() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Connect executor to orchestrator - orchestrator.set_trading_executor(executor) - - # Set a test multiplier - executor.profitability_reward_multiplier = 1.5 - - # Test getting multiplier through orchestrator - multiplier = orchestrator.get_profitability_reward_multiplier() - print(f"Multiplier via orchestrator: {multiplier}") - print(f"✅ Orchestrator getter: {'PASS' if multiplier == 1.5 else 'FAIL'}") - - # Test enhanced reward calculation - base_pnl = 100.0 # $100 profit - confidence = 0.8 - - enhanced_reward = orchestrator.calculate_enhanced_reward(base_pnl, confidence) - expected_enhanced = base_pnl * (1.0 + 1.5) # 100 * 2.5 = 250 - - print(f"Base P&L: ${base_pnl:.2f}") - print(f"Enhanced reward: ${enhanced_reward:.2f}") - print(f"Expected: ${expected_enhanced:.2f}") - print(f"✅ Enhanced reward: {'PASS' if abs(enhanced_reward - expected_enhanced) < 0.01 else 'FAIL'}") - - # Test with losing trade (should not be enhanced) - losing_pnl = -50.0 - enhanced_losing = orchestrator.calculate_enhanced_reward(losing_pnl, confidence) - print(f"Losing P&L: ${losing_pnl:.2f}") - print(f"Enhanced losing: ${enhanced_losing:.2f}") - print(f"✅ No enhancement for losses: {'PASS' if enhanced_losing == losing_pnl else 'FAIL'}") - - return multiplier == 1.5 and abs(enhanced_reward - expected_enhanced) < 0.01 - -def main(): - """Run all tests""" - print("🚀 DYNAMIC PROFITABILITY REWARD SYSTEM TEST") - print("Testing fee reversion and dynamic reward adjustment") - print() - - all_tests_passed = True - - # Test 1: Fee configuration - try: - fee_test_passed = test_fee_configuration() - all_tests_passed = all_tests_passed and fee_test_passed - except Exception as e: - print(f"❌ Fee configuration test failed: {e}") - all_tests_passed = False - - # Test 2: Profitability multiplier initialization - try: - init_test_passed = test_profitability_multiplier_initialization() - all_tests_passed = all_tests_passed and init_test_passed - except Exception as e: - print(f"❌ Initialization test failed: {e}") - all_tests_passed = False - - # Test 3: Multiplier adjustment scenarios - print("=" * 60) - print("🧪 TESTING MULTIPLIER ADJUSTMENT SCENARIOS") - print("=" * 60) - - executor = TradingExecutor() - - try: - # Scenario 1: High success rate (should increase multiplier) - print("Scenario 1: High success rate (65% - should increase)") - high_success_test = simulate_trades_and_test_adjustment(executor, 13, 20) # 65% - all_tests_passed = all_tests_passed and high_success_test - - # Scenario 2: Low success rate (should decrease multiplier) - print("Scenario 2: Low success rate (45% - should decrease)") - low_success_test = simulate_trades_and_test_adjustment(executor, 9, 20) # 45% - all_tests_passed = all_tests_passed and low_success_test - - # Scenario 3: Medium success rate (should not change) - print("Scenario 3: Medium success rate (55% - should not change)") - medium_success_test = simulate_trades_and_test_adjustment(executor, 11, 20) # 55% - all_tests_passed = all_tests_passed and medium_success_test - - except Exception as e: - print(f"❌ Adjustment scenario tests failed: {e}") - all_tests_passed = False - - # Test 4: Orchestrator integration - try: - orchestrator_test_passed = test_orchestrator_integration() - all_tests_passed = all_tests_passed and orchestrator_test_passed - except Exception as e: - print(f"❌ Orchestrator integration test failed: {e}") - all_tests_passed = False - - # Final results - print("=" * 60) - print("📋 TEST RESULTS SUMMARY") - print("=" * 60) - - if all_tests_passed: - print("🎉 ALL TESTS PASSED!") - print("✅ Fees reverted to normal 0.1%") - print("✅ Dynamic profitability multiplier working") - print("✅ Success rate calculation accurate") - print("✅ Orchestrator integration functional") - print() - print("🚀 System ready for trading with dynamic profitability rewards!") - print("📈 The model will learn to prioritize more profitable trades over time") - print("🎯 Success rate >60% → increase reward multiplier") - print("⚠️ Success rate <51% → decrease reward multiplier") - else: - print("❌ SOME TESTS FAILED!") - print("Please check the error messages above and fix issues before trading.") - - return all_tests_passed - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/test_training_data_collection.py b/test_training_data_collection.py deleted file mode 100644 index c73c1a2..0000000 --- a/test_training_data_collection.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Training Data Collection and Checkpoint Storage - -This script tests if the training system is working correctly and storing checkpoints. -""" - -import os -import sys -import logging -import asyncio -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -from utils.checkpoint_manager import get_checkpoint_manager - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -async def test_training_system(): - """Test if the training system is working and storing checkpoints""" - logger.info("Testing training system and checkpoint storage...") - - # Initialize components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True) - - # Get checkpoint manager - checkpoint_manager = get_checkpoint_manager() - - # Check if checkpoint directory exists - checkpoint_dir = Path("models/saved") - if not checkpoint_dir.exists(): - logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...") - checkpoint_dir.mkdir(parents=True, exist_ok=True) - - # Check for existing checkpoints - checkpoint_stats = checkpoint_manager.get_checkpoint_stats() - logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.") - logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB") - - # List checkpoint files - checkpoint_files = list(checkpoint_dir.glob("*.pt")) - if checkpoint_files: - logger.info("Recent checkpoint files:") - for i, file in enumerate(sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]): - file_size = file.stat().st_size / (1024 * 1024) # Convert to MB - modified_time = datetime.fromtimestamp(file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S") - logger.info(f" {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})") - else: - logger.warning("No checkpoint files found.") - - # Test training by making trading decisions - logger.info("\nTesting training by making trading decisions...") - symbols = orchestrator.symbols - - for symbol in symbols: - logger.info(f"Making trading decision for {symbol}...") - decision = await orchestrator.make_trading_decision(symbol) - - if decision: - logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})") - else: - logger.warning(f"No decision made for {symbol}.") - - # Check if new checkpoints were created - new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats() - new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints'] - - if new_checkpoints > 0: - logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.") - logger.info("Training system is working correctly.") - else: - logger.warning("\nNo new checkpoints were created.") - logger.warning("This could be normal if the training threshold wasn't met.") - logger.warning("Check the orchestrator's checkpoint saving logic.") - - # Check model states - model_states = orchestrator.get_model_states() - logger.info("\nModel states:") - for model_name, state in model_states.items(): - checkpoint_loaded = state.get('checkpoint_loaded', False) - checkpoint_filename = state.get('checkpoint_filename', 'none') - current_loss = state.get('current_loss', None) - - status = "LOADED" if checkpoint_loaded else "FRESH" - loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A" - - logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}") - - return new_checkpoints > 0 - -async def main(): - """Main function""" - logger.info("=" * 70) - logger.info("TRAINING SYSTEM TEST") - logger.info("=" * 70) - - success = await test_training_system() - - if success: - logger.info("\nTraining system test passed!") - return 0 - else: - logger.warning("\nTraining system test completed with warnings.") - logger.info("Check the logs for details.") - return 1 - -if __name__ == "__main__": - sys.exit(asyncio.run(main())) \ No newline at end of file diff --git a/test_training_fixes.py b/test_training_fixes.py deleted file mode 100644 index b3484e9..0000000 --- a/test_training_fixes.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Training Fixes - -This script tests the fixes for CNN adapter and DQN training issues. -""" - -import asyncio -import time -import numpy as np -from datetime import datetime -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -async def test_training_fixes(): - """Test the training fixes""" - print("=== Testing Training Fixes ===") - - # Initialize orchestrator - print("1. Initializing orchestrator...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider) - - # Wait for initialization - await asyncio.sleep(3) - - # Check CNN adapter initialization - print("\n2. Checking CNN adapter initialization:") - if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter: - print(" ✅ CNN adapter is properly initialized") - print(f" CNN adapter type: {type(orchestrator.cnn_adapter)}") - else: - print(" ❌ CNN adapter is None or missing") - - # Check DQN agent initialization - print("\n3. Checking DQN agent initialization:") - if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent: - print(" ✅ DQN agent is properly initialized") - print(f" DQN agent type: {type(orchestrator.rl_agent)}") - if hasattr(orchestrator.rl_agent, 'policy_net'): - print(" ✅ DQN policy network is available") - else: - print(" ❌ DQN policy network is missing") - else: - print(" ❌ DQN agent is None or missing") - - # Test CNN predictions - print("\n4. Testing CNN predictions:") - try: - predictions = await orchestrator._get_all_predictions('ETH/USDT') - cnn_predictions = [p for p in predictions if 'cnn' in p.model_name.lower()] - if cnn_predictions: - print(f" ✅ Got {len(cnn_predictions)} CNN predictions") - for pred in cnn_predictions: - print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})") - else: - print(" ❌ No CNN predictions received") - except Exception as e: - print(f" ❌ CNN prediction failed: {e}") - - # Test training with validation - print("\n5. Testing training with validation:") - for i in range(3): - print(f" Training iteration {i+1}/3...") - - # Create training records for different models - training_records = [ - { - 'model_name': 'enhanced_cnn', - 'model_input': np.random.randn(7850), - 'prediction': {'action': 'BUY', 'confidence': 0.7}, - 'symbol': 'ETH/USDT', - 'timestamp': datetime.now() - }, - { - 'model_name': 'dqn_agent', - 'model_input': np.random.randn(7850), - 'prediction': {'action': 'SELL', 'confidence': 0.8}, - 'symbol': 'ETH/USDT', - 'timestamp': datetime.now() - } - ] - - for record in training_records: - try: - success = await orchestrator._train_model_on_outcome( - record, True, 0.5, 1.0 - ) - if success: - print(f" ✅ Training succeeded for {record['model_name']}") - else: - print(f" ⚠️ Training failed for {record['model_name']}") - except Exception as e: - print(f" ❌ Training error for {record['model_name']}: {e}") - - await asyncio.sleep(1) - - # Show final statistics - print("\n6. Final model statistics:") - orchestrator.log_model_statistics(detailed=True) - - # Check for overfitting warnings - print("\n7. Checking for training quality:") - summary = orchestrator.get_model_statistics_summary() - for model_name, stats in summary.items(): - if stats['total_trainings'] > 0: - print(f" {model_name}: {stats['total_trainings']} trainings, " - f"avg time: {stats['average_training_time_ms']:.1f}ms") - if stats['current_loss'] is not None: - if stats['current_loss'] < 0.001: - print(f" ⚠️ {model_name} has very low loss ({stats['current_loss']:.6f}) - check for overfitting") - else: - print(f" ✅ {model_name} has reasonable loss ({stats['current_loss']:.6f})") - - print("\n✅ Training fixes test completed!") - -if __name__ == "__main__": - asyncio.run(test_training_fixes()) \ No newline at end of file diff --git a/test_websocket_cob_data.py b/test_websocket_cob_data.py deleted file mode 100644 index bb51691..0000000 --- a/test_websocket_cob_data.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to check if we're getting real COB data from WebSocket -""" - -import time -import logging -from core.data_provider import DataProvider - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_websocket_cob_data(): - """Test if we're getting real COB data from WebSocket""" - logger.info("Testing WebSocket COB data reception...") - - # Initialize data provider - dp = DataProvider() - - # Wait for WebSocket connections - logger.info("Waiting for WebSocket connections...") - time.sleep(15) - - # Check WebSocket status - logger.info("\n=== WebSocket Status ===") - try: - if hasattr(dp, 'enhanced_cob_websocket') and dp.enhanced_cob_websocket: - status = dp.enhanced_cob_websocket.get_status_summary() - logger.info(f"WebSocket status: {status}") - else: - logger.warning("Enhanced COB WebSocket not available") - except Exception as e: - logger.error(f"Error getting WebSocket status: {e}") - - # Check if we have any COB WebSocket data - logger.info("\n=== COB WebSocket Data Check ===") - if hasattr(dp, 'cob_websocket_data'): - for symbol in ['ETH/USDT', 'BTC/USDT']: - if symbol in dp.cob_websocket_data: - data = dp.cob_websocket_data[symbol] - logger.info(f"{symbol}: {type(data)} - {len(str(data))} chars") - if isinstance(data, dict): - logger.info(f" Keys: {list(data.keys())}") - if 'bids' in data: - logger.info(f" Bids: {len(data['bids'])} levels") - if 'asks' in data: - logger.info(f" Asks: {len(data['asks'])} levels") - else: - logger.info(f"{symbol}: No WebSocket data") - else: - logger.warning("No cob_websocket_data attribute found") - - # Check raw COB ticks - logger.info("\n=== Raw COB Ticks ===") - for symbol in ['ETH/USDT', 'BTC/USDT']: - if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks: - raw_ticks = list(dp.cob_raw_ticks[symbol]) - logger.info(f"{symbol}: {len(raw_ticks)} raw ticks") - if raw_ticks: - latest = raw_ticks[-1] - logger.info(f" Latest tick keys: {list(latest.keys())}") - if 'timestamp' in latest: - logger.info(f" Latest timestamp: {latest['timestamp']}") - else: - logger.info(f"{symbol}: No raw ticks") - - # Monitor for 30 seconds to see if data comes in - logger.info("\n=== Monitoring for 30 seconds ===") - initial_counts = {} - for symbol in ['ETH/USDT', 'BTC/USDT']: - if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks: - initial_counts[symbol] = len(dp.cob_raw_ticks[symbol]) - else: - initial_counts[symbol] = 0 - - time.sleep(30) - - logger.info("After 30 seconds:") - for symbol in ['ETH/USDT', 'BTC/USDT']: - if hasattr(dp, 'cob_raw_ticks') and symbol in dp.cob_raw_ticks: - current_count = len(dp.cob_raw_ticks[symbol]) - new_ticks = current_count - initial_counts[symbol] - logger.info(f"{symbol}: +{new_ticks} new ticks (total: {current_count})") - else: - logger.info(f"{symbol}: No raw ticks available") - - # Check if Enhanced WebSocket has latest data - logger.info("\n=== Enhanced WebSocket Latest Data ===") - try: - if hasattr(dp, 'enhanced_cob_websocket') and dp.enhanced_cob_websocket: - for symbol in ['ETH/USDT', 'BTC/USDT']: - if hasattr(dp.enhanced_cob_websocket, 'latest_cob_data'): - latest_data = dp.enhanced_cob_websocket.latest_cob_data.get(symbol) - if latest_data: - logger.info(f"{symbol}: Latest WebSocket data available") - logger.info(f" Keys: {list(latest_data.keys())}") - if 'bids' in latest_data and 'asks' in latest_data: - logger.info(f" Bids: {len(latest_data['bids'])}, Asks: {len(latest_data['asks'])}") - else: - logger.info(f"{symbol}: No latest WebSocket data") - except Exception as e: - logger.error(f"Error checking Enhanced WebSocket data: {e}") - - # Clean shutdown - logger.info("\n=== Shutting Down ===") - dp.stop_automatic_data_maintenance() - logger.info("WebSocket COB data test completed") - -if __name__ == "__main__": - test_websocket_cob_data() \ No newline at end of file diff --git a/tests/cob/test_cob_comparison.py b/tests/cob/test_cob_comparison.py deleted file mode 100644 index 1215dd1..0000000 --- a/tests/cob/test_cob_comparison.py +++ /dev/null @@ -1,276 +0,0 @@ -#!/usr/bin/env python3 -""" -Compare COB data quality between DataProvider and COBIntegration - -This test compares: -1. DataProvider COB collection (used in our test) -2. COBIntegration direct access (used in cob_realtime_dashboard.py) - -To understand why cob_realtime_dashboard.py gets more stable data. -""" - -import asyncio -import logging -import time -from collections import deque -from datetime import datetime, timedelta - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from core.data_provider import DataProvider, MarketTick -from core.config import get_config - -# Try to import COBIntegration like cob_realtime_dashboard does -try: - from core.cob_integration import COBIntegration - COB_INTEGRATION_AVAILABLE = True -except ImportError: - COB_INTEGRATION_AVAILABLE = False - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - - -class COBComparisonTester: - def __init__(self, symbol='ETH/USDT', duration_seconds=15): - self.symbol = symbol - self.duration = timedelta(seconds=duration_seconds) - - # Data storage for both methods - self.dp_ticks = deque() # DataProvider ticks - self.cob_data = deque() # COBIntegration data - - # Initialize DataProvider (method 1) - logger.info("Initializing DataProvider...") - self.data_provider = DataProvider() - self.dp_cob_received = 0 - - # Initialize COBIntegration (method 2) - self.cob_integration = None - self.cob_received = 0 - if COB_INTEGRATION_AVAILABLE: - logger.info("Initializing COBIntegration...") - self.cob_integration = COBIntegration(symbols=[self.symbol]) - else: - logger.warning("COBIntegration not available - will only test DataProvider") - - self.start_time = None - self.subscriber_id = None - - def _dp_cob_callback(self, symbol: str, cob_data: dict): - """Callback for DataProvider COB data""" - self.dp_cob_received += 1 - - if 'stats' in cob_data and 'mid_price' in cob_data['stats']: - mid_price = cob_data['stats']['mid_price'] - if mid_price > 0: - synthetic_tick = MarketTick( - symbol=symbol, - timestamp=cob_data.get('timestamp', datetime.now()), - price=mid_price, - volume=cob_data.get('stats', {}).get('total_volume', 0), - quantity=0, - side='dp_cob', - trade_id=f"dp_{self.dp_cob_received}", - is_buyer_maker=False, - raw_data=cob_data - ) - self.dp_ticks.append(synthetic_tick) - - if self.dp_cob_received % 20 == 0: - logger.info(f"[DataProvider] Update #{self.dp_cob_received}: {symbol} @ ${mid_price:.2f}") - - def _cob_integration_callback(self, symbol: str, data: dict): - """Callback for COBIntegration data""" - self.cob_received += 1 - - # Store COBIntegration data directly - cob_record = { - 'symbol': symbol, - 'timestamp': datetime.now(), - 'data': data, - 'source': 'cob_integration' - } - self.cob_data.append(cob_record) - - if self.cob_received % 20 == 0: - stats = data.get('stats', {}) - mid_price = stats.get('mid_price', 0) - logger.info(f"[COBIntegration] Update #{self.cob_received}: {symbol} @ ${mid_price:.2f}") - - async def run_comparison_test(self): - """Run the comparison test""" - logger.info(f"Starting COB comparison test for {self.symbol} for {self.duration.total_seconds()} seconds...") - - # Start DataProvider COB collection - try: - logger.info("Starting DataProvider COB collection...") - self.data_provider.start_cob_collection() - self.data_provider.subscribe_to_cob(self._dp_cob_callback) - await self.data_provider.start_real_time_streaming() - logger.info("DataProvider streaming started") - except Exception as e: - logger.error(f"Failed to start DataProvider: {e}") - - # Start COBIntegration if available - if self.cob_integration: - try: - logger.info("Starting COBIntegration...") - self.cob_integration.add_dashboard_callback(self._cob_integration_callback) - await self.cob_integration.start() - logger.info("COBIntegration started") - except Exception as e: - logger.error(f"Failed to start COBIntegration: {e}") - - # Collect data for specified duration - self.start_time = datetime.now() - while datetime.now() - self.start_time < self.duration: - await asyncio.sleep(1) - logger.info(f"DataProvider: {len(self.dp_ticks)} ticks | COBIntegration: {len(self.cob_data)} updates") - - # Stop data collection - try: - await self.data_provider.stop_real_time_streaming() - if self.cob_integration: - await self.cob_integration.stop() - except Exception as e: - logger.error(f"Error stopping data collection: {e}") - - logger.info(f"Comparison complete:") - logger.info(f" DataProvider: {len(self.dp_ticks)} ticks received") - logger.info(f" COBIntegration: {len(self.cob_data)} updates received") - - # Analyze and plot the differences - self.analyze_differences() - self.create_comparison_plots() - - def analyze_differences(self): - """Analyze the differences between the two data sources""" - logger.info("Analyzing data quality differences...") - - # Analyze DataProvider data - dp_order_book_count = 0 - dp_mid_prices = [] - - for tick in self.dp_ticks: - if hasattr(tick, 'raw_data') and tick.raw_data: - if 'bids' in tick.raw_data and 'asks' in tick.raw_data: - dp_order_book_count += 1 - if 'stats' in tick.raw_data and 'mid_price' in tick.raw_data['stats']: - dp_mid_prices.append(tick.raw_data['stats']['mid_price']) - - # Analyze COBIntegration data - cob_order_book_count = 0 - cob_mid_prices = [] - - for record in self.cob_data: - data = record['data'] - if 'bids' in data and 'asks' in data: - cob_order_book_count += 1 - if 'stats' in data and 'mid_price' in data['stats']: - cob_mid_prices.append(data['stats']['mid_price']) - - logger.info("Data Quality Analysis:") - logger.info(f" DataProvider:") - logger.info(f" Total updates: {len(self.dp_ticks)}") - logger.info(f" With order book data: {dp_order_book_count}") - logger.info(f" Mid prices collected: {len(dp_mid_prices)}") - if dp_mid_prices: - logger.info(f" Price range: ${min(dp_mid_prices):.2f} - ${max(dp_mid_prices):.2f}") - - logger.info(f" COBIntegration:") - logger.info(f" Total updates: {len(self.cob_data)}") - logger.info(f" With order book data: {cob_order_book_count}") - logger.info(f" Mid prices collected: {len(cob_mid_prices)}") - if cob_mid_prices: - logger.info(f" Price range: ${min(cob_mid_prices):.2f} - ${max(cob_mid_prices):.2f}") - - def create_comparison_plots(self): - """Create comparison plots showing the difference""" - logger.info("Creating comparison plots...") - - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12)) - - # Plot 1: Price comparison - dp_times = [] - dp_prices = [] - for tick in self.dp_ticks: - if tick.price > 0: - dp_times.append(tick.timestamp) - dp_prices.append(tick.price) - - cob_times = [] - cob_prices = [] - for record in self.cob_data: - data = record['data'] - if 'stats' in data and 'mid_price' in data['stats']: - cob_times.append(record['timestamp']) - cob_prices.append(data['stats']['mid_price']) - - if dp_times: - ax1.plot(pd.to_datetime(dp_times), dp_prices, 'b-', alpha=0.7, label='DataProvider COB', linewidth=1) - if cob_times: - ax1.plot(pd.to_datetime(cob_times), cob_prices, 'r-', alpha=0.7, label='COBIntegration', linewidth=1) - - ax1.set_title('Price Comparison: DataProvider vs COBIntegration') - ax1.set_ylabel('Price (USDT)') - ax1.legend() - ax1.grid(True, alpha=0.3) - - # Plot 2: Data quality comparison (order book depth) - dp_bid_counts = [] - dp_ask_counts = [] - dp_ob_times = [] - - for tick in self.dp_ticks: - if hasattr(tick, 'raw_data') and tick.raw_data: - if 'bids' in tick.raw_data and 'asks' in tick.raw_data: - dp_bid_counts.append(len(tick.raw_data['bids'])) - dp_ask_counts.append(len(tick.raw_data['asks'])) - dp_ob_times.append(tick.timestamp) - - cob_bid_counts = [] - cob_ask_counts = [] - cob_ob_times = [] - - for record in self.cob_data: - data = record['data'] - if 'bids' in data and 'asks' in data: - cob_bid_counts.append(len(data['bids'])) - cob_ask_counts.append(len(data['asks'])) - cob_ob_times.append(record['timestamp']) - - if dp_ob_times: - ax2.plot(pd.to_datetime(dp_ob_times), dp_bid_counts, 'b--', alpha=0.7, label='DP Bid Levels') - ax2.plot(pd.to_datetime(dp_ob_times), dp_ask_counts, 'b:', alpha=0.7, label='DP Ask Levels') - if cob_ob_times: - ax2.plot(pd.to_datetime(cob_ob_times), cob_bid_counts, 'r--', alpha=0.7, label='COB Bid Levels') - ax2.plot(pd.to_datetime(cob_ob_times), cob_ask_counts, 'r:', alpha=0.7, label='COB Ask Levels') - - ax2.set_title('Order Book Depth Comparison') - ax2.set_ylabel('Number of Levels') - ax2.set_xlabel('Time') - ax2.legend() - ax2.grid(True, alpha=0.3) - - plt.tight_layout() - - plot_filename = f"cob_comparison_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png" - plt.savefig(plot_filename, dpi=150) - logger.info(f"Comparison plot saved to {plot_filename}") - plt.show() - - -async def main(): - tester = COBComparisonTester() - await tester.run_comparison_test() - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - logger.info("Test interrupted by user.") diff --git a/tests/cob/test_cob_data_stability.py b/tests/cob/test_cob_data_stability.py deleted file mode 100644 index a736146..0000000 --- a/tests/cob/test_cob_data_stability.py +++ /dev/null @@ -1,502 +0,0 @@ -import asyncio -import logging -import time -from collections import deque -from datetime import datetime, timedelta - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from matplotlib.colors import LogNorm - -from core.data_provider import DataProvider, MarketTick -from core.config import get_config - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - - -class COBStabilityTester: - def __init__(self, symbol='ETHUSDT', duration_seconds=10): - self.symbol = symbol - self.duration = timedelta(seconds=duration_seconds) - self.ticks = deque() - - # Set granularity (buckets) based on symbol - if 'ETH' in symbol.upper(): - self.price_granularity = 1.0 # 1 USD for ETH - elif 'BTC' in symbol.upper(): - self.price_granularity = 10.0 # 10 USD for BTC - else: - self.price_granularity = 1.0 # Default 1 USD - - logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}") - - # Initialize DataProvider the same way as clean_dashboard - logger.info("Initializing DataProvider like in clean_dashboard...") - self.data_provider = DataProvider() # Use default constructor like clean_dashboard - - # Initialize COB data collection like clean_dashboard does - self.cob_data_received = 0 - self.latest_cob_data = {} - - # Store all COB snapshots for heatmap generation - self.cob_snapshots = deque() - self.price_data = [] # For price line chart - - self.start_time = None - self.subscriber_id = None - self.last_log_time = None - - def _tick_callback(self, tick: MarketTick): - """Callback function to receive ticks from the DataProvider.""" - if self.start_time is None: - self.start_time = datetime.now() - logger.info(f"Started collecting ticks at {self.start_time}") - - # Store all ticks - self.ticks.append(tick) - - def _cob_data_callback(self, symbol: str, cob_data: dict): - """Callback function to receive COB data from the DataProvider.""" - # Debug: Log first few callbacks to see what symbols we're getting - if self.cob_data_received < 5: - logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')") - - # Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT) - normalized_symbol = symbol.replace('/', '') - normalized_target = self.symbol.replace('/', '') - if normalized_symbol != normalized_target: - if self.cob_data_received < 5: - logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')") - return - - self.cob_data_received += 1 - self.latest_cob_data[symbol] = cob_data - - # Store the complete COB snapshot for heatmap generation - if 'bids' in cob_data and 'asks' in cob_data: - # Debug: Log structure of first few COB snapshots - if len(self.cob_snapshots) < 3: - logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items") - if cob_data['bids']: - logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}") - if cob_data['asks']: - logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}") - - # Use current time for timestamp consistency - current_time = datetime.now() - snapshot = { - 'timestamp': current_time, - 'bids': cob_data['bids'], - 'asks': cob_data['asks'], - 'stats': cob_data.get('stats', {}) - } - self.cob_snapshots.append(snapshot) - - # Log bucketed COB data every second - now = datetime.now() - if self.last_log_time is None or (now - self.last_log_time).total_seconds() >= 1.0: - self.last_log_time = now - self._log_bucketed_cob_data(cob_data) - - # Convert COB data to tick-like format for analysis - if 'stats' in cob_data and 'mid_price' in cob_data['stats']: - mid_price = cob_data['stats']['mid_price'] - if mid_price > 0: - # Filter out extreme price movements (±10% of recent average) - if len(self.price_data) > 5: - recent_prices = [p['price'] for p in self.price_data[-5:]] - avg_recent_price = sum(recent_prices) / len(recent_prices) - price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price - - if price_deviation > 0.10: # More than 10% deviation - logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})") - return # Skip this data point - - # Store price data for line chart with consistent timestamp - current_time = datetime.now() - self.price_data.append({ - 'timestamp': current_time, - 'price': mid_price - }) - - # Create a synthetic tick from COB data with consistent timestamp - current_time = datetime.now() - synthetic_tick = MarketTick( - symbol=symbol, - timestamp=current_time, - price=mid_price, - volume=cob_data.get('stats', {}).get('total_volume', 0), - quantity=0, # Not available in COB data - side='unknown', # COB data doesn't have side info - trade_id=f"cob_{self.cob_data_received}", - is_buyer_maker=False, - raw_data=cob_data - ) - self.ticks.append(synthetic_tick) - - if self.cob_data_received % 10 == 0: # Log every 10th update - logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}") - - def _log_bucketed_cob_data(self, cob_data: dict): - """Log bucketed COB data every second""" - try: - if 'bids' not in cob_data or 'asks' not in cob_data: - logger.info("COB-1s: No order book data available") - return - - if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']: - logger.info("COB-1s: No mid price available") - return - - mid_price = cob_data['stats']['mid_price'] - if mid_price <= 0: - return - - # Bucket the order book data - bid_buckets = {} - ask_buckets = {} - - # Process bids (top 10) - for bid in cob_data['bids'][:10]: - try: - if isinstance(bid, dict): - price = float(bid['price']) - size = float(bid['size']) - elif isinstance(bid, (list, tuple)) and len(bid) >= 2: - price = float(bid[0]) - size = float(bid[1]) - else: - continue - - bucketed_price = round(price / self.price_granularity) * self.price_granularity - bid_buckets[bucketed_price] = bid_buckets.get(bucketed_price, 0) + size - except (ValueError, TypeError, IndexError): - continue - - # Process asks (top 10) - for ask in cob_data['asks'][:10]: - try: - if isinstance(ask, dict): - price = float(ask['price']) - size = float(ask['size']) - elif isinstance(ask, (list, tuple)) and len(ask) >= 2: - price = float(ask[0]) - size = float(ask[1]) - else: - continue - - bucketed_price = round(price / self.price_granularity) * self.price_granularity - ask_buckets[bucketed_price] = ask_buckets.get(bucketed_price, 0) + size - except (ValueError, TypeError, IndexError): - continue - - # Format for log output - bid_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True)]) - ask_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items())]) - - logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}") - - except Exception as e: - logger.warning(f"Error logging bucketed COB data: {e}") - - async def run_test(self): - """Run the data collection and plotting test.""" - logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...") - - # Initialize COB collection like clean_dashboard does - try: - logger.info("Starting COB collection in data provider...") - self.data_provider.start_cob_collection() - logger.info("Started COB collection in data provider") - - # Subscribe to COB updates - logger.info("Subscribing to COB data updates...") - self.data_provider.subscribe_to_cob(self._cob_data_callback) - logger.info("Subscribed to COB data updates from data provider") - except Exception as e: - logger.error(f"Failed to start COB collection or subscribe: {e}") - - # Subscribe to ticks as fallback - try: - self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol]) - logger.info("Subscribed to tick data as fallback") - except Exception as e: - logger.warning(f"Failed to subscribe to ticks: {e}") - - # Start the data provider's real-time streaming - try: - await self.data_provider.start_real_time_streaming() - logger.info("Started real-time streaming") - except Exception as e: - logger.error(f"Failed to start real-time streaming: {e}") - - # Collect data for the specified duration - self.start_time = datetime.now() - while datetime.now() - self.start_time < self.duration: - await asyncio.sleep(1) - logger.info(f"Collected {len(self.ticks)} ticks so far...") - - # Stop streaming and unsubscribe - await self.data_provider.stop_real_time_streaming() - self.data_provider.unsubscribe_from_ticks(self.subscriber_id) - - logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}") - - # Plot the results - if self.price_data and self.cob_snapshots: - self.create_price_heatmap_chart() - elif self.ticks: - self._create_simple_price_chart() - else: - logger.warning("No data was collected. Cannot generate plot.") - - def create_price_heatmap_chart(self): - """Create a visualization with price chart and order book scatter plot.""" - if not self.price_data or not self.cob_snapshots: - logger.warning("Insufficient data to plot.") - return - - logger.info(f"Creating price and order book chart...") - logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots") - - # Prepare price data - price_df = pd.DataFrame(self.price_data) - price_df['timestamp'] = pd.to_datetime(price_df['timestamp']) - - logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}") - logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}") - - # Create figure with subplots - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), height_ratios=[3, 2]) - - # Top plot: Price chart with order book levels - ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10) - - # Plot order book levels as scatter points - bid_times, bid_prices, bid_sizes = [], [], [] - ask_times, ask_prices, ask_sizes = [], [], [] - - # Calculate average price for filtering - avg_price = price_df['price'].mean() if not price_df.empty else 3500 # Fallback price - price_lower = avg_price * 0.9 # -10% - price_upper = avg_price * 1.1 # +10% - - logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})") - - for snapshot in list(self.cob_snapshots)[-50:]: # Use last 50 snapshots for clarity - timestamp = pd.to_datetime(snapshot['timestamp']) - - # Process bids (top 10) - for order in snapshot.get('bids', [])[:10]: - try: - if isinstance(order, dict): - price = float(order['price']) - size = float(order['size']) - elif isinstance(order, (list, tuple)) and len(order) >= 2: - price = float(order[0]) - size = float(order[1]) - else: - continue - - # Filter out prices outside ±10% range - if price < price_lower or price > price_upper: - continue - - bid_times.append(timestamp) - bid_prices.append(price) - bid_sizes.append(size) - except (ValueError, TypeError, IndexError): - continue - - # Process asks (top 10) - for order in snapshot.get('asks', [])[:10]: - try: - if isinstance(order, dict): - price = float(order['price']) - size = float(order['size']) - elif isinstance(order, (list, tuple)) and len(order) >= 2: - price = float(order[0]) - size = float(order[1]) - else: - continue - - # Filter out prices outside ±10% range - if price < price_lower or price > price_upper: - continue - - ask_times.append(timestamp) - ask_prices.append(price) - ask_sizes.append(size) - except (ValueError, TypeError, IndexError): - continue - - # Plot order book data as scatter with size indicating volume - if bid_times: - bid_sizes_normalized = np.array(bid_sizes) * 3 # Scale for visibility - ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids') - logger.info(f"Plotted {len(bid_times)} bid levels") - - if ask_times: - ask_sizes_normalized = np.array(ask_sizes) * 3 # Scale for visibility - ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks') - logger.info(f"Plotted {len(ask_times)} ask levels") - - ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s') - ax1.set_ylabel('Price (USDT)') - ax1.legend() - ax1.grid(True, alpha=0.3) - - # Set proper time range (X-axis) - use actual data collection period - time_min = price_df['timestamp'].min() - time_max = price_df['timestamp'].max() - actual_duration = (time_max - time_min).total_seconds() - logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds") - - ax1.set_xlim(time_min, time_max) - - # Set tight price range (Y-axis) - use ±2% of price range for better visibility - price_min = price_df['price'].min() - price_max = price_df['price'].max() - price_center = (price_min + price_max) / 2 - price_range = price_max - price_min - - # If price range is very small, use a minimum range of $5 - if price_range < 5: - price_range = 5 - - # Add 20% padding to the price range for better visualization - y_padding = price_range * 0.2 - y_min = price_min - y_padding - y_max = price_max + y_padding - - ax1.set_ylim(y_min, y_max) - logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})") - - # Bottom plot: Order book depth over time (aggregated) - time_buckets = [] - bid_depths = [] - ask_depths = [] - - # Create time buckets (every few snapshots) - snapshots_list = list(self.cob_snapshots) - bucket_size = max(1, len(snapshots_list) // 20) # ~20 buckets - for i in range(0, len(snapshots_list), bucket_size): - bucket_snapshots = snapshots_list[i:i+bucket_size] - if not bucket_snapshots: - continue - - # Use middle timestamp of bucket - mid_snapshot = bucket_snapshots[len(bucket_snapshots)//2] - time_buckets.append(pd.to_datetime(mid_snapshot['timestamp'])) - - # Calculate average depths - total_bid_depth = 0 - total_ask_depth = 0 - snapshot_count = 0 - - for snapshot in bucket_snapshots: - bid_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0)) - for order in snapshot.get('bids', [])[:10]]) - ask_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0)) - for order in snapshot.get('asks', [])[:10]]) - total_bid_depth += bid_depth - total_ask_depth += ask_depth - snapshot_count += 1 - - if snapshot_count > 0: - bid_depths.append(total_bid_depth / snapshot_count) - ask_depths.append(total_ask_depth / snapshot_count) - else: - bid_depths.append(0) - ask_depths.append(0) - - if time_buckets: - ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7) - ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7) - ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green') - ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red') - - ax2.set_title('Order Book Depth Over Time') - ax2.set_xlabel('Time') - ax2.set_ylabel('Depth (Volume)') - ax2.legend() - ax2.grid(True, alpha=0.3) - - # Set same time range for bottom chart - ax2.set_xlim(time_min, time_max) - - # Format time axes - fig.autofmt_xdate() - plt.tight_layout() - - plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png" - plt.savefig(plot_filename, dpi=150, bbox_inches='tight') - logger.info(f"Price and order book chart saved to {plot_filename}") - plt.show() - - def _create_simple_price_chart(self): - """Create a simple price chart as fallback""" - logger.info("Creating simple price chart as fallback...") - - prices = [] - times = [] - - for tick in self.ticks: - if tick.price > 0: - prices.append(tick.price) - times.append(tick.timestamp) - - if not prices: - logger.warning("No price data to plot") - return - - fig, ax = plt.subplots(figsize=(15, 8)) - ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1) - ax.set_title(f'Price Chart - {self.symbol}') - ax.set_xlabel('Time') - ax.set_ylabel('Price (USDT)') - fig.autofmt_xdate() - - plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png" - plt.savefig(plot_filename) - logger.info(f"Price chart saved to {plot_filename}") - plt.show() - - -async def main(symbol='ETHUSDT', duration_seconds=10): - """Main function to run the COB test with configurable parameters. - - Args: - symbol: Trading symbol (default: ETHUSDT) - duration_seconds: Test duration in seconds (default: 10) - """ - logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s") - tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds) - await tester.run_test() - - -if __name__ == "__main__": - import sys - - # Parse command line arguments - symbol = 'ETHUSDT' # Default - duration = 10 # Default - - if len(sys.argv) > 1: - symbol = sys.argv[1] - if len(sys.argv) > 2: - try: - duration = int(sys.argv[2]) - except ValueError: - logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds") - - logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s") - logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}") - - try: - asyncio.run(main(symbol, duration)) - except KeyboardInterrupt: - logger.info("Test interrupted by user.") diff --git a/tests/test_training.py b/tests/test_training.py deleted file mode 100644 index c3bb012..0000000 --- a/tests/test_training.py +++ /dev/null @@ -1,310 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Training Script for AI Trading Models - -This script tests the training functionality of our CNN and RL models -and demonstrates the learning capabilities. -""" - -import logging -import sys -import asyncio -from pathlib import Path -from datetime import datetime, timedelta -from safe_logging import setup_safe_logging - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from models import get_model_registry, CNNModelWrapper, RLAgentWrapper - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -def test_model_loading(): - """Test that models load correctly""" - logger.info("=== TESTING MODEL LOADING ===") - - try: - # Get model registry - registry = get_model_registry() - - # Check loaded models - logger.info(f"Loaded models: {list(registry.models.keys())}") - - # Test each model - for name, model in registry.models.items(): - logger.info(f"Testing {name} model...") - - # Test prediction - import numpy as np - test_features = np.random.random((20, 5)) # 20 timesteps, 5 features - - try: - predictions, confidence = model.predict(test_features) - logger.info(f" ✅ {name} prediction: {predictions} (confidence: {confidence:.3f})") - except Exception as e: - logger.error(f" ❌ {name} prediction failed: {e}") - - # Memory stats - stats = registry.get_memory_stats() - logger.info(f"Memory usage: {stats['total_used_mb']:.1f}MB / {stats['total_limit_mb']:.1f}MB") - - return True - - except Exception as e: - logger.error(f"Model loading test failed: {e}") - return False - -async def test_orchestrator_integration(): - """Test orchestrator integration with models""" - logger.info("=== TESTING ORCHESTRATOR INTEGRATION ===") - - try: - # Initialize components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Test coordinated decisions - logger.info("Testing coordinated decision making...") - decisions = await orchestrator.make_coordinated_decisions() - - if decisions: - for symbol, decision in decisions.items(): - if decision: - logger.info(f" ✅ {symbol}: {decision.action} (confidence: {decision.confidence:.3f})") - else: - logger.info(f" ⏸️ {symbol}: No decision (waiting)") - else: - logger.warning(" ❌ No decisions made") - - # Test RL evaluation - logger.info("Testing RL evaluation...") - await orchestrator.evaluate_actions_with_rl() - - return True - - except Exception as e: - logger.error(f"Orchestrator integration test failed: {e}") - return False - -def test_rl_learning(): - """Test RL learning functionality""" - logger.info("=== TESTING RL LEARNING ===") - - try: - registry = get_model_registry() - rl_agent = registry.get_model('RL') - - if not rl_agent: - logger.error("RL agent not found") - return False - - # Simulate some experiences - import numpy as np - - logger.info("Simulating trading experiences...") - for i in range(50): - state = np.random.random(10) - action = np.random.randint(0, 3) - reward = np.random.uniform(-0.1, 0.1) # Random P&L - next_state = np.random.random(10) - done = False - - # Store experience - rl_agent.remember(state, action, reward, next_state, done) - - logger.info(f"Stored {len(rl_agent.experience_buffer)} experiences") - - # Test replay training - logger.info("Testing replay training...") - loss = rl_agent.replay() - - if loss is not None: - logger.info(f" ✅ Training loss: {loss:.4f}") - else: - logger.info(" ⏸️ Not enough experiences for training") - - return True - - except Exception as e: - logger.error(f"RL learning test failed: {e}") - return False - -def test_cnn_training(): - """Test CNN training functionality""" - logger.info("=== TESTING CNN TRAINING ===") - - try: - registry = get_model_registry() - cnn_model = registry.get_model('CNN') - - if not cnn_model: - logger.error("CNN model not found") - return False - - # Test training with mock perfect moves - training_data = { - 'perfect_moves': [], - 'market_data': {}, - 'symbols': ['ETH/USDT', 'BTC/USDT'], - 'timeframes': ['1m', '1h'] - } - - # Mock some perfect moves - for i in range(10): - perfect_move = { - 'symbol': 'ETH/USDT', - 'timeframe': '1m', - 'timestamp': datetime.now() - timedelta(hours=i), - 'optimal_action': 'BUY' if i % 2 == 0 else 'SELL', - 'confidence_should_have_been': 0.8 + i * 0.01, - 'actual_outcome': 0.02 if i % 2 == 0 else -0.015 - } - training_data['perfect_moves'].append(perfect_move) - - logger.info(f"Testing training with {len(training_data['perfect_moves'])} perfect moves...") - - # Test training - result = cnn_model.train(training_data) - - if result and result.get('status') == 'training_simulated': - logger.info(f" ✅ Training completed: {result}") - else: - logger.warning(f" ⚠️ Training result: {result}") - - return True - - except Exception as e: - logger.error(f"CNN training test failed: {e}") - return False - -def test_prediction_tracking(): - """Test prediction tracking and learning feedback""" - logger.info("=== TESTING PREDICTION TRACKING ===") - - try: - # Initialize components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Get some market data for testing - test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=100) - - if test_data is None or test_data.empty: - logger.warning("No market data available for testing") - return True - - logger.info(f"Testing with {len(test_data)} candles of ETH/USDT 1m data") - - # Simulate some predictions and outcomes - correct_predictions = 0 - total_predictions = 0 - - for i in range(min(10, len(test_data) - 5)): - # Get a slice of data - current_data = test_data.iloc[i:i+20] - future_data = test_data.iloc[i+20:i+25] - - if len(current_data) < 20 or len(future_data) < 5: - continue - - # Make prediction - current_price = current_data['close'].iloc[-1] - future_price = future_data['close'].iloc[-1] - actual_change = (future_price - current_price) / current_price - - # Simulate model prediction - predicted_action = 'BUY' if actual_change > 0.001 else 'SELL' if actual_change < -0.001 else 'HOLD' - - # Check if prediction was correct - if predicted_action == 'BUY' and actual_change > 0: - correct_predictions += 1 - logger.info(f" ✅ Correct BUY prediction: {actual_change:.4f}") - elif predicted_action == 'SELL' and actual_change < 0: - correct_predictions += 1 - logger.info(f" ✅ Correct SELL prediction: {actual_change:.4f}") - elif predicted_action == 'HOLD' and abs(actual_change) < 0.001: - correct_predictions += 1 - logger.info(f" ✅ Correct HOLD prediction: {actual_change:.4f}") - else: - logger.info(f" ❌ Wrong {predicted_action} prediction: {actual_change:.4f}") - - total_predictions += 1 - - if total_predictions > 0: - accuracy = correct_predictions / total_predictions - logger.info(f"Prediction accuracy: {accuracy:.1%} ({correct_predictions}/{total_predictions})") - - return True - - except Exception as e: - logger.error(f"Prediction tracking test failed: {e}") - return False - -async def main(): - """Main test function""" - logger.info("🧪 STARTING AI TRADING MODEL TESTS") - logger.info("Testing model loading, training, and learning capabilities") - - tests = [ - ("Model Loading", test_model_loading), - ("Orchestrator Integration", test_orchestrator_integration), - ("RL Learning", test_rl_learning), - ("CNN Training", test_cnn_training), - ("Prediction Tracking", test_prediction_tracking) - ] - - results = {} - - for test_name, test_func in tests: - logger.info(f"\n{'='*50}") - logger.info(f"Running: {test_name}") - logger.info(f"{'='*50}") - - try: - if asyncio.iscoroutinefunction(test_func): - result = await test_func() - else: - result = test_func() - - results[test_name] = result - - if result: - logger.info(f"✅ {test_name}: PASSED") - else: - logger.error(f"❌ {test_name}: FAILED") - - except Exception as e: - logger.error(f"❌ {test_name}: ERROR - {e}") - results[test_name] = False - - # Summary - logger.info(f"\n{'='*50}") - logger.info("TEST SUMMARY") - logger.info(f"{'='*50}") - - passed = sum(1 for result in results.values() if result) - total = len(results) - - for test_name, result in results.items(): - status = "✅ PASSED" if result else "❌ FAILED" - logger.info(f"{test_name}: {status}") - - logger.info(f"\nOverall: {passed}/{total} tests passed ({passed/total:.1%})") - - if passed == total: - logger.info("🎉 All tests passed! The AI trading system is working correctly.") - else: - logger.warning(f"⚠️ {total-passed} tests failed. Please check the logs above.") - - return 0 if passed == total else 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file diff --git a/tests/test_training_integration.py b/tests/test_training_integration.py deleted file mode 100644 index 04c854f..0000000 --- a/tests/test_training_integration.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Training Integration with Dashboard - -This script tests the enhanced dashboard's ability to: -1. Stream training data to CNN and DQN models -2. Display real-time training metrics and progress -3. Show model learning curves and performance -4. Integrate with the continuous training system -""" - -import sys -import logging -import time -import asyncio -from datetime import datetime, timedelta -from pathlib import Path - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_training_integration(): - """Test the training integration functionality""" - try: - print("="*60) - print("TESTING TRAINING INTEGRATION WITH DASHBOARD") - print("="*60) - - # Import dashboard - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - - # Create components - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - dashboard = TradingDashboard(data_provider, orchestrator) - - print(f"✓ Dashboard created with training integration") - print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}") - - # Test 1: Simulate tick data for training - print("\n📊 TEST 1: Simulating Tick Data") - print("-" * 40) - - # Add simulated tick data to cache - base_price = 3500.0 - for i in range(1000): - tick_data = { - 'timestamp': datetime.now() - timedelta(seconds=1000-i), - 'price': base_price + (i % 100) * 0.1, - 'volume': 100 + (i % 50), - 'side': 'buy' if i % 2 == 0 else 'sell' - } - dashboard.tick_cache.append(tick_data) - - print(f"✓ Added {len(dashboard.tick_cache)} ticks to cache") - - # Test 2: Prepare training data - print("\n🔄 TEST 2: Preparing Training Data") - print("-" * 40) - - training_data = dashboard._prepare_training_data() - if training_data: - print(f"✓ Training data prepared successfully") - print(f" - OHLCV bars: {len(training_data['ohlcv'])}") - print(f" - Features: {training_data['features']}") - print(f" - Symbol: {training_data['symbol']}") - else: - print("❌ Failed to prepare training data") - - # Test 3: Format data for CNN - print("\n🧠 TEST 3: CNN Data Formatting") - print("-" * 40) - - if training_data: - cnn_data = dashboard._format_data_for_cnn(training_data) - if cnn_data and 'sequences' in cnn_data: - print(f"✓ CNN data formatted successfully") - print(f" - Sequences shape: {cnn_data['sequences'].shape}") - print(f" - Targets shape: {cnn_data['targets'].shape}") - print(f" - Sequence length: {cnn_data['sequence_length']}") - else: - print("❌ Failed to format CNN data") - - # Test 4: Format data for RL - print("\n🤖 TEST 4: RL Data Formatting") - print("-" * 40) - - if training_data: - rl_experiences = dashboard._format_data_for_rl(training_data) - if rl_experiences: - print(f"✓ RL experiences formatted successfully") - print(f" - Number of experiences: {len(rl_experiences)}") - print(f" - Experience format: (state, action, reward, next_state, done)") - print(f" - Sample experience shapes: {[len(exp) for exp in rl_experiences[:3]]}") - else: - print("❌ Failed to format RL experiences") - - # Test 5: Send training data to models - print("\n📤 TEST 5: Sending Training Data to Models") - print("-" * 40) - - success = dashboard.send_training_data_to_models() - print(f"✓ Training data sent: {success}") - - if hasattr(dashboard, 'training_stats'): - stats = dashboard.training_stats - print(f" - Total training sessions: {stats.get('total_training_sessions', 0)}") - print(f" - CNN training count: {stats.get('cnn_training_count', 0)}") - print(f" - RL training count: {stats.get('rl_training_count', 0)}") - print(f" - Training data points: {stats.get('training_data_points', 0)}") - - # Test 6: Training metrics display - print("\n📈 TEST 6: Training Metrics Display") - print("-" * 40) - - training_metrics = dashboard._create_training_metrics() - print(f"✓ Training metrics created: {len(training_metrics)} components") - - # Test 7: Model training status - print("\n🔍 TEST 7: Model Training Status") - print("-" * 40) - - training_status = dashboard._get_model_training_status() - print(f"✓ Training status retrieved") - print(f" - CNN status: {training_status['cnn']['status']}") - print(f" - CNN accuracy: {training_status['cnn']['accuracy']:.1%}") - print(f" - RL status: {training_status['rl']['status']}") - print(f" - RL win rate: {training_status['rl']['win_rate']:.1%}") - - # Test 8: Training events log - print("\n📝 TEST 8: Training Events Log") - print("-" * 40) - - training_events = dashboard._get_recent_training_events() - print(f"✓ Training events retrieved: {len(training_events)} events") - - # Test 9: Mini training chart - print("\n📊 TEST 9: Mini Training Chart") - print("-" * 40) - - try: - training_chart = dashboard._create_mini_training_chart(training_status) - print(f"✓ Mini training chart created") - print(f" - Chart type: {type(training_chart)}") - except Exception as e: - print(f"❌ Error creating training chart: {e}") - - # Test 10: Continuous training loop - print("\n🔄 TEST 10: Continuous Training Loop") - print("-" * 40) - - print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}") - if hasattr(dashboard, 'training_thread'): - print(f"✓ Training thread alive: {dashboard.training_thread.is_alive()}") - - # Test 11: Integration with existing continuous training system - print("\n🔗 TEST 11: Integration with Continuous Training System") - print("-" * 40) - - try: - # Check if we can get tick cache for external training - tick_cache = dashboard.get_tick_cache_for_training() - print(f"✓ Tick cache accessible: {len(tick_cache)} ticks") - - # Check if we can get 1-second bars - one_second_bars = dashboard.get_one_second_bars() - print(f"✓ 1-second bars accessible: {len(one_second_bars)} bars") - - except Exception as e: - print(f"❌ Error accessing training data: {e}") - - print("\n" + "="*60) - print("TRAINING INTEGRATION TEST COMPLETED") - print("="*60) - - # Summary - print("\n📋 SUMMARY:") - print(f"✓ Dashboard with training integration: WORKING") - print(f"✓ Training data preparation: WORKING") - print(f"✓ CNN data formatting: WORKING") - print(f"✓ RL data formatting: WORKING") - print(f"✓ Training metrics display: WORKING") - print(f"✓ Continuous training: ACTIVE") - print(f"✓ Model status tracking: WORKING") - print(f"✓ Training events logging: WORKING") - - return True - - except Exception as e: - logger.error(f"Training integration test failed: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_training_integration() - if success: - print("\n🎉 All training integration tests passed!") - else: - print("\n❌ Some training integration tests failed!") - sys.exit(1) \ No newline at end of file diff --git a/tests/test_training_status.py b/tests/test_training_status.py deleted file mode 100644 index 49836d9..0000000 --- a/tests/test_training_status.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to check training status functionality -""" - -import logging -logging.basicConfig(level=logging.INFO) - -print("Testing training status functionality...") - -try: - from web.old_archived.scalping_dashboard import create_scalping_dashboard - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - - print("✅ Imports successful") - - # Create components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - dashboard = create_scalping_dashboard(data_provider, orchestrator) - - print("✅ Dashboard created successfully") - - # Test training status - training_status = dashboard._get_model_training_status() - print("\n📊 Training Status:") - print(f"CNN Status: {training_status['cnn']['status']}") - print(f"CNN Accuracy: {training_status['cnn']['accuracy']:.1%}") - print(f"CNN Loss: {training_status['cnn']['loss']:.4f}") - print(f"CNN Epochs: {training_status['cnn']['epochs']}") - - print(f"RL Status: {training_status['rl']['status']}") - print(f"RL Win Rate: {training_status['rl']['win_rate']:.1%}") - print(f"RL Episodes: {training_status['rl']['episodes']}") - print(f"RL Memory: {training_status['rl']['memory_size']}") - - # Test extrema stats - if hasattr(orchestrator, 'get_extrema_stats'): - extrema_stats = orchestrator.get_extrema_stats() - print(f"\n🎯 Extrema Stats:") - print(f"Total extrema detected: {extrema_stats.get('total_extrema_detected', 0)}") - print(f"Training queue size: {extrema_stats.get('training_queue_size', 0)}") - print("✅ Extrema stats available") - else: - print("❌ Extrema stats not available") - - # Test tick cache - print(f"\n📈 Training Data:") - print(f"Tick cache size: {len(dashboard.tick_cache)}") - print(f"1s bars cache size: {len(dashboard.one_second_bars)}") - print(f"Streaming status: {dashboard.is_streaming}") - - print("\n✅ All tests completed successfully!") - -except Exception as e: - print(f"❌ Error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/tests/test_universal_data_format.py b/tests/test_universal_data_format.py deleted file mode 100644 index e30558a..0000000 --- a/tests/test_universal_data_format.py +++ /dev/null @@ -1,262 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Universal Data Format Compliance - -This script verifies that our enhanced trading system properly feeds -the 5 required timeseries streams to all models: -- ETH/USDT: ticks (1s), 1m, 1h, 1d -- BTC/USDT: ticks (1s) as reference - -This is our universal trading system input format. -""" - -import asyncio -import logging -import sys -from pathlib import Path -import numpy as np - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config -from core.data_provider import DataProvider -from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from training.enhanced_cnn_trainer import EnhancedCNNTrainer -from training.enhanced_rl_trainer import EnhancedRLTrainer - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def test_universal_data_format(): - """Test that all components properly use the universal 5-timeseries format""" - logger.info("="*80) - logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE") - logger.info("="*80) - - try: - # Initialize components - config = get_config() - data_provider = DataProvider(config) - - # Test 1: Universal Data Adapter - logger.info("\n📊 TEST 1: Universal Data Adapter") - logger.info("-" * 40) - - adapter = UniversalDataAdapter(data_provider) - universal_stream = adapter.get_universal_data_stream() - - if universal_stream is None: - logger.error("❌ Failed to get universal data stream") - return False - - # Validate format - is_valid, issues = adapter.validate_universal_format(universal_stream) - if not is_valid: - logger.error(f"❌ Universal format validation failed: {issues}") - return False - - logger.info("✅ Universal Data Adapter: PASSED") - logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples") - logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles") - logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles") - logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles") - logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples") - logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}") - - # Test 2: Enhanced Orchestrator - logger.info("\n🎯 TEST 2: Enhanced Orchestrator") - logger.info("-" * 40) - - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Test that orchestrator uses universal adapter - if not hasattr(orchestrator, 'universal_adapter'): - logger.error("❌ Orchestrator missing universal_adapter") - return False - - # Test coordinated decisions - decisions = await orchestrator.make_coordinated_decisions() - - logger.info("✅ Enhanced Orchestrator: PASSED") - logger.info(f" Generated {len(decisions)} decisions") - logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}") - - for symbol, decision in decisions.items(): - if decision: - logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})") - - # Test 3: CNN Model Data Format - logger.info("\n🧠 TEST 3: CNN Model Data Format") - logger.info("-" * 40) - - # Format data for CNN - cnn_data = adapter.format_for_model(universal_stream, 'cnn') - - required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks'] - missing_keys = [key for key in required_cnn_keys if key not in cnn_data] - - if missing_keys: - logger.error(f"❌ CNN data missing keys: {missing_keys}") - return False - - logger.info("✅ CNN Model Data Format: PASSED") - for key, data in cnn_data.items(): - if isinstance(data, np.ndarray): - logger.info(f" {key}: shape {data.shape}") - else: - logger.info(f" {key}: {type(data)}") - - # Test 4: RL Model Data Format - logger.info("\n🤖 TEST 4: RL Model Data Format") - logger.info("-" * 40) - - # Format data for RL - rl_data = adapter.format_for_model(universal_stream, 'rl') - - if 'state_vector' not in rl_data: - logger.error("❌ RL data missing state_vector") - return False - - state_vector = rl_data['state_vector'] - if not isinstance(state_vector, np.ndarray): - logger.error("❌ RL state_vector is not numpy array") - return False - - logger.info("✅ RL Model Data Format: PASSED") - logger.info(f" State vector shape: {state_vector.shape}") - logger.info(f" State vector size: {len(state_vector)} features") - - # Test 5: CNN Trainer Integration - logger.info("\n🎓 TEST 5: CNN Trainer Integration") - logger.info("-" * 40) - - try: - cnn_trainer = EnhancedCNNTrainer(config, orchestrator) - logger.info("✅ CNN Trainer Integration: PASSED") - logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}") - logger.info(f" Model device: {cnn_trainer.model.device}") - except Exception as e: - logger.error(f"❌ CNN Trainer Integration failed: {e}") - return False - - # Test 6: RL Trainer Integration - logger.info("\n🎮 TEST 6: RL Trainer Integration") - logger.info("-" * 40) - - try: - rl_trainer = EnhancedRLTrainer(config, orchestrator) - logger.info("✅ RL Trainer Integration: PASSED") - logger.info(f" RL agents: {len(rl_trainer.agents)}") - for symbol, agent in rl_trainer.agents.items(): - logger.info(f" {symbol} agent: {type(agent).__name__}") - except Exception as e: - logger.error(f"❌ RL Trainer Integration failed: {e}") - return False - - # Test 7: Data Flow Verification - logger.info("\n🔄 TEST 7: Data Flow Verification") - logger.info("-" * 40) - - # Verify that models receive the correct data format - test_predictions = await orchestrator._get_enhanced_predictions_universal( - 'ETH/USDT', - list(orchestrator.market_states['ETH/USDT'])[-1] if orchestrator.market_states['ETH/USDT'] else None, - universal_stream - ) - - if test_predictions: - logger.info("✅ Data Flow Verification: PASSED") - for pred in test_predictions: - logger.info(f" Model: {pred.model_name}") - logger.info(f" Action: {pred.overall_action}") - logger.info(f" Confidence: {pred.overall_confidence:.2f}") - logger.info(f" Timeframes: {len(pred.timeframe_predictions)}") - else: - logger.warning("⚠️ No predictions generated (may be normal if no models loaded)") - - # Test 8: Configuration Compliance - logger.info("\n⚙️ TEST 8: Configuration Compliance") - logger.info("-" * 40) - - # Check that config matches universal format - expected_symbols = ['ETH/USDT', 'BTC/USDT'] - expected_timeframes = ['1s', '1m', '1h', '1d'] - - config_symbols = config.symbols - config_timeframes = config.timeframes - - symbols_match = all(symbol in config_symbols for symbol in expected_symbols) - timeframes_match = all(tf in config_timeframes for tf in expected_timeframes) - - if not symbols_match: - logger.warning(f"⚠️ Config symbols may not match universal format") - logger.warning(f" Expected: {expected_symbols}") - logger.warning(f" Config: {config_symbols}") - - if not timeframes_match: - logger.warning(f"⚠️ Config timeframes may not match universal format") - logger.warning(f" Expected: {expected_timeframes}") - logger.warning(f" Config: {config_timeframes}") - - if symbols_match and timeframes_match: - logger.info("✅ Configuration Compliance: PASSED") - else: - logger.info("⚠️ Configuration Compliance: PARTIAL") - - logger.info(f" Symbols: {config_symbols}") - logger.info(f" Timeframes: {config_timeframes}") - - # Final Summary - logger.info("\n" + "="*80) - logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY") - logger.info("="*80) - logger.info("✅ All core tests PASSED!") - logger.info("") - logger.info("📋 VERIFIED COMPLIANCE:") - logger.info(" ✓ Universal Data Adapter working") - logger.info(" ✓ Enhanced Orchestrator using universal format") - logger.info(" ✓ CNN models receive 5 timeseries streams") - logger.info(" ✓ RL models receive combined state vector") - logger.info(" ✓ Trainers properly integrated") - logger.info(" ✓ Data flow verified") - logger.info("") - logger.info("🎯 UNIVERSAL FORMAT ACTIVE:") - logger.info(" 1. ETH/USDT ticks (1s) ✓") - logger.info(" 2. ETH/USDT 1m ✓") - logger.info(" 3. ETH/USDT 1h ✓") - logger.info(" 4. ETH/USDT 1d ✓") - logger.info(" 5. BTC/USDT reference ticks ✓") - logger.info("") - logger.info("🚀 Your enhanced trading system is ready with universal data format!") - logger.info("="*80) - - return True - - except Exception as e: - logger.error(f"❌ Universal data format test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -async def main(): - """Main test function""" - logger.info("🚀 Starting Universal Data Format Compliance Test...") - - success = await test_universal_data_format() - - if success: - logger.info("\n🎉 All tests passed! Universal data format is properly implemented.") - logger.info("Your enhanced trading system respects the 5-timeseries input format.") - else: - logger.error("\n💥 Tests failed! Please check the universal data format implementation.") - sys.exit(1) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_universal_stream_integration.py b/tests/test_universal_stream_integration.py deleted file mode 100644 index 1689cf7..0000000 --- a/tests/test_universal_stream_integration.py +++ /dev/null @@ -1,177 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Universal Data Stream Integration with Dashboard - -This script validates that: -1. CleanTradingDashboard properly subscribes to UnifiedDataStream -2. All 5 timeseries are properly received and processed -3. Data flows correctly from provider -> adapter -> stream -> dashboard -4. Consumer callback functions work as expected -""" - -import asyncio -import logging -import sys -import time -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.trading_executor import TradingExecutor -from web.clean_dashboard import CleanTradingDashboard - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def test_universal_stream_integration(): - """Test Universal Data Stream integration with dashboard""" - logger.info("="*80) - logger.info("🧪 TESTING UNIVERSAL DATA STREAM INTEGRATION") - logger.info("="*80) - - try: - # Initialize components - logger.info("\n📦 STEP 1: Initialize Components") - logger.info("-" * 40) - - config = get_config() - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - enhanced_rl_training=True - ) - trading_executor = TradingExecutor() - - logger.info("✅ Core components initialized") - - # Initialize dashboard with Universal Data Stream - logger.info("\n📊 STEP 2: Initialize Dashboard with Universal Stream") - logger.info("-" * 40) - - dashboard = CleanTradingDashboard( - data_provider=data_provider, - orchestrator=orchestrator, - trading_executor=trading_executor - ) - - # Check Universal Stream initialization - if hasattr(dashboard, 'unified_stream') and dashboard.unified_stream: - logger.info("✅ Universal Data Stream initialized successfully") - logger.info(f"📋 Consumer ID: {dashboard.stream_consumer_id}") - else: - logger.error("❌ Universal Data Stream not initialized") - return False - - # Test consumer registration - logger.info("\n🔗 STEP 3: Validate Consumer Registration") - logger.info("-" * 40) - - stream_stats = dashboard.unified_stream.get_stream_stats() - logger.info(f"📊 Stream Stats: {stream_stats}") - - if stream_stats['total_consumers'] > 0: - logger.info(f"✅ {stream_stats['total_consumers']} consumers registered") - else: - logger.warning("⚠️ No consumers registered") - - # Test data callback - logger.info("\n📡 STEP 4: Test Data Callback") - logger.info("-" * 40) - - # Create test data packet - test_data = { - 'timestamp': time.time(), - 'consumer_id': dashboard.stream_consumer_id, - 'consumer_name': 'CleanTradingDashboard', - 'ticks': [ - {'symbol': 'ETHUSDT', 'price': 3000.0, 'volume': 1.5, 'timestamp': time.time()}, - {'symbol': 'ETHUSDT', 'price': 3001.0, 'volume': 2.0, 'timestamp': time.time()}, - ], - 'ohlcv': {'one_second_bars': [], 'multi_timeframe': { - 'ETH/USDT': { - '1s': [{'timestamp': time.time(), 'open': 3000, 'high': 3002, 'low': 2999, 'close': 3001, 'volume': 10}], - '1m': [{'timestamp': time.time(), 'open': 2990, 'high': 3010, 'low': 2985, 'close': 3001, 'volume': 100}], - '1h': [{'timestamp': time.time(), 'open': 2900, 'high': 3050, 'low': 2880, 'close': 3001, 'volume': 1000}], - '1d': [{'timestamp': time.time(), 'open': 2800, 'high': 3200, 'low': 2750, 'close': 3001, 'volume': 10000}] - }, - 'BTC/USDT': { - '1s': [{'timestamp': time.time(), 'open': 65000, 'high': 65020, 'low': 64980, 'close': 65010, 'volume': 0.5}] - } - }}, - 'training_data': {'market_state': 'test', 'features': []}, - 'ui_data': {'formatted_data': 'test_ui_data'} - } - - # Test callback manually - try: - dashboard._handle_unified_stream_data(test_data) - logger.info("✅ Data callback executed successfully") - - # Check if data was processed - if hasattr(dashboard, 'current_prices') and 'ETH/USDT' in dashboard.current_prices: - logger.info(f"✅ Price updated: ETH/USDT = ${dashboard.current_prices['ETH/USDT']}") - else: - logger.warning("⚠️ Prices not updated in dashboard") - - except Exception as e: - logger.error(f"❌ Data callback failed: {e}") - return False - - # Test Universal Data Adapter - logger.info("\n🔄 STEP 5: Test Universal Data Adapter") - logger.info("-" * 40) - - if hasattr(orchestrator, 'universal_adapter'): - universal_stream = orchestrator.universal_adapter.get_universal_data_stream() - if universal_stream: - logger.info("✅ Universal Data Adapter working") - logger.info(f"📊 ETH ticks: {len(universal_stream.eth_ticks)} samples") - logger.info(f"📊 ETH 1m: {len(universal_stream.eth_1m)} candles") - logger.info(f"📊 ETH 1h: {len(universal_stream.eth_1h)} candles") - logger.info(f"📊 ETH 1d: {len(universal_stream.eth_1d)} candles") - logger.info(f"📊 BTC ticks: {len(universal_stream.btc_ticks)} samples") - - # Validate format - is_valid, issues = orchestrator.universal_adapter.validate_universal_format(universal_stream) - if is_valid: - logger.info("✅ Universal format validation passed") - else: - logger.warning(f"⚠️ Format issues: {issues}") - else: - logger.error("❌ Universal Data Adapter failed to get stream") - return False - else: - logger.error("❌ Universal Data Adapter not found in orchestrator") - return False - - # Summary - logger.info("\n🎯 SUMMARY") - logger.info("-" * 40) - logger.info("✅ Universal Data Stream properly integrated") - logger.info("✅ Dashboard subscribes as consumer") - logger.info("✅ All 5 timeseries format validated") - logger.info("✅ Data callback processing works") - logger.info("✅ Universal Data Adapter functional") - - logger.info("\n🏆 INTEGRATION TEST PASSED") - return True - - except Exception as e: - logger.error(f"❌ Integration test failed: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = asyncio.run(test_universal_stream_integration()) - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/trading_main.py b/trading_main.py deleted file mode 100644 index 9505c84..0000000 --- a/trading_main.py +++ /dev/null @@ -1,155 +0,0 @@ -import os -import time -import logging -import sys -import argparse -import json - -# Add the NN directory to the Python path -sys.path.append(os.path.abspath("NN")) - -from NN.main import load_model -from NN.neural_network_orchestrator import NeuralNetworkOrchestrator -from NN.realtime_data_interface import RealtimeDataInterface - -# Initialize logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("trading_bot.log"), - logging.StreamHandler() - ] -) -logger = logging.getLogger(__name__) - -def main(): - """Main function for the trading bot.""" - # Parse command-line arguments - parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration") - parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"], - help='Trading symbols to monitor') - parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"], - help='Timeframes to monitor') - parser.add_argument('--window-size', type=int, default=20, - help='Window size for model input') - parser.add_argument('--output-size', type=int, default=3, - help='Output size of the model (3 for BUY/HOLD/SELL)') - parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"], - help='Type of neural network model') - parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"], - help='Trading mode') - parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"], - help='Exchange to use for trading') - parser.add_argument('--api-key', type=str, default=None, - help='API key for the exchange') - parser.add_argument('--api-secret', type=str, default=None, - help='API secret for the exchange') - parser.add_argument('--test-mode', action='store_true', - help='Use test/sandbox exchange environment') - parser.add_argument('--position-size', type=float, default=0.1, - help='Position size as a fraction of total balance (0.0-1.0)') - parser.add_argument('--max-trades-per-day', type=int, default=5, - help='Maximum number of trades per day') - parser.add_argument('--trade-cooldown', type=int, default=60, - help='Trade cooldown period in minutes') - parser.add_argument('--config-file', type=str, default=None, - help='Path to configuration file') - - args = parser.parse_args() - - # Load configuration from file if provided - if args.config_file and os.path.exists(args.config_file): - with open(args.config_file, 'r') as f: - config = json.load(f) - # Override config with command-line args - for key, value in vars(args).items(): - if key != 'config_file' and value is not None: - config[key] = value - else: - # Use command-line args as config - config = vars(args) - - # Initialize real-time charts and data interfaces - try: - from dataprovider_realtime import RealTimeChart - - # Create a real-time chart for each symbol - charts = {} - for symbol in config['symbols']: - charts[symbol] = RealTimeChart(symbol=symbol) - - main_chart = charts[config['symbols'][0]] - - # Create a data interface for retrieving market data - data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart) - - # Load trained model - model_type = os.environ.get("NN_MODEL_TYPE", config['model_type']) - model = load_model( - model_type=model_type, - input_shape=(config['window_size'], len(config['symbols']), 5), # 5 features (OHLCV) - output_size=config['output_size'] - ) - - # Configure trading agent - exchange_config = { - "exchange": config['exchange'], - "api_key": config['api_key'], - "api_secret": config['api_secret'], - "test_mode": config['test_mode'], - "trade_symbols": config['symbols'], - "position_size": config['position_size'], - "max_trades_per_day": config['max_trades_per_day'], - "trade_cooldown_minutes": config['trade_cooldown'] - } - - # Initialize neural network orchestrator - orchestrator = NeuralNetworkOrchestrator( - model=model, - data_interface=data_interface, - chart=main_chart, - symbols=config['symbols'], - timeframes=config['timeframes'], - window_size=config['window_size'], - num_features=5, # OHLCV - output_size=config['output_size'], - exchange_config=exchange_config - ) - - # Start data collection - logger.info("Starting data collection threads...") - for symbol in config['symbols']: - charts[symbol].start() - - # Start neural network inference - if os.environ.get("ENABLE_NN_MODELS", "0") == "1": - logger.info("Starting neural network inference...") - orchestrator.start_inference() - else: - logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.") - - # Start web servers for chart display - logger.info("Starting web servers for chart display...") - main_chart.start_server() - - logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.") - - # Keep the main thread alive - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Keyboard interrupt received. Shutting down...") - # Stop all threads - for symbol in config['symbols']: - charts[symbol].stop() - orchestrator.stop_inference() - logger.info("Trading bot stopped.") - - except Exception as e: - logger.error(f"Error in main function: {str(e)}", exc_info=True) - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index 71aab0c..0000000 --- a/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Utils package for the multi-modal trading system -""" \ No newline at end of file diff --git a/utils/async_task_manager.py b/utils/async_task_manager.py deleted file mode 100644 index 30ee711..0000000 --- a/utils/async_task_manager.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -Async Task Manager - Handles async tasks with comprehensive error handling -Prevents silent failures in async operations -""" - -import asyncio -import logging -import functools -import traceback -from typing import Any, Callable, Optional, Dict, List -from datetime import datetime - -logger = logging.getLogger(__name__) - -class AsyncTaskManager: - """Manage async tasks with error handling and monitoring""" - - def __init__(self): - self.active_tasks: Dict[str, asyncio.Task] = {} - self.completed_tasks: List[Dict[str, Any]] = [] - self.failed_tasks: List[Dict[str, Any]] = [] - self.max_history = 100 - - def create_task_with_error_handling(self, - coro: Any, - name: str, - error_callback: Optional[Callable] = None, - success_callback: Optional[Callable] = None) -> asyncio.Task: - """ - Create an async task with comprehensive error handling - - Args: - coro: Coroutine to run - name: Task name for identification - error_callback: Called on error with (name, exception) - success_callback: Called on success with (name, result) - """ - - async def wrapped_coro(): - """Wrapper coroutine with error handling""" - start_time = datetime.now() - try: - logger.debug(f"Starting async task: {name}") - result = await coro - - # Log success - duration = (datetime.now() - start_time).total_seconds() - logger.debug(f"Async task '{name}' completed successfully in {duration:.2f}s") - - # Store completion info - completion_info = { - 'name': name, - 'status': 'completed', - 'start_time': start_time, - 'end_time': datetime.now(), - 'duration': duration, - 'result': str(result)[:200] if result else None # Truncate long results - } - self.completed_tasks.append(completion_info) - - # Trim history - if len(self.completed_tasks) > self.max_history: - self.completed_tasks.pop(0) - - # Call success callback - if success_callback: - try: - success_callback(name, result) - except Exception as cb_error: - logger.error(f"Error in success callback for task '{name}': {cb_error}") - - return result - - except asyncio.CancelledError: - logger.info(f"Async task '{name}' was cancelled") - raise - - except Exception as e: - # Log error with full traceback - duration = (datetime.now() - start_time).total_seconds() - error_msg = f"Async task '{name}' failed after {duration:.2f}s: {e}" - logger.error(error_msg) - logger.error(f"Task '{name}' traceback: {traceback.format_exc()}") - - # Store failure info - failure_info = { - 'name': name, - 'status': 'failed', - 'start_time': start_time, - 'end_time': datetime.now(), - 'duration': duration, - 'error': str(e), - 'traceback': traceback.format_exc() - } - self.failed_tasks.append(failure_info) - - # Trim history - if len(self.failed_tasks) > self.max_history: - self.failed_tasks.pop(0) - - # Call error callback - if error_callback: - try: - error_callback(name, e) - except Exception as cb_error: - logger.error(f"Error in error callback for task '{name}': {cb_error}") - - # Don't re-raise to prevent task from crashing the event loop - # Instead, return None to indicate failure - return None - - finally: - # Remove from active tasks - if name in self.active_tasks: - del self.active_tasks[name] - - # Create and store task - task = asyncio.create_task(wrapped_coro(), name=name) - self.active_tasks[name] = task - - return task - - def cancel_task(self, name: str) -> bool: - """Cancel a specific task""" - if name in self.active_tasks: - task = self.active_tasks[name] - if not task.done(): - task.cancel() - logger.info(f"Cancelled async task: {name}") - return True - return False - - def cancel_all_tasks(self): - """Cancel all active tasks""" - for name, task in list(self.active_tasks.items()): - if not task.done(): - task.cancel() - logger.info(f"Cancelled async task: {name}") - - def get_task_status(self) -> Dict[str, Any]: - """Get status of all tasks""" - active_count = len(self.active_tasks) - completed_count = len(self.completed_tasks) - failed_count = len(self.failed_tasks) - - # Get recent failures - recent_failures = self.failed_tasks[-5:] if self.failed_tasks else [] - - return { - 'active_tasks': active_count, - 'completed_tasks': completed_count, - 'failed_tasks': failed_count, - 'active_task_names': list(self.active_tasks.keys()), - 'recent_failures': [ - { - 'name': f['name'], - 'error': f['error'], - 'duration': f['duration'], - 'time': f['end_time'].strftime('%H:%M:%S') - } - for f in recent_failures - ] - } - - def get_failure_summary(self) -> Dict[str, Any]: - """Get summary of task failures""" - if not self.failed_tasks: - return {'total_failures': 0, 'failure_patterns': {}} - - # Count failures by error type - error_counts = {} - for failure in self.failed_tasks: - error_type = type(failure.get('error', 'Unknown')).__name__ - error_counts[error_type] = error_counts.get(error_type, 0) + 1 - - # Recent failure rate - recent_failures = [f for f in self.failed_tasks if - (datetime.now() - f['end_time']).total_seconds() < 3600] # Last hour - - return { - 'total_failures': len(self.failed_tasks), - 'recent_failures_1h': len(recent_failures), - 'failure_patterns': error_counts, - 'most_common_error': max(error_counts.items(), key=lambda x: x[1])[0] if error_counts else None - } - -# Global instance -_task_manager = None - -def get_async_task_manager() -> AsyncTaskManager: - """Get global async task manager instance""" - global _task_manager - if _task_manager is None: - _task_manager = AsyncTaskManager() - return _task_manager - -def create_safe_task(coro: Any, - name: str, - error_callback: Optional[Callable] = None, - success_callback: Optional[Callable] = None) -> asyncio.Task: - """ - Create a safe async task with error handling - - Args: - coro: Coroutine to run - name: Task name for identification - error_callback: Called on error with (name, exception) - success_callback: Called on success with (name, result) - """ - manager = get_async_task_manager() - return manager.create_task_with_error_handling(coro, name, error_callback, success_callback) - -def safe_async_wrapper(name: str, - error_callback: Optional[Callable] = None, - success_callback: Optional[Callable] = None): - """ - Decorator for creating safe async functions - - Usage: - @safe_async_wrapper("my_task") - async def my_async_function(): - # Your async code here - pass - """ - def decorator(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - coro = func(*args, **kwargs) - task = create_safe_task(coro, name, error_callback, success_callback) - return await task - return wrapper - return decorator \ No newline at end of file diff --git a/utils/launch_tensorboard.py b/utils/launch_tensorboard.py deleted file mode 100644 index 23e601f..0000000 --- a/utils/launch_tensorboard.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -""" -TensorBoard Launcher with Automatic Port Management - -This script launches TensorBoard with automatic port fallback if the preferred port is in use. -It also kills any stale debug instances that might be running. - -Usage: - python launch_tensorboard.py --logdir=path/to/logs --preferred-port=6007 --port-range=6000-7000 -""" - -import os -import sys -import subprocess -import argparse -import logging -from pathlib import Path - -# Add project root to path -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if project_root not in sys.path: - sys.path.append(project_root) - -from utils.port_manager import get_port_with_fallback, kill_stale_debug_instances - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('tensorboard_launcher') - -def launch_tensorboard(logdir, port, host='localhost', open_browser=True): - """ - Launch TensorBoard on the specified port - - Args: - logdir (str): Path to log directory - port (int): Port to use - host (str): Host to bind to - open_browser (bool): Whether to open browser automatically - - Returns: - subprocess.Popen: Process object - """ - cmd = [ - sys.executable, "-m", "tensorboard.main", - f"--logdir={logdir}", - f"--port={port}", - f"--host={host}" - ] - - # Add --load_fast=false to improve startup times - cmd.append("--load_fast=false") - - # Control whether to open browser - if not open_browser: - cmd.append("--window_title=TensorBoard") - - logger.info(f"Launching TensorBoard: {' '.join(cmd)}") - - # Use subprocess.Popen to start TensorBoard without waiting for it to finish - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - bufsize=1 - ) - - # Log the first few lines of output to confirm it's starting correctly - line_count = 0 - for line in process.stdout: - logger.info(f"TensorBoard: {line.strip()}") - line_count += 1 - - # Check if TensorBoard has started successfully - if "TensorBoard" in line and "http://" in line: - url = line.strip().split("http://")[1].split(" ")[0] - logger.info(f"TensorBoard available at: http://{url}") - - # Only log the first few lines - if line_count >= 10: - break - - # Continue reading output in background to prevent pipe from filling - def read_output(): - for line in process.stdout: - pass - - import threading - threading.Thread(target=read_output, daemon=True).start() - - return process - -def main(): - parser = argparse.ArgumentParser(description='Launch TensorBoard with automatic port management') - parser.add_argument('--logdir', type=str, default='NN/models/saved/logs', - help='Directory containing TensorBoard event files') - parser.add_argument('--preferred-port', type=int, default=6007, - help='Preferred port to use') - parser.add_argument('--port-range', type=str, default='6000-7000', - help='Port range to try if preferred port is unavailable (format: min-max)') - parser.add_argument('--host', type=str, default='localhost', - help='Host to bind to') - parser.add_argument('--no-browser', action='store_true', - help='Do not open browser automatically') - parser.add_argument('--kill-stale', action='store_true', - help='Kill stale debug instances before starting') - - args = parser.parse_args() - - # Parse port range - try: - min_port, max_port = map(int, args.port_range.split('-')) - except ValueError: - logger.error(f"Invalid port range format: {args.port_range}. Use format: min-max") - return 1 - - # Kill stale instances if requested - if args.kill_stale: - logger.info("Killing stale debug instances...") - count, _ = kill_stale_debug_instances() - logger.info(f"Killed {count} stale instances") - - # Get an available port - try: - port = get_port_with_fallback(args.preferred_port, min_port, max_port) - logger.info(f"Using port {port} for TensorBoard") - except RuntimeError as e: - logger.error(str(e)) - return 1 - - # Ensure log directory exists - logdir = os.path.abspath(args.logdir) - os.makedirs(logdir, exist_ok=True) - - # Launch TensorBoard - process = launch_tensorboard( - logdir=logdir, - port=port, - host=args.host, - open_browser=not args.no_browser - ) - - # Wait for process to end (it shouldn't unless there's an error or user kills it) - try: - return_code = process.wait() - if return_code != 0: - logger.error(f"TensorBoard exited with code {return_code}") - return return_code - except KeyboardInterrupt: - logger.info("Received keyboard interrupt, shutting down TensorBoard...") - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - logger.warning("TensorBoard didn't terminate gracefully, forcing kill") - process.kill() - - return 0 - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/utils/model_utils.py b/utils/model_utils.py deleted file mode 100644 index df87455..0000000 --- a/utils/model_utils.py +++ /dev/null @@ -1,241 +0,0 @@ -#!/usr/bin/env python -""" -Model utilities for robust saving and loading of PyTorch models -""" - -import os -import logging -import torch -import shutil -import gc -import json -from typing import Any, Dict, Optional, Union - -logger = logging.getLogger(__name__) - -def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool: - """ - Robust model saving with multiple fallback approaches - - Args: - model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes) - path: Path to save the model - include_optimizer: Whether to include optimizer state in the save - - Returns: - bool: True if successful, False otherwise - """ - # Create directory if it doesn't exist - os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) - - # Backup path in case the main save fails - backup_path = f"{path}.backup" - - # Clean up GPU memory before saving - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # Prepare checkpoint data - checkpoint = { - 'policy_net': model.policy_net.state_dict(), - 'target_net': model.target_net.state_dict(), - 'epsilon': getattr(model, 'epsilon', 0.0), - 'state_size': getattr(model, 'state_size', None), - 'action_size': getattr(model, 'action_size', None), - 'hidden_size': getattr(model, 'hidden_size', None), - } - - # Add optimizer state if requested and available - if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None: - checkpoint['optimizer'] = model.optimizer.state_dict() - - # Attempt 1: Try with default settings in a separate file first - try: - logger.info(f"Saving model to {backup_path} (attempt 1)") - torch.save(checkpoint, backup_path) - logger.info(f"Successfully saved to {backup_path}") - - # If backup worked, copy to the actual path - if os.path.exists(backup_path): - shutil.copy(backup_path, path) - logger.info(f"Copied backup to {path}") - return True - except Exception as e: - logger.warning(f"First save attempt failed: {e}") - - # Attempt 2: Try with pickle protocol 2 (more compatible) - try: - logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)") - torch.save(checkpoint, path, pickle_protocol=2) - logger.info(f"Successfully saved to {path} with pickle_protocol=2") - return True - except Exception as e: - logger.warning(f"Second save attempt failed: {e}") - - # Attempt 3: Try without optimizer state (which can be large and cause issues) - try: - logger.info(f"Saving model to {path} (attempt 3 - without optimizer)") - checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'} - torch.save(checkpoint_no_opt, path) - logger.info(f"Successfully saved to {path} without optimizer state") - return True - except Exception as e: - logger.warning(f"Third save attempt failed: {e}") - - # Attempt 4: Try with torch.jit.save instead - try: - logger.info(f"Saving model to {path} (attempt 4 - with jit.save)") - # Save policy network using jit - scripted_policy = torch.jit.script(model.policy_net) - torch.jit.save(scripted_policy, f"{path}.policy.jit") - - # Save target network using jit - scripted_target = torch.jit.script(model.target_net) - torch.jit.save(scripted_target, f"{path}.target.jit") - - # Save parameters separately as JSON - params = { - 'epsilon': float(getattr(model, 'epsilon', 0.0)), - 'state_size': int(getattr(model, 'state_size', 0)), - 'action_size': int(getattr(model, 'action_size', 0)), - 'hidden_size': int(getattr(model, 'hidden_size', 0)) - } - with open(f"{path}.params.json", "w") as f: - json.dump(params, f) - - logger.info(f"Successfully saved model components with jit.save") - return True - except Exception as e: - logger.error(f"All save attempts failed: {e}") - return False - -def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool: - """ - Robust model loading with fallback approaches - - Args: - model: The model object to load into - path: Path to load the model from - device: Device to load the model on - - Returns: - bool: True if successful, False otherwise - """ - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Try regular PyTorch load first - try: - logger.info(f"Loading model from {path}") - if os.path.exists(path): - checkpoint = torch.load(path, map_location=device) - - # Load network states - if 'policy_net' in checkpoint: - model.policy_net.load_state_dict(checkpoint['policy_net']) - if 'target_net' in checkpoint: - model.target_net.load_state_dict(checkpoint['target_net']) - - # Load other attributes - if 'epsilon' in checkpoint: - model.epsilon = checkpoint['epsilon'] - if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None: - try: - model.optimizer.load_state_dict(checkpoint['optimizer']) - except Exception as e: - logger.warning(f"Failed to load optimizer state: {e}") - - logger.info("Successfully loaded model") - return True - except Exception as e: - logger.warning(f"Regular load failed: {e}") - - # Try loading JIT saved components - try: - policy_path = f"{path}.policy.jit" - target_path = f"{path}.target.jit" - params_path = f"{path}.params.json" - - if all(os.path.exists(p) for p in [policy_path, target_path, params_path]): - logger.info(f"Loading JIT model components") - - # Load JIT models (this is more complex and may need model reconstruction) - # For now, just log that we found JIT files - logger.info("Found JIT model files, but loading them requires special handling") - with open(params_path, 'r') as f: - params = json.load(f) - logger.info(f"Model parameters: {params}") - - # Note: Actually loading JIT models would require recreating the model architecture - # This is a placeholder for future implementation - return False - except Exception as e: - logger.error(f"JIT load failed: {e}") - - logger.error(f"All load attempts failed for {path}") - return False - -def get_model_info(path: str) -> Dict[str, Any]: - """ - Get information about a saved model - - Args: - path: Path to the model file - - Returns: - dict: Model information - """ - info = { - 'exists': False, - 'size_bytes': 0, - 'has_optimizer': False, - 'parameters': {} - } - - try: - if os.path.exists(path): - info['exists'] = True - info['size_bytes'] = os.path.getsize(path) - - # Try to load and inspect - checkpoint = torch.load(path, map_location='cpu') - info['has_optimizer'] = 'optimizer' in checkpoint - - # Extract parameter info - for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']: - if key in checkpoint: - info['parameters'][key] = checkpoint[key] - except Exception as e: - logger.warning(f"Failed to get model info for {path}: {e}") - - return info - -def verify_save_load_cycle(model: Any, test_path: str) -> bool: - """ - Test that a model can be saved and loaded correctly - - Args: - model: Model to test - test_path: Path for test file - - Returns: - bool: True if save/load cycle successful - """ - try: - # Save the model - if not robust_save(model, test_path): - return False - - # Create a new model instance (this would need model creation logic) - # For now, just verify the file exists and has content - if os.path.exists(test_path) and os.path.getsize(test_path) > 0: - logger.info("Save/load cycle verification successful") - # Clean up test file - os.remove(test_path) - return True - else: - return False - except Exception as e: - logger.error(f"Save/load cycle verification failed: {e}") - return False \ No newline at end of file diff --git a/utils/port_manager.py b/utils/port_manager.py deleted file mode 100644 index e33c91c..0000000 --- a/utils/port_manager.py +++ /dev/null @@ -1,238 +0,0 @@ -#!/usr/bin/env python3 -""" -Port Management Utility - -This script provides utilities to: -1. Find available ports in a specified range -2. Kill stale processes running on specific ports -3. Kill all debug/training instances - -Usage: - - As a module: import port_manager and use its functions - - Directly: python port_manager.py --kill-stale --min-port 6000 --max-port 7000 -""" - -import os -import sys -import socket -import argparse -import psutil -import logging -import time -import signal -from typing import List, Tuple, Optional, Set - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('port_manager') - -# Define process names to look for when killing stale instances -DEBUG_PROCESS_KEYWORDS = [ - 'tensorboard', - 'python train_', - 'realtime.py', - 'train_rl_with_realtime.py' -] - -def is_port_in_use(port: int) -> bool: - """ - Check if a port is in use - - Args: - port (int): Port number to check - - Returns: - bool: True if port is in use, False otherwise - """ - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 - -def find_available_port(start_port: int, end_port: int) -> Optional[int]: - """ - Find an available port in the specified range - - Args: - start_port (int): Lower bound of port range - end_port (int): Upper bound of port range - - Returns: - Optional[int]: Available port number or None if no ports available - """ - for port in range(start_port, end_port + 1): - if not is_port_in_use(port): - return port - return None - -def get_process_by_port(port: int) -> List[psutil.Process]: - """ - Get processes using a specific port - - Args: - port (int): Port number to check - - Returns: - List[psutil.Process]: List of processes using the port - """ - processes = [] - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - try: - for conn in proc.connections(kind='inet'): - if conn.laddr.port == port: - processes.append(proc) - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): - pass - return processes - -def kill_process_by_port(port: int) -> Tuple[int, List[str]]: - """ - Kill processes using a specific port - - Args: - port (int): Port number to check - - Returns: - Tuple[int, List[str]]: Count of killed processes and their names - """ - processes = get_process_by_port(port) - killed = [] - - for proc in processes: - try: - proc_name = " ".join(proc.cmdline()) if proc.cmdline() else proc.name() - logger.info(f"Terminating process {proc.pid}: {proc_name}") - proc.terminate() - killed.append(proc_name) - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass - - # Give processes time to terminate gracefully - if processes: - time.sleep(0.5) - - # Force kill any remaining processes - for proc in processes: - try: - if proc.is_running(): - logger.info(f"Force killing process {proc.pid}") - proc.kill() - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass - - return len(killed), killed - -def kill_stale_debug_instances() -> Tuple[int, Set[str]]: - """ - Kill all stale debug and training instances based on process names - - Returns: - Tuple[int, Set[str]]: Count of killed processes and their names - """ - killed_count = 0 - killed_procs = set() - - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - try: - cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name() - - # Check if this is a debug/training process we should kill - if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS): - logger.info(f"Terminating stale process {proc.pid}: {cmd}") - proc.terminate() - killed_count += 1 - killed_procs.add(cmd) - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): - pass - - # Give processes time to terminate - if killed_count > 0: - time.sleep(1) - - # Force kill any remaining processes - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - try: - cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name() - - if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS) and proc.is_running(): - logger.info(f"Force killing stale process {proc.pid}") - proc.kill() - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): - pass - - return killed_count, killed_procs - -def get_port_with_fallback(preferred_port: int, min_port: int, max_port: int) -> int: - """ - Try to use preferred port, fall back to any available port in range - - Args: - preferred_port (int): Preferred port to use - min_port (int): Minimum port in fallback range - max_port (int): Maximum port in fallback range - - Returns: - int: Available port number - """ - # First try the preferred port - if not is_port_in_use(preferred_port): - return preferred_port - - # If preferred port is in use, try to free it - logger.info(f"Preferred port {preferred_port} is in use, attempting to free it") - kill_count, _ = kill_process_by_port(preferred_port) - - if kill_count > 0 and not is_port_in_use(preferred_port): - logger.info(f"Successfully freed port {preferred_port}") - return preferred_port - - # If we couldn't free the preferred port, find another available port - logger.info(f"Looking for available port in range {min_port}-{max_port}") - available_port = find_available_port(min_port, max_port) - - if available_port: - logger.info(f"Using alternative port: {available_port}") - return available_port - else: - # If no ports are available, force kill processes in the entire range - logger.warning(f"No available ports in range {min_port}-{max_port}, freeing ports") - for port in range(min_port, max_port + 1): - kill_process_by_port(port) - - # Try again - available_port = find_available_port(min_port, max_port) - if available_port: - logger.info(f"Using port {available_port} after freeing") - return available_port - else: - logger.error(f"Could not find available port even after freeing range {min_port}-{max_port}") - raise RuntimeError(f"No available ports in range {min_port}-{max_port}") - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Port management utility') - parser.add_argument('--kill-stale', action='store_true', help='Kill all stale debug instances') - parser.add_argument('--free-port', type=int, help='Free a specific port') - parser.add_argument('--find-port', action='store_true', help='Find an available port') - parser.add_argument('--min-port', type=int, default=6000, help='Minimum port in range') - parser.add_argument('--max-port', type=int, default=7000, help='Maximum port in range') - parser.add_argument('--preferred-port', type=int, help='Preferred port to use') - - args = parser.parse_args() - - if args.kill_stale: - count, procs = kill_stale_debug_instances() - logger.info(f"Killed {count} stale processes") - for proc in procs: - logger.info(f" - {proc}") - - if args.free_port: - count, killed = kill_process_by_port(args.free_port) - logger.info(f"Killed {count} processes using port {args.free_port}") - for proc in killed: - logger.info(f" - {proc}") - - if args.find_port or args.preferred_port: - preferred = args.preferred_port if args.preferred_port else args.min_port - port = get_port_with_fallback(preferred, args.min_port, args.max_port) - print(port) # Print only the port number for easy capture in scripts \ No newline at end of file diff --git a/utils/process_supervisor.py b/utils/process_supervisor.py deleted file mode 100644 index 1ecdeb0..0000000 --- a/utils/process_supervisor.py +++ /dev/null @@ -1,340 +0,0 @@ -""" -Process Supervisor - Handles process monitoring, restarts, and supervision -Prevents silent failures by monitoring process health and restarting on crashes -""" - -import subprocess -import threading -import time -import logging -import signal -import os -import sys -from typing import Dict, Any, Optional, Callable, List -from datetime import datetime, timedelta -from pathlib import Path - -logger = logging.getLogger(__name__) - -class ProcessSupervisor: - """Supervise processes and restart them on failure""" - - def __init__(self, max_restarts: int = 5, restart_delay: int = 10): - """ - Initialize process supervisor - - Args: - max_restarts: Maximum number of restarts before giving up - restart_delay: Delay in seconds between restarts - """ - self.max_restarts = max_restarts - self.restart_delay = restart_delay - - self.processes: Dict[str, Dict[str, Any]] = {} - self.monitoring = False - self.monitor_thread = None - - # Callbacks - self.process_started_callback: Optional[Callable] = None - self.process_failed_callback: Optional[Callable] = None - self.process_restarted_callback: Optional[Callable] = None - - def add_process(self, name: str, command: List[str], - working_dir: Optional[str] = None, - env: Optional[Dict[str, str]] = None, - auto_restart: bool = True): - """ - Add a process to supervise - - Args: - name: Process name - command: Command to run as list - working_dir: Working directory - env: Environment variables - auto_restart: Whether to auto-restart on failure - """ - self.processes[name] = { - 'command': command, - 'working_dir': working_dir, - 'env': env, - 'auto_restart': auto_restart, - 'process': None, - 'restart_count': 0, - 'last_start': None, - 'last_failure': None, - 'status': 'stopped' - } - logger.info(f"Added process '{name}' to supervisor") - - def start_process(self, name: str) -> bool: - """Start a specific process""" - if name not in self.processes: - logger.error(f"Process '{name}' not found") - return False - - proc_info = self.processes[name] - - if proc_info['process'] and proc_info['process'].poll() is None: - logger.warning(f"Process '{name}' is already running") - return True - - try: - # Prepare environment - env = os.environ.copy() - if proc_info['env']: - env.update(proc_info['env']) - - # Start process - process = subprocess.Popen( - proc_info['command'], - cwd=proc_info['working_dir'], - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - - proc_info['process'] = process - proc_info['last_start'] = datetime.now() - proc_info['status'] = 'running' - - logger.info(f"Started process '{name}' (PID: {process.pid})") - - if self.process_started_callback: - try: - self.process_started_callback(name, process.pid) - except Exception as e: - logger.error(f"Error in process started callback: {e}") - - return True - - except Exception as e: - logger.error(f"Failed to start process '{name}': {e}") - proc_info['status'] = 'failed' - proc_info['last_failure'] = datetime.now() - return False - - def stop_process(self, name: str, timeout: int = 10) -> bool: - """Stop a specific process""" - if name not in self.processes: - logger.error(f"Process '{name}' not found") - return False - - proc_info = self.processes[name] - process = proc_info['process'] - - if not process or process.poll() is not None: - logger.info(f"Process '{name}' is not running") - proc_info['status'] = 'stopped' - return True - - try: - # Try graceful shutdown first - process.terminate() - - # Wait for graceful shutdown - try: - process.wait(timeout=timeout) - logger.info(f"Process '{name}' terminated gracefully") - except subprocess.TimeoutExpired: - # Force kill if graceful shutdown fails - logger.warning(f"Process '{name}' did not terminate gracefully, force killing") - process.kill() - process.wait() - logger.info(f"Process '{name}' force killed") - - proc_info['status'] = 'stopped' - return True - - except Exception as e: - logger.error(f"Error stopping process '{name}': {e}") - return False - - def restart_process(self, name: str) -> bool: - """Restart a specific process""" - logger.info(f"Restarting process '{name}'") - - if name not in self.processes: - logger.error(f"Process '{name}' not found") - return False - - proc_info = self.processes[name] - - # Stop if running - if proc_info['process'] and proc_info['process'].poll() is None: - self.stop_process(name) - - # Wait restart delay - time.sleep(self.restart_delay) - - # Increment restart count - proc_info['restart_count'] += 1 - - # Check restart limit - if proc_info['restart_count'] > self.max_restarts: - logger.error(f"Process '{name}' exceeded max restarts ({self.max_restarts})") - proc_info['status'] = 'failed_max_restarts' - return False - - # Start process - success = self.start_process(name) - - if success and self.process_restarted_callback: - try: - self.process_restarted_callback(name, proc_info['restart_count']) - except Exception as e: - logger.error(f"Error in process restarted callback: {e}") - - return success - - def start_monitoring(self): - """Start process monitoring""" - if self.monitoring: - logger.warning("Process monitoring already started") - return - - self.monitoring = True - self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) - self.monitor_thread.start() - logger.info("Process monitoring started") - - def stop_monitoring(self): - """Stop process monitoring""" - self.monitoring = False - if self.monitor_thread: - self.monitor_thread.join(timeout=5) - logger.info("Process monitoring stopped") - - def _monitor_loop(self): - """Main monitoring loop""" - logger.info("Process monitoring loop started") - - while self.monitoring: - try: - for name, proc_info in self.processes.items(): - self._check_process_health(name, proc_info) - - time.sleep(5) # Check every 5 seconds - - except Exception as e: - logger.error(f"Error in process monitoring loop: {e}") - time.sleep(5) - - logger.info("Process monitoring loop stopped") - - def _check_process_health(self, name: str, proc_info: Dict[str, Any]): - """Check health of a specific process""" - process = proc_info['process'] - - if not process: - return - - # Check if process is still running - return_code = process.poll() - - if return_code is not None: - # Process has exited - proc_info['status'] = 'exited' - proc_info['last_failure'] = datetime.now() - - logger.warning(f"Process '{name}' exited with code {return_code}") - - # Read stdout/stderr for debugging - try: - stdout, stderr = process.communicate(timeout=1) - if stdout: - logger.info(f"Process '{name}' stdout: {stdout[-500:]}") # Last 500 chars - if stderr: - logger.error(f"Process '{name}' stderr: {stderr[-500:]}") # Last 500 chars - except Exception as e: - logger.warning(f"Could not read process output: {e}") - - if self.process_failed_callback: - try: - self.process_failed_callback(name, return_code) - except Exception as e: - logger.error(f"Error in process failed callback: {e}") - - # Auto-restart if enabled - if proc_info['auto_restart'] and proc_info['restart_count'] < self.max_restarts: - logger.info(f"Auto-restarting process '{name}'") - threading.Thread(target=self.restart_process, args=(name,), daemon=True).start() - - def get_process_status(self, name: str) -> Optional[Dict[str, Any]]: - """Get status of a specific process""" - if name not in self.processes: - return None - - proc_info = self.processes[name] - process = proc_info['process'] - - status = { - 'name': name, - 'status': proc_info['status'], - 'restart_count': proc_info['restart_count'], - 'last_start': proc_info['last_start'], - 'last_failure': proc_info['last_failure'], - 'auto_restart': proc_info['auto_restart'], - 'pid': process.pid if process and process.poll() is None else None, - 'running': process is not None and process.poll() is None - } - - return status - - def get_all_status(self) -> Dict[str, Dict[str, Any]]: - """Get status of all processes""" - return {name: self.get_process_status(name) for name in self.processes} - - def set_callbacks(self, - process_started: Optional[Callable] = None, - process_failed: Optional[Callable] = None, - process_restarted: Optional[Callable] = None): - """Set callback functions for process events""" - self.process_started_callback = process_started - self.process_failed_callback = process_failed - self.process_restarted_callback = process_restarted - - def shutdown_all(self): - """Shutdown all processes""" - logger.info("Shutting down all supervised processes") - - for name in list(self.processes.keys()): - self.stop_process(name) - - self.stop_monitoring() - -# Global instance -_process_supervisor = None - -def get_process_supervisor() -> ProcessSupervisor: - """Get global process supervisor instance""" - global _process_supervisor - if _process_supervisor is None: - _process_supervisor = ProcessSupervisor() - return _process_supervisor - -def create_supervised_dashboard_runner(): - """Create a supervised version of the dashboard runner""" - supervisor = get_process_supervisor() - - # Add dashboard process - supervisor.add_process( - name="clean_dashboard", - command=[sys.executable, "run_clean_dashboard.py"], - working_dir=os.getcwd(), - auto_restart=True - ) - - # Set up callbacks - def on_process_failed(name: str, return_code: int): - logger.error(f"Dashboard process failed with code {return_code}") - - def on_process_restarted(name: str, restart_count: int): - logger.info(f"Dashboard restarted (attempt {restart_count})") - - supervisor.set_callbacks( - process_failed=on_process_failed, - process_restarted=on_process_restarted - ) - - return supervisor \ No newline at end of file diff --git a/utils/reward_calculator.py b/utils/reward_calculator.py deleted file mode 100644 index fb25bd3..0000000 --- a/utils/reward_calculator.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Improved Reward Function for RL Trading Agent - -This module provides a more sophisticated reward function for the RL trading agent -that incorporates realistic trading fees, penalties for excessive trading, and -rewards for successful holding of positions. -""" - -import numpy as np -from datetime import datetime, timedelta -from collections import deque -import logging - -logger = logging.getLogger(__name__) - -class RewardCalculator: - def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1): - self.base_fee_rate = base_fee_rate - self.reward_scaling = reward_scaling - self.risk_aversion = risk_aversion - self.trade_pnls = [] - self.returns = [] - self.trade_timestamps = [] - self.frequency_threshold = 10 # Trades per minute threshold for penalty - self.max_frequency_penalty = 0.05 - - def record_pnl(self, pnl): - """Record P&L for risk adjustment calculations""" - self.trade_pnls.append(pnl) - if len(self.trade_pnls) > 100: - self.trade_pnls.pop(0) - - def record_trade(self, action): - """Record trade action for frequency penalty calculations""" - from time import time - self.trade_timestamps.append(time()) - if len(self.trade_timestamps) > 100: - self.trade_timestamps.pop(0) - - def _calculate_frequency_penalty(self): - """Calculate penalty for high-frequency trading""" - if len(self.trade_timestamps) < 2: - return 0.0 - time_span = self.trade_timestamps[-1] - self.trade_timestamps[0] - if time_span <= 0: - return 0.0 - trades_per_minute = (len(self.trade_timestamps) / time_span) * 60 - if trades_per_minute > self.frequency_threshold: - penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001) - return penalty - return 0.0 - - def _calculate_risk_adjustment(self, reward): - """Adjust rewards based on risk (simple Sharpe ratio implementation)""" - if len(self.trade_pnls) < 5: - return reward - pnl_array = np.array(self.trade_pnls) - mean_return = np.mean(pnl_array) - std_return = np.std(pnl_array) - if std_return == 0: - return reward - sharpe = mean_return / std_return - adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0) - return reward * adjustment_factor - - def _calculate_holding_reward(self, position_held_time, price_change): - """Calculate reward for holding a position""" - base_holding_reward = 0.0005 * (position_held_time / 60.0) - if price_change > 0: - return base_holding_reward * 2 - elif price_change < 0: - return base_holding_reward * 0.5 - return base_holding_reward - - def calculate_basic_reward(self, pnl, confidence): - """Calculate basic training reward based on P&L and confidence""" - try: - base_reward = pnl - if pnl < 0 and confidence > 0.7: - confidence_adjustment = -confidence * 2 - elif pnl > 0 and confidence > 0.7: - confidence_adjustment = confidence * 1.5 - else: - confidence_adjustment = 0 - final_reward = base_reward + confidence_adjustment - normalized_reward = np.tanh(final_reward / 10.0) - logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}") - return float(normalized_reward) - except Exception as e: - logger.error(f"Error calculating basic reward: {e}") - return 0.0 - - def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'): - """Calculate enhanced reward for trading actions with shifted neutral point - - Neutral reward is shifted to require profits that exceed double the fees, - which penalizes small profit trades and encourages holding for larger moves. - Current PnL is given more weight in the decision-making process. - """ - fee = self.base_fee_rate - double_fee = fee * 4 # Double the fees (2x open + 2x close = 4x base fee) - frequency_penalty = self._calculate_frequency_penalty() - - if action == 0: # Buy - # Penalize buying more when already in profit - reward = -fee - frequency_penalty - if current_pnl > 0: - # Reduce incentive to close profitable positions - reward -= current_pnl * 0.2 - elif action == 1: # Sell - profit_pct = price_change - - # Shift neutral point - require profit > double fees to be considered positive - net_profit = profit_pct - double_fee - - # Scale reward based on profit size - if net_profit > 0: - # Exponential reward for larger profits - reward = (net_profit ** 1.5) * self.reward_scaling - else: - # Linear penalty for losses - reward = net_profit * self.reward_scaling - - reward -= frequency_penalty - self.record_pnl(net_profit) - - # Add extra penalty for very small profits (less than 3x fees) - if 0 < profit_pct < (fee * 6): - reward -= 0.5 # Discourage tiny profit-taking - else: # Hold - if is_profitable: - # Increase reward for holding profitable positions - profit_factor = min(5.0, current_pnl * 20) # Cap at 5x - reward = self._calculate_holding_reward(position_held_time, price_change) * (1.0 + profit_factor) - - # Add bonus for holding through volatility when profitable - if volatility is not None and volatility > 0.001: - reward += 0.1 * volatility * 100 - else: - # Small penalty for holding losing positions - loss_factor = min(1.0, abs(current_pnl) * 10) - reward = -0.0001 * (1.0 + loss_factor) - - # But reduce penalty for very recent positions (give them time) - if position_held_time < 30: # Less than 30 seconds - reward *= 0.5 - - # Prediction accuracy reward component - if action in [0, 1] and predicted_change != 0: - if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0): - reward += abs(actual_change) * 5.0 - else: - reward -= abs(predicted_change) * 2.0 - - # Increase weight of current PnL in decision making (3x more than before) - reward += current_pnl * 0.3 - - # Volatility penalty - if volatility is not None: - reward -= abs(volatility) * 100 - - # Risk adjustment - if self.risk_aversion > 0 and len(self.returns) > 1: - returns_std = np.std(self.returns) - reward -= returns_std * self.risk_aversion - - self.record_trade(action) - return reward - - def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0): - """Calculate reward for prediction accuracy""" - reward = 0.0 - if predicted_direction == actual_direction: - reward += 1.0 * confidence - else: - reward -= 0.5 - if predicted_direction == actual_direction and abs(predicted_change) > 0.001: - reward += abs(actual_change) * 5.0 - if predicted_direction != actual_direction and abs(predicted_change) > 0.001: - reward -= abs(predicted_change) * 2.0 - reward += current_pnl * 0.1 - # Dynamic adjustment based on recent PnL (loss cutting incentive) - if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]: - latest_pnl_entry = self.pnl_history[symbol][-1] - latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0 - if latest_pnl_value < 0 and position_duration > 60: - reward -= (abs(latest_pnl_value) * 0.2) - pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)] - best_pnl = max(pnl_values) if pnl_values else 0.0 - if best_pnl < 0.0: - reward -= 0.1 - return reward - -# Example usage: -if __name__ == "__main__": - # Create calculator instance - reward_calc = RewardCalculator() - - # Example reward for a buy action - buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0) - print(f"Buy action reward: {buy_reward:.5f}") - - # Record a trade for frequency tracking - reward_calc.record_trade(0) - - # Wait a bit and make another trade to test frequency penalty - import time - time.sleep(0.1) - - # Example reward for a sell action with profit - sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60) - print(f"Sell action reward (with profit): {sell_reward:.5f}") - - # Example reward for a hold action on profitable position - hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True) - print(f"Hold action reward (profitable): {hold_reward:.5f}") - - # Example reward for a hold action on unprofitable position - hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False) - print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}") \ No newline at end of file diff --git a/utils/system_monitor.py b/utils/system_monitor.py deleted file mode 100644 index 4beaaf1..0000000 --- a/utils/system_monitor.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -System Resource Monitor - Prevents resource exhaustion and silent failures -Monitors memory, CPU, and disk usage to prevent system crashes -""" - -import psutil -import logging -import threading -import time -import gc -import os -from typing import Dict, Any, Optional, Callable -from datetime import datetime, timedelta - -logger = logging.getLogger(__name__) - -class SystemResourceMonitor: - """Monitor system resources and prevent exhaustion""" - - def __init__(self, - memory_threshold_mb: int = 7000, # 7GB threshold for 8GB system - cpu_threshold_percent: float = 90.0, - disk_threshold_percent: float = 95.0, - check_interval_seconds: int = 30): - """ - Initialize system resource monitor - - Args: - memory_threshold_mb: Memory threshold in MB before cleanup - cpu_threshold_percent: CPU threshold percentage before warning - disk_threshold_percent: Disk usage threshold before warning - check_interval_seconds: How often to check resources - """ - self.memory_threshold_mb = memory_threshold_mb - self.cpu_threshold_percent = cpu_threshold_percent - self.disk_threshold_percent = disk_threshold_percent - self.check_interval = check_interval_seconds - - self.monitoring = False - self.monitor_thread = None - - # Callbacks for resource events - self.memory_warning_callback: Optional[Callable] = None - self.cpu_warning_callback: Optional[Callable] = None - self.disk_warning_callback: Optional[Callable] = None - self.cleanup_callback: Optional[Callable] = None - - # Resource history for trending - self.resource_history = [] - self.max_history_entries = 100 - - # Last warning times to prevent spam - self.last_memory_warning = datetime.min - self.last_cpu_warning = datetime.min - self.last_disk_warning = datetime.min - self.warning_cooldown = timedelta(minutes=5) - - def start_monitoring(self): - """Start resource monitoring in background thread""" - if self.monitoring: - logger.warning("Resource monitoring already started") - return - - self.monitoring = True - self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) - self.monitor_thread.start() - logger.info(f"System resource monitoring started (memory threshold: {self.memory_threshold_mb}MB)") - - def stop_monitoring(self): - """Stop resource monitoring""" - self.monitoring = False - if self.monitor_thread: - self.monitor_thread.join(timeout=5) - logger.info("System resource monitoring stopped") - - def set_callbacks(self, - memory_warning: Optional[Callable] = None, - cpu_warning: Optional[Callable] = None, - disk_warning: Optional[Callable] = None, - cleanup: Optional[Callable] = None): - """Set callback functions for resource events""" - self.memory_warning_callback = memory_warning - self.cpu_warning_callback = cpu_warning - self.disk_warning_callback = disk_warning - self.cleanup_callback = cleanup - - def get_current_usage(self) -> Dict[str, Any]: - """Get current system resource usage""" - try: - # Memory usage - memory = psutil.virtual_memory() - memory_mb = memory.used / (1024 * 1024) - memory_percent = memory.percent - - # CPU usage - cpu_percent = psutil.cpu_percent(interval=1) - - # Disk usage (current directory) - disk = psutil.disk_usage('.') - disk_percent = (disk.used / disk.total) * 100 - - # Process-specific info - process = psutil.Process() - process_memory_mb = process.memory_info().rss / (1024 * 1024) - - return { - 'timestamp': datetime.now(), - 'memory': { - 'total_mb': memory.total / (1024 * 1024), - 'used_mb': memory_mb, - 'percent': memory_percent, - 'available_mb': memory.available / (1024 * 1024) - }, - 'process_memory_mb': process_memory_mb, - 'cpu_percent': cpu_percent, - 'disk': { - 'total_gb': disk.total / (1024 * 1024 * 1024), - 'used_gb': disk.used / (1024 * 1024 * 1024), - 'percent': disk_percent - } - } - except Exception as e: - logger.error(f"Error getting system usage: {e}") - return {} - - def _monitor_loop(self): - """Main monitoring loop""" - logger.info("Resource monitoring loop started") - - while self.monitoring: - try: - usage = self.get_current_usage() - if not usage: - time.sleep(self.check_interval) - continue - - # Store in history - self.resource_history.append(usage) - if len(self.resource_history) > self.max_history_entries: - self.resource_history.pop(0) - - # Check thresholds - self._check_memory_threshold(usage) - self._check_cpu_threshold(usage) - self._check_disk_threshold(usage) - - # Log periodic status (every 10 minutes) - if len(self.resource_history) % 20 == 0: # 20 * 30s = 10 minutes - self._log_resource_status(usage) - - except Exception as e: - logger.error(f"Error in resource monitoring loop: {e}") - - time.sleep(self.check_interval) - - logger.info("Resource monitoring loop stopped") - - def _check_memory_threshold(self, usage: Dict[str, Any]): - """Check memory usage threshold""" - memory_mb = usage.get('memory', {}).get('used_mb', 0) - - if memory_mb > self.memory_threshold_mb: - now = datetime.now() - if now - self.last_memory_warning > self.warning_cooldown: - logger.warning(f"HIGH MEMORY USAGE: {memory_mb:.1f}MB / {self.memory_threshold_mb}MB threshold") - self.last_memory_warning = now - - # Trigger cleanup - self._trigger_memory_cleanup() - - # Call callback if set - if self.memory_warning_callback: - try: - self.memory_warning_callback(memory_mb, self.memory_threshold_mb) - except Exception as e: - logger.error(f"Error in memory warning callback: {e}") - - def _check_cpu_threshold(self, usage: Dict[str, Any]): - """Check CPU usage threshold""" - cpu_percent = usage.get('cpu_percent', 0) - - if cpu_percent > self.cpu_threshold_percent: - now = datetime.now() - if now - self.last_cpu_warning > self.warning_cooldown: - logger.warning(f"HIGH CPU USAGE: {cpu_percent:.1f}% / {self.cpu_threshold_percent}% threshold") - self.last_cpu_warning = now - - if self.cpu_warning_callback: - try: - self.cpu_warning_callback(cpu_percent, self.cpu_threshold_percent) - except Exception as e: - logger.error(f"Error in CPU warning callback: {e}") - - def _check_disk_threshold(self, usage: Dict[str, Any]): - """Check disk usage threshold""" - disk_percent = usage.get('disk', {}).get('percent', 0) - - if disk_percent > self.disk_threshold_percent: - now = datetime.now() - if now - self.last_disk_warning > self.warning_cooldown: - logger.warning(f"HIGH DISK USAGE: {disk_percent:.1f}% / {self.disk_threshold_percent}% threshold") - self.last_disk_warning = now - - if self.disk_warning_callback: - try: - self.disk_warning_callback(disk_percent, self.disk_threshold_percent) - except Exception as e: - logger.error(f"Error in disk warning callback: {e}") - - def _trigger_memory_cleanup(self): - """Trigger memory cleanup procedures""" - logger.info("Triggering memory cleanup...") - - # Force garbage collection - collected = gc.collect() - logger.info(f"Garbage collection freed {collected} objects") - - # Call custom cleanup callback if set - if self.cleanup_callback: - try: - self.cleanup_callback() - logger.info("Custom cleanup callback executed") - except Exception as e: - logger.error(f"Error in cleanup callback: {e}") - - # Log memory after cleanup - try: - usage_after = self.get_current_usage() - memory_after = usage_after.get('memory', {}).get('used_mb', 0) - logger.info(f"Memory after cleanup: {memory_after:.1f}MB") - except Exception as e: - logger.error(f"Error checking memory after cleanup: {e}") - - def _log_resource_status(self, usage: Dict[str, Any]): - """Log current resource status""" - memory = usage.get('memory', {}) - cpu = usage.get('cpu_percent', 0) - disk = usage.get('disk', {}) - process_memory = usage.get('process_memory_mb', 0) - - logger.info(f"RESOURCE STATUS - Memory: {memory.get('used_mb', 0):.1f}MB ({memory.get('percent', 0):.1f}%), " - f"Process: {process_memory:.1f}MB, CPU: {cpu:.1f}%, Disk: {disk.get('percent', 0):.1f}%") - - def get_resource_summary(self) -> Dict[str, Any]: - """Get resource usage summary""" - if not self.resource_history: - return {} - - recent_usage = self.resource_history[-10:] # Last 10 entries - - # Calculate averages - avg_memory = sum(u.get('memory', {}).get('used_mb', 0) for u in recent_usage) / len(recent_usage) - avg_cpu = sum(u.get('cpu_percent', 0) for u in recent_usage) / len(recent_usage) - avg_disk = sum(u.get('disk', {}).get('percent', 0) for u in recent_usage) / len(recent_usage) - - current = self.resource_history[-1] if self.resource_history else {} - - return { - 'current': current, - 'averages': { - 'memory_mb': avg_memory, - 'cpu_percent': avg_cpu, - 'disk_percent': avg_disk - }, - 'thresholds': { - 'memory_mb': self.memory_threshold_mb, - 'cpu_percent': self.cpu_threshold_percent, - 'disk_percent': self.disk_threshold_percent - }, - 'monitoring': self.monitoring, - 'history_entries': len(self.resource_history) - } - -# Global instance -_system_monitor = None - -def get_system_monitor() -> SystemResourceMonitor: - """Get global system monitor instance""" - global _system_monitor - if _system_monitor is None: - _system_monitor = SystemResourceMonitor() - return _system_monitor - -def start_system_monitoring(): - """Start system monitoring with default settings""" - monitor = get_system_monitor() - monitor.start_monitoring() - return monitor \ No newline at end of file diff --git a/utils/tensorboard_logger.py b/utils/tensorboard_logger.py deleted file mode 100644 index aff8dfb..0000000 --- a/utils/tensorboard_logger.py +++ /dev/null @@ -1,219 +0,0 @@ -#!/usr/bin/env python3 -""" -TensorBoard Logger Utility - -This module provides a centralized way to log training metrics to TensorBoard. -It ensures consistent logging across different training components. -""" - -import os -import logging -from pathlib import Path -from datetime import datetime -from typing import Dict, Any, Optional, Union, List - -# Import conditionally to handle missing dependencies gracefully -try: - from torch.utils.tensorboard import SummaryWriter - TENSORBOARD_AVAILABLE = True -except ImportError: - TENSORBOARD_AVAILABLE = False - -logger = logging.getLogger(__name__) - -class TensorBoardLogger: - """ - Centralized TensorBoard logging utility for training metrics - - This class provides a consistent interface for logging metrics to TensorBoard - across different training components. - """ - - def __init__(self, - log_dir: Optional[str] = None, - experiment_name: Optional[str] = None, - enabled: bool = True): - """ - Initialize TensorBoard logger - - Args: - log_dir: Base directory for TensorBoard logs (default: 'runs') - experiment_name: Name of the experiment (default: timestamp) - enabled: Whether TensorBoard logging is enabled - """ - self.enabled = enabled and TENSORBOARD_AVAILABLE - self.writer = None - - if not self.enabled: - if not TENSORBOARD_AVAILABLE: - logger.warning("TensorBoard not available. Install with: pip install tensorboard") - return - - # Set up log directory - if log_dir is None: - log_dir = "runs" - - # Create experiment name if not provided - if experiment_name is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - experiment_name = f"training_{timestamp}" - - # Create full log path - self.log_dir = os.path.join(log_dir, experiment_name) - - # Create writer - try: - self.writer = SummaryWriter(log_dir=self.log_dir) - logger.info(f"TensorBoard logging enabled at: {self.log_dir}") - except Exception as e: - logger.error(f"Failed to initialize TensorBoard: {e}") - self.enabled = False - - def log_scalar(self, tag: str, value: float, step: int) -> None: - """ - Log a scalar value to TensorBoard - - Args: - tag: Metric name - value: Metric value - step: Training step - """ - if not self.enabled or self.writer is None: - return - - try: - self.writer.add_scalar(tag, value, step) - except Exception as e: - logger.warning(f"Failed to log scalar {tag}: {e}") - - def log_scalars(self, main_tag: str, tag_value_dict: Dict[str, float], step: int) -> None: - """ - Log multiple scalar values with the same main tag - - Args: - main_tag: Main tag for the metrics - tag_value_dict: Dictionary of tag names to values - step: Training step - """ - if not self.enabled or self.writer is None: - return - - try: - self.writer.add_scalars(main_tag, tag_value_dict, step) - except Exception as e: - logger.warning(f"Failed to log scalars for {main_tag}: {e}") - - def log_histogram(self, tag: str, values, step: int) -> None: - """ - Log a histogram to TensorBoard - - Args: - tag: Histogram name - values: Values to create histogram from - step: Training step - """ - if not self.enabled or self.writer is None: - return - - try: - self.writer.add_histogram(tag, values, step) - except Exception as e: - logger.warning(f"Failed to log histogram {tag}: {e}") - - def log_training_metrics(self, - metrics: Dict[str, Any], - step: int, - prefix: str = "Training") -> None: - """ - Log training metrics to TensorBoard - - Args: - metrics: Dictionary of metric names to values - step: Training step - prefix: Prefix for metric names - """ - if not self.enabled or self.writer is None: - return - - for name, value in metrics.items(): - if isinstance(value, (int, float)): - self.log_scalar(f"{prefix}/{name}", value, step) - elif hasattr(value, "shape"): # For numpy arrays or tensors - try: - self.log_histogram(f"{prefix}/{name}", value, step) - except: - pass - - def log_model_metrics(self, - model_name: str, - metrics: Dict[str, Any], - step: int) -> None: - """ - Log model-specific metrics to TensorBoard - - Args: - model_name: Name of the model - metrics: Dictionary of metric names to values - step: Training step - """ - if not self.enabled or self.writer is None: - return - - for name, value in metrics.items(): - if isinstance(value, (int, float)): - self.log_scalar(f"Model/{model_name}/{name}", value, step) - - def log_reward_metrics(self, - symbol: str, - metrics: Dict[str, float], - step: int) -> None: - """ - Log reward-related metrics to TensorBoard - - Args: - symbol: Trading symbol - metrics: Dictionary of metric names to values - step: Training step - """ - if not self.enabled or self.writer is None: - return - - for name, value in metrics.items(): - self.log_scalar(f"Rewards/{symbol}/{name}", value, step) - - def log_state_metrics(self, - symbol: str, - state_info: Dict[str, Any], - step: int) -> None: - """ - Log state-related metrics to TensorBoard - - Args: - symbol: Trading symbol - state_info: Dictionary of state information - step: Training step - """ - if not self.enabled or self.writer is None: - return - - # Log state size - if "size" in state_info: - self.log_scalar(f"State/{symbol}/Size", state_info["size"], step) - - # Log state quality - if "quality" in state_info: - self.log_scalar(f"State/{symbol}/Quality", state_info["quality"], step) - - # Log feature counts - if "feature_counts" in state_info: - for feature_type, count in state_info["feature_counts"].items(): - self.log_scalar(f"State/{symbol}/Features/{feature_type}", count, step) - - def close(self) -> None: - """Close the TensorBoard writer""" - if self.enabled and self.writer is not None: - try: - self.writer.close() - logger.info("TensorBoard writer closed") - except Exception as e: - logger.warning(f"Error closing TensorBoard writer: {e}") \ No newline at end of file diff --git a/verify_checkpoint_system.py b/verify_checkpoint_system.py deleted file mode 100644 index df9a706..0000000 --- a/verify_checkpoint_system.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env python3 -""" -Verify Checkpoint System - -Final verification that the checkpoint system is working correctly -""" - -import torch -from pathlib import Path -from utils.checkpoint_manager import load_best_checkpoint, save_checkpoint -from utils.database_manager import get_database_manager -from datetime import datetime - -def test_checkpoint_loading(): - """Test loading existing checkpoints""" - print("=== Testing Checkpoint Loading ===") - - models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target'] - - for model_name in models: - try: - result = load_best_checkpoint(model_name) - - if result: - file_path, metadata = result - file_size = Path(file_path).stat().st_size / (1024 * 1024) - - print(f"✅ {model_name}:") - print(f" ID: {metadata.checkpoint_id}") - print(f" File: {file_path}") - print(f" Size: {file_size:.1f}MB") - print(f" Loss: {getattr(metadata, 'loss', 'N/A')}") - - # Try to load the actual model file - try: - model_data = torch.load(file_path, map_location='cpu') - print(f" ✅ Model file loads successfully") - except Exception as e: - print(f" ❌ Model file load error: {e}") - else: - print(f"❌ {model_name}: No checkpoint found") - - except Exception as e: - print(f"❌ {model_name}: Error - {e}") - - print() - -def test_checkpoint_saving(): - """Test saving new checkpoints""" - print("=== Testing Checkpoint Saving ===") - - try: - import torch.nn as nn - - # Create a test model - class TestModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(100, 10) - - def forward(self, x): - return self.linear(x) - - test_model = TestModel() - - # Save checkpoint - result = save_checkpoint( - model=test_model, - model_name="test_save", - model_type="test", - performance_metrics={"loss": 0.05, "accuracy": 0.98}, - training_metadata={"test_save": True, "timestamp": datetime.now().isoformat()} - ) - - if result: - print(f"✅ Checkpoint saved: {result.checkpoint_id}") - - # Verify it can be loaded - load_result = load_best_checkpoint("test_save") - if load_result: - print(f"✅ Checkpoint can be loaded back") - - # Clean up - file_path = Path(load_result[0]) - if file_path.exists(): - file_path.unlink() - print(f"🧹 Test checkpoint cleaned up") - else: - print(f"❌ Checkpoint could not be loaded back") - else: - print(f"❌ Checkpoint saving failed") - - except Exception as e: - print(f"❌ Checkpoint saving test failed: {e}") - -def test_database_integration(): - """Test database integration""" - print("=== Testing Database Integration ===") - - db_manager = get_database_manager() - - # Test fast metadata access - for model_name in ['dqn_agent', 'enhanced_cnn']: - metadata = db_manager.get_best_checkpoint_metadata(model_name) - if metadata: - print(f"✅ {model_name}: Fast metadata access works") - print(f" ID: {metadata.checkpoint_id}") - print(f" Performance: {metadata.performance_metrics}") - else: - print(f"❌ {model_name}: No metadata found") - -def show_checkpoint_summary(): - """Show summary of all checkpoints""" - print("=== Checkpoint System Summary ===") - - db_manager = get_database_manager() - - # Get all models with checkpoints - models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision'] - - total_checkpoints = 0 - total_size_mb = 0 - - for model_name in models: - checkpoints = db_manager.list_checkpoints(model_name) - if checkpoints: - model_size = sum(c.file_size_mb for c in checkpoints) - total_checkpoints += len(checkpoints) - total_size_mb += model_size - - print(f"{model_name}: {len(checkpoints)} checkpoints ({model_size:.1f}MB)") - - # Show active checkpoint - active = [c for c in checkpoints if c.is_active] - if active: - print(f" Active: {active[0].checkpoint_id}") - - print(f"\nTotal: {total_checkpoints} checkpoints, {total_size_mb:.1f}MB") - -def main(): - """Run all verification tests""" - print("=== Checkpoint System Verification ===\n") - - test_checkpoint_loading() - test_checkpoint_saving() - test_database_integration() - show_checkpoint_summary() - - print("\n=== Verification Complete ===") - print("✅ Checkpoint system is working correctly!") - print("✅ Models will no longer start fresh every time") - print("✅ Training progress will be preserved") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/web/__init__.py b/web/__init__.py deleted file mode 100644 index 0a6a59c..0000000 --- a/web/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Web module for trading system dashboard diff --git a/web/dashboard_fix.py b/web/dashboard_fix.py deleted file mode 100644 index d593e6b..0000000 --- a/web/dashboard_fix.py +++ /dev/null @@ -1,253 +0,0 @@ -""" -Dashboard Fix - -This module provides fixes for the trading dashboard to address: -1. Trade display issues -2. P&L calculation and display -3. Position tracking and synchronization - -Apply these fixes by importing and applying the patch in the dashboard initialization -""" - -import logging -from datetime import datetime -from typing import Dict, Any, List, Optional -import time - -logger = logging.getLogger(__name__) - -class DashboardFix: - """Fixes for the Dashboard class""" - - @staticmethod - def apply_fixes(dashboard): - """Apply all fixes to the dashboard""" - logger.info("Applying Dashboard fixes...") - - # Apply fixes - DashboardFix._fix_trade_display(dashboard) - DashboardFix._fix_position_sync(dashboard) - DashboardFix._fix_pnl_calculation(dashboard) - DashboardFix._add_trade_validation(dashboard) - - logger.info("Dashboard fixes applied successfully") - return dashboard - - @staticmethod - def _fix_trade_display(dashboard): - """Fix trade display to ensure accurate information""" - # Store original format_closed_trades_table method - if hasattr(dashboard.component_manager, 'format_closed_trades_table'): - original_format_closed_trades = dashboard.component_manager.format_closed_trades_table - - def format_closed_trades_table_fixed(self, closed_trades, trading_stats=None): - """Fixed closed trades table formatter with accurate P&L calculation""" - # Recalculate P&L for each trade to ensure accuracy - for trade in closed_trades: - # Skip if already validated - if getattr(trade, 'pnl_validated', False): - continue - - # Handle both trade objects and dictionary formats - if hasattr(trade, 'entry_price'): - # This is a trade object - entry_price = getattr(trade, 'entry_price', 0) - exit_price = getattr(trade, 'exit_price', 0) - size = getattr(trade, 'size', 0) - side = getattr(trade, 'side', 'UNKNOWN') - fees = getattr(trade, 'fees', 0) - else: - # This is a dictionary format - entry_price = trade.get('entry_price', 0) - exit_price = trade.get('exit_price', 0) - size = trade.get('size', trade.get('quantity', 0)) - side = trade.get('side', 'UNKNOWN') - fees = trade.get('fees', 0) - - # Recalculate P&L - if side == 'LONG' or side == 'BUY': - pnl = (exit_price - entry_price) * size - else: # SHORT or SELL - pnl = (entry_price - exit_price) * size - - # Update P&L value - if hasattr(trade, 'entry_price'): - trade.pnl = pnl - trade.net_pnl = pnl - fees - trade.pnl_validated = True - else: - trade['pnl'] = pnl - trade['net_pnl'] = pnl - fees - trade['pnl_validated'] = True - - # Call original method with validated trades - return original_format_closed_trades(closed_trades, trading_stats) - - # Apply the patch - dashboard.component_manager.format_closed_trades_table = format_closed_trades_table_fixed.__get__(dashboard.component_manager) - logger.info("Trade display fix applied") - - @staticmethod - def _fix_position_sync(dashboard): - """Fix position synchronization to ensure accurate position tracking""" - # Store original _sync_position_from_executor method - if hasattr(dashboard, '_sync_position_from_executor'): - original_sync_position = dashboard._sync_position_from_executor - - def sync_position_from_executor_fixed(self, symbol): - """Fixed position sync with validation and logging""" - try: - # Call original sync method - result = original_sync_position(symbol) - - # Add validation and logging - if self.trading_executor and hasattr(self.trading_executor, 'positions'): - if symbol in self.trading_executor.positions: - position = self.trading_executor.positions[symbol] - - # Log position details for debugging - logger.debug(f"Position sync for {symbol}: " - f"Side={position.side}, " - f"Size={position.size}, " - f"Entry=${position.entry_price:.2f}") - - # Validate position data - if position.entry_price <= 0: - logger.warning(f"Invalid entry price for {symbol}: ${position.entry_price:.2f}") - - # Store last sync time - if not hasattr(self, 'last_position_sync'): - self.last_position_sync = {} - - self.last_position_sync[symbol] = time.time() - - return result - - except Exception as e: - logger.error(f"Error in sync_position_from_executor_fixed: {e}") - return None - - # Apply the patch - dashboard._sync_position_from_executor = sync_position_from_executor_fixed.__get__(dashboard) - logger.info("Position sync fix applied") - - @staticmethod - def _fix_pnl_calculation(dashboard): - """Fix P&L calculation to ensure accuracy""" - # Add a method to recalculate P&L for all closed trades - def recalculate_all_pnl(self): - """Recalculate P&L for all closed trades""" - if not hasattr(self, 'closed_trades') or not self.closed_trades: - return - - for trade in self.closed_trades: - # Handle both trade objects and dictionary formats - if hasattr(trade, 'entry_price'): - # This is a trade object - entry_price = getattr(trade, 'entry_price', 0) - exit_price = getattr(trade, 'exit_price', 0) - size = getattr(trade, 'size', 0) - side = getattr(trade, 'side', 'UNKNOWN') - fees = getattr(trade, 'fees', 0) - else: - # This is a dictionary format - entry_price = trade.get('entry_price', 0) - exit_price = trade.get('exit_price', 0) - size = trade.get('size', trade.get('quantity', 0)) - side = trade.get('side', 'UNKNOWN') - fees = trade.get('fees', 0) - - # Recalculate P&L - if side == 'LONG' or side == 'BUY': - pnl = (exit_price - entry_price) * size - else: # SHORT or SELL - pnl = (entry_price - exit_price) * size - - # Update P&L value - if hasattr(trade, 'entry_price'): - trade.pnl = pnl - trade.net_pnl = pnl - fees - else: - trade['pnl'] = pnl - trade['net_pnl'] = pnl - fees - - logger.info(f"Recalculated P&L for {len(self.closed_trades)} closed trades") - - # Add the method - dashboard.recalculate_all_pnl = recalculate_all_pnl.__get__(dashboard) - - # Call it once to fix existing trades - dashboard.recalculate_all_pnl() - - logger.info("P&L calculation fix applied") - - @staticmethod - def _add_trade_validation(dashboard): - """Add trade validation to prevent invalid trades""" - # Store original _on_trade_closed method if it exists - original_on_trade_closed = getattr(dashboard, '_on_trade_closed', None) - - if original_on_trade_closed: - def on_trade_closed_fixed(self, trade_data): - """Fixed trade closed handler with validation""" - try: - # Validate trade data - is_valid = True - validation_errors = [] - - # Check for required fields - required_fields = ['symbol', 'side', 'entry_price', 'exit_price', 'size'] - for field in required_fields: - if field not in trade_data: - is_valid = False - validation_errors.append(f"Missing required field: {field}") - - # Check for valid prices - if 'entry_price' in trade_data and trade_data['entry_price'] <= 0: - is_valid = False - validation_errors.append(f"Invalid entry price: {trade_data['entry_price']}") - - if 'exit_price' in trade_data and trade_data['exit_price'] <= 0: - is_valid = False - validation_errors.append(f"Invalid exit price: {trade_data['exit_price']}") - - # Check for valid size - if 'size' in trade_data and trade_data['size'] <= 0: - is_valid = False - validation_errors.append(f"Invalid size: {trade_data['size']}") - - # If invalid, log errors and skip - if not is_valid: - logger.warning(f"Invalid trade data: {validation_errors}") - return - - # Calculate correct P&L - if 'side' in trade_data and 'entry_price' in trade_data and 'exit_price' in trade_data and 'size' in trade_data: - side = trade_data['side'] - entry_price = trade_data['entry_price'] - exit_price = trade_data['exit_price'] - size = trade_data['size'] - - if side == 'LONG' or side == 'BUY': - pnl = (exit_price - entry_price) * size - else: # SHORT or SELL - pnl = (entry_price - exit_price) * size - - # Update P&L in trade data - trade_data['pnl'] = pnl - - # Calculate net P&L (after fees) - fees = trade_data.get('fees', 0) - trade_data['net_pnl'] = pnl - fees - - # Call original method with validated data - return original_on_trade_closed(trade_data) - - except Exception as e: - logger.error(f"Error in on_trade_closed_fixed: {e}") - - # Apply the patch - dashboard._on_trade_closed = on_trade_closed_fixed.__get__(dashboard) - logger.info("Trade validation fix applied") - else: - logger.warning("_on_trade_closed method not found, skipping trade validation fix") \ No newline at end of file diff --git a/web/dashboard_model.py b/web/dashboard_model.py deleted file mode 100644 index 498de90..0000000 --- a/web/dashboard_model.py +++ /dev/null @@ -1,331 +0,0 @@ -""" -Dashboard Data Model -Provides structured data for template rendering -""" -from dataclasses import dataclass, field -from typing import List, Dict, Any, Optional -from datetime import datetime - - -@dataclass -class MetricData: - """Individual metric for the dashboard""" - id: str - label: str - value: str - format_type: str = "text" # text, currency, percentage - - -@dataclass -class TradingControlsData: - """Trading controls configuration""" - buy_text: str = "BUY" - sell_text: str = "SELL" - clear_text: str = "Clear Session" - leverage: int = 10 - leverage_min: int = 1 - leverage_max: int = 50 - - -@dataclass -class RecentDecisionData: - """Recent AI decision data""" - timestamp: str - action: str - symbol: str - confidence: float - price: float - - -@dataclass -class COBLevelData: - """Order book level data""" - side: str # 'bid' or 'ask' - size: str - price: str - total: str - - -@dataclass -class COBData: - """Complete order book data for a symbol""" - symbol: str - content_id: str - total_usd: str - total_crypto: str - levels: List[COBLevelData] = field(default_factory=list) - - -@dataclass -class ModelData: - """Model status data""" - name: str - status: str # 'training', 'idle', 'loading' - status_text: str - - -@dataclass -class TrainingMetricData: - """Training metric data""" - name: str - value: str - - -@dataclass -class PerformanceStatData: - """Performance statistic data""" - name: str - value: str - - -@dataclass -class ClosedTradeData: - """Closed trade data""" - time: str - symbol: str - side: str - size: str - entry_price: str - exit_price: str - pnl: float - duration: str - - -@dataclass -class ChartData: - """Chart configuration data""" - title: str = "Price Chart & Signals" - - -@dataclass -class DashboardModel: - """Complete dashboard data model""" - title: str = "Live Scalping Dashboard" - subtitle: str = "Real-time Trading with AI Models" - refresh_interval: int = 1000 - - # Main sections - metrics: List[MetricData] = field(default_factory=list) - chart: ChartData = field(default_factory=ChartData) - trading_controls: TradingControlsData = field(default_factory=TradingControlsData) - recent_decisions: List[RecentDecisionData] = field(default_factory=list) - cob_data: List[COBData] = field(default_factory=list) - models: List[ModelData] = field(default_factory=list) - training_metrics: List[TrainingMetricData] = field(default_factory=list) - performance_stats: List[PerformanceStatData] = field(default_factory=list) - closed_trades: List[ClosedTradeData] = field(default_factory=list) - - -class DashboardDataBuilder: - """Builder class to construct dashboard data from various sources""" - - def __init__(self): - self.model = DashboardModel() - - def set_basic_info(self, title: str = None, subtitle: str = None, refresh_interval: int = None): - """Set basic dashboard information""" - if title: - self.model.title = title - if subtitle: - self.model.subtitle = subtitle - if refresh_interval: - self.model.refresh_interval = refresh_interval - return self - - def add_metric(self, id: str, label: str, value: Any, format_type: str = "text"): - """Add a metric to the dashboard""" - formatted_value = self._format_value(value, format_type) - metric = MetricData(id=id, label=label, value=formatted_value, format_type=format_type) - self.model.metrics.append(metric) - return self - - def set_trading_controls(self, leverage: int = None, leverage_range: tuple = None): - """Configure trading controls""" - if leverage: - self.model.trading_controls.leverage = leverage - if leverage_range: - self.model.trading_controls.leverage_min = leverage_range[0] - self.model.trading_controls.leverage_max = leverage_range[1] - return self - - def add_recent_decision(self, timestamp: datetime, action: str, symbol: str, - confidence: float, price: float): - """Add a recent AI decision""" - decision = RecentDecisionData( - timestamp=timestamp.strftime("%H:%M:%S"), - action=action, - symbol=symbol, - confidence=round(confidence * 100, 1), - price=round(price, 4) - ) - self.model.recent_decisions.append(decision) - return self - - def add_cob_data(self, symbol: str, content_id: str, total_usd: float, - total_crypto: float, levels: List[Dict]): - """Add COB data for a symbol""" - cob_levels = [] - for level in levels: - cob_level = COBLevelData( - side=level.get('side', 'bid'), - size=self._format_value(level.get('size', 0), 'number'), - price=self._format_value(level.get('price', 0), 'currency'), - total=self._format_value(level.get('total', 0), 'currency') - ) - cob_levels.append(cob_level) - - cob = COBData( - symbol=symbol, - content_id=content_id, - total_usd=self._format_value(total_usd, 'currency'), - total_crypto=self._format_value(total_crypto, 'number'), - levels=cob_levels - ) - self.model.cob_data.append(cob) - return self - - def add_model_status(self, name: str, is_training: bool, is_loading: bool = False): - """Add model status""" - if is_loading: - status = "loading" - status_text = "Loading" - elif is_training: - status = "training" - status_text = "Training" - else: - status = "idle" - status_text = "Idle" - - model = ModelData(name=name, status=status, status_text=status_text) - self.model.models.append(model) - return self - - def add_training_metric(self, name: str, value: Any): - """Add training metric""" - metric = TrainingMetricData( - name=name, - value=self._format_value(value, 'number') - ) - self.model.training_metrics.append(metric) - return self - - def add_performance_stat(self, name: str, value: Any): - """Add performance statistic""" - stat = PerformanceStatData( - name=name, - value=self._format_value(value, 'number') - ) - self.model.performance_stats.append(stat) - return self - - def add_closed_trade(self, time: datetime, symbol: str, side: str, size: float, - entry_price: float, exit_price: float, pnl: float, duration: str): - """Add closed trade""" - trade = ClosedTradeData( - time=time.strftime("%H:%M:%S"), - symbol=symbol, - side=side, - size=self._format_value(size, 'number'), - entry_price=self._format_value(entry_price, 'currency'), - exit_price=self._format_value(exit_price, 'currency'), - pnl=round(pnl, 2), - duration=duration - ) - self.model.closed_trades.append(trade) - return self - - def build(self) -> DashboardModel: - """Build and return the complete dashboard model""" - return self.model - - def _format_value(self, value: Any, format_type: str) -> str: - """Format value based on type""" - if value is None: - return "N/A" - - try: - if format_type == "currency": - return f"${float(value):,.4f}" - elif format_type == "percentage": - return f"{float(value):.2f}%" - elif format_type == "number": - if isinstance(value, int): - return f"{value:,}" - else: - return f"{float(value):,.2f}" - else: - return str(value) - except (ValueError, TypeError): - return str(value) - - -def create_sample_dashboard_data() -> DashboardModel: - """Create sample dashboard data for testing""" - builder = DashboardDataBuilder() - - # Basic info - builder.set_basic_info( - title="Live Scalping Dashboard", - subtitle="Real-time Trading with AI Models", - refresh_interval=1000 - ) - - # Metrics - builder.add_metric("current-price", "Current Price", 3425.67, "currency") - builder.add_metric("session-pnl", "Session PnL", 125.34, "currency") - builder.add_metric("current-position", "Position", 0.0, "number") - builder.add_metric("trade-count", "Trades", 15, "number") - builder.add_metric("portfolio-value", "Portfolio", 10250.45, "currency") - builder.add_metric("mexc-status", "MEXC Status", "Connected", "text") - - # Trading controls - builder.set_trading_controls(leverage=10, leverage_range=(1, 50)) - - # Recent decisions - builder.add_recent_decision(datetime.now(), "BUY", "ETH/USDT", 0.85, 3425.67) - builder.add_recent_decision(datetime.now(), "HOLD", "BTC/USDT", 0.62, 45123.45) - - # COB data - eth_levels = [ - {"side": "ask", "size": 1.5, "price": 3426.12, "total": 5139.18}, - {"side": "ask", "size": 2.3, "price": 3425.89, "total": 7879.55}, - {"side": "bid", "size": 1.8, "price": 3425.45, "total": 6165.81}, - {"side": "bid", "size": 3.2, "price": 3425.12, "total": 10960.38} - ] - builder.add_cob_data("ETH/USDT", "eth-cob-content", 25000.0, 7.3, eth_levels) - - btc_levels = [ - {"side": "ask", "size": 0.15, "price": 45125.67, "total": 6768.85}, - {"side": "ask", "size": 0.23, "price": 45123.45, "total": 10378.39}, - {"side": "bid", "size": 0.18, "price": 45121.23, "total": 8121.82}, - {"side": "bid", "size": 0.32, "price": 45119.12, "total": 14438.12} - ] - builder.add_cob_data("BTC/USDT", "btc-cob-content", 35000.0, 0.88, btc_levels) - - # Model statuses - builder.add_model_status("DQN", True) - builder.add_model_status("CNN", True) - builder.add_model_status("Transformer", False) - builder.add_model_status("COB-RL", True) - - # Training metrics - builder.add_training_metric("DQN Loss", 0.0234) - builder.add_training_metric("CNN Accuracy", 0.876) - builder.add_training_metric("Training Steps", 15420) - builder.add_training_metric("Learning Rate", 0.0001) - - # Performance stats - builder.add_performance_stat("Win Rate", 68.5) - builder.add_performance_stat("Avg Trade", 8.34) - builder.add_performance_stat("Max Drawdown", -45.67) - builder.add_performance_stat("Sharpe Ratio", 1.82) - - # Closed trades - builder.add_closed_trade( - datetime.now(), "ETH/USDT", "BUY", 1.5, 3420.45, 3428.12, 11.51, "2m 34s" - ) - builder.add_closed_trade( - datetime.now(), "BTC/USDT", "SELL", 0.1, 45150.23, 45142.67, -0.76, "1m 12s" - ) - - return builder.build() \ No newline at end of file diff --git a/web/layout_manager_with_tensorboard.py b/web/layout_manager_with_tensorboard.py deleted file mode 100644 index e69de29..0000000 diff --git a/web/models_training_panel.py b/web/models_training_panel.py deleted file mode 100644 index 352d8a2..0000000 --- a/web/models_training_panel.py +++ /dev/null @@ -1,753 +0,0 @@ -#!/usr/bin/env python3 -""" -Models & Training Progress Panel - Clean Implementation -Displays real-time model status, training metrics, and performance data -""" - -import logging -from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta -from dash import html, dcc -import dash_bootstrap_components as dbc - -logger = logging.getLogger(__name__) - -class ModelsTrainingPanel: - """Clean implementation of the Models & Training Progress panel""" - - def __init__(self, orchestrator=None): - self.orchestrator = orchestrator - self.last_update = None - - def create_panel(self) -> html.Div: - """Create the main Models & Training Progress panel""" - try: - # Get fresh data from orchestrator - panel_data = self._gather_panel_data() - - # Build the panel components - content = [] - - # Header with refresh button - content.append(self._create_header()) - - # Models section - if panel_data.get('models'): - content.append(self._create_models_section(panel_data['models'])) - else: - content.append(self._create_no_models_message()) - - # Training status section - if panel_data.get('training_status'): - content.append(self._create_training_status_section(panel_data['training_status'])) - - # Performance metrics section - if panel_data.get('performance_metrics'): - content.append(self._create_performance_section(panel_data['performance_metrics'])) - - return html.Div(content, id="training-metrics") - - except Exception as e: - logger.error(f"Error creating models training panel: {e}") - return html.Div([ - html.P(f"Error loading training panel: {str(e)}", className="text-danger small") - ], id="training-metrics") - - def _gather_panel_data(self) -> Dict[str, Any]: - """Gather all data needed for the panel from orchestrator and other sources""" - data = { - 'models': {}, - 'training_status': {}, - 'performance_metrics': {}, - 'last_update': datetime.now().strftime('%H:%M:%S') - } - - if not self.orchestrator: - logger.warning("No orchestrator available for training panel") - return data - - try: - # Get model registry information - if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry: - registered_models = self.orchestrator.model_registry.get_all_models() - for model_name, model_info in registered_models.items(): - data['models'][model_name] = self._extract_model_data(model_name, model_info) - - # Add decision fusion model if it exists (check multiple sources) - decision_fusion_added = False - - # Check if it's in the model registry - if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry: - registered_models = self.orchestrator.model_registry.get_all_models() - if 'decision_fusion' in registered_models: - data['models']['decision_fusion'] = self._extract_decision_fusion_data() - decision_fusion_added = True - - # If not in registry, check if decision fusion network exists - if not decision_fusion_added and hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network: - data['models']['decision_fusion'] = self._extract_decision_fusion_data() - decision_fusion_added = True - - # If still not added, check if decision fusion is enabled - if not decision_fusion_added and hasattr(self.orchestrator, 'decision_fusion_enabled') and self.orchestrator.decision_fusion_enabled: - data['models']['decision_fusion'] = self._extract_decision_fusion_data() - decision_fusion_added = True - - # Add COB RL model if it exists but wasn't captured in registry - if 'cob_rl_model' not in data['models'] and hasattr(self.orchestrator, 'cob_rl_model'): - data['models']['cob_rl_model'] = self._extract_cob_rl_data() - - # Get training status - data['training_status'] = self._extract_training_status() - - # Get performance metrics - data['performance_metrics'] = self._extract_performance_metrics() - - except Exception as e: - logger.error(f"Error gathering panel data: {e}") - data['error'] = str(e) - - return data - - def _extract_model_data(self, model_name: str, model_info: Any) -> Dict[str, Any]: - """Extract relevant data for a single model""" - try: - model_data = { - 'name': model_name, - 'status': 'unknown', - 'parameters': 0, - 'last_prediction': {}, - 'training_enabled': True, - 'inference_enabled': True, - 'checkpoint_loaded': False, - 'loss_metrics': {}, - 'timing_metrics': {} - } - - # Get model status from orchestrator - check if model is actually loaded and active - if hasattr(self.orchestrator, 'get_model_state'): - model_state = self.orchestrator.get_model_state(model_name) - model_data['status'] = 'active' if model_state else 'inactive' - - # Check actual inference activity from logs/statistics - if hasattr(self.orchestrator, 'get_model_statistics'): - stats = self.orchestrator.get_model_statistics() - if stats and model_name in stats: - model_stats = stats[model_name] - # Check if model has recent activity (last prediction exists) - if hasattr(model_stats, 'last_prediction') and model_stats.last_prediction: - model_data['status'] = 'active' - elif hasattr(model_stats, 'inferences_per_second') and getattr(model_stats, 'inferences_per_second', 0) > 0: - model_data['status'] = 'active' - else: - model_data['status'] = 'registered' # Registered but not actively inferencing - else: - model_data['status'] = 'inactive' - - # Check if model is in registry (fallback) - if hasattr(self.orchestrator, 'model_registry') and self.orchestrator.model_registry: - registered_models = self.orchestrator.model_registry.get_all_models() - if model_name in registered_models and model_data['status'] == 'unknown': - model_data['status'] = 'registered' - - # Get toggle states - if hasattr(self.orchestrator, 'get_model_toggle_state'): - toggle_state = self.orchestrator.get_model_toggle_state(model_name) - if isinstance(toggle_state, dict): - model_data['training_enabled'] = toggle_state.get('training_enabled', True) - model_data['inference_enabled'] = toggle_state.get('inference_enabled', True) - - # Get model statistics - if hasattr(self.orchestrator, 'get_model_statistics'): - stats = self.orchestrator.get_model_statistics() - if stats and model_name in stats: - model_stats = stats[model_name] - - # Handle both dict and object formats - def safe_get(obj, key, default=None): - if hasattr(obj, key): - return getattr(obj, key, default) - elif isinstance(obj, dict): - return obj.get(key, default) - else: - return default - - # Extract loss metrics - model_data['loss_metrics'] = { - 'current_loss': safe_get(model_stats, 'current_loss'), - 'best_loss': safe_get(model_stats, 'best_loss'), - 'loss_5ma': safe_get(model_stats, 'loss_5ma'), - 'improvement': safe_get(model_stats, 'improvement', 0) - } - - # Extract timing metrics - model_data['timing_metrics'] = { - 'last_inference': safe_get(model_stats, 'last_inference'), - 'last_training': safe_get(model_stats, 'last_training'), - 'inferences_per_second': safe_get(model_stats, 'inferences_per_second', 0), - 'predictions_24h': safe_get(model_stats, 'predictions_24h', 0) - } - - # Extract last prediction - last_pred = safe_get(model_stats, 'last_prediction') - if last_pred: - model_data['last_prediction'] = { - 'action': safe_get(last_pred, 'action', 'NONE'), - 'confidence': safe_get(last_pred, 'confidence', 0), - 'timestamp': safe_get(last_pred, 'timestamp', 'N/A'), - 'predicted_price': safe_get(last_pred, 'predicted_price'), - 'price_change': safe_get(last_pred, 'price_change') - } - - # Extract model parameters count - model_data['parameters'] = safe_get(model_stats, 'parameters', 0) - - # Check checkpoint status from orchestrator model states (more reliable) - checkpoint_loaded = False - checkpoint_failed = False - if hasattr(self.orchestrator, 'model_states'): - model_state_mapping = { - 'dqn_agent': 'dqn', - 'enhanced_cnn': 'cnn', - 'cob_rl_model': 'cob_rl', - 'extrema_trainer': 'extrema_trainer' - } - state_key = model_state_mapping.get(model_name, model_name) - if state_key in self.orchestrator.model_states: - checkpoint_loaded = self.orchestrator.model_states[state_key].get('checkpoint_loaded', False) - checkpoint_failed = self.orchestrator.model_states[state_key].get('checkpoint_failed', False) - - # If not found in model states, check model stats as fallback - if not checkpoint_loaded and not checkpoint_failed: - checkpoint_loaded = safe_get(model_stats, 'checkpoint_loaded', False) - - model_data['checkpoint_loaded'] = checkpoint_loaded - model_data['checkpoint_failed'] = checkpoint_failed - - # Extract signal generation statistics and real performance data - model_data['signal_stats'] = { - 'buy_signals': safe_get(model_stats, 'buy_signals_count', 0), - 'sell_signals': safe_get(model_stats, 'sell_signals_count', 0), - 'hold_signals': safe_get(model_stats, 'hold_signals_count', 0), - 'total_signals': safe_get(model_stats, 'total_signals', 0), - 'accuracy': safe_get(model_stats, 'accuracy', 0), - 'win_rate': safe_get(model_stats, 'win_rate', 0) - } - - # Extract real performance metrics from logs - # For DQN: we see "Performance: 81.9% (158/193)" in logs - if model_name == 'dqn_agent': - model_data['signal_stats']['accuracy'] = 81.9 # From logs - model_data['signal_stats']['total_signals'] = 193 # From logs - model_data['signal_stats']['correct_predictions'] = 158 # From logs - elif model_name == 'enhanced_cnn': - model_data['signal_stats']['accuracy'] = 65.3 # From logs - model_data['signal_stats']['total_signals'] = 193 # From logs - model_data['signal_stats']['correct_predictions'] = 126 # From logs - - return model_data - - except Exception as e: - logger.error(f"Error extracting data for model {model_name}: {e}") - return {'name': model_name, 'status': 'error', 'error': str(e)} - - def _extract_decision_fusion_data(self) -> Dict[str, Any]: - """Extract data for the decision fusion model""" - try: - decision_data = { - 'name': 'decision_fusion', - 'status': 'active', - 'parameters': 0, - 'last_prediction': {}, - 'training_enabled': True, - 'inference_enabled': True, - 'checkpoint_loaded': False, - 'loss_metrics': {}, - 'timing_metrics': {}, - 'signal_stats': {} - } - - # Check if decision fusion is actually enabled and working - if hasattr(self.orchestrator, 'decision_fusion_enabled'): - decision_data['status'] = 'active' if self.orchestrator.decision_fusion_enabled else 'registered' - - # Check if decision fusion network exists - if hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network: - decision_data['status'] = 'active' - # Get network parameters - if hasattr(self.orchestrator.decision_fusion_network, 'parameters'): - decision_data['parameters'] = sum(p.numel() for p in self.orchestrator.decision_fusion_network.parameters()) - - # Check decision fusion mode - if hasattr(self.orchestrator, 'decision_fusion_mode'): - decision_data['mode'] = self.orchestrator.decision_fusion_mode - if self.orchestrator.decision_fusion_mode == 'neural': - decision_data['status'] = 'active' - elif self.orchestrator.decision_fusion_mode == 'programmatic': - decision_data['status'] = 'active' # Still active, just using programmatic mode - - # Get decision fusion statistics - if hasattr(self.orchestrator, 'get_decision_fusion_stats'): - stats = self.orchestrator.get_decision_fusion_stats() - if stats: - decision_data['loss_metrics']['current_loss'] = stats.get('recent_loss') - decision_data['timing_metrics']['decisions_per_second'] = stats.get('decisions_per_second', 0) - decision_data['signal_stats'] = { - 'buy_decisions': stats.get('buy_decisions', 0), - 'sell_decisions': stats.get('sell_decisions', 0), - 'hold_decisions': stats.get('hold_decisions', 0), - 'total_decisions': stats.get('total_decisions', 0), - 'consensus_rate': stats.get('consensus_rate', 0) - } - - # Get decision fusion network parameters - if hasattr(self.orchestrator, 'decision_fusion') and self.orchestrator.decision_fusion: - if hasattr(self.orchestrator.decision_fusion, 'parameters'): - decision_data['parameters'] = sum(p.numel() for p in self.orchestrator.decision_fusion.parameters()) - - # Check for decision fusion checkpoint status - if hasattr(self.orchestrator, 'model_states') and 'decision_fusion' in self.orchestrator.model_states: - df_state = self.orchestrator.model_states['decision_fusion'] - decision_data['checkpoint_loaded'] = df_state.get('checkpoint_loaded', False) - - return decision_data - - except Exception as e: - logger.error(f"Error extracting decision fusion data: {e}") - return {'name': 'decision_fusion', 'status': 'error', 'error': str(e)} - - def _extract_cob_rl_data(self) -> Dict[str, Any]: - """Extract data for the COB RL model""" - try: - cob_data = { - 'name': 'cob_rl_model', - 'status': 'registered', # Usually registered but not actively inferencing - 'parameters': 0, - 'last_prediction': {}, - 'training_enabled': True, - 'inference_enabled': True, - 'checkpoint_loaded': False, - 'loss_metrics': {}, - 'timing_metrics': {}, - 'signal_stats': {} - } - - # Check if COB RL has actual statistics - if hasattr(self.orchestrator, 'get_model_statistics'): - stats = self.orchestrator.get_model_statistics() - if stats and 'cob_rl_model' in stats: - cob_stats = stats['cob_rl_model'] - # Use the safe_get function from above - def safe_get(obj, key, default=None): - if hasattr(obj, key): - return getattr(obj, key, default) - elif isinstance(obj, dict): - return obj.get(key, default) - else: - return default - - cob_data['parameters'] = safe_get(cob_stats, 'parameters', 356647429) # Known COB RL size - cob_data['status'] = 'active' if safe_get(cob_stats, 'inferences_per_second', 0) > 0 else 'registered' - - # Extract metrics if available - cob_data['loss_metrics'] = { - 'current_loss': safe_get(cob_stats, 'current_loss'), - 'best_loss': safe_get(cob_stats, 'best_loss'), - } - - return cob_data - - except Exception as e: - logger.error(f"Error extracting COB RL data: {e}") - return {'name': 'cob_rl_model', 'status': 'error', 'error': str(e)} - - def _extract_training_status(self) -> Dict[str, Any]: - """Extract overall training status""" - try: - status = { - 'active_sessions': 0, - 'total_training_steps': 0, - 'is_training': False, - 'last_update': 'N/A' - } - - # Check if enhanced training system is available - if hasattr(self.orchestrator, 'enhanced_training') and self.orchestrator.enhanced_training: - enhanced_stats = self.orchestrator.enhanced_training.get_training_statistics() - if enhanced_stats: - status.update({ - 'is_training': enhanced_stats.get('is_training', False), - 'training_iteration': enhanced_stats.get('training_iteration', 0), - 'experience_buffer_size': enhanced_stats.get('experience_buffer_size', 0), - 'last_update': datetime.now().strftime('%H:%M:%S') - }) - - return status - - except Exception as e: - logger.error(f"Error extracting training status: {e}") - return {'error': str(e)} - - def _extract_performance_metrics(self) -> Dict[str, Any]: - """Extract performance metrics""" - try: - metrics = { - 'decision_fusion_active': False, - 'cob_integration_active': False, - 'symbols_tracking': 0, - 'recent_decisions': 0 - } - - # Check decision fusion status - if hasattr(self.orchestrator, 'decision_fusion_enabled'): - metrics['decision_fusion_active'] = self.orchestrator.decision_fusion_enabled - - # Check COB integration - if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration: - metrics['cob_integration_active'] = True - if hasattr(self.orchestrator.cob_integration, 'symbols'): - metrics['symbols_tracking'] = len(self.orchestrator.cob_integration.symbols) - - return metrics - - except Exception as e: - logger.error(f"Error extracting performance metrics: {e}") - return {'error': str(e)} - - def _create_header(self) -> html.Div: - """Create the panel header with title and refresh button""" - return html.Div([ - html.H6([ - html.I(className="fas fa-brain me-2 text-primary"), - "Models & Training Progress" - ], className="mb-2"), - html.Button([ - html.I(className="fas fa-sync-alt me-1"), - "Refresh" - ], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2") - ], className="d-flex justify-content-between align-items-start") - - def _create_models_section(self, models_data: Dict[str, Any]) -> html.Div: - """Create the models section showing each loaded model""" - model_cards = [] - - for model_name, model_data in models_data.items(): - if model_data.get('error'): - # Error card - model_cards.append(html.Div([ - html.Strong(f"{model_name.upper()}", className="text-danger"), - html.P(f"Error: {model_data['error']}", className="text-danger small mb-0") - ], className="border border-danger rounded p-2 mb-2")) - else: - model_cards.append(self._create_model_card(model_name, model_data)) - - return html.Div([ - html.H6([ - html.I(className="fas fa-microchip me-2 text-success"), - f"Loaded Models ({len(models_data)})" - ], className="mb-2"), - html.Div(model_cards) - ]) - - def _create_model_card(self, model_name: str, model_data: Dict[str, Any]) -> html.Div: - """Create a card for a single model""" - # Status styling - status = model_data.get('status', 'unknown') - if status == 'active': - status_class = "text-success" - status_icon = "fas fa-check-circle" - status_text = "ACTIVE" - elif status == 'registered': - status_class = "text-warning" - status_icon = "fas fa-circle" - status_text = "REGISTERED" - elif status == 'inactive': - status_class = "text-muted" - status_icon = "fas fa-pause-circle" - status_text = "INACTIVE" - else: - status_class = "text-danger" - status_icon = "fas fa-exclamation-circle" - status_text = "UNKNOWN" - - # Model size formatting - params = model_data.get('parameters', 0) - if params > 1e9: - size_str = f"{params/1e9:.1f}B" - elif params > 1e6: - size_str = f"{params/1e6:.1f}M" - elif params > 1e3: - size_str = f"{params/1e3:.1f}K" - else: - size_str = str(params) - - # Last prediction info - last_pred = model_data.get('last_prediction', {}) - pred_action = last_pred.get('action', 'NONE') - pred_confidence = last_pred.get('confidence', 0) - pred_time = last_pred.get('timestamp', 'N/A') - - # Loss metrics - loss_metrics = model_data.get('loss_metrics', {}) - current_loss = loss_metrics.get('current_loss') - loss_class = "text-success" if current_loss and current_loss < 0.1 else "text-warning" if current_loss and current_loss < 0.5 else "text-danger" - - # Timing metrics - timing = model_data.get('timing_metrics', {}) - - return html.Div([ - # Header with model name and status - html.Div([ - html.Div([ - html.I(className=f"{status_icon} me-2 {status_class}"), - html.Strong(f"{model_name.upper()}", className=status_class), - html.Span(f" - {status_text}", className=f"{status_class} small ms-1"), - html.Span(f" ({size_str})", className="text-muted small ms-2"), - # Show mode for decision fusion - *([html.Span(f" [{model_data.get('mode', 'unknown').upper()}]", className="text-info small ms-1")] if model_name == 'decision_fusion' and model_data.get('mode') else []), - html.Span( - " [CKPT]" if model_data.get('checkpoint_loaded') - else " [FAILED]" if model_data.get('checkpoint_failed') - else " [FRESH]", - className=f"small {'text-success' if model_data.get('checkpoint_loaded') else 'text-danger' if model_data.get('checkpoint_failed') else 'text-warning'} ms-1" - ) - ], style={"flex": "1"}), - - # Toggle switches with pattern matching IDs - html.Div([ - html.Div([ - html.Label("Inf", className="text-muted small me-1", style={"font-size": "10px"}), - dcc.Checklist( - id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'inference'}, - options=[{"label": "", "value": True}], - value=[True] if model_data.get('inference_enabled', True) else [], - className="form-check-input me-2", - style={"transform": "scale(0.7)"} - ) - ], className="d-flex align-items-center me-2"), - html.Div([ - html.Label("Trn", className="text-muted small me-1", style={"font-size": "10px"}), - dcc.Checklist( - id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'training'}, - options=[{"label": "", "value": True}], - value=[True] if model_data.get('training_enabled', True) else [], - className="form-check-input", - style={"transform": "scale(0.7)"} - ) - ], className="d-flex align-items-center") - ], className="d-flex") - ], className="d-flex align-items-center mb-2"), - - # Model metrics - html.Div([ - # Last prediction - html.Div([ - html.Span("Last: ", className="text-muted small"), - html.Span(f"{pred_action}", - className=f"small fw-bold {'text-success' if pred_action == 'BUY' else 'text-danger' if pred_action == 'SELL' else 'text-warning'}"), - html.Span(f" ({pred_confidence:.1f}%)", className="text-muted small"), - html.Span(f" @ {pred_time}", className="text-muted small") - ], className="mb-1"), - - # Loss information - html.Div([ - html.Span("Loss: ", className="text-muted small"), - html.Span(f"{current_loss:.4f}" if current_loss is not None else "N/A", - className=f"small fw-bold {loss_class}"), - *([ - html.Span(" | Best: ", className="text-muted small"), - html.Span(f"{loss_metrics.get('best_loss', 0):.4f}", className="text-success small") - ] if loss_metrics.get('best_loss') is not None else []) - ], className="mb-1"), - - # Timing information - html.Div([ - html.Span("Rate: ", className="text-muted small"), - html.Span(f"{timing.get('inferences_per_second', 0):.2f}/s", className="text-info small"), - html.Span(" | 24h: ", className="text-muted small"), - html.Span(f"{timing.get('predictions_24h', 0)}", className="text-primary small") - ], className="mb-1"), - - # Last activity times - html.Div([ - html.Span("Last Inf: ", className="text-muted small"), - html.Span(f"{timing.get('last_inference', 'N/A')}", className="text-info small"), - html.Span(" | Train: ", className="text-muted small"), - html.Span(f"{timing.get('last_training', 'N/A')}", className="text-warning small") - ], className="mb-1"), - - # Signal generation statistics - *self._create_signal_stats_display(model_data.get('signal_stats', {})), - - # Performance metrics - *self._create_performance_metrics_display(model_data) - ]) - ], className="border rounded p-2 mb-2", - style={"backgroundColor": "rgba(255,255,255,0.05)" if status == 'active' else "rgba(128,128,128,0.1)"}) - - def _create_no_models_message(self) -> html.Div: - """Create message when no models are loaded""" - return html.Div([ - html.H6([ - html.I(className="fas fa-exclamation-triangle me-2 text-warning"), - "No Models Loaded" - ], className="mb-2"), - html.P("No machine learning models are currently loaded. Check orchestrator status.", - className="text-muted small") - ]) - - def _create_training_status_section(self, training_status: Dict[str, Any]) -> html.Div: - """Create the training status section""" - if training_status.get('error'): - return html.Div([ - html.Hr(), - html.H6([ - html.I(className="fas fa-exclamation-triangle me-2 text-danger"), - "Training Status Error" - ], className="mb-2"), - html.P(f"Error: {training_status['error']}", className="text-danger small") - ]) - - is_training = training_status.get('is_training', False) - - return html.Div([ - html.Hr(), - html.H6([ - html.I(className="fas fa-brain me-2 text-secondary"), - "Training Status" - ], className="mb-2"), - - html.Div([ - html.Span("Status: ", className="text-muted small"), - html.Span("ACTIVE" if is_training else "INACTIVE", - className=f"small fw-bold {'text-success' if is_training else 'text-warning'}"), - html.Span(f" | Iteration: {training_status.get('training_iteration', 0):,}", - className="text-info small ms-2") - ], className="mb-1"), - - html.Div([ - html.Span("Buffer: ", className="text-muted small"), - html.Span(f"{training_status.get('experience_buffer_size', 0):,}", - className="text-success small"), - html.Span(" | Updated: ", className="text-muted small"), - html.Span(f"{training_status.get('last_update', 'N/A')}", - className="text-muted small") - ], className="mb-0") - ]) - - def _create_performance_section(self, performance_metrics: Dict[str, Any]) -> html.Div: - """Create the performance metrics section""" - if performance_metrics.get('error'): - return html.Div([ - html.Hr(), - html.P(f"Performance metrics error: {performance_metrics['error']}", - className="text-danger small") - ]) - - return html.Div([ - html.Hr(), - html.H6([ - html.I(className="fas fa-chart-line me-2 text-primary"), - "System Performance" - ], className="mb-2"), - - html.Div([ - html.Span("Decision Fusion: ", className="text-muted small"), - html.Span("ON" if performance_metrics.get('decision_fusion_active') else "OFF", - className=f"small {'text-success' if performance_metrics.get('decision_fusion_active') else 'text-muted'}"), - html.Span(" | COB: ", className="text-muted small"), - html.Span("ON" if performance_metrics.get('cob_integration_active') else "OFF", - className=f"small {'text-success' if performance_metrics.get('cob_integration_active') else 'text-muted'}") - ], className="mb-1"), - - html.Div([ - html.Span("Tracking: ", className="text-muted small"), - html.Span(f"{performance_metrics.get('symbols_tracking', 0)} symbols", - className="text-info small"), - html.Span(" | Decisions: ", className="text-muted small"), - html.Span(f"{performance_metrics.get('recent_decisions', 0):,}", - className="text-primary small") - ], className="mb-0") - ]) - - def _create_signal_stats_display(self, signal_stats: Dict[str, Any]) -> List[html.Div]: - """Create display elements for signal generation statistics""" - if not signal_stats or not any(signal_stats.values()): - return [] - - buy_signals = signal_stats.get('buy_signals', 0) - sell_signals = signal_stats.get('sell_signals', 0) - hold_signals = signal_stats.get('hold_signals', 0) - total_signals = signal_stats.get('total_signals', 0) - - if total_signals == 0: - return [] - - # Calculate percentages - ensure all values are numeric - buy_signals = buy_signals or 0 - sell_signals = sell_signals or 0 - hold_signals = hold_signals or 0 - total_signals = total_signals or 0 - - buy_pct = (buy_signals / total_signals * 100) if total_signals > 0 else 0 - sell_pct = (sell_signals / total_signals * 100) if total_signals > 0 else 0 - hold_pct = (hold_signals / total_signals * 100) if total_signals > 0 else 0 - - return [ - html.Div([ - html.Span("Signals: ", className="text-muted small"), - html.Span(f"B:{buy_signals}({buy_pct:.0f}%)", className="text-success small"), - html.Span(" | ", className="text-muted small"), - html.Span(f"S:{sell_signals}({sell_pct:.0f}%)", className="text-danger small"), - html.Span(" | ", className="text-muted small"), - html.Span(f"H:{hold_signals}({hold_pct:.0f}%)", className="text-warning small") - ], className="mb-1"), - - html.Div([ - html.Span("Total: ", className="text-muted small"), - html.Span(f"{total_signals:,}", className="text-primary small fw-bold"), - *([ - html.Span(" | Accuracy: ", className="text-muted small"), - html.Span(f"{signal_stats.get('accuracy', 0):.1f}%", - className=f"small fw-bold {'text-success' if signal_stats.get('accuracy', 0) > 60 else 'text-warning' if signal_stats.get('accuracy', 0) > 40 else 'text-danger'}") - ] if signal_stats.get('accuracy', 0) > 0 else []) - ], className="mb-1") - ] - - def _create_performance_metrics_display(self, model_data: Dict[str, Any]) -> List[html.Div]: - """Create display elements for performance metrics""" - elements = [] - - # Win rate and accuracy - signal_stats = model_data.get('signal_stats', {}) - loss_metrics = model_data.get('loss_metrics', {}) - - # Safely get numeric values - win_rate = signal_stats.get('win_rate', 0) or 0 - accuracy = signal_stats.get('accuracy', 0) or 0 - - if win_rate > 0 or accuracy > 0: - - elements.append(html.Div([ - html.Span("Performance: ", className="text-muted small"), - *([ - html.Span(f"Win: {win_rate:.1f}%", - className=f"small fw-bold {'text-success' if win_rate > 55 else 'text-warning' if win_rate > 45 else 'text-danger'}"), - html.Span(" | ", className="text-muted small") - ] if win_rate > 0 else []), - *([ - html.Span(f"Acc: {accuracy:.1f}%", - className=f"small fw-bold {'text-success' if accuracy > 60 else 'text-warning' if accuracy > 40 else 'text-danger'}") - ] if accuracy > 0 else []) - ], className="mb-1")) - - # Loss improvement - if loss_metrics.get('improvement', 0) != 0: - improvement = loss_metrics.get('improvement', 0) - elements.append(html.Div([ - html.Span("Improvement: ", className="text-muted small"), - html.Span(f"{improvement:+.1f}%", - className=f"small fw-bold {'text-success' if improvement > 0 else 'text-danger'}") - ], className="mb-1")) - - return elements \ No newline at end of file diff --git a/web/template_renderer.py b/web/template_renderer.py deleted file mode 100644 index 30ebac8..0000000 --- a/web/template_renderer.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -Template Renderer for Dashboard -Handles HTML template rendering with Jinja2 -""" -import os -from typing import Dict, Any -from jinja2 import Environment, FileSystemLoader, select_autoescape -from dash import html, dcc -import plotly.graph_objects as go - -from .dashboard_model import DashboardModel, DashboardDataBuilder - - -class DashboardTemplateRenderer: - """Renders dashboard templates using Jinja2""" - - def __init__(self, template_dir: str = "web/templates"): - """Initialize the template renderer""" - self.template_dir = template_dir - - # Create Jinja2 environment - self.env = Environment( - loader=FileSystemLoader(template_dir), - autoescape=select_autoescape(['html', 'xml']) - ) - - # Add custom filters - self.env.filters['currency'] = self._currency_filter - self.env.filters['percentage'] = self._percentage_filter - self.env.filters['number'] = self._number_filter - - def render_dashboard(self, model: DashboardModel) -> html.Div: - """Render the complete dashboard using the template""" - try: - # Convert model to dict for template - template_data = self._model_to_dict(model) - - # Render template - template = self.env.get_template('dashboard.html') - rendered_html = template.render(**template_data) - - # Convert to Dash components - return self._convert_to_dash_components(model) - - except Exception as e: - # Fallback to basic layout if template fails - return self._create_fallback_layout(str(e)) - - def _model_to_dict(self, model: DashboardModel) -> Dict[str, Any]: - """Convert dashboard model to dictionary for template rendering""" - return { - 'title': model.title, - 'subtitle': model.subtitle, - 'refresh_interval': model.refresh_interval, - 'metrics': [self._dataclass_to_dict(m) for m in model.metrics], - 'chart': self._dataclass_to_dict(model.chart), - 'trading_controls': self._dataclass_to_dict(model.trading_controls), - 'recent_decisions': [self._dataclass_to_dict(d) for d in model.recent_decisions], - 'cob_data': [self._dataclass_to_dict(c) for c in model.cob_data], - 'models': [self._dataclass_to_dict(m) for m in model.models], - 'training_metrics': [self._dataclass_to_dict(m) for m in model.training_metrics], - 'performance_stats': [self._dataclass_to_dict(s) for s in model.performance_stats], - 'closed_trades': [self._dataclass_to_dict(t) for t in model.closed_trades] - } - - def _dataclass_to_dict(self, obj) -> Dict[str, Any]: - """Convert dataclass to dictionary""" - if hasattr(obj, '__dict__'): - result = {} - for key, value in obj.__dict__.items(): - if hasattr(value, '__dict__'): - result[key] = self._dataclass_to_dict(value) - elif isinstance(value, list): - result[key] = [self._dataclass_to_dict(item) if hasattr(item, '__dict__') else item for item in value] - else: - result[key] = value - return result - return obj - - def _convert_to_dash_components(self, model: DashboardModel) -> html.Div: - """Convert template model to Dash components""" - return html.Div([ - # Header - html.Div([ - html.H1(model.title, className="text-center"), - html.P(model.subtitle, className="text-center text-muted") - ], className="row mb-3"), - - # Metrics Row - html.Div([ - html.Div([ - self._create_metric_card(metric) - ], className="col-md-2") for metric in model.metrics - ], className="row mb-3"), - - # Main Content Row - html.Div([ - # Price Chart - html.Div([ - html.Div([ - html.Div([ - html.H5(model.chart.title) - ], className="card-header"), - html.Div([ - dcc.Graph(id="price-chart", style={"height": "500px"}) - ], className="card-body") - ], className="card") - ], className="col-md-8"), - - # Trading Controls & Recent Decisions - html.Div([ - # Trading Controls - self._create_trading_controls(model.trading_controls), - # Recent Decisions - self._create_recent_decisions(model.recent_decisions) - ], className="col-md-4") - ], className="row mb-3"), - - # COB Data and Models Row - html.Div([ - # COB Ladders - html.Div([ - html.Div([ - html.Div([ - self._create_cob_card(cob) - ], className="col-md-6") for cob in model.cob_data - ], className="row") - ], className="col-md-7"), - - # Models & Training - html.Div([ - self._create_training_panel(model) - ], className="col-md-5") - ], className="row mb-3"), - - # Closed Trades Row - html.Div([ - html.Div([ - self._create_closed_trades_table(model.closed_trades) - ], className="col-12") - ], className="row"), - - # Auto-refresh interval - dcc.Interval(id='interval-component', interval=model.refresh_interval, n_intervals=0) - - ], className="container-fluid") - - def _create_metric_card(self, metric) -> html.Div: - """Create a metric card component""" - return html.Div([ - html.Div(metric.value, className="metric-value", id=metric.id), - html.Div(metric.label, className="metric-label") - ], className="metric-card") - - def _create_trading_controls(self, controls) -> html.Div: - """Create trading controls component""" - return html.Div([ - html.Div([ - html.H6("Manual Trading") - ], className="card-header"), - html.Div([ - html.Div([ - html.Div([ - html.Button(controls.buy_text, id="manual-buy-btn", - className="btn btn-success w-100") - ], className="col-6"), - html.Div([ - html.Button(controls.sell_text, id="manual-sell-btn", - className="btn btn-danger w-100") - ], className="col-6") - ], className="row mb-2"), - html.Div([ - html.Div([ - html.Label([ - f"Leverage: ", - html.Span(f"{controls.leverage}x", id="leverage-display") - ], className="form-label"), - dcc.Slider( - id="leverage-slider", - min=controls.leverage_min, - max=controls.leverage_max, - value=controls.leverage, - step=1, - marks={i: str(i) for i in range(controls.leverage_min, controls.leverage_max + 1, 10)} - ) - ], className="col-12") - ], className="row mb-2"), - html.Div([ - html.Div([ - html.Button(controls.clear_text, id="clear-session-btn", - className="btn btn-warning w-100") - ], className="col-12") - ], className="row") - ], className="card-body") - ], className="card mb-3") - - def _create_recent_decisions(self, decisions) -> html.Div: - """Create recent decisions component""" - decision_items = [] - for decision in decisions: - border_class = { - 'BUY': 'border-success bg-success bg-opacity-10', - 'SELL': 'border-danger bg-danger bg-opacity-10' - }.get(decision.action, 'border-secondary bg-secondary bg-opacity-10') - - decision_items.append( - html.Div([ - html.Small(decision.timestamp, className="text-muted"), - html.Br(), - html.Strong(f"{decision.action} - {decision.symbol}"), - html.Br(), - html.Small(f"Confidence: {decision.confidence}% | Price: ${decision.price}") - ], className=f"mb-2 p-2 border-start border-3 {border_class}") - ) - - return html.Div([ - html.Div([ - html.H6("Recent AI Decisions") - ], className="card-header"), - html.Div([ - html.Div(decision_items, id="recent-decisions") - ], className="card-body", style={"max-height": "300px", "overflow-y": "auto"}) - ], className="card") - - def _create_cob_card(self, cob) -> html.Div: - """Create COB ladder card""" - return html.Div([ - html.Div([ - html.H6(f"{cob.symbol} Order Book"), - html.Small(f"Total: {cob.total_usd} USD | {cob.total_crypto} {cob.symbol.split('/')[0]}", - className="text-muted") - ], className="card-header"), - html.Div([ - html.Div(id=cob.content_id, className="cob-ladder") - ], className="card-body p-2") - ], className="card") - - def _create_training_panel(self, model: DashboardModel) -> html.Div: - """Create training panel component""" - # Model status indicators - model_status_items = [] - for model_item in model.models: - status_class = f"status-{model_item.status}" - model_status_items.append( - html.Span(f"{model_item.name}: {model_item.status_text}", - className=f"model-status {status_class}") - ) - - # Training metrics - training_items = [] - for metric in model.training_metrics: - training_items.append( - html.Div([ - html.Div([ - html.Small(f"{metric.name}:") - ], className="col-6"), - html.Div([ - html.Small(metric.value, className="fw-bold") - ], className="col-6") - ], className="row mb-1") - ) - - # Performance stats - performance_items = [] - for stat in model.performance_stats: - performance_items.append( - html.Div([ - html.Div([ - html.Small(f"{stat.name}:") - ], className="col-8"), - html.Div([ - html.Small(stat.value, className="fw-bold") - ], className="col-4") - ], className="row mb-1") - ) - - return html.Div([ - html.Div([ - html.H6("Models & Training Progress") - ], className="card-header"), - html.Div([ - html.Div([ - # Model Status - html.Div([ - html.H6("Model Status"), - html.Div(model_status_items) - ], className="mb-3"), - - # Training Metrics - html.Div([ - html.H6("Training Metrics"), - html.Div(training_items, id="training-metrics") - ], className="mb-3"), - - # Performance Stats - html.Div([ - html.H6("Performance"), - html.Div(performance_items) - ], className="mb-3") - ]) - ], className="card-body training-panel") - ], className="card") - - def _create_closed_trades_table(self, trades) -> html.Div: - """Create closed trades table""" - trade_rows = [] - for trade in trades: - pnl_class = "trade-profit" if trade.pnl > 0 else "trade-loss" - side_class = "bg-success" if trade.side == "BUY" else "bg-danger" - - trade_rows.append( - html.Tr([ - html.Td(trade.time), - html.Td(trade.symbol), - html.Td([ - html.Span(trade.side, className=f"badge {side_class}") - ]), - html.Td(trade.size), - html.Td(trade.entry_price), - html.Td(trade.exit_price), - html.Td(f"${trade.pnl}", className=pnl_class), - html.Td(trade.duration) - ]) - ) - - return html.Div([ - html.Div([ - html.H6("Recent Closed Trades") - ], className="card-header"), - html.Div([ - html.Div([ - html.Table([ - html.Thead([ - html.Tr([ - html.Th("Time"), - html.Th("Symbol"), - html.Th("Side"), - html.Th("Size"), - html.Th("Entry"), - html.Th("Exit"), - html.Th("PnL"), - html.Th("Duration") - ]) - ]), - html.Tbody(trade_rows) - ], className="table table-sm", id="closed-trades-table") - ]) - ], className="card-body closed-trades") - ], className="card") - - def _create_fallback_layout(self, error_msg: str) -> html.Div: - """Create fallback layout if template rendering fails""" - return html.Div([ - html.Div([ - html.H1("Dashboard Error", className="text-center text-danger"), - html.P(f"Template rendering failed: {error_msg}", className="text-center"), - html.P("Using fallback layout.", className="text-center text-muted") - ], className="container mt-5") - ]) - - # Jinja2 custom filters - def _currency_filter(self, value) -> str: - """Format value as currency""" - try: - return f"${float(value):,.4f}" - except (ValueError, TypeError): - return str(value) - - def _percentage_filter(self, value) -> str: - """Format value as percentage""" - try: - return f"{float(value):.2f}%" - except (ValueError, TypeError): - return str(value) - - def _number_filter(self, value) -> str: - """Format value as number""" - try: - if isinstance(value, int): - return f"{value:,}" - else: - return f"{float(value):,.2f}" - except (ValueError, TypeError): - return str(value) \ No newline at end of file diff --git a/web/templated_dashboard.py b/web/templated_dashboard.py deleted file mode 100644 index cce2222..0000000 --- a/web/templated_dashboard.py +++ /dev/null @@ -1,1258 +0,0 @@ -""" -Template-based Trading Dashboard -Uses MVC architecture with HTML templates and data models -""" -import logging -import sys -import os -from typing import Optional, Any, Dict, List, Deque -from datetime import datetime, timedelta -import pandas as pd -import pytz -import time -import threading -from collections import deque -from dataclasses import asdict - -import dash -from dash import dcc, html, Input, Output, State, callback_context -import plotly.graph_objects as go -import plotly.express as px - -from core.data_provider import DataProvider -from core.orchestrator import TradingOrchestrator -from core.trading_executor import TradingExecutor -from core.config import get_config -from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream -from web.dashboard_model import DashboardModel, DashboardDataBuilder, create_sample_dashboard_data -from web.template_renderer import DashboardTemplateRenderer -from web.component_manager import DashboardComponentManager -from web.layout_manager import DashboardLayoutManager -from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint -from NN.models.advanced_transformer_trading import create_trading_transformer, TradingTransformerConfig - -# Configure logging -logger = logging.getLogger(__name__) - - -class TemplatedTradingDashboard: - """Template-based trading dashboard with MVC architecture""" - - def __init__(self, data_provider: Optional[DataProvider] = None, - orchestrator: Optional[TradingOrchestrator] = None, - trading_executor: Optional[TradingExecutor] = None): - """Initialize the templated dashboard""" - self.config = get_config() - - # Initialize components - self.data_provider = data_provider or DataProvider() - self.trading_executor = trading_executor or TradingExecutor() - - # Initialize template renderer - self.renderer = DashboardTemplateRenderer() - - # Initialize unified orchestrator with full ML capabilities - if orchestrator is None: - self.orchestrator = TradingOrchestrator( - data_provider=self.data_provider, - enhanced_rl_training=True, - model_registry={} - ) - logger.info("TEMPLATED DASHBOARD: Using unified Trading Orchestrator with full ML capabilities") - else: - self.orchestrator = orchestrator - - # Initialize enhanced training system for predictions - self.training_system = None - self._initialize_enhanced_training_system() - - # Initialize layout and component managers - self.layout_manager = DashboardLayoutManager( - starting_balance=self._get_initial_balance(), - trading_executor=self.trading_executor - ) - self.component_manager = DashboardComponentManager() - - # Initialize Universal Data Stream for the 5 timeseries architecture - self.universal_adapter = UniversalDataAdapter(self.data_provider) - # Data access now through orchestrator instead of complex stream management - logger.debug("Universal Data Adapter initialized - accessing data through orchestrator") - logger.info(f"TEMPLATED DASHBOARD: Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}") - logger.info("TEMPLATED DASHBOARD: Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)") - - # Dashboard state - self.recent_decisions: list = [] - self.closed_trades: list = [] - self.current_prices: dict = {} - self.session_pnl = 0.0 - self.total_fees = 0.0 - self.current_position: Optional[float] = 0.0 - self.session_trades: list = [] - - # Model control toggles - separate inference and training - self.dqn_inference_enabled = True # Default: enabled - self.dqn_training_enabled = True # Default: enabled - self.cnn_inference_enabled = True - self.cnn_training_enabled = True - - # Leverage management - adjustable x1 to x100 - self.current_leverage = 50 # Default x50 leverage - self.min_leverage = 1 - self.max_leverage = 100 - self.pending_trade_case_id = None # For tracking opening trades until closure - - # WebSocket streaming - self.ws_price_cache: dict = {} - self.is_streaming = False - self.tick_cache: list = [] - - # COB data cache - enhanced with price buckets and memory system - self.cob_cache: dict = { - 'ETH/USDT': {'last_update': 0, 'data': None, 'updates_count': 0}, - 'BTC/USDT': {'last_update': 0, 'data': None, 'updates_count': 0} - } - self.latest_cob_data: dict = {} # Cache for COB integration data - self.cob_predictions: dict = {} # Cache for COB predictions (both ETH and BTC for display) - - # COB High-frequency data handling (50-100 updates/sec) - self.cob_data_buffer: dict = {} # Buffer for high-freq data - self.cob_memory: dict = {} # Memory system like GPT - keeps last N snapshots - self.cob_price_buckets: dict = {} # Price bucket cache - self.cob_update_count = 0 - self.last_cob_broadcast: Dict[str, Optional[float]] = {'ETH/USDT': None, 'BTC/USDT': None} # Rate limiting for UI updates, updated type - self.cob_data_history: Dict[str, Deque[Any]] = { - 'ETH/USDT': deque(maxlen=61), # Store ~60 seconds of 1s snapshots - 'BTC/USDT': deque(maxlen=61) - } - - # Initialize timezone - timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia') - self.timezone = pytz.timezone(timezone_name) - - # Create Dash app - self.app = dash.Dash(__name__, external_stylesheets=[ - 'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css', - 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css' - ]) - - # Suppress Dash development mode logging - self.app.enable_dev_tools(debug=False, dev_tools_silence_routes_logging=True) - - # Setup layout and callbacks - self._setup_layout() - self._setup_callbacks() - - # Start data streams - self._initialize_streaming() - - # Connect to orchestrator for real trading signals - self._connect_to_orchestrator() - - # Initialize COB integration with high-frequency data handling - self._initialize_cob_integration() - - # Start signal generation loop to ensure continuous trading signals - self._start_signal_generation_loop() - - # Start training sessions if models are showing FRESH status - threading.Thread(target=self._delayed_training_check, daemon=True).start() - - logger.info("TEMPLATED DASHBOARD: Initialized with HIGH-FREQUENCY COB integration and signal generation") - - def _setup_layout(self): - """Setup the dashboard layout using templates""" - # Create initial dashboard data - dashboard_data = self._build_dashboard_data() - - # Render layout using template - layout = self.renderer.render_dashboard(dashboard_data) - - # Custom CSS will be handled via external stylesheets - - self.app.layout = layout - - def _get_initial_balance(self) -> float: - """Get initial balance from trading executor or default""" - try: - if self.trading_executor and hasattr(self.trading_executor, 'starting_balance'): - balance = getattr(self.trading_executor, 'starting_balance', None) - if balance and balance > 0: - return balance - except Exception as e: - logger.warning(f"Error getting balance: {e}") - return 100.0 # Default balance - - def _setup_callbacks(self): - """Setup dashboard callbacks""" - - @self.app.callback( - [Output('current-price', 'children'), - Output('session-pnl', 'children'), - Output('current-position', 'children'), - Output('trade-count', 'children'), - Output('portfolio-value', 'children'), - Output('mexc-status', 'children')], - [Input('interval-component', 'n_intervals')] - ) - def update_metrics(n): - """Update main metrics""" - try: - # Get current price - current_price = self._get_current_price("ETH/USDT") - - # Calculate portfolio value - portfolio_value = 10000.0 + self.session_pnl # Base + PnL - - # Get MEXC status - mexc_status = "Connected" if self.trading_executor else "Disconnected" - - return ( - f"${current_price:.4f}" if current_price else "N/A", - f"${self.session_pnl:.2f}", - f"{self.current_position:.4f}", - str(len(self.session_trades)), - f"${portfolio_value:.2f}", - mexc_status - ) - except Exception as e: - logger.error(f"Error updating metrics: {e}") - return "N/A", "N/A", "N/A", "N/A", "N/A", "Error" - - @self.app.callback( - Output('price-chart', 'figure'), - [Input('interval-component', 'n_intervals')] - ) - def update_price_chart(n): - """Update price chart""" - try: - return self._create_price_chart("ETH/USDT") - except Exception as e: - logger.error(f"Error updating chart: {e}") - return go.Figure() - - @self.app.callback( - Output('recent-decisions', 'children'), - [Input('interval-component', 'n_intervals')] - ) - def update_recent_decisions(n): - """Update recent AI decisions""" - try: - decisions = self._get_recent_decisions() - return self._render_decisions(decisions) - except Exception as e: - logger.error(f"Error updating decisions: {e}") - return html.Div("No recent decisions") - - @self.app.callback( - [Output('eth-cob-content', 'children'), - Output('btc-cob-content', 'children')], - [Input('interval-component', 'n_intervals')] - ) - def update_cob_data(n): - """Update COB data""" - try: - eth_cob = self._render_cob_ladder("ETH/USDT") - btc_cob = self._render_cob_ladder("BTC/USDT") - return eth_cob, btc_cob - except Exception as e: - logger.error(f"Error updating COB: {e}") - return html.Div("COB Error"), html.Div("COB Error") - - @self.app.callback( - Output('training-metrics', 'children'), - [Input('interval-component', 'n_intervals')] - ) - def update_training_metrics(n): - """Update training metrics""" - try: - return self._render_training_metrics() - except Exception as e: - logger.error(f"Error updating training metrics: {e}") - return html.Div("Training metrics unavailable") - - @self.app.callback( - Output('closed-trades-table', 'children'), - [Input('interval-component', 'n_intervals')] - ) - def update_closed_trades(n): - """Update closed trades table""" - try: - # Return the table wrapped in a Div - return html.Div(self._render_closed_trades()) - except Exception as e: - logger.error(f"Error updating closed trades: {e}") - return html.Div("No trades") - - # Trading control callbacks - @self.app.callback( - Output('manual-buy-btn', 'children'), - [Input('manual-buy-btn', 'n_clicks')], - prevent_initial_call=True - ) - def handle_manual_buy(n_clicks): - """Handle manual buy button""" - if n_clicks: - self._execute_manual_trade("BUY") - return "BUY ✓" - return "BUY" - - @self.app.callback( - Output('manual-sell-btn', 'children'), - [Input('manual-sell-btn', 'n_clicks')], - prevent_initial_call=True - ) - def handle_manual_sell(n_clicks): - """Handle manual sell button""" - if n_clicks: - self._execute_manual_trade("SELL") - return "SELL ✓" - return "SELL" - - @self.app.callback( - Output('leverage-display', 'children'), - [Input('leverage-slider', 'value')] - ) - def update_leverage_display(leverage_value): - """Update leverage display""" - return f"{leverage_value}x" - - @self.app.callback( - Output('clear-session-btn', 'children'), - [Input('clear-session-btn', 'n_clicks')], - prevent_initial_call=True - ) - def handle_clear_session(n_clicks): - """Handle clear session button""" - if n_clicks: - self._clear_session() - return "Cleared ✓" - return "Clear Session" - - def _build_dashboard_data(self) -> DashboardModel: - """Build dashboard data model from current state""" - builder = DashboardDataBuilder() - - # Basic info - builder.set_basic_info( - title="Live Scalping Dashboard (Templated)", - subtitle="Template-based MVC Architecture", - refresh_interval=1000 - ) - - # Get current metrics - current_price = self._get_current_price("ETH/USDT") - portfolio_value = 10000.0 + self.session_pnl - mexc_status = "Connected" if self.trading_executor else "Disconnected" - - # Add metrics - builder.add_metric("current-price", "Current Price", current_price or 0, "currency") - builder.add_metric("session-pnl", "Session PnL", self.session_pnl, "currency") - builder.add_metric("current-position", "Position", self.current_position, "number") - builder.add_metric("trade-count", "Trades", len(self.session_trades), "number") - builder.add_metric("portfolio-value", "Portfolio", portfolio_value, "currency") - builder.add_metric("mexc-status", "MEXC Status", mexc_status, "text") - - # Trading controls - builder.set_trading_controls(leverage=10, leverage_range=(1, 50)) - - # Recent decisions (sample data for now) - builder.add_recent_decision(datetime.now(), "BUY", "ETH/USDT", 0.85, current_price or 3425.67) - - # COB data (sample) - builder.add_cob_data("ETH/USDT", "eth-cob-content", 25000.0, 7.3, []) - builder.add_cob_data("BTC/USDT", "btc-cob-content", 35000.0, 0.88, []) - - # Model statuses - builder.add_model_status("DQN", True) - builder.add_model_status("CNN", True) - builder.add_model_status("Transformer", False) - builder.add_model_status("COB-RL", True) - - # Training metrics - builder.add_training_metric("DQN Loss", 0.0234) - builder.add_training_metric("CNN Accuracy", 0.876) - builder.add_training_metric("Training Steps", 15420) - - # Performance stats - builder.add_performance_stat("Win Rate", 68.5) - builder.add_performance_stat("Avg Trade", 8.34) - builder.add_performance_stat("Sharpe Ratio", 1.82) - - return builder.build() - - def _get_current_price(self, symbol: str) -> Optional[float]: - """Get current price for symbol""" - try: - if self.data_provider: - return self.data_provider.get_current_price(symbol) - return 3425.67 # Sample price - except Exception as e: - logger.error(f"Error getting price for {symbol}: {e}") - return None - - def _create_price_chart(self, symbol: str) -> go.Figure: - """Create price chart""" - try: - # Get price data - df = self._get_chart_data(symbol) - - if df is None or df.empty: - return go.Figure().add_annotation( - text="No data available", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False - ) - - # Create candlestick chart - fig = go.Figure(data=[go.Candlestick( - x=df.index, - open=df['open'], - high=df['high'], - low=df['low'], - close=df['close'], - name=symbol - )]) - - fig.update_layout( - title=f"{symbol} Price Chart", - xaxis_title="Time", - yaxis_title="Price (USDT)", - height=500, - showlegend=False - ) - - return fig - - except Exception as e: - logger.error(f"Error creating chart for {symbol}: {e}") - return go.Figure() - - def _get_chart_data(self, symbol: str) -> Optional[pd.DataFrame]: - """Get chart data for symbol""" - try: - if self.data_provider: - return self.data_provider.get_historical_data(symbol, "1m", 100) - - # Sample data - import numpy as np - dates = pd.date_range(start='2024-01-01', periods=100, freq='1min') - base_price = 3425.67 - - df = pd.DataFrame({ - 'open': base_price + np.random.randn(100) * 10, - 'high': base_price + np.random.randn(100) * 15, - 'low': base_price + np.random.randn(100) * 15, - 'close': base_price + np.random.randn(100) * 10, - 'volume': np.random.randint(100, 1000, 100) - }, index=dates) - - return df - - except Exception as e: - logger.error(f"Error getting chart data: {e}") - return None - - def _get_recent_decisions(self) -> List[Dict]: - """Get recent AI decisions""" - # Sample decisions for now - return [ - { - "timestamp": datetime.now().strftime("%H:%M:%S"), - "action": "BUY", - "symbol": "ETH/USDT", - "confidence": 85.3, - "price": 3425.67 - }, - { - "timestamp": datetime.now().strftime("%H:%M:%S"), - "action": "HOLD", - "symbol": "BTC/USDT", - "confidence": 62.1, - "price": 45123.45 - } - ] - - def _render_decisions(self, decisions: List[Dict]) -> List[html.Div]: - """Render recent decisions""" - items = [] - for decision in decisions: - border_class = { - 'BUY': 'border-success bg-success bg-opacity-10', - 'SELL': 'border-danger bg-danger bg-opacity-10' - }.get(decision['action'], 'border-secondary bg-secondary bg-opacity-10') - - items.append( - html.Div([ - html.Small(decision['timestamp'], className="text-muted"), - html.Br(), - html.Strong(f"{decision['action']} - {decision['symbol']}"), - html.Br(), - html.Small(f"Confidence: {decision['confidence']}% | Price: ${decision['price']}") - ], className=f"mb-2 p-2 border-start border-3 {border_class}") - ) - - return items - - def _render_cob_ladder(self, symbol: str) -> html.Div: - """Render COB ladder for symbol""" - # Sample COB data - return html.Table([ - html.Thead([ - html.Tr([ - html.Th("Size"), - html.Th("Price"), - html.Th("Total") - ]) - ]), - html.Tbody([ - html.Tr([ - html.Td("1.5"), - html.Td("$3426.12"), - html.Td("$5139.18") - ], className="ask-row"), - html.Tr([ - html.Td("2.3"), - html.Td("$3425.89"), - html.Td("$7879.55") - ], className="ask-row"), - html.Tr([ - html.Td("1.8"), - html.Td("$3425.45"), - html.Td("$6165.81") - ], className="bid-row"), - html.Tr([ - html.Td("3.2"), - html.Td("$3425.12"), - html.Td("$10960.38") - ], className="bid-row") - ]) - ], className="table table-sm table-borderless") - - def _render_training_metrics(self) -> html.Div: - """Render training metrics""" - return html.Div([ - # Model Status - html.Div([ - html.H6("Model Status"), - html.Div([ - html.Span("DQN: Training", className="model-status status-training"), - html.Span("CNN: Training", className="model-status status-training"), - html.Span("Transformer: Idle", className="model-status status-idle"), - html.Span("COB-RL: Training", className="model-status status-training") - ]) - ], className="mb-3"), - - # Training Metrics - html.Div([ - html.H6("Training Metrics"), - html.Div([ - html.Div([ - html.Div([html.Small("DQN Loss:")], className="col-6"), - html.Div([html.Small("0.0234", className="fw-bold")], className="col-6") - ], className="row mb-1"), - html.Div([ - html.Div([html.Small("CNN Accuracy:")], className="col-6"), - html.Div([html.Small("87.6%", className="fw-bold")], className="col-6") - ], className="row mb-1"), - html.Div([ - html.Div([html.Small("Training Steps:")], className="col-6"), - html.Div([html.Small("15,420", className="fw-bold")], className="col-6") - ], className="row mb-1") - ]) - ], className="mb-3"), - - # Performance Stats - html.Div([ - html.H6("Performance"), - html.Div([ - html.Div([ - html.Div([html.Small("Win Rate:")], className="col-8"), - html.Div([html.Small("68.5%", className="fw-bold")], className="col-4") - ], className="row mb-1"), - html.Div([ - html.Div([html.Small("Avg Trade:")], className="col-8"), - html.Div([html.Small("$8.34", className="fw-bold")], className="col-4") - ], className="row mb-1"), - html.Div([ - html.Div([html.Small("Sharpe Ratio:")], className="col-8"), - html.Div([html.Small("1.82", className="fw-bold")], className="col-4") - ], className="row mb-1") - ]) - ]) - ]) - - def _render_closed_trades(self) -> html.Div: - """Render closed trades table""" - if not self.closed_trades: - return html.Div("No closed trades yet.", className="alert alert-info mt-3") - - # Create a DataFrame from closed trades - df_trades = pd.DataFrame(self.closed_trades) - - # Format columns for display - df_trades['timestamp'] = pd.to_datetime(df_trades['timestamp']).dt.strftime('%Y-%m-%d %H:%M:%S') - df_trades['entry_price'] = df_trades['entry_price'].apply(lambda x: f"${x:,.2f}") - df_trades['exit_price'] = df_trades['exit_price'].apply(lambda x: f"${x:,.2f}") - df_trades['pnl'] = df_trades['pnl'].apply(lambda x: f"${x:,.2f}") - df_trades['profit_percentage'] = df_trades['profit_percentage'].apply(lambda x: f"{x:,.2f}%") - df_trades['size'] = df_trades['size'].apply(lambda x: f"{x:,.4f}") - df_trades['fees'] = df_trades['fees'].apply(lambda x: f"${x:,.2f}") - - table_header = [html.Thead(html.Tr([html.Th(col) for col in df_trades.columns]))] - table_body = [html.Tbody([ - html.Tr([html.Td(df_trades.iloc[i][col]) for col in df_trades.columns]) for i in range(len(df_trades)) - ])] - - return html.Div( - html.Table(table_header + table_body, className="table table-striped table-hover table-sm"), - className="table-responsive" - ) - - def _execute_manual_trade(self, action: str): - """Execute manual trade""" - try: - logger.info(f"MANUAL TRADE: {action} executed") - # Add to session trades - trade = { - "time": datetime.now(), - "action": action, - "symbol": "ETH/USDT", - "price": self._get_current_price("ETH/USDT") or 3425.67 - } - self.session_trades.append(trade) - except Exception as e: - logger.error(f"Error executing manual trade: {e}") - - def _clear_session(self): - """Clear session data""" - self.session_trades = [] - self.session_pnl = 0.0 - self.current_position = 0.0 - self.session_start_time = datetime.now() - logger.info("SESSION: Cleared") - - def run_server(self, host='127.0.0.1', port=8052, debug=False): - """Run the dashboard server""" - logger.info(f"TEMPLATED DASHBOARD: Starting at http://{host}:{port}") - self.app.run(host=host, port=port, debug=debug) - - def _handle_unified_stream_data(self, data): - """Placeholder for unified stream data handling.""" - logger.debug(f"Received data from unified stream: {data}") - - def _delayed_training_check(self): - """Check and start training after a delay to allow initialization""" - try: - time.sleep(10) # Wait 10 seconds for initialization - logger.info("Checking if models need training activation...") - self._start_actual_training_if_needed() - except Exception as e: - logger.error(f"Error in delayed training check: {e}") - - def _initialize_enhanced_training_system(self): - """Initialize enhanced training system for model predictions""" - try: - # Try to import and initialize enhanced training system - from enhanced_realtime_training import EnhancedRealtimeTrainingSystem - - self.training_system = EnhancedRealtimeTrainingSystem( - orchestrator=self.orchestrator, - data_provider=self.data_provider, - dashboard=self - ) - - # Initialize prediction storage - if not hasattr(self.orchestrator, 'recent_dqn_predictions'): - self.orchestrator.recent_dqn_predictions = {} - if not hasattr(self.orchestrator, 'recent_cnn_predictions'): - self.orchestrator.recent_cnn_predictions = {} - - logger.info("TEMPLATED DASHBOARD: Enhanced training system initialized for model predictions") - - except ImportError: - logger.warning("TEMPLATED DASHBOARD: Enhanced training system not available - using mock predictions") - self.training_system = None - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error initializing enhanced training system: {e}") - self.training_system = None - - def _initialize_streaming(self): - """Initialize data streaming""" - try: - self._start_websocket_streaming() - self._start_data_collection() - logger.info("TEMPLATED DASHBOARD: Data streaming initialized") - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error initializing streaming: {e}") - - def _start_websocket_streaming(self): - """Start WebSocket streaming for real-time data.""" - ws_thread = threading.Thread(target=self._ws_worker, daemon=True) - ws_thread.start() - - def _ws_worker(self): - try: - import websocket - import json # Added import - def on_message(ws, message): - try: - data = json.loads(message) - if 'k' in data: - kline = data['k'] - tick_record = { - 'symbol': 'ETHUSDT', - 'datetime': datetime.fromtimestamp(int(kline['t']) / 1000), - 'open': float(kline['o']), - 'high': float(kline['h']), - 'low': float(kline['l']), - 'close': float(kline['c']), - 'price': float(kline['c']), - 'volume': float(kline['v']), - } - self.ws_price_cache['ETHUSDT'] = tick_record['price'] - self.current_prices['ETH/USDT'] = tick_record['price'] - self.tick_cache.append(tick_record) - if len(self.tick_cache) > 1000: - self.tick_cache.pop(0) - except Exception as e: - logger.warning(f"TEMPLATED DASHBOARD: WebSocket message error: {e}") - def on_error(ws, error): - logger.error(f"TEMPLATED DASHBOARD: WebSocket error: {error}") - self.is_streaming = False - def on_close(ws, close_status_code, close_msg): - logger.warning("TEMPLATED DASHBOARD: WebSocket connection closed") - self.is_streaming = False - def on_open(ws): - logger.info("TEMPLATED DASHBOARD: WebSocket connected") - self.is_streaming = True - ws_url = "wss://stream.binance.com:9443/ws/ethusdt@kline_1s" - ws = websocket.WebSocketApp(ws_url, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open) - ws.run_forever() - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: WebSocket worker error: {e}") - self.is_streaming = False - - def _start_data_collection(self): - """Start background data collection""" - data_thread = threading.Thread(target=self._data_worker, daemon=True) - data_thread.start() - - def _data_worker(self): - while True: - try: - self._update_session_metrics() - time.sleep(5) - except Exception as e: - logger.warning(f"TEMPLATED DASHBOARD: Data collection error: {e}") - time.sleep(10) - - def _update_session_metrics(self): - """Update session P&L and total fees from closed trades.""" - try: - closed_trades = [] - if self.trading_executor and hasattr(self.trading_executor, 'get_closed_trades'): - closed_trades = self.trading_executor.get_closed_trades() - self.closed_trades = closed_trades - if closed_trades: - self.session_pnl = sum(trade.get('pnl', 0) for trade in closed_trades) - self.total_fees = sum(trade.get('fees', 0) for trade in closed_trades) - else: - self.session_pnl = 0.0 - self.total_fees = 0.0 - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error updating session metrics: {e}") - - def _connect_to_orchestrator(self): - """Connect to orchestrator for real trading signals""" - try: - if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'): - import asyncio # Added import - # from dataclasses import asdict # Moved asdict to top-level import - - def connect_worker(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - # No need to run_until_complete here, just register the callback - self.orchestrator.add_decision_callback(self._on_trading_decision) - logger.info("TEMPLATED DASHBOARD: Successfully connected to orchestrator for trading signals.") - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Orchestrator connection worker failed: {e}") - thread = threading.Thread(target=connect_worker, daemon=True) - thread.start() - else: - logger.warning("TEMPLATED DASHBOARD: Orchestrator not available or doesn\'t support callbacks") - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error initiating orchestrator connection: {e}") - - async def _on_trading_decision(self, decision): - """Handle trading decision from orchestrator.""" - try: - action = getattr(decision, 'action', decision.get('action')) - if action == 'HOLD': - return - symbol = getattr(decision, 'symbol', decision.get('symbol', 'ETH/USDT')) - if 'ETH' not in symbol.upper(): - return - dashboard_decision = asdict(decision) if not isinstance(decision, dict) else decision.copy() - dashboard_decision['timestamp'] = datetime.now() - dashboard_decision['executed'] = False - self.recent_decisions.append(dashboard_decision) - if len(self.recent_decisions) > 200: - self.recent_decisions.pop(0) - logger.info(f"TEMPLATED DASHBOARD: [ORCHESTRATOR SIGNAL] Received: {action} for {symbol}") - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error handling trading decision: {e}") - - def _initialize_cob_integration(self): - """Initialize simple COB integration that works without async event loops""" - try: - logger.info("TEMPLATED DASHBOARD: Initializing simple COB integration for model feeding") - - # Initialize COB data storage - self.cob_bucketed_data = { - 'ETH/USDT': {}, - 'BTC/USDT': {} - } - self.cob_last_update: Dict[str, Optional[float]] = { - 'ETH/USDT': None, - 'BTC/USDT': None - } # Corrected type hint - - # Start simple COB data collection - self._start_simple_cob_collection() - - logger.info("TEMPLATED DASHBOARD: Simple COB integration initialized successfully") - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error initializing COB integration: {e}") - self.cob_integration = None - - def _start_simple_cob_collection(self): - """Start simple COB data collection using REST APIs (no async required)""" - try: - # threading and time already imported - - def cob_collector(): - """Collect COB data using simple REST API calls""" - while True: - try: - # Collect data for both symbols - for symbol in ['ETH/USDT', 'BTC/USDT']: - self._collect_simple_cob_data(symbol) - - # Sleep for 1 second between collections - time.sleep(1) - except Exception as e: - logger.debug(f"TEMPLATED DASHBOARD: Error in COB collection: {e}") - time.sleep(5) # Wait longer on error - - # Start collector in background thread - cob_thread = threading.Thread(target=cob_collector, daemon=True) - cob_thread.start() - - logger.info("TEMPLATED DASHBOARD: Simple COB data collection started") - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error starting COB collection: {e}") - - def _collect_simple_cob_data(self, symbol: str): - """Collect simple COB data using Binance REST API""" - try: - import requests # Added import - # time already imported - - # Use Binance REST API for order book data - binance_symbol = symbol.replace('/', '') - url = f"https://api.binance.com/api/v3/depth?symbol={binance_symbol}&limit=500" - - response = requests.get(url, timeout=5) - if response.status_code == 200: - data = response.json() - - # Process order book data - bids = [] - asks = [] - - # Process bids (buy orders) - for bid in data['bids'][:100]: # Top 100 levels - price = float(bid[0]) - size = float(bid[1]) - bids.append({ - 'price': price, - 'size': size, - 'total': price * size - }) - - # Process asks (sell orders) - for ask in data['asks'][:100]: # Top 100 levels - price = float(ask[0]) - size = float(ask[1]) - asks.append({ - 'price': price, - 'size': size, - 'total': price * size - }) - - # Calculate statistics - if bids and asks: - best_bid = max(bids, key=lambda x: x['price']) - best_ask = min(asks, key=lambda x: x['price']) - mid_price = (best_bid['price'] + best_ask['price']) / 2 - spread_bps = ((best_ask['price'] - best_bid['price']) / mid_price) * 10000 if mid_price > 0 else 0 - - total_bid_liquidity = sum(bid['total'] for bid in bids[:20]) - total_ask_liquidity = sum(ask['total'] for ask in asks[:20]) - total_liquidity = total_bid_liquidity + total_ask_liquidity - imbalance = (total_bid_liquidity - total_ask_liquidity) / total_liquidity if total_liquidity > 0 else 0 - - # Create COB snapshot - cob_snapshot = { - 'symbol': symbol, - 'timestamp': time.time(), - 'bids': bids, - 'asks': asks, - 'stats': { - 'mid_price': mid_price, - 'spread_bps': spread_bps, - 'total_bid_liquidity': total_bid_liquidity, - 'total_ask_liquidity': total_ask_liquidity, - 'imbalance': imbalance, - 'exchanges_active': ['Binance'] - } - } - - # Store in history (keep last 15 seconds) - self.cob_data_history[symbol].append(cob_snapshot) - if len(self.cob_data_history[symbol]) > 15: # Keep 15 seconds - # Use slicing to remove old elements from deque to ensure correct behavior - while len(self.cob_data_history[symbol]) > 15: - self.cob_data_history[symbol].popleft() - - # Update latest data - self.latest_cob_data[symbol] = cob_snapshot - self.cob_last_update[symbol] = time.time() - - # Generate bucketed data for models - self._generate_bucketed_cob_data(symbol, cob_snapshot) - - logger.debug(f"TEMPLATED DASHBOARD: COB data collected for {symbol}: {len(bids)} bids, {len(asks)} asks") - - except Exception as e: - logger.debug(f"TEMPLATED DASHBOARD: Error collecting COB data for {symbol}: {e}") - - def _generate_bucketed_cob_data(self, symbol: str, cob_snapshot: dict): - """Generate bucketed COB data for model feeding""" - try: - # Create price buckets (1 basis point granularity) - bucket_size_bps = 1.0 - mid_price = cob_snapshot['stats']['mid_price'] - - # Initialize buckets - buckets = {} - - # Process bids into buckets - for bid in cob_snapshot['bids']: - price_offset_bps = ((bid['price'] - mid_price) / mid_price) * 10000 - bucket_key = int(price_offset_bps / bucket_size_bps) - - if bucket_key not in buckets: - buckets[bucket_key] = {'bid_volume': 0, 'ask_volume': 0} - - buckets[bucket_key]['bid_volume'] += bid['total'] - - # Process asks into buckets - for ask in cob_snapshot['asks']: - price_offset_bps = ((ask['price'] - mid_price) / mid_price) * 10000 - bucket_key = int(price_offset_bps / bucket_size_bps) - - if bucket_key not in buckets: - buckets[bucket_key] = {'bid_volume': 0, 'ask_volume': 0} - - buckets[bucket_key]['ask_volume'] += ask['total'] - - # Store bucketed data - self.cob_bucketed_data[symbol] = { - 'timestamp': cob_snapshot['timestamp'], - 'mid_price': mid_price, - 'buckets': buckets, - 'bucket_size_bps': bucket_size_bps - } - - # Feed to models - self._feed_cob_data_to_models(symbol, cob_snapshot) - - except Exception as e: - logger.debug(f"TEMPLATED DASHBOARD: Error generating bucketed COB data: {e}") - - def _calculate_cumulative_imbalance(self, symbol: str) -> Dict[str, float]: - """Calculate Moving Averages (MA) of imbalance over different periods.""" - stats = {} - history = self.cob_data_history.get(symbol) - - if not history: - return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0} - - # Convert history to list and get recent snapshots - history_list = list(history) - if not history_list: - return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0} - - # Extract imbalance values from recent snapshots - imbalances = [] - for snap in history_list: - if isinstance(snap, dict) and 'stats' in snap and snap['stats']: - imbalance = snap['stats'].get('imbalance') - if imbalance is not None: - imbalances.append(imbalance) - - if not imbalances: - return {'1s': 0.0, '5s': 0.0, '15s': 0.0, '60s': 0.0} - - # Calculate Moving Averages over different periods - # MA periods: 1s=1 period, 5s=5 periods, 15s=15 periods, 60s=60 periods - ma_periods = {'1s': 1, '5s': 5, '15s': 15, '60s': 60} - - for name, period in ma_periods.items(): - if len(imbalances) >= period: - # Calculate SMA over the last 'period' values - recent_imbalances = imbalances[-period:] - sma_value = sum(recent_imbalances) / len(recent_imbalances) - - # Also calculate EMA for better responsiveness - if period > 1: - # EMA calculation with alpha = 2/(period+1) - alpha = 2.0 / (period + 1) - ema_value = recent_imbalances[0] # Start with first value - for value in recent_imbalances[1:]: - ema_value = alpha * value + (1 - alpha) * ema_value - # Use EMA for better responsiveness - stats[name] = ema_value - else: - # For 1s, use SMA (no EMA needed) - stats[name] = sma_value - else: - # If not enough data, use available data - available_imbalances = imbalances[-min(period, len(imbalances)):] - if available_imbalances: - if len(available_imbalances) > 1: - # Calculate EMA for available data - alpha = 2.0 / (len(available_imbalances) + 1) - ema_value = available_imbalances[0] - for value in available_imbalances[1:]: - ema_value = alpha * value + (1 - alpha) * ema_value - stats[name] = ema_value - else: - # Single value, use as is - stats[name] = available_imbalances[0] - else: - stats[name] = 0.0 - - # Debug logging to verify MA calculation - if any(value != 0.0 for value in stats.values()): - logger.debug(f"TEMPLATED DASHBOARD: [MOVING-AVERAGE-IMBALANCE] {symbol}: {stats} (from {len(imbalances)} snapshots)") - - return stats - - def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict): - """Feed COB data to models for training and inference""" - try: - # Calculate cumulative imbalance for model feeding - cumulative_imbalance = self._calculate_cumulative_imbalance(symbol) # Assumes _calculate_cumulative_imbalance is available - - history_data = { - 'symbol': symbol, - 'current_snapshot': cob_snapshot, - 'history': list(self.cob_data_history[symbol]), # Convert deque to list for consistent slicing - 'bucketed_data': self.cob_bucketed_data[symbol], - 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance - 'timestamp': cob_snapshot['timestamp'] - } - - # Pass to orchestrator for model feeding - if self.orchestrator and hasattr(self.orchestrator, 'feed_cob_data'): - self.orchestrator.feed_cob_data(symbol, history_data) # Assumes feed_cob_data exists in orchestrator - - except Exception as e: - logger.debug(f"TEMPLATED DASHBOARD: Error feeding COB data to models: {e}") - - def _is_signal_generation_active(self) -> bool: - """Check if signal generation is active (e.g., models are loaded and running)""" - # For now, return true to always generate signals - # In a real system, this would check model loading status, training status, etc. - return True # Simplified for initial integration - - def _start_signal_generation_loop(self): - """Start signal generation loop to ensure continuous trading signals""" - try: - def signal_worker(): - logger.info("TEMPLATED DASHBOARD: Signal generation worker started") - while True: - try: - # Ensure signal generation is active before processing - if self._is_signal_generation_active(): - symbol = 'ETH/USDT' # Focus on ETH for now - current_price = self._get_current_price(symbol) - if current_price: - # Generate a momentum signal (simplified for demo) - signal = self._generate_momentum_signal(symbol, current_price) # Assumes _generate_momentum_signal is available - if signal: - self._process_dashboard_signal(signal) # Assumes _process_dashboard_signal is available - - # Generate a DQN signal if enabled - if self.dqn_inference_enabled: - dqn_signal = self._generate_dqn_signal(symbol, current_price) # Assumes _generate_dqn_signal is available - if dqn_signal: - self._process_dashboard_signal(dqn_signal) - - # Generate a CNN pivot signal if enabled - if self.cnn_inference_enabled: - cnn_signal = self._get_cnn_pivot_prediction() # Assumes _get_cnn_pivot_prediction is available - if cnn_signal: - self._process_dashboard_signal(cnn_signal) - - # Update session metrics every 1 second interval to reflect new trades - self._update_session_metrics() - - time.sleep(1) # Run every second for signal generation - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error in signal worker: {e}") - time.sleep(5) # Longer sleep on error - - signal_thread = threading.Thread(target=signal_worker, daemon=True) - signal_thread.start() - logger.info("TEMPLATED DASHBOARD: Signal generation loop started") - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error starting signal generation loop: {e}") - - def _start_actual_training_if_needed(self): - """Start actual model training with real data collection and training loops""" - try: - if not self.orchestrator: - logger.warning("TEMPLATED DASHBOARD: No orchestrator available for training") - return - logger.info("TEMPLATED DASHBOARD: TRAINING: Starting actual training system with real data collection") - self._start_real_training_system() - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error starting comprehensive training system: {e}") - - def _start_real_training_system(self): - """Start real training system with data collection and actual model training""" - try: - # Training performance metrics - self.training_performance = { - 'decision': {'inference_times': [], 'training_times': [], 'total_calls': 0}, - 'cob_rl': {'inference_times': [], 'training_times': [], 'total_calls': 0}, - 'dqn': {'inference_times': [], 'training_times': [], 'total_calls': 0}, - 'cnn': {'inference_times': [], 'training_times': [], 'total_calls': 0}, - 'transformer': {'inference_times': [], 'training_times': [], 'total_calls': 0} # Added for transformer - } - - def training_coordinator(): - logger.info("TEMPLATED DASHBOARD: TRAINING: High-frequency training coordinator started") - training_iteration = 0 - last_dqn_training = 0 - last_cnn_training = 0 - last_decision_training = 0 - last_cob_rl_training = 0 - last_transformer_training = 0 # For transformer - - while True: - try: - training_iteration += 1 - current_time = time.time() - market_data = self._collect_training_data() # Assumes _collect_training_data is available - - if market_data: - logger.debug(f"TEMPLATED DASHBOARD: TRAINING: Collected {len(market_data)} market data points for training") - - # High-frequency training for split-second decisions - # Train decision fusion and COB RL as fast as hardware allows - if current_time - last_decision_training > 0.1: # Every 100ms - start_time = time.time() - self._perform_real_decision_training(market_data) # Assumes _perform_real_decision_training is available - training_time = time.time() - start_time - self.training_performance['decision']['training_times'].append(training_time) - self.training_performance['decision']['total_calls'] += 1 - last_decision_training = current_time - - # Keep only last 100 measurements - if len(self.training_performance['decision']['training_times']) > 100: - self.training_performance['decision']['training_times'] = self.training_performance['decision']['training_times'][-100:] - - # Advanced Transformer Training (every 200ms for comprehensive features) - if current_time - last_transformer_training > 0.2: # Every 200ms for transformer - start_time = time.time() - self._perform_real_transformer_training(market_data) # Assumes _perform_real_transformer_training is available - training_time = time.time() - start_time - self.training_performance['transformer']['training_times'].append(training_time) - self.training_performance['transformer']['total_calls'] += 1 - last_transformer_training = current_time # Update last training time - - # Keep only last 100 measurements - if len(self.training_performance['transformer']['training_times']) > 100: - self.training_performance['transformer']['training_times'] = self.training_performance['transformer']['training_times'][-100:] - - if current_time - last_cob_rl_training > 0.1: # Every 100ms - start_time = time.time() - self._perform_real_cob_rl_training(market_data) # Assumes _perform_real_cob_rl_training is available - training_time = time.time() - start_time - self.training_performance['cob_rl']['training_times'].append(training_time) - self.training_performance['cob_rl']['total_calls'] += 1 - last_cob_rl_training = current_time - - # Keep only last 100 measurements - if len(self.training_performance['cob_rl']['training_times']) > 100: - self.training_performance['cob_rl']['training_times'] = self.training_performance['cob_rl']['training_times'][-100:] - - # Standard frequency for larger models - if current_time - last_dqn_training > 30: - start_time = time.time() - self._perform_real_dqn_training(market_data) # Assumes _perform_real_dqn_training is available - training_time = time.time() - start_time - self.training_performance['dqn']['training_times'].append(training_time) - self.training_performance['dqn']['total_calls'] += 1 - last_dqn_training = current_time - - if len(self.training_performance['dqn']['training_times']) > 50: - self.training_performance['dqn']['training_times'] = self.training_performance['dqn']['training_times'][-50:] - - if current_time - last_cnn_training > 45: - start_time = time.time() - self._perform_real_cnn_training(market_data) # Assumes _perform_real_cnn_training is available - training_time = time.time() - start_time - self.training_performance['cnn']['training_times'].append(training_time) - self.training_performance['cnn']['total_calls'] += 1 - last_cnn_training = current_time - - if len(self.training_performance['cnn']['training_times']) > 50: - self.training_performance['cnn']['training_times'] = self.training_performance['cnn']['training_times'][-50:] - - self._update_training_progress(training_iteration) # Assumes _update_training_progress is available - - # Log performance metrics every 100 iterations - if training_iteration % 100 == 0: - self._log_training_performance() # Assumes _log_training_performance is available - logger.info(f"TEMPLATED DASHBOARD: TRAINING: Iteration {training_iteration} - High-frequency training active") - - # Minimal sleep for maximum responsiveness - time.sleep(0.05) # 50ms sleep for 20Hz training loop - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: TRAINING: Error in training iteration {training_iteration}: {e}") - time.sleep(1) # Shorter error recovery - - training_thread = threading.Thread(target=training_coordinator, daemon=True) - training_thread.start() - logger.info("TEMPLATED DASHBOARD: Real training system started") - - except Exception as e: - logger.error(f"TEMPLATED DASHBOARD: Error starting real training system: {e}") - -def create_templated_dashboard(data_provider: Optional[DataProvider] = None, - orchestrator: Optional[TradingOrchestrator] = None, - trading_executor: Optional[TradingExecutor] = None) -> TemplatedTradingDashboard: - """Create templated trading dashboard""" - return TemplatedTradingDashboard(data_provider, orchestrator, trading_executor) \ No newline at end of file diff --git a/web/tensorboard_component.py b/web/tensorboard_component.py deleted file mode 100644 index e20bf7c..0000000 --- a/web/tensorboard_component.py +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env python3 -""" -TensorBoard Component for Dashboard - -This module provides a Dash component that embeds TensorBoard in the dashboard. -""" - -import dash -from dash import html, dcc -import dash_bootstrap_components as dbc -import logging -from typing import Optional, Dict, Any - -logger = logging.getLogger(__name__) - -def create_tensorboard_tab(tensorboard_url: str = "http://localhost:6006") -> html.Div: - """ - Create a dashboard tab that embeds TensorBoard - - Args: - tensorboard_url: URL of the TensorBoard server - - Returns: - html.Div: Dash component containing TensorBoard iframe - """ - return html.Div([ - dbc.Alert([ - html.I(className="fas fa-chart-line me-2"), - "TensorBoard Training Visualization", - html.A( - "Open in New Window", - href=tensorboard_url, - target="_blank", - className="ms-2 btn btn-sm btn-primary" - ) - ], color="info", className="mb-3"), - - # TensorBoard iframe - html.Iframe( - src=tensorboard_url, - style={ - 'width': '100%', - 'height': '800px', - 'border': 'none' - } - ), - - # Training metrics summary - html.Div([ - html.H5("Training Metrics Summary", className="mt-3"), - html.Div(id="training-metrics-summary", className="mt-2") - ], className="mt-3") - ]) - -def create_training_metrics_card() -> dbc.Card: - """ - Create a card displaying key training metrics - - Returns: - dbc.Card: Dash Bootstrap card component - """ - return dbc.Card([ - dbc.CardHeader([ - html.I(className="fas fa-brain me-2"), - "Training Metrics" - ]), - dbc.CardBody([ - dbc.Row([ - dbc.Col([ - html.H6("Model Status"), - html.Div(id="model-training-status", children="Initializing...") - ], width=6), - dbc.Col([ - html.H6("Training Progress"), - dbc.Progress(id="training-progress-bar", value=0, className="mb-2"), - html.Div(id="training-progress-text", children="0%") - ], width=6) - ], className="mb-3"), - - dbc.Row([ - dbc.Col([ - html.H6("Loss"), - html.Div(id="training-loss-value", children="N/A") - ], width=4), - dbc.Col([ - html.H6("Reward"), - html.Div(id="training-reward-value", children="N/A") - ], width=4), - dbc.Col([ - html.H6("State Quality"), - html.Div(id="training-state-quality", children="N/A") - ], width=4) - ], className="mb-3"), - - dbc.Row([ - dbc.Col([ - html.A( - dbc.Button([ - html.I(className="fas fa-chart-line me-2"), - "Open TensorBoard" - ], color="primary", size="sm", className="w-100"), - href="http://localhost:6006", - target="_blank" - ) - ], width=12) - ]) - ]) - ], className="mb-3") - -def create_tensorboard_status_indicator(tensorboard_url: str = "http://localhost:6006") -> html.Div: - """ - Create a status indicator for TensorBoard - - Args: - tensorboard_url: URL of the TensorBoard server - - Returns: - html.Div: Dash component showing TensorBoard status - """ - return html.Div([ - dbc.Button([ - html.I(className="fas fa-chart-line me-2"), - "TensorBoard" - ], - id="tensorboard-status-button", - color="success", - size="sm", - href=tensorboard_url, - target="_blank", - external_link=True, - className="ms-2") - ], id="tensorboard-status-container") - -def update_training_metrics_card(metrics: Dict[str, Any]) -> Dict[str, Any]: - """ - Update training metrics card with latest data - - Args: - metrics: Dictionary of training metrics - - Returns: - Dict: Dictionary of Dash component updates - """ - # Extract metrics - training_active = metrics.get("training_active", False) - loss = metrics.get("loss", None) - reward = metrics.get("reward", None) - state_quality = metrics.get("state_quality", None) - progress = metrics.get("progress", 0) - - # Format values - loss_str = f"{loss:.4f}" if loss is not None else "N/A" - reward_str = f"{reward:.4f}" if reward is not None else "N/A" - state_quality_str = f"{state_quality:.1%}" if state_quality is not None else "N/A" - progress_str = f"{progress:.1%}" - - # Determine status - if training_active: - status = "Training Active" - status_class = "text-success" - else: - status = "Training Inactive" - status_class = "text-warning" - - # Return updates - return { - "model-training-status": html.Span(status, className=status_class), - "training-progress-bar": progress * 100, - "training-progress-text": progress_str, - "training-loss-value": loss_str, - "training-reward-value": reward_str, - "training-state-quality": state_quality_str - } \ No newline at end of file diff --git a/web/tensorboard_integration.py b/web/tensorboard_integration.py deleted file mode 100644 index b95a745..0000000 --- a/web/tensorboard_integration.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -""" -TensorBoard Integration for Dashboard - -This module provides integration between the trading dashboard and TensorBoard, -allowing training metrics to be visualized in real-time. -""" - -import os -import sys -import subprocess -import threading -import time -import logging -import webbrowser -from pathlib import Path -from typing import Optional, Dict, Any - -logger = logging.getLogger(__name__) - -class TensorBoardIntegration: - """ - TensorBoard integration for dashboard - - Provides methods to start TensorBoard server and access training metrics - """ - - def __init__(self, log_dir: str = "runs", port: int = 6006): - """ - Initialize TensorBoard integration - - Args: - log_dir: Directory containing TensorBoard logs - port: Port to run TensorBoard on - """ - self.log_dir = log_dir - self.port = port - self.process = None - self.url = f"http://localhost:{port}" - self.is_running = False - self.latest_metrics = {} - - # Create log directory if it doesn't exist - os.makedirs(log_dir, exist_ok=True) - - def start_tensorboard(self, open_browser: bool = False) -> bool: - """ - Start TensorBoard server in a separate process - - Args: - open_browser: Whether to open browser automatically - - Returns: - bool: True if TensorBoard was started successfully - """ - if self.is_running: - logger.info("TensorBoard is already running") - return True - - try: - # Check if TensorBoard is available - try: - import tensorboard - logger.info(f"TensorBoard version {tensorboard.__version__} available") - except ImportError: - logger.warning("TensorBoard not installed. Install with: pip install tensorboard") - return False - - # Check if log directory exists and has content - log_dir_path = Path(self.log_dir) - if not log_dir_path.exists(): - logger.warning(f"Log directory {self.log_dir} does not exist") - os.makedirs(self.log_dir, exist_ok=True) - logger.info(f"Created log directory {self.log_dir}") - - # Start TensorBoard process - cmd = [ - sys.executable, - "-m", - "tensorboard.main", - "--logdir", self.log_dir, - "--port", str(self.port), - "--reload_interval", "5", # Reload data every 5 seconds - "--reload_multifile", "true" # Better handling of multiple log files - ] - - logger.info(f"Starting TensorBoard: {' '.join(cmd)}") - - # Start process without capturing output - self.process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - - # Wait a moment for TensorBoard to start - time.sleep(2) - - # Check if process is running - if self.process.poll() is None: - self.is_running = True - logger.info(f"TensorBoard started at {self.url}") - - # Open browser if requested - if open_browser: - try: - webbrowser.open(self.url) - logger.info("Browser opened automatically") - except Exception as e: - logger.warning(f"Could not open browser: {e}") - - # Start monitoring thread - threading.Thread(target=self._monitor_process, daemon=True).start() - - return True - else: - stdout, stderr = self.process.communicate() - logger.error(f"TensorBoard failed to start: {stderr}") - return False - - except Exception as e: - logger.error(f"Error starting TensorBoard: {e}") - return False - - def _monitor_process(self): - """Monitor TensorBoard process and capture output""" - try: - while self.process and self.process.poll() is None: - # Read output line by line - for line in iter(self.process.stdout.readline, ''): - if line: - line = line.strip() - if line: - logger.debug(f"TensorBoard: {line}") - - time.sleep(0.1) - - # Process has ended - self.is_running = False - logger.info("TensorBoard process has ended") - - except Exception as e: - logger.error(f"Error monitoring TensorBoard process: {e}") - - def stop_tensorboard(self): - """Stop TensorBoard server""" - if self.process and self.process.poll() is None: - try: - self.process.terminate() - self.process.wait(timeout=5) - logger.info("TensorBoard stopped") - except subprocess.TimeoutExpired: - self.process.kill() - logger.warning("TensorBoard process killed after timeout") - except Exception as e: - logger.error(f"Error stopping TensorBoard: {e}") - - self.is_running = False - - def get_tensorboard_url(self) -> str: - """Get TensorBoard URL""" - return self.url - - def is_tensorboard_running(self) -> bool: - """Check if TensorBoard is running""" - if self.process: - return self.process.poll() is None - return False - - def get_latest_metrics(self) -> Dict[str, Any]: - """ - Get latest training metrics from TensorBoard - - This is a placeholder - in a real implementation, you would - parse TensorBoard event files to extract metrics - """ - # In a real implementation, you would parse TensorBoard event files - # For now, return placeholder data - return { - "training_active": self.is_running, - "tensorboard_url": self.url, - "metrics_available": self.is_running - } - -# Singleton instance -_tensorboard_integration = None - -def get_tensorboard_integration(log_dir: str = "runs", port: int = 6006) -> TensorBoardIntegration: - """ - Get TensorBoard integration singleton instance - - Args: - log_dir: Directory containing TensorBoard logs - port: Port to run TensorBoard on - - Returns: - TensorBoardIntegration: Singleton instance - """ - global _tensorboard_integration - if _tensorboard_integration is None: - _tensorboard_integration = TensorBoardIntegration(log_dir, port) - return _tensorboard_integration \ No newline at end of file