BIG CLEANUP

This commit is contained in:
Dobromir Popov
2025-08-08 14:58:55 +03:00
parent e39e9ee95a
commit 2b0d2679c6
162 changed files with 455 additions and 42814 deletions

56
CLEANUP_TODO.md Normal file
View File

@ -0,0 +1,56 @@
Cleanup run summary:
- Deleted files: 183
- NN\__init__.py
- NN\models\__init__.py
- NN\models\cnn_model.py
- NN\models\transformer_model.py
- NN\start_tensorboard.py
- NN\training\enhanced_rl_training_integration.py
- NN\training\example_checkpoint_usage.py
- NN\training\integrate_checkpoint_management.py
- NN\utils\__init__.py
- NN\utils\data_interface.py
- NN\utils\multi_data_interface.py
- NN\utils\realtime_analyzer.py
- NN\utils\signal_interpreter.py
- NN\utils\trading_env.py
- _dev\cleanup_models_now.py
- _tools\build_keep_set.py
- apply_trading_fixes.py
- apply_trading_fixes_to_main.py
- audit_training_system.py
- balance_trading_signals.py
- check_live_trading.py
- check_mexc_symbols.py
- cleanup_checkpoint_db.py
- cleanup_checkpoints.py
- core\__init__.py
- core\api_rate_limiter.py
- core\async_handler.py
- core\bookmap_data_provider.py
- core\bookmap_integration.py
- core\cnn_monitor.py
- core\cnn_training_pipeline.py
- core\config_sync.py
- core\enhanced_cnn_adapter.py
- core\enhanced_cob_websocket.py
- core\enhanced_orchestrator.py
- core\enhanced_training_integration.py
- core\exchanges\__init__.py
- core\exchanges\binance_interface.py
- core\exchanges\bybit\debug\test_bybit_balance.py
- core\exchanges\bybit_interface.py
- core\exchanges\bybit_rest_client.py
- core\exchanges\deribit_interface.py
- core\exchanges\mexc\debug\final_mexc_order_test.py
- core\exchanges\mexc\debug\fix_mexc_orders.py
- core\exchanges\mexc\debug\fix_mexc_orders_v2.py
- core\exchanges\mexc\debug\fix_mexc_orders_v3.py
- core\exchanges\mexc\debug\test_mexc_interface_debug.py
- core\exchanges\mexc\debug\test_mexc_order_signature.py
- core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
- core\exchanges\mexc\debug\test_mexc_signature_debug.py
... and 133 more
- Removed test directories: 1
- tests
- Kept (excluded): 1

184
DELETE_CANDIDATES.txt Normal file
View File

@ -0,0 +1,184 @@
NN\__init__.py
NN\models\__init__.py
NN\models\cnn_model.py
NN\models\transformer_model.py
NN\start_tensorboard.py
NN\training\enhanced_realtime_training.py
NN\training\enhanced_rl_training_integration.py
NN\training\example_checkpoint_usage.py
NN\training\integrate_checkpoint_management.py
NN\utils\__init__.py
NN\utils\data_interface.py
NN\utils\multi_data_interface.py
NN\utils\realtime_analyzer.py
NN\utils\signal_interpreter.py
NN\utils\trading_env.py
_dev\cleanup_models_now.py
_tools\build_keep_set.py
apply_trading_fixes.py
apply_trading_fixes_to_main.py
audit_training_system.py
balance_trading_signals.py
check_live_trading.py
check_mexc_symbols.py
cleanup_checkpoint_db.py
cleanup_checkpoints.py
core\__init__.py
core\api_rate_limiter.py
core\async_handler.py
core\bookmap_data_provider.py
core\bookmap_integration.py
core\cnn_monitor.py
core\cnn_training_pipeline.py
core\config_sync.py
core\enhanced_cnn_adapter.py
core\enhanced_cob_websocket.py
core\enhanced_orchestrator.py
core\enhanced_training_integration.py
core\exchanges\__init__.py
core\exchanges\binance_interface.py
core\exchanges\bybit\debug\test_bybit_balance.py
core\exchanges\bybit_interface.py
core\exchanges\bybit_rest_client.py
core\exchanges\deribit_interface.py
core\exchanges\mexc\debug\final_mexc_order_test.py
core\exchanges\mexc\debug\fix_mexc_orders.py
core\exchanges\mexc\debug\fix_mexc_orders_v2.py
core\exchanges\mexc\debug\fix_mexc_orders_v3.py
core\exchanges\mexc\debug\test_mexc_interface_debug.py
core\exchanges\mexc\debug\test_mexc_order_signature.py
core\exchanges\mexc\debug\test_mexc_order_signature_v2.py
core\exchanges\mexc\debug\test_mexc_signature_debug.py
core\exchanges\mexc\debug\test_small_mexc_order.py
core\exchanges\mexc\test_live_trading.py
core\exchanges\mexc_interface.py
core\exchanges\trading_agent_test.py
core\mexc_webclient\__init__.py
core\mexc_webclient\auto_browser.py
core\mexc_webclient\browser_automation.py
core\mexc_webclient\mexc_futures_client.py
core\mexc_webclient\session_manager.py
core\mexc_webclient\test_mexc_futures_webclient.py
core\model_output_manager.py
core\negative_case_trainer.py
core\nn_decision_fusion.py
core\prediction_tracker.py
core\realtime_tick_processor.py
core\retrospective_trainer.py
core\rl_training_pipeline.py
core\robust_cob_provider.py
core\shared_cob_service.py
core\shared_data_manager.py
core\tick_aggregator.py
core\trading_action.py
core\trading_executor_fix.py
core\training_data_collector.py
core\williams_market_structure.py
dataprovider_realtime.py
debug\test_fixed_issues.py
debug\test_trading_fixes.py
debug\trade_audit.py
debug_training_methods.py
docs\exchanges\bybit\examples.py
example_usage_simplified_data_provider.py
kill_stale_processes.py
launch_training.py
main.py
main_clean.py
migrate_existing_models.py
model_manager.py
position_sync_enhancement.py
read_logs.py
reset_db_manager.py
reset_models_and_fix_mapping.py
run_clean_dashboard.py
run_continuous_training.py
run_crash_safe_dashboard.py
run_enhanced_rl_training.py
run_enhanced_training_dashboard.py
run_integrated_rl_cob_dashboard.py
run_mexc_browser.py
run_optimized_cob_system.py
run_realtime_rl_cob_trader.py
run_simple_dashboard.py
run_stable_dashboard.py
run_templated_dashboard.py
run_tensorboard.py
run_tests.py
scripts\kill_stale_processes.py
scripts\restart_dashboard_with_learning.py
scripts\restart_main_overnight.py
setup_mexc_browser.py
start_monitoring.py
start_overnight_training.py
system_stability_audit.py
test_build_base_data_performance.py
test_bybit_eth_futures.py
test_bybit_eth_futures_fixed.py
test_bybit_eth_live.py
test_bybit_public_api.py
test_cache_fix.py
test_cnn_integration.py
test_cob_dashboard.py
test_cob_data_quality.py
test_cob_websocket_only.py
test_continuous_cnn_training.py
test_dashboard_data_flow.py
test_dashboard_performance.py
test_data_integration.py
test_data_provider_integration.py
test_db_migration.py
test_deribit_integration.py
test_device_fix.py
test_device_training_fix.py
test_enhanced_cnn_adapter.py
test_enhanced_cob_websocket.py
test_enhanced_data_provider_websocket.py
test_enhanced_inference_logging.py
test_enhanced_training_integration.py
test_enhanced_training_simple.py
test_fifo_queues.py
test_hold_position_fix.py
test_imbalance_calculation.py
test_improved_data_integration.py
test_integrated_standardized_provider.py
test_leverage_fix.py
test_massive_dqn.py
test_mexc_order_fix.py
test_model_output_manager.py
test_model_registry.py
test_model_statistics.py
test_model_stats.py
test_model_training.py
test_orchestrator_fix.py
test_order_sync_and_fees.py
test_position_based_rewards.py
test_profitability_reward_system.py
test_training_data_collection.py
test_training_fixes.py
test_websocket_cob_data.py
tests\cob\test_cob_comparison.py
tests\cob\test_cob_data_stability.py
tests\test_training.py
tests\test_training_integration.py
tests\test_training_status.py
tests\test_universal_data_format.py
tests\test_universal_stream_integration.py
trading_main.py
utils\__init__.py
utils\async_task_manager.py
utils\launch_tensorboard.py
utils\model_utils.py
utils\port_manager.py
utils\process_supervisor.py
utils\reward_calculator.py
utils\system_monitor.py
utils\tensorboard_logger.py
utils\text_logger.py
verify_checkpoint_system.py
web\__init__.py
web\dashboard_fix.py
web\dashboard_model.py
web\layout_manager_with_tensorboard.py
web\tensorboard_component.py
web\tensorboard_integration.py

84
DEPENDENCY_TREE.md Normal file
View File

@ -0,0 +1,84 @@
Dependency tree from dashboards (module -> deps):
- NN\models\advanced_transformer_trading.py
- NN\models\cob_rl_model.py
- models\__init__.py
- NN\models\dqn_agent.py
- utils\checkpoint_manager.py
- utils\training_integration.py
- NN\models\enhanced_cnn.py
- NN\models\model_interfaces.py
- NN\models\standardized_cnn.py
- core\data_models.py
- core\cob_integration.py
- core\config.py
- safe_logging.py
- core\data_models.py
- core\data_provider.py
- utils\cache_manager.py
- utils\timezone_utils.py
- core\exchanges\exchange_factory.py
- core\exchanges\exchange_interface.py
- core\extrema_trainer.py
- utils\checkpoint_manager.py
- utils\training_integration.py
- core\multi_exchange_cob_provider.py
- core\orchestrator.py
- NN\models\advanced_transformer_trading.py
- NN\models\cob_rl_model.py
- NN\models\dqn_agent.py
- NN\models\enhanced_cnn.py
- NN\models\model_interfaces.py
- NN\models\standardized_cnn.py
- core\data_models.py
- core\extrema_trainer.py
- enhanced_realtime_training.py
- models\__init__.py
- utils\checkpoint_manager.py
- utils\database_manager.py
- utils\inference_logger.py
- core\overnight_training_coordinator.py
- core\realtime_rl_cob_trader.py
- NN\models\cob_rl_model.py
- core\trading_executor.py
- utils\checkpoint_manager.py
- core\standardized_data_provider.py
- core\trade_data_manager.py
- core\trading_executor.py
- core\data_provider.py
- core\exchanges\exchange_factory.py
- core\exchanges\exchange_interface.py
- core\training_integration.py
- core\universal_data_adapter.py
- enhanced_realtime_training.py
- models\__init__.py
- safe_logging.py
- utils\cache_manager.py
- utils\checkpoint_manager.py
- utils\database_manager.py
- utils\inference_logger.py
- utils\timezone_utils.py
- utils\training_integration.py
- web\clean_dashboard.py
- NN\models\advanced_transformer_trading.py
- NN\models\standardized_cnn.py
- core\cob_integration.py
- core\config.py
- core\data_models.py
- core\data_provider.py
- core\multi_exchange_cob_provider.py
- core\orchestrator.py
- core\overnight_training_coordinator.py
- core\realtime_rl_cob_trader.py
- core\standardized_data_provider.py
- core\trade_data_manager.py
- core\trading_executor.py
- core\training_integration.py
- core\universal_data_adapter.py
- utils\checkpoint_manager.py
- utils\timezone_utils.py
- web\component_manager.py
- web\layout_manager.py
- web\cob_realtime_dashboard.py
- core\cob_integration.py
- web\component_manager.py
- web\layout_manager.py

35
KEEP_SET.txt Normal file
View File

@ -0,0 +1,35 @@
NN\models\advanced_transformer_trading.py
NN\models\cob_rl_model.py
NN\models\dqn_agent.py
NN\models\enhanced_cnn.py
NN\models\model_interfaces.py
NN\models\standardized_cnn.py
core\cob_integration.py
core\config.py
core\data_models.py
core\data_provider.py
core\exchanges\exchange_factory.py
core\exchanges\exchange_interface.py
core\extrema_trainer.py
core\multi_exchange_cob_provider.py
core\orchestrator.py
core\overnight_training_coordinator.py
core\realtime_rl_cob_trader.py
core\standardized_data_provider.py
core\trade_data_manager.py
core\trading_executor.py
core\training_integration.py
core\universal_data_adapter.py
enhanced_realtime_training.py
models\__init__.py
safe_logging.py
utils\cache_manager.py
utils\checkpoint_manager.py
utils\database_manager.py
utils\inference_logger.py
utils\timezone_utils.py
utils\training_integration.py
web\clean_dashboard.py
web\cob_realtime_dashboard.py
web\component_manager.py
web\layout_manager.py

View File

@ -1,16 +0,0 @@
"""
Neural Network Trading System
=============================
A comprehensive neural network trading system that uses deep learning models
to analyze cryptocurrency price data and generate trading signals.
The system consists of:
1. Data Interface: Connects to realtime trading data
2. CNN Model: Deep convolutional neural network for feature extraction
3. Transformer Model: Processes high-level features for improved pattern recognition
4. MoE: Mixture of Experts model that combines multiple neural networks
"""
__version__ = '0.1.0'
__author__ = 'Gogo2 Project'

View File

@ -1,27 +0,0 @@
"""
Neural Network Models
=====================
This package contains the neural network models used in the trading system:
- CNN Model: Deep convolutional neural network for feature extraction
- DQN Agent: Deep Q-Network for reinforcement learning
- COB RL Model: Specialized RL model for order book data
- Advanced Transformer: High-performance transformer for trading
PyTorch implementation only.
"""
# Import core models
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.standardized_cnn import StandardizedCNN # Use the unified CNN model
# Import model interfaces
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
# Backward-compatibility alias: code that still refers to `CNNModel` gets the
# unified StandardizedCNN class (same object, just a second name).
CNNModel = StandardizedCNN
# Names re-exported by this package: the model classes imported above, the
# CNNModel alias, and the model interface classes.
__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']

View File

@ -1,201 +0,0 @@
# """
# Legacy CNN Model Compatibility Layer
# This module provides compatibility redirects to the unified StandardizedCNN model.
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
# in favor of the StandardizedCNN architecture.
# """
# import logging
# import warnings
# from typing import Tuple, Dict, Any, Optional
# import torch
# import numpy as np
# # Import the standardized CNN model
# from .standardized_cnn import StandardizedCNN
# logger = logging.getLogger(__name__)
# # Compatibility aliases and wrappers
# class EnhancedCNNModel:
# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
# def __init__(self, *args, **kwargs):
# warnings.warn(
# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Create StandardizedCNN with default parameters
# self.standardized_cnn = StandardizedCNN()
# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
# def __getattr__(self, name):
# """Delegate all method calls to StandardizedCNN"""
# return getattr(self.standardized_cnn, name)
# class CNNModelTrainer:
# """Legacy compatibility wrapper for CNN training"""
# def __init__(self, model=None, *args, **kwargs):
# warnings.warn(
# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
# DeprecationWarning,
# stacklevel=2
# )
# if isinstance(model, EnhancedCNNModel):
# self.model = model.standardized_cnn
# else:
# self.model = StandardizedCNN()
# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
# def train_step(self, x, y, *args, **kwargs):
# """Legacy train step wrapper"""
# try:
# # Convert to BaseDataInput format if needed
# if hasattr(x, 'get_feature_vector'):
# # Already BaseDataInput
# base_input = x
# else:
# # Create mock BaseDataInput for legacy compatibility
# from core.data_models import BaseDataInput
# base_input = BaseDataInput()
# # Set mock feature vector
# if isinstance(x, torch.Tensor):
# feature_vector = x.flatten().cpu().numpy()
# else:
# feature_vector = np.array(x).flatten()
# # Pad or truncate to expected size
# expected_size = self.model.expected_feature_dim
# if len(feature_vector) < expected_size:
# padding = np.zeros(expected_size - len(feature_vector))
# feature_vector = np.concatenate([feature_vector, padding])
# else:
# feature_vector = feature_vector[:expected_size]
# base_input._feature_vector = feature_vector
# # Convert target to string format
# if isinstance(y, torch.Tensor):
# y_val = y.item() if y.numel() == 1 else y.argmax().item()
# else:
# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
# target = target_map.get(y_val, 'HOLD')
# # Use StandardizedCNN training
# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
# loss = self.model.train_step([base_input], [target], optimizer)
# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
# except Exception as e:
# logger.error(f"Legacy train_step error: {e}")
# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
# # class CNNModel:
# # """Legacy compatibility wrapper for CNN model interface"""
# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
# # warnings.warn(
# # "CNNModel is deprecated. Use StandardizedCNN directly.",
# # DeprecationWarning,
# # stacklevel=2
# # )
# # self.input_shape = input_shape
# # self.output_size = output_size
# # self.standardized_cnn = StandardizedCNN()
# # self.trainer = CNNModelTrainer(self.standardized_cnn)
# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
# # def build_model(self, **kwargs):
# # """Legacy build method - no-op for StandardizedCNN"""
# # return self
# # def predict(self, X):
# # """Legacy predict method"""
# # try:
# # # Convert input to BaseDataInput
# # from core.data_models import BaseDataInput
# # base_input = BaseDataInput()
# # if isinstance(X, np.ndarray):
# # feature_vector = X.flatten()
# # else:
# # feature_vector = np.array(X).flatten()
# # # Pad or truncate to expected size
# # expected_size = self.standardized_cnn.expected_feature_dim
# # if len(feature_vector) < expected_size:
# # padding = np.zeros(expected_size - len(feature_vector))
# # feature_vector = np.concatenate([feature_vector, padding])
# # else:
# # feature_vector = feature_vector[:expected_size]
# # base_input._feature_vector = feature_vector
# # # Get prediction from StandardizedCNN
# # result = self.standardized_cnn.predict_from_base_input(base_input)
# # # Convert to legacy format
# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
# # pred_proba = np.array([result.predictions['action_probabilities']])
# # return pred_class, pred_proba
# # except Exception as e:
# # logger.error(f"Legacy predict error: {e}")
# # # Return safe defaults
# # pred_class = np.array([2]) # HOLD
# # pred_proba = np.array([[0.33, 0.33, 0.34]])
# # return pred_class, pred_proba
# # def fit(self, X, y, **kwargs):
# # """Legacy fit method"""
# # try:
# # return self.trainer.train_step(X, y)
# # except Exception as e:
# # logger.error(f"Legacy fit error: {e}")
# # return self
# # def save(self, filepath: str):
# # """Legacy save method"""
# # try:
# # torch.save(self.standardized_cnn.state_dict(), filepath)
# # logger.info(f"StandardizedCNN saved to {filepath}")
# # except Exception as e:
# # logger.error(f"Error saving model: {e}")
# def create_enhanced_cnn_model(input_size: int = 60,
# feature_dim: int = 50,
# output_size: int = 3,
# base_channels: int = 256,
# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
# """Legacy compatibility function - returns StandardizedCNN"""
# warnings.warn(
# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
# DeprecationWarning,
# stacklevel=2
# )
# model = StandardizedCNN()
# trainer = CNNModelTrainer(model)
# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
# return model, trainer
# # Export compatibility symbols
# __all__ = [
# 'EnhancedCNNModel',
# 'CNNModelTrainer',
# # 'CNNModel',
# 'create_enhanced_cnn_model'
# ]

View File

@ -1,821 +0,0 @@
"""
Transformer Neural Network for timeseries analysis
This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
It also includes a Mixture of Experts model that combines predictions from multiple models.
"""
import os
import logging
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
Input, Dense, Dropout, BatchNormalization,
Concatenate, Layer, LayerNormalization, MultiHeadAttention,
Add, GlobalAveragePooling1D, Conv1D, Reshape
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import datetime
import json
logger = logging.getLogger(__name__)
class TransformerBlock(Layer):
    """
    Transformer block implementation with multi-head attention and feed-forward networks.

    The block applies self-attention, dropout and a residual connection followed
    by layer normalization, then a two-layer feed-forward network with the same
    dropout / residual / normalization pattern.

    Args:
        embed_dim (int): Width of the input embeddings; also used as ``key_dim``
            of the attention layer and as the feed-forward output width.
        num_heads (int): Number of attention heads.
        ff_dim (int): Hidden width of the feed-forward network.
        rate (float): Dropout rate applied after attention and feed-forward.
        **kwargs: Standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...). Accepting them is required so the layer can be
            re-created from the dictionary returned by ``get_config``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=False):
        """Run self-attention and the feed-forward network over ``inputs``."""
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        Only plain constructor arguments are reported. The sub-layers are
        re-created by ``__init__`` and must not be placed in the config: they
        are not constructor parameters, so ``from_config`` could not consume
        them.
        """
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config
class PositionalEncoding(Layer):
    """
    Positional encoding layer to add position information to input embeddings.

    The sinusoidal table is computed once in the constructor and added to the
    inputs in ``call``.

    Args:
        position (int): Number of positions (sequence length) to precompute.
        d_model (int): Embedding width of the inputs.
        **kwargs: Standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...). Accepting them is required so the layer can be
            re-created from the dictionary returned by ``get_config``.
    """

    def __init__(self, position, d_model, **kwargs):
        super().__init__(**kwargs)
        self.position = position
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        """Return the angle (in radians) for each (position, channel) pair."""
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        """Build the encoding table with a leading axis of size 1."""
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model
        )
        # Apply sin to even indices in the array
        sines = tf.math.sin(angle_rads[:, 0::2])
        # Apply cos to odd indices in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])
        # Sines and cosines are concatenated (not interleaved) along the last axis.
        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]
        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        """Add the encoding for the first ``inputs.shape[1]`` positions."""
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The precomputed ``pos_encoding`` tensor is intentionally left out: it is
        derived from ``position`` and ``d_model`` in ``__init__`` and is not a
        constructor parameter.
        """
        config = super().get_config()
        config.update({
            'position': self.position,
            'd_model': self.d_model,
        })
        return config
class TransformerModel:
"""
Transformer Neural Network for time series analysis.
This model uses self-attention mechanisms to capture relationships between
different time points in the input data.
"""
def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
"""
Initialize the Transformer model.
Args:
ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
feature_input_shape (int): Shape of additional feature input (e.g., from CNN)
output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
model_dir (str): Directory to save trained models
"""
self.ts_input_shape = ts_input_shape
self.feature_input_shape = feature_input_shape
self.output_size = output_size
self.model_dir = model_dir
self.model = None
self.history = None
# Create model directory if it doesn't exist
os.makedirs(self.model_dir, exist_ok=True)
logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
f"feature input shape {feature_input_shape}, and output size {output_size}")
def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
"""
Build the Transformer model architecture.
Args:
embed_dim (int): Embedding dimension for transformer
num_heads (int): Number of attention heads
ff_dim (int): Hidden dimension of the feed forward network
num_transformer_blocks (int): Number of transformer blocks
dropout_rate (float): Dropout rate for regularization
learning_rate (float): Learning rate for Adam optimizer
Returns:
The compiled model
"""
# Time series input
ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")
# Additional feature input (e.g., from CNN)
feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")
# Process time series with transformer
# First, project the input to the embedding dimension
x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)
# Add positional encoding
x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)
# Add transformer blocks
for _ in range(num_transformer_blocks):
x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)
# Global pooling to get a single vector representation
x = GlobalAveragePooling1D()(x)
x = Dropout(dropout_rate)(x)
# Combine with additional features
combined = Concatenate()([x, feature_inputs])
# Dense layers for final classification/regression
x = Dense(64, activation="relu")(combined)
x = BatchNormalization()(x)
x = Dropout(dropout_rate)(x)
# Output layer
if self.output_size == 1:
# Binary classification (up/down)
outputs = Dense(1, activation='sigmoid', name='output')(x)
loss = 'binary_crossentropy'
metrics = ['accuracy']
elif self.output_size == 3:
# Multi-class classification (buy/hold/sell)
outputs = Dense(3, activation='softmax', name='output')(x)
loss = 'categorical_crossentropy'
metrics = ['accuracy']
else:
# Regression
outputs = Dense(self.output_size, activation='linear', name='output')(x)
loss = 'mse'
metrics = ['mae']
# Create and compile model
self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)
# Compile with Adam optimizer
self.model.compile(
optimizer=Adam(learning_rate=learning_rate),
loss=loss,
metrics=metrics
)
# Log model summary
self.model.summary(print_fn=lambda x: logger.info(x))
return self.model
def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
callbacks=None, class_weights=None):
"""
Train the Transformer model on the provided data.
Args:
X_ts (numpy.ndarray): Time series input features
X_features (numpy.ndarray): Additional input features
y (numpy.ndarray): Target labels
batch_size (int): Batch size
epochs (int): Number of epochs
validation_split (float): Fraction of data to use for validation
callbacks (list): List of Keras callbacks
class_weights (dict): Class weights for imbalanced datasets
Returns:
History object containing training metrics
"""
if self.model is None:
self.build_model()
# Default callbacks if none provided
if callbacks is None:
# Create a timestamp for model checkpoints
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
callbacks = [
EarlyStopping(
monitor='val_loss',
patience=10,
restore_best_weights=True
),
ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=5,
min_lr=1e-6
),
ModelCheckpoint(
filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
monitor='val_loss',
save_best_only=True
)
]
# Check if y needs to be one-hot encoded for multi-class
if self.output_size == 3 and len(y.shape) == 1:
y = tf.keras.utils.to_categorical(y, num_classes=3)
# Train the model
logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
self.history = self.model.fit(
[X_ts, X_features], y,
batch_size=batch_size,
epochs=epochs,
validation_split=validation_split,
callbacks=callbacks,
class_weight=class_weights,
verbose=2
)
# Save the trained model
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
self.model.save(model_path)
logger.info(f"Model saved to {model_path}")
# Save training history
history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
with open(history_path, 'w') as f:
# Convert numpy values to Python native types for JSON serialization
history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
json.dump(history_dict, f, indent=2)
return self.history
def evaluate(self, X_ts, X_features, y):
"""
Evaluate the model on test data.
Args:
X_ts (numpy.ndarray): Time series input features
X_features (numpy.ndarray): Additional input features
y (numpy.ndarray): Target labels
Returns:
dict: Evaluation metrics
"""
if self.model is None:
raise ValueError("Model has not been built or trained yet")
# Convert y to one-hot encoding for multi-class
if self.output_size == 3 and len(y.shape) == 1:
y = tf.keras.utils.to_categorical(y, num_classes=3)
# Evaluate model
logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)
metrics = {}
for metric, value in zip(self.model.metrics_names, eval_results):
metrics[metric] = value
logger.info(f"{metric}: {value:.4f}")
return metrics
def predict(self, X_ts, X_features=None):
"""
Make predictions on new data.
Args:
X_ts (numpy.ndarray): Time series input features
X_features (numpy.ndarray): Additional input features
Returns:
tuple: (y_pred, y_proba) where:
y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
y_proba is the class probability
"""
if self.model is None:
raise ValueError("Model has not been built or trained yet")
# Ensure X_ts has the right shape
if len(X_ts.shape) == 2:
# Single sample, add batch dimension
X_ts = np.expand_dims(X_ts, axis=0)
# Ensure X_features has the right shape
if X_features is None:
# Extract features from time series data if no external features provided
X_features = self._extract_features_from_timeseries(X_ts)
elif len(X_features.shape) == 1:
# Single sample, add batch dimension
X_features = np.expand_dims(X_features, axis=0)
def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
"""Extract meaningful features from time series data instead of using dummy zeros"""
try:
batch_size = X_ts.shape[0]
features = []
for i in range(batch_size):
sample = X_ts[i] # Shape: (timesteps, features)
# Extract statistical features from each feature dimension
sample_features = []
for feature_idx in range(sample.shape[1]):
feature_data = sample[:, feature_idx]
# Basic statistical features
sample_features.extend([
np.mean(feature_data), # Mean
np.std(feature_data), # Standard deviation
np.min(feature_data), # Minimum
np.max(feature_data), # Maximum
np.percentile(feature_data, 25), # 25th percentile
np.percentile(feature_data, 75), # 75th percentile
])
# Trend features
if len(feature_data) > 1:
# Linear trend (slope)
x = np.arange(len(feature_data))
slope = np.polyfit(x, feature_data, 1)[0]
sample_features.append(slope)
# Rate of change
rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
sample_features.append(rate_of_change)
else:
sample_features.extend([0.0, 0.0])
# Pad or truncate to expected feature size
while len(sample_features) < self.feature_input_shape:
sample_features.append(0.0)
sample_features = sample_features[:self.feature_input_shape]
features.append(sample_features)
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error extracting features from time series: {e}")
# Fallback to zeros if extraction fails
return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
# Get predictions
y_proba = self.model.predict([X_ts, X_features])
# Process based on output type
if self.output_size == 1:
# Binary classification
y_pred = (y_proba > 0.5).astype(int).flatten()
return y_pred, y_proba.flatten()
elif self.output_size == 3:
# Multi-class classification
y_pred = np.argmax(y_proba, axis=1)
return y_pred, y_proba
else:
# Regression
return y_proba, y_proba
def save(self, filepath=None):
"""
Save the model to disk.
Args:
filepath (str): Path to save the model
Returns:
str: Path where the model was saved
"""
if self.model is None:
raise ValueError("Model has not been built yet")
if filepath is None:
# Create a default filepath with timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5")
self.model.save(filepath)
logger.info(f"Model saved to {filepath}")
return filepath
def load(self, filepath):
    """
    Restore a previously saved model and keep it on ``self.model``.

    Args:
        filepath (str): Path to the saved model

    Returns:
        The loaded model
    """
    # The custom layer classes are passed to load_model by name so it can
    # resolve them while deserializing.
    self.model = load_model(
        filepath,
        custom_objects={
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding,
        },
    )
    logger.info(f"Model loaded from {filepath}")
    return self.model
def plot_training_history(self):
    """
    Plot training history (loss and metrics) and save it as a PNG.

    The left panel shows training/validation loss. The right panel shows
    accuracy when the history contains an ``accuracy`` entry, otherwise MAE
    when it contains ``mae``; if neither is present the panel stays empty.

    Returns:
        str: Path to the saved plot

    Raises:
        ValueError: If ``self.history`` is ``None``.
    """
    if self.history is None:
        raise ValueError("Model has not been trained yet")

    metrics = self.history.history
    fig = plt.figure(figsize=(12, 5))
    try:
        # Plot loss
        plt.subplot(1, 2, 1)
        plt.plot(metrics['loss'], label='Training Loss')
        if 'val_loss' in metrics:
            plt.plot(metrics['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Plot accuracy (classification) or MAE (regression)
        plt.subplot(1, 2, 2)
        if 'accuracy' in metrics:
            plt.plot(metrics['accuracy'], label='Training Accuracy')
            if 'val_accuracy' in metrics:
                plt.plot(metrics['val_accuracy'], label='Validation Accuracy')
            plt.title('Model Accuracy')
            plt.ylabel('Accuracy')
        elif 'mae' in metrics:
            plt.plot(metrics['mae'], label='Training MAE')
            if 'val_mae' in metrics:
                plt.plot(metrics['val_mae'], label='Validation MAE')
            plt.title('Model MAE')
            plt.ylabel('MAE')
        plt.xlabel('Epoch')
        plt.legend()
        plt.tight_layout()

        # Save figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
        plt.savefig(fig_path)
    finally:
        # Always release the figure: pyplot keeps every open figure alive
        # until it is closed, so an exception above would otherwise leak it.
        plt.close(fig)

    logger.info(f"Training history plot saved to {fig_path}")
    return fig_path
class MixtureOfExpertsModel:
    """
    Mixture of Experts (MoE) model.

    Combines the predictions of several expert models (such as a CNN and a
    Transformer) into one output using a fixed weighted average.
    """

    def __init__(self, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the MoE model.

        Args:
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None
        self.experts = {}
        # Width of the auxiliary feature input; build_model() overwrites it.
        self.feature_input_shape = 64

        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized Mixture of Experts model with output size {output_size}")

    def add_expert(self, name, model):
        """
        Add an expert model to the MoE.

        Only experts registered under the names ``'cnn'`` and ``'transformer'``
        are wired in by :meth:`build_model`; other names are skipped there.

        Args:
            name (str): Name of the expert model
            model: The expert model instance

        Returns:
            None
        """
        self.experts[name] = model
        logger.info(f"Added expert model '{name}' to MoE")

    @staticmethod
    def _resolve_expert_weights(expert_names, expert_weights=None):
        """
        Turn optional user weights into a normalized list aligned with expert_names.

        Args:
            expert_names (list): Names of the experts that were wired in
            expert_weights (dict): Optional ``name -> weight`` mapping; experts
                missing from it receive ``1 / len(expert_names)``

        Returns:
            list: One weight per expert, summing to 1

        Raises:
            ValueError: If the supplied weights sum to zero.
        """
        count = len(expert_names)
        if count == 0:
            return []
        if expert_weights is None:
            # Equal weighting
            return [1.0 / count] * count

        raw = [expert_weights.get(name, 1.0 / count) for name in expert_names]
        total = sum(raw)
        if total == 0:
            # Normalizing would divide by zero; report the real problem instead.
            raise ValueError("Expert weights must not sum to zero")
        return [weight / total for weight in raw]

    def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001,
                    feature_input_shape=64):
        """
        Build the MoE model by combining expert models.

        Args:
            ts_input_shape (tuple): Shape of time series input data
            expert_weights (dict): Weights for each expert model
            learning_rate (float): Learning rate for Adam optimizer
            feature_input_shape (int): Width of the additional feature input

        Returns:
            The compiled model, or None when no expert could be wired in

        Raises:
            ValueError: If the supplied expert weights sum to zero.
        """
        # Time series input
        ts_inputs = Input(shape=ts_input_shape, name="ts_input")

        # Additional feature input (from CNN)
        feature_inputs = Input(shape=(feature_input_shape,), name="feature_input")

        # Process with each expert model
        expert_outputs = []
        expert_names = []

        for name, expert in self.experts.items():
            # Skip if expert model is not valid
            if expert is None:
                logger.warning(f"Expert model '{name}' is None, skipping")
                continue

            try:
                # Different handling based on model type
                if name == 'cnn':
                    # CNN model takes only time series input
                    expert_output = expert(ts_inputs)
                elif name == 'transformer':
                    # Transformer model takes both time series and feature inputs
                    expert_output = expert([ts_inputs, feature_inputs])
                else:
                    logger.warning(f"Unknown expert model type: {name}")
                    continue
                expert_outputs.append(expert_output)
                expert_names.append(name)
            except Exception as e:
                # A failing expert is logged and left out of the ensemble.
                logger.error(f"Error adding expert '{name}': {str(e)}")

        if not expert_outputs:
            logger.error("No valid expert models found")
            return None

        weights = self._resolve_expert_weights(expert_names, expert_weights)

        # Combine expert outputs using weighted average
        if len(expert_outputs) == 1:
            # Only one expert, use its output directly
            combined_output = expert_outputs[0]
        else:
            # Multiple experts, compute weighted average
            weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
            combined_output = Add()(weighted_outputs)

        # Create the MoE model
        moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)

        # Pick loss/metrics from the output type, then compile once
        if self.output_size == 1:
            # Binary classification
            loss, metrics = 'binary_crossentropy', ['accuracy']
        elif self.output_size == 3:
            # Multi-class classification for BUY/HOLD/SELL
            loss, metrics = 'categorical_crossentropy', ['accuracy']
        else:
            # Regression
            loss, metrics = 'mse', ['mae']
        moe_model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss=loss,
            metrics=metrics
        )

        self.model = moe_model
        # Remember the feature width so predict() can build matching dummy features.
        self.feature_input_shape = feature_input_shape

        # Log model summary
        self.model.summary(print_fn=lambda x: logger.info(x))

        logger.info(f"Built MoE model with weights: {weights}")
        return self.model

    def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the MoE model on the provided data.

        After fitting, the model is saved as ``moe_model_final_<timestamp>.h5``
        and the history as ``moe_model_history_<timestamp>.json`` in ``model_dir``.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics, or None when the
            model has not been built
        """
        if self.model is None:
            logger.error("MoE model has not been built yet")
            return None

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # Check if y needs to be one-hot encoded for multi-class
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Train the model
        logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            [X_ts, X_features], y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save training history
        history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values]
                            for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history

    def predict(self, X_ts, X_features=None):
        """
        Make predictions on new data.

        Args:
            X_ts (numpy.ndarray): Time series input features; a 2-D array is
                treated as a single sample and gets a batch dimension
            X_features (numpy.ndarray): Additional input features; when omitted,
                an all-zero matrix of width ``feature_input_shape`` is used

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
            For any other output size both elements are the raw model output.

        Raises:
            ValueError: If the model has not been built or loaded.
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X_ts has the right shape
        if len(X_ts.shape) == 2:
            # Single sample, add batch dimension
            X_ts = np.expand_dims(X_ts, axis=0)

        # Ensure X_features has the right shape
        if X_features is None:
            # Create dummy features with zeros; fall back to 64 for instances
            # that were created without running __init__.
            feature_width = getattr(self, 'feature_input_shape', 64)
            X_features = np.zeros((X_ts.shape[0], feature_width))
        elif len(X_features.shape) == 1:
            # Single sample, add batch dimension
            X_features = np.expand_dims(X_features, axis=0)

        # Get predictions
        y_proba = self.model.predict([X_ts, X_features])

        # Process based on output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model; defaults to a timestamped
                ``moe_model_<timestamp>.h5`` inside ``model_dir``

        Returns:
            str: Path where the model was saved

        Raises:
            ValueError: If the model has not been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        ``feature_input_shape`` is not updated from the loaded model and keeps
        its current value.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        # Register custom layers
        custom_objects = {
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding
        }

        self.model = load_model(filepath, custom_objects=custom_objects)
        logger.info(f"Model loaded from {filepath}")
        return self.model
# Script entry point.
if __name__ == "__main__":
    # Running the module directly only prints a notice; no demo is wired up here.
    print("Transformer and MoE models defined, but not implemented here.")

View File

@ -1,88 +0,0 @@
#!/usr/bin/env python
"""
Start TensorBoard for monitoring neural network training
"""
import os
import sys
import subprocess
import webbrowser
from time import sleep
def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
    """
    Start TensorBoard in a subprocess

    The child's combined stdout/stderr is echoed to this process's stdout:
    synchronously until the startup line containing the URL is seen, then from
    a background daemon thread. Callers should therefore not read
    ``process.stdout`` themselves.

    Args:
        logdir: Directory containing TensorBoard logs (created if missing)
        port: Port to run TensorBoard on
        open_browser: Whether to open a browser automatically

    Returns:
        subprocess.Popen: The TensorBoard process, for the caller to manage
    """
    # Local import: only needed for the output-drain thread below.
    import threading

    # Make sure the log directory exists
    os.makedirs(logdir, exist_ok=True)

    # Create command
    # NOTE(review): --bind_all makes TensorBoard listen on all interfaces.
    cmd = [
        sys.executable,
        "-m",
        "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        "--bind_all"
    ]

    print(f"Starting TensorBoard with logs from {logdir} on port {port}")
    print(f"Command: {' '.join(cmd)}")

    # Start TensorBoard in a subprocess
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True
    )

    # Wait for TensorBoard to start up (or exit, which ends the loop at EOF)
    for line in process.stdout:
        print(line.strip())
        if "TensorBoard" in line and "http://" in line:
            # TensorBoard is running, extract the first URL-looking token
            url = next(
                (part for part in line.split() if part.startswith(("http://", "https://"))),
                None,
            )

            # Open browser if requested and URL found
            if open_browser and url:
                print(f"Opening TensorBoard in browser: {url}")
                webbrowser.open(url)

            break

    def _drain(stream):
        """Echo the remaining child output so the pipe never fills up."""
        for remaining in stream:
            print(remaining.strip())

    # Without a reader the child would block on write once the OS pipe buffer
    # is full, so keep consuming its output in the background.
    threading.Thread(
        target=_drain,
        args=(process.stdout,),
        name="tensorboard-output",
        daemon=True,
    ).start()

    # Return the process for the caller to manage
    return process
if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
    parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
    parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
    args = parser.parse_args()

    # Start TensorBoard
    process = start_tensorboard(args.logdir, args.port, not args.no_browser)

    try:
        print("TensorBoard is running. Press Ctrl+C to stop.")
        # Block on the child instead of sleeping in a loop, so the script also
        # ends when TensorBoard exits on its own (e.g. the port is in use).
        exit_code = process.wait()
        print(f"TensorBoard exited with code {exit_code}")
    except KeyboardInterrupt:
        print("Stopping TensorBoard...")
        process.terminate()
        process.wait()

View File

@ -1,490 +0,0 @@
#!/usr/bin/env python3
"""
Enhanced RL Training Integration - Comprehensive Fix
This script addresses the critical RL training audit issues:
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
2. Disconnected Training Pipeline - Provides proper data flow integration
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
5. Williams Market Structure Integration - Proper feature extraction
6. Real-time Data Integration - Live market data to RL
Usage:
python enhanced_rl_training_integration.py
"""
import os
import sys
import asyncio
import logging
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Any
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
from utils.tensorboard_logger import TensorBoardLogger
logger = logging.getLogger(__name__)
class EnhancedRLTrainingIntegrator:
    """
    Comprehensive RL Training Integrator

    Drives a six-step integration check: orchestrator setup, dashboard wiring,
    state-building verification, reward calculation test, Williams
    market-structure validation and a short live training demonstration.
    Counters for each step are kept in ``self.training_stats`` and mirrored to
    TensorBoard through ``self.tb_logger``.
    """

    def __init__(self):
        """Initialize the enhanced RL training integrator"""
        # Setup logging
        setup_logging()

        logger.info("=" * 70)
        logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX")
        logger.info("=" * 70)

        # Get configuration
        self.config = get_config()

        # Initialize core components; orchestrator and dashboard are created
        # later by start_integration().
        self.data_provider = DataProvider()
        self.enhanced_orchestrator = None
        self.trading_executor = TradingExecutor()
        self.dashboard = None

        # Training metrics
        self.training_stats = {
            'total_episodes': 0,
            'successful_state_builds': 0,
            'enhanced_reward_calculations': 0,
            'comprehensive_features_used': 0,
            'pivot_features_extracted': 0,
            'cob_features_available': 0
        }

        # Initialize TensorBoard logger
        experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.tb_logger = TensorBoardLogger(
            log_dir="runs",
            experiment_name=experiment_name,
            enabled=True
        )
        logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")

        logger.info("Enhanced RL Training Integrator initialized")

    @staticmethod
    def _mean_or_zero(total, count):
        """Return ``total / count``, or 0.0 when ``count`` is not positive."""
        return total / count if count > 0 else 0.0

    async def start_integration(self):
        """
        Start the comprehensive RL training integration

        Runs the six steps in order and logs the final statistics.

        Raises:
            Exception: Any error escaping a step is logged with its traceback
                and re-raised so the caller can report the failure.
        """
        try:
            logger.info("Starting comprehensive RL training integration...")

            # 1. Initialize Enhanced Orchestrator with comprehensive features
            await self._initialize_enhanced_orchestrator()

            # 2. Create enhanced dashboard with proper connections
            await self._create_enhanced_dashboard()

            # 3. Verify comprehensive state building
            await self._verify_comprehensive_state_building()

            # 4. Test enhanced reward calculation
            await self._test_enhanced_reward_calculation()

            # 5. Validate Williams market structure integration
            await self._validate_williams_integration()

            # 6. Start live training with comprehensive features
            await self._start_live_comprehensive_training()

            logger.info("=" * 70)
            logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE")
            logger.info("=" * 70)
            self._log_integration_stats()

        except Exception as e:
            logger.error(f"Error in RL training integration: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Propagate: swallowing here made the caller report success and
            # exit code 0 even though the integration had failed.
            raise

    async def _initialize_enhanced_orchestrator(self):
        """Initialize enhanced orchestrator with comprehensive RL capabilities"""
        try:
            logger.info("[STEP 1] Initializing Enhanced Orchestrator...")

            # Create enhanced orchestrator with RL training enabled
            self.enhanced_orchestrator = EnhancedTradingOrchestrator(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                enhanced_rl_training=True,
                model_registry={}  # Will be populated as needed
            )

            # Start COB integration for real-time market microstructure
            await self.enhanced_orchestrator.start_cob_integration()

            # Start real-time processing
            await self.enhanced_orchestrator.start_realtime_processing()

            logger.info("[SUCCESS] Enhanced Orchestrator initialized with:")
            logger.info(" - Comprehensive RL state building: ENABLED")
            logger.info(" - Enhanced pivot-based rewards: ENABLED")
            logger.info(" - COB integration: ENABLED")
            logger.info(" - Williams market structure: ENABLED")
            logger.info(" - Real-time tick processing: ENABLED")

        except Exception as e:
            logger.error(f"Error initializing enhanced orchestrator: {e}")
            raise

    async def _create_enhanced_dashboard(self):
        """Create dashboard with enhanced orchestrator connections"""
        try:
            logger.info("[STEP 2] Creating Enhanced Dashboard...")

            # Create trading dashboard with enhanced orchestrator
            self.dashboard = TradingDashboard(
                data_provider=self.data_provider,
                orchestrator=self.enhanced_orchestrator,  # Use enhanced orchestrator
                trading_executor=self.trading_executor
            )

            # Verify enhanced connections by probing for the expected methods
            has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state')
            has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward')
            has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation')

            logger.info("[SUCCESS] Enhanced Dashboard created with:")
            logger.info(f" - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}")
            logger.info(f" - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}")
            logger.info(f" - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}")

            if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]):
                logger.warning("Some enhanced features are missing - this will cause fallbacks to basic training")
            else:
                logger.info(" - ALL ENHANCED FEATURES AVAILABLE!")

        except Exception as e:
            logger.error(f"Error creating enhanced dashboard: {e}")
            raise

    async def _verify_comprehensive_state_building(self):
        """Verify that comprehensive RL state building works correctly"""
        try:
            logger.info("[STEP 3] Verifying Comprehensive State Building...")

            # Test comprehensive state building for ETH
            eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT')
            if eth_state is not None:
                logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features")

                # Verify feature count
                if len(eth_state) == 13400:
                    logger.info(" - PERFECT: Exactly 13,400 features as required!")
                    self.training_stats['comprehensive_features_used'] += 1
                else:
                    logger.warning(f" - MISMATCH: Expected 13,400 features, got {len(eth_state)}")

                # Analyze feature distribution
                self._analyze_state_features(eth_state)
                self.training_stats['successful_state_builds'] += 1
            else:
                logger.error(" - FAILED: Comprehensive state building returned None")

            # Test for BTC reference
            btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT')
            if btc_state is not None:
                logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features")
                self.training_stats['successful_state_builds'] += 1

        except Exception as e:
            # Best-effort check: failures are logged and the run continues.
            logger.error(f"Error verifying comprehensive state building: {e}")

    def _analyze_state_features(self, state_vector: np.ndarray):
        """Analyze the comprehensive state feature distribution"""
        try:
            # Calculate feature statistics
            non_zero_features = np.count_nonzero(state_vector)
            zero_features = len(state_vector) - non_zero_features
            feature_mean = np.mean(state_vector)
            feature_std = np.std(state_vector)
            feature_min = np.min(state_vector)
            feature_max = np.max(state_vector)

            logger.info(" - Feature Analysis:")
            logger.info(f" * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f" * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f" * Mean: {feature_mean:.6f}")
            logger.info(f" * Std: {feature_std:.6f}")
            logger.info(f" * Range: [{feature_min:.6f}, {feature_max:.6f}]")

            # Log feature statistics to TensorBoard
            step = self.training_stats['total_episodes']
            self.tb_logger.log_scalars('Features/Distribution', {
                'non_zero_percentage': non_zero_features/len(state_vector)*100,
                'mean': feature_mean,
                'std': feature_std,
                'min': feature_min,
                'max': feature_max
            }, step)

            # Log feature histogram to TensorBoard
            self.tb_logger.log_histogram('Features/Values', state_vector, step)

            # Check if features are properly distributed
            if non_zero_features > len(state_vector) * 0.1:  # At least 10% non-zero
                logger.info(" * GOOD: Features are well distributed")
            else:
                logger.warning(" * WARNING: Too many zero features - data may be incomplete")

        except Exception as e:
            # Includes the empty-vector case, where the percentages divide by zero.
            logger.warning(f"Error analyzing state features: {e}")

    async def _test_enhanced_reward_calculation(self):
        """Test enhanced pivot-based reward calculation"""
        try:
            logger.info("[STEP 4] Testing Enhanced Reward Calculation...")

            # Create mock trade data for testing
            trade_decision = {
                'action': 'BUY',
                'confidence': 0.75,
                'price': 2500.0,
                'timestamp': datetime.now()
            }

            trade_outcome = {
                'net_pnl': 50.0,
                'exit_price': 2550.0,
                'duration': timedelta(minutes=15)
            }

            # Fixed sample market data for the reward calculation
            market_data = {
                'volatility': 0.03,
                'order_flow_direction': 'bullish',
                'order_flow_strength': 0.8
            }

            # Test enhanced reward calculation
            if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'):
                enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward(
                    trade_decision, market_data, trade_outcome
                )

                logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}")
                logger.info(" - Enhanced pivot-based reward system: WORKING")
                self.training_stats['enhanced_reward_calculations'] += 1

                # Log reward metrics to TensorBoard
                step = self.training_stats['enhanced_reward_calculations']
                self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step)

                # Log reward components to TensorBoard
                self.tb_logger.log_scalars('Rewards/Components', {
                    'pnl_component': trade_outcome['net_pnl'],
                    'confidence': trade_decision['confidence'],
                    'volatility': market_data['volatility'],
                    'order_flow_strength': market_data['order_flow_strength']
                }, step)
            else:
                logger.error(" - FAILED: Enhanced reward calculation method not available")

        except Exception as e:
            logger.error(f"Error testing enhanced reward calculation: {e}")

    async def _validate_williams_integration(self):
        """Validate Williams market structure integration"""
        try:
            logger.info("[STEP 5] Validating Williams Market Structure Integration...")

            # Test Williams pivot feature extraction
            try:
                from training.williams_market_structure import extract_pivot_features, analyze_pivot_context

                # Get test market data
                df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100)

                if df is not None and not df.empty:
                    # Test pivot feature extraction
                    pivot_features = extract_pivot_features(df)

                    if pivot_features is not None:
                        logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features")
                        self.training_stats['pivot_features_extracted'] += 1

                        # Test pivot context analysis
                        market_data = {'ohlcv_data': df}
                        pivot_context = analyze_pivot_context(
                            market_data, datetime.now(), 'BUY'
                        )

                        if pivot_context is not None:
                            logger.info("[SUCCESS] Williams pivot context analysis: WORKING")
                            logger.info(f" - Near pivot: {pivot_context.get('near_pivot', False)}")
                            logger.info(f" - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}")
                        else:
                            logger.warning(" - Williams pivot context analysis returned None")
                    else:
                        logger.warning(" - Williams pivot feature extraction returned None")
                else:
                    logger.warning(" - No market data available for Williams testing")

            except ImportError:
                logger.error(" - Williams market structure module not available")
            except Exception as e:
                logger.error(f" - Error in Williams integration: {e}")

        except Exception as e:
            logger.error(f"Error validating Williams integration: {e}")

    async def _start_live_comprehensive_training(self):
        """
        Start live training with comprehensive feature integration

        Runs five iterations, two seconds apart. Each iteration requests
        coordinated decisions, builds a state per non-empty decision and logs
        per-iteration averages to TensorBoard.
        """
        try:
            logger.info("[STEP 6] Starting Live Comprehensive Training...")

            # Run a few training iterations to verify integration
            for iteration in range(5):
                logger.info(f"Training iteration {iteration + 1}/5")

                # Make coordinated decisions using enhanced orchestrator
                decisions = await self.enhanced_orchestrator.make_coordinated_decisions()

                # Track iteration metrics for TensorBoard
                iteration_metrics = {
                    'decisions_count': len(decisions),
                    'confidence_avg': 0.0,
                    'state_size_avg': 0.0,
                    'successful_states': 0
                }
                # Number of decisions that actually contributed a confidence value
                decided_count = 0

                # Process each decision
                for symbol, decision in decisions.items():
                    if not decision:
                        continue

                    decided_count += 1
                    logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")

                    # Track confidence for TensorBoard
                    iteration_metrics['confidence_avg'] += decision.confidence

                    # Build comprehensive state for this decision
                    comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol)

                    if comprehensive_state is not None:
                        state_size = len(comprehensive_state)
                        logger.info(f" - Comprehensive state: {state_size} features")
                        self.training_stats['total_episodes'] += 1

                        # Track state size for TensorBoard
                        iteration_metrics['state_size_avg'] += state_size
                        iteration_metrics['successful_states'] += 1

                        # Log individual state metrics to TensorBoard
                        self.tb_logger.log_state_metrics(
                            symbol=symbol,
                            state_info={
                                'size': state_size,
                                'quality': 1.0 if state_size == 13400 else 0.8,
                                'feature_counts': {
                                    'total': state_size,
                                    'non_zero': np.count_nonzero(comprehensive_state)
                                }
                            },
                            step=self.training_stats['total_episodes']
                        )
                    else:
                        logger.warning(f" - Failed to build comprehensive state for {symbol}")

                # Calculate averages for TensorBoard. The confidence sum only
                # covers non-empty decisions, so divide by that count rather
                # than by len(decisions).
                iteration_metrics['confidence_avg'] = self._mean_or_zero(
                    iteration_metrics['confidence_avg'], decided_count
                )
                iteration_metrics['state_size_avg'] = self._mean_or_zero(
                    iteration_metrics['state_size_avg'], iteration_metrics['successful_states']
                )

                # Log iteration metrics to TensorBoard
                self.tb_logger.log_scalars('Training/Iteration', {
                    'iteration': iteration + 1,
                    'decisions_count': iteration_metrics['decisions_count'],
                    'confidence_avg': iteration_metrics['confidence_avg'],
                    'state_size_avg': iteration_metrics['state_size_avg'],
                    'successful_states': iteration_metrics['successful_states']
                }, iteration + 1)

                # Wait between iterations
                await asyncio.sleep(2)

            logger.info("[SUCCESS] Live comprehensive training demonstration complete")

        except Exception as e:
            logger.error(f"Error in live comprehensive training: {e}")

    def _log_integration_stats(self):
        """Log comprehensive integration statistics"""
        logger.info("INTEGRATION STATISTICS:")
        logger.info(f" - Total training episodes: {self.training_stats['total_episodes']}")
        logger.info(f" - Successful state builds: {self.training_stats['successful_state_builds']}")
        logger.info(f" - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}")
        logger.info(f" - Comprehensive features used: {self.training_stats['comprehensive_features_used']}")
        logger.info(f" - Pivot features extracted: {self.training_stats['pivot_features_extracted']}")

        # Calculate success rates
        # NOTE(review): the numerator is counted in step 3 and the denominator
        # in step 6, so this ratio can exceed 100% - confirm the intended metric.
        state_success_rate = 0
        if self.training_stats['total_episodes'] > 0:
            state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100
            logger.info(f" - State building success rate: {state_success_rate:.1f}%")

        # Log final statistics to TensorBoard
        self.tb_logger.log_scalars('Integration/Statistics', {
            'total_episodes': self.training_stats['total_episodes'],
            'successful_state_builds': self.training_stats['successful_state_builds'],
            'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'],
            'comprehensive_features_used': self.training_stats['comprehensive_features_used'],
            'pivot_features_extracted': self.training_stats['pivot_features_extracted'],
            'state_success_rate': state_success_rate
        }, 0)  # Use step 0 for final summary stats

        # Integration status
        if self.training_stats['comprehensive_features_used'] > 0:
            logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅")
            logger.info("The system is now using the full 13,400 feature comprehensive state.")
            # Log success status to TensorBoard
            self.tb_logger.log_scalar('Integration/Success', 1.0, 0)
        else:
            logger.warning("STATUS: Integration partially successful - some fallbacks may occur")
            # Log partial success status to TensorBoard
            self.tb_logger.log_scalar('Integration/Success', 0.5, 0)
async def main():
    """
    Main entry point

    Builds the integrator, runs the integration and maps the outcome to a
    process exit code: 0 on completion or user interrupt, 1 on any other error.
    """
    exit_code = 0
    try:
        # Create and run the enhanced RL training integrator
        runner = EnhancedRLTrainingIntegrator()
        await runner.start_integration()
        logger.info("Enhanced RL training integration completed successfully!")
    except KeyboardInterrupt:
        logger.info("Integration interrupted by user")
    except Exception as e:
        logger.error(f"Fatal error in integration: {e}")
        import traceback
        logger.error(traceback.format_exc())
        exit_code = 1
    return exit_code
if __name__ == "__main__":
    # Run the async entry point and use its return value as the exit status.
    sys.exit(asyncio.run(main()))

View File

@ -1,148 +0,0 @@
#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""
import logging
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ExampleCNN(nn.Module):
    """
    Small demonstration CNN: two 3x3 convolutions with ReLU, adaptive
    average pooling down to 1x1, and a linear classification head.
    """

    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return the head's output of shape (batch, num_classes) for input x."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        pooled = self.pool(hidden)
        batch = pooled.size(0)
        # Collapse the (64, 1, 1) pooled map into a flat 64-vector per sample
        flat = pooled.view(batch, -1)
        return self.fc(flat)
def example_cnn_training():
    """
    Simulate five CNN training epochs and checkpoint each one.

    The metrics are synthetic (a linear trend plus Gaussian noise); no real
    training happens. After the loop the best stored checkpoint for
    ``example_cnn`` is looked up and its id and score are logged.
    """
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1)
        )

        if saved:
            logger.info(f" Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info(" Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    # A real newline here: the escaped "\\n" logged a literal backslash-n.
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        file_path, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")
def example_manual_checkpoint():
    """
    Save one checkpoint by calling ``save_checkpoint`` directly.

    Uses a throwaway ``nn.Linear(10, 3)`` model with hard-coded metrics and
    ``force_save=True``, then logs the returned checkpoint id and score.
    """
    # A real newline here: the escaped "\\n" logged a literal backslash-n.
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters())
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True
    )

    if metadata:
        logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f" Performance score: {metadata.performance_score:.4f}")
def show_checkpoint_stats():
    """Log the totals and per-model figures returned by the checkpoint manager."""
    logger.info("\\n=== Checkpoint Statistics ===")
    summary = get_checkpoint_manager().get_checkpoint_stats()

    # Aggregate figures first.
    logger.info(f"Total models: {summary['total_models']}")
    logger.info(f"Total checkpoints: {summary['total_checkpoints']}")
    logger.info(f"Total size: {summary['total_size_mb']:.2f} MB")

    # Then one short section per model.
    per_model = summary['models']
    for name in per_model:
        entry = per_model[name]
        logger.info(f"\\n{name}:")
        logger.info(f" Checkpoints: {entry['checkpoint_count']}")
        logger.info(f" Size: {entry['total_size_mb']:.2f} MB")
        logger.info(f" Best performance: {entry['best_performance']:.4f}")
def main():
    """Run every checkpoint example in order, then log short usage hints.

    Any exception raised by an example is logged and re-raised.
    """
    logger.info(" Checkpoint Management System Examples")
    logger.info("=" * 50)
    try:
        examples = (
            example_cnn_training,
            example_manual_checkpoint,
            show_checkpoint_stats,
        )
        for run_example in examples:
            run_example()

        closing_lines = (
            "\\n All examples completed successfully!",
            "\\nTo use in your training:",
            "1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint",
            "2. Or use: from utils.training_integration import get_training_integration",
            "3. Save checkpoints during training with performance metrics",
            "4. Load best checkpoints for inference or continued training",
        )
        for line in closing_lines:
            logger.info(line)
    except Exception as exc:
        logger.error(f"Error in examples: {exc}")
        raise


if __name__ == "__main__":
    main()

View File

@ -1,517 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive Checkpoint Management Integration
This script demonstrates how to integrate the checkpoint management system
across all training pipelines in the gogo2 project.
Features:
- DQN Agent training with automatic checkpointing
- CNN Model training with checkpoint management
- ExtremaTrainer with checkpoint persistence
- NegativeCaseTrainer with checkpoint integration
- Unified training orchestration with checkpoint coordination
"""
import asyncio
import logging
import time
import signal
import sys
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List
# Setup logging
# logging.FileHandler opens its target file as soon as it is constructed, so
# the directory must exist before basicConfig() is evaluated; otherwise
# importing this module fails with FileNotFoundError on a fresh checkout.
Path("logs").mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/checkpoint_integration.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
from utils.training_integration import get_training_integration
# Import training components
from NN.models.dqn_agent import DQNAgent
from NN.models.standardized_cnn import StandardizedCNN
from core.extrema_trainer import ExtremaTrainer
from core.negative_case_trainer import NegativeCaseTrainer
from core.data_provider import DataProvider
from core.config import get_config
class CheckpointIntegratedTrainingSystem:
    """Unified training system with comprehensive checkpoint management.

    Holds a DQN agent, a CNN model, an extrema trainer and a negative-case
    trainer, drives them in a periodic loop and force-saves a checkpoint for
    every initialized component on shutdown.
    """

    def __init__(self):
        """Initialize the checkpoint-integrated training system.

        Training components are only created by ``initialize_components``;
        until then they are ``None`` and the per-component training methods
        return a ``'skipped'`` status.
        """
        self.config = get_config()
        self.running = False
        # Set by shutdown() so that the repeated calls made by the training
        # loop, main() and the signal handler do not force-save twice.
        self._shutdown_complete = False

        # Checkpoint management
        self.checkpoint_manager = get_checkpoint_manager()
        self.training_integration = get_training_integration()

        # Data provider
        self.data_provider = DataProvider(
            symbols=['ETH/USDT', 'BTC/USDT'],
            timeframes=['1s', '1m', '1h', '1d']
        )

        # Training components with checkpoint management
        self.dqn_agent = None
        # NOTE(review): initialize_components() assigns ``cnn_model`` while
        # every training/checkpoint method reads ``cnn_trainer``, which is
        # never assigned, so CNN training is always skipped. Confirm whether
        # StandardizedCNN is meant to be used as the trainer.
        self.cnn_model = None
        self.cnn_trainer = None
        self.extrema_trainer = None
        self.negative_case_trainer = None

        # Training statistics
        self.training_stats = {
            'start_time': None,
            'total_training_sessions': 0,
            'checkpoints_saved': 0,
            'models_loaded': 0,
            'best_performances': {}
        }
        logger.info("Checkpoint-Integrated Training System initialized")

    async def initialize_components(self):
        """Initialize all training components with checkpoint management.

        Starts data streaming, constructs each component and records how many
        of them report resumed state. Any failure is logged and re-raised.
        """
        try:
            logger.info("Initializing training components with checkpoint management...")

            # Initialize data provider
            await self.data_provider.start_real_time_streaming()
            logger.info("Data provider streaming started")

            # Initialize DQN Agent with checkpoint management
            logger.info("Initializing DQN Agent with checkpoints...")
            self.dqn_agent = DQNAgent(
                state_shape=(100,),  # Example state shape
                n_actions=3,
                model_name="integrated_dqn_agent",
                enable_checkpoints=True
            )
            logger.info("✅ DQN Agent initialized with checkpoint management")

            # Initialize StandardizedCNN Model with checkpoint management
            logger.info("Initializing StandardizedCNN Model with checkpoints...")
            self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model")
            logger.info("✅ StandardizedCNN Model initialized with checkpoint management")

            # Initialize ExtremaTrainer with checkpoint management
            logger.info("Initializing ExtremaTrainer with checkpoints...")
            self.extrema_trainer = ExtremaTrainer(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                model_name="integrated_extrema_trainer",
                enable_checkpoints=True
            )
            await self.extrema_trainer.initialize_context_data()
            logger.info("✅ ExtremaTrainer initialized with checkpoint management")

            # Initialize NegativeCaseTrainer with checkpoint management
            logger.info("Initializing NegativeCaseTrainer with checkpoints...")
            self.negative_case_trainer = NegativeCaseTrainer(
                model_name="integrated_negative_case_trainer",
                enable_checkpoints=True
            )
            logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")

            # Load existing checkpoints for all components
            self.training_stats['models_loaded'] = await self._load_all_checkpoints()
            logger.info("All training components initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing components: {e}")
            raise

    async def _load_all_checkpoints(self) -> int:
        """Count the components that report resumed progress.

        Nothing is loaded here; each component is only inspected for a
        positive progress counter. Returns 0 if the inspection itself fails.
        """
        loaded_count = 0
        try:
            # DQN Agent checkpoint loading is handled in __init__
            if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
                loaded_count += 1
                logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")

            # CNN Trainer checkpoint loading is handled in __init__
            if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
                loaded_count += 1
                logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")

            # ExtremaTrainer checkpoint loading is handled in __init__
            if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")

            # NegativeCaseTrainer checkpoint loading is handled in __init__
            if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")

            return loaded_count
        except Exception as e:
            logger.error(f"Error loading checkpoints: {e}")
            return 0

    async def run_integrated_training_loop(self):
        """Run the integrated training loop with checkpoint coordination.

        Repeats one training cycle per minute until ``self.running`` is
        cleared, then always calls ``shutdown``.
        """
        logger.info("Starting integrated training loop with checkpoint management...")
        self.running = True
        # A new run needs its own final save even if a previous run shut down.
        self._shutdown_complete = False
        self.training_stats['start_time'] = datetime.now()
        training_cycle = 0
        try:
            while self.running:
                training_cycle += 1
                cycle_start = time.time()
                logger.info(f"=== Training Cycle {training_cycle} ===")

                # DQN Training
                dqn_results = await self._train_dqn_agent()
                # CNN Training
                cnn_results = await self._train_cnn_model()
                # Extrema Detection Training
                extrema_results = await self._train_extrema_detector()
                # Negative Case Training (runs in background)
                negative_results = await self._process_negative_cases()

                # Coordinate checkpoint saving
                await self._coordinate_checkpoint_saving(
                    dqn_results, cnn_results, extrema_results, negative_results
                )

                # Update statistics
                self.training_stats['total_training_sessions'] += 1

                # Log cycle summary
                cycle_duration = time.time() - cycle_start
                logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")

                # Wait before next cycle
                await asyncio.sleep(60)  # 1-minute cycles
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        except Exception as e:
            logger.error(f"Error in training loop: {e}")
        finally:
            await self.shutdown()

    async def _train_dqn_agent(self) -> Dict[str, Any]:
        """Run one simulated DQN episode and let the agent checkpoint itself.

        Returns a status dict; errors are logged and reported as
        ``{'status': 'error', ...}`` instead of being raised.
        """
        try:
            if not self.dqn_agent:
                return {'status': 'skipped', 'reason': 'no_agent'}

            # Simulate DQN training episode
            episode_reward = 0.0

            # Add some training experiences (simulate real training)
            for _ in range(10):  # Simulate 10 training steps
                state = np.random.randn(100).astype(np.float32)
                action = np.random.randint(0, 3)
                reward = np.random.randn() * 0.1
                next_state = np.random.randn(100).astype(np.float32)
                done = np.random.random() < 0.1
                self.dqn_agent.remember(state, action, reward, next_state, done)
                episode_reward += reward

            # Train if enough experiences
            loss = 0.0
            if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
                loss = self.dqn_agent.replay()

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)
            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'episode_reward': episode_reward,
                'loss': loss,
                'checkpoint_saved': checkpoint_saved,
                'episode': self.dqn_agent.episode_count
            }
        except Exception as e:
            logger.error(f"Error training DQN agent: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_cnn_model(self) -> Dict[str, Any]:
        """Run one simulated CNN train/validation step on synthetic data.

        Returns a status dict; skipped while ``self.cnn_trainer`` is unset.
        """
        try:
            if not self.cnn_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate CNN training step
            import torch

            batch_size = 32
            input_size = 60
            feature_dim = 50

            # Generate synthetic training data
            x = torch.randn(batch_size, input_size, feature_dim)
            y = torch.randint(0, 3, (batch_size,))

            # Training step
            results = self.cnn_trainer.train_step(x, y)

            # Simulate validation
            val_x = torch.randn(16, input_size, feature_dim)
            val_y = torch.randint(0, 3, (16,))
            val_results = self.cnn_trainer.train_step(val_x, val_y)

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.cnn_trainer.save_checkpoint(
                train_accuracy=results.get('accuracy', 0.5),
                val_accuracy=val_results.get('accuracy', 0.5),
                train_loss=results.get('total_loss', 1.0),
                val_loss=val_results.get('total_loss', 1.0)
            )
            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'train_accuracy': results.get('accuracy', 0.5),
                'val_accuracy': val_results.get('accuracy', 0.5),
                'train_loss': results.get('total_loss', 1.0),
                'val_loss': val_results.get('total_loss', 1.0),
                'checkpoint_saved': checkpoint_saved,
                'epoch': self.cnn_trainer.epoch_count
            }
        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_extrema_detector(self) -> Dict[str, Any]:
        """Refresh extrema context data, update simulated stats and checkpoint.

        Returns a status dict; errors are logged and reported, not raised.
        """
        try:
            if not self.extrema_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Update context data and detect extrema
            update_results = self.extrema_trainer.update_context_data()

            # Get training data
            extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)
            # A falsy result (empty or None) counts as zero detections rather
            # than failing on len(None) further down.
            extrema_count = len(extrema_data) if extrema_data else 0

            # Simulate training accuracy improvement
            if extrema_data:
                self.extrema_trainer.training_stats['total_extrema_detected'] += extrema_count
                self.extrema_trainer.training_stats['successful_predictions'] += extrema_count // 2
                self.extrema_trainer.training_stats['failed_predictions'] += extrema_count // 2

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.extrema_trainer.save_checkpoint()
            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'extrema_detected': extrema_count,
                'context_updates': sum(1 for success in update_results.values() if success),
                'checkpoint_saved': checkpoint_saved,
                'session': self.extrema_trainer.training_session_count
            }
        except Exception as e:
            logger.error(f"Error training extrema detector: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _process_negative_cases(self) -> Dict[str, Any]:
        """Occasionally add a simulated losing trade and checkpoint the trainer.

        A case is added with 10% probability per call; otherwise the result
        is ``{'status': 'no_cases'}``.
        """
        try:
            if not self.negative_case_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate adding a negative case
            if np.random.random() < 0.1:  # 10% chance of negative case
                trade_info = {
                    'symbol': 'ETH/USDT',
                    'action': 'BUY',
                    'price': 2000.0,
                    'pnl': -50.0,  # Loss
                    'value': 1000.0,
                    'confidence': 0.7,
                    'timestamp': datetime.now()
                }
                market_data = {
                    'exit_price': 1950.0,
                    'state_before': {},
                    'state_after': {},
                    'tick_data': [],
                    'technical_indicators': {}
                }
                case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)

                # Simulate loss improvement
                loss_improvement = np.random.random() * 0.1

                # Save checkpoint (automatic based on performance)
                checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)
                if checkpoint_saved:
                    self.training_stats['checkpoints_saved'] += 1

                return {
                    'status': 'completed',
                    'case_added': case_id,
                    'loss_improvement': loss_improvement,
                    'checkpoint_saved': checkpoint_saved,
                    'session': self.negative_case_trainer.training_session_count
                }
            else:
                return {'status': 'no_cases'}
        except Exception as e:
            logger.error(f"Error processing negative cases: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
                                            extrema_results: Dict, negative_results: Dict):
        """Summarize this cycle's checkpoint saves and track best results.

        Updates ``training_stats['best_performances']`` and logs the full
        checkpoint statistics whenever the session counter is a multiple of 10.
        """
        try:
            # Count successful checkpoints
            checkpoints_saved = sum([
                dqn_results.get('checkpoint_saved', False),
                cnn_results.get('checkpoint_saved', False),
                extrema_results.get('checkpoint_saved', False),
                negative_results.get('checkpoint_saved', False)
            ])
            if checkpoints_saved > 0:
                logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")

            # Update best performances
            if 'episode_reward' in dqn_results:
                current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
                if dqn_results['episode_reward'] > current_best:
                    self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']

            if 'val_accuracy' in cnn_results:
                current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
                if cnn_results['val_accuracy'] > current_best:
                    self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']

            # Log checkpoint statistics every 10 cycles
            if self.training_stats['total_training_sessions'] % 10 == 0:
                await self._log_checkpoint_statistics()
        except Exception as e:
            logger.error(f"Error coordinating checkpoint saving: {e}")

    async def _log_checkpoint_statistics(self):
        """Log comprehensive checkpoint statistics; failures are only logged."""
        try:
            stats = get_checkpoint_stats()
            logger.info("=== Checkpoint Statistics ===")
            logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
            logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
            logger.info(f"Models managed: {len(stats['models'])}")
            for model_name, model_stats in stats['models'].items():
                logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, "
                            f"{model_stats['total_size_mb']:.2f} MB, "
                            f"best: {model_stats['best_performance']:.4f}")
            logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
            logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
            logger.info(f"Best performances: {self.training_stats['best_performances']}")
        except Exception as e:
            logger.error(f"Error logging checkpoint statistics: {e}")

    async def shutdown(self):
        """Shutdown the training system and save final checkpoints.

        Idempotent per run: the training loop, ``main()`` and the signal
        handler may all call this, but only the first call after a run
        force-saves and logs the final statistics.
        """
        if getattr(self, '_shutdown_complete', False):
            logger.debug("Shutdown already performed; ignoring repeated call")
            return
        self._shutdown_complete = True

        logger.info("Shutting down checkpoint-integrated training system...")
        self.running = False
        try:
            # Force save checkpoints for all components
            if self.dqn_agent:
                self.dqn_agent.save_checkpoint(0.0, force_save=True)
            if self.cnn_trainer:
                self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)
            if self.extrema_trainer:
                self.extrema_trainer.save_checkpoint(force_save=True)
            if self.negative_case_trainer:
                self.negative_case_trainer.save_checkpoint(force_save=True)

            # Final statistics
            await self._log_checkpoint_statistics()
            logger.info("Checkpoint-integrated training system shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
async def main():
    """Main function to run the checkpoint-integrated training system.

    Installs SIGINT/SIGTERM handlers that request a shutdown, initializes
    the components, runs the training loop and always calls ``shutdown`` on
    the way out. Errors are logged and re-raised.
    """
    logger.info("🚀 Starting Checkpoint-Integrated Training System")

    # Create and initialize the training system
    training_system = CheckpointIntegratedTrainingSystem()

    loop = asyncio.get_running_loop()
    # The event loop only keeps weak references to tasks, so hold the
    # shutdown task here until it has finished.
    pending_shutdowns = set()

    def schedule_shutdown():
        """Create the shutdown task from inside the event loop."""
        task = loop.create_task(training_system.shutdown())
        pending_shutdowns.add(task)
        task.add_done_callback(pending_shutdowns.discard)

    # Setup signal handlers for graceful shutdown
    def signal_handler(signum, frame):
        logger.info("Received shutdown signal")
        # Hand over to the loop instead of touching asyncio state from the
        # signal handler; call_soon_threadsafe also wakes a sleeping loop.
        if not loop.is_closed():
            loop.call_soon_threadsafe(schedule_shutdown)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Initialize components
        await training_system.initialize_components()
        # Run the integrated training loop
        await training_system.run_integrated_training_loop()
    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
        await training_system.shutdown()

    # Only reached when no exception propagated.
    logger.info("✅ Checkpoint management integration complete!")
    logger.info("All training pipelines now support automatic checkpointing")


if __name__ == "__main__":
    # Ensure logs directory exists
    Path("logs").mkdir(exist_ok=True)
    # Run the checkpoint-integrated training system
    asyncio.run(main())

View File

@ -1,13 +0,0 @@
"""
Neural Network Utilities
========================
This package contains utility functions and classes used in the neural network trading system:
- Data Interface: Connects to realtime trading data and processes it for the neural network models
- TradingEnvironment and SignalInterpreter: re-exported from the trading_env and signal_interpreter modules
"""
from .data_interface import DataInterface
from .trading_env import TradingEnvironment
from .signal_interpreter import SignalInterpreter
# Public API of the package: exactly the three names imported above.
__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']

View File

@ -1,123 +0,0 @@
"""
Enhanced Data Interface with additional NN trading parameters
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from .data_interface import DataInterface
class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.
    """

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.

        Args:
            symbol: Trading pair, forwarded to DataInterface.
            timeframes: Timeframes, forwarded to DataInterface; the first one
                is used as the primary timeframe by the prepare_* methods.
            window_size: Number of candles per input sequence.
            output_size: 1 for binary up/down labels, 3 for buy/hold/sell.
            data_dir: Data directory, forwarded to DataInterface.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.
        """
        return 5  # open, high, low, close, volume

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data with windowed sequences.

        Returns:
            ``(X_train, y_train, X_val, y_val)`` using a chronological 80/20
            split of the generated windows.

        Raises:
            ValueError: If too few candles are available or ``output_size``
                is neither 1 nor 3.
        """
        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)
        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        # Prepare OHLCV sequences
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Create sequences and labels
        X = []
        y = []
        for i in range(len(ohlcv) - self.window_size - self.output_size):
            # Input sequence
            seq = ohlcv[i:i + self.window_size]
            X.append(seq)

            # Output target (price movement direction)
            target_start = i + self.window_size
            close_prices = ohlcv[target_start:target_start + self.output_size, 3]  # Close prices

            if self.output_size == 1:
                # Binary classification (up/down). With a single future close
                # np.diff() would be empty, so the move is measured against
                # the last close of the input window instead.
                price_change = close_prices[0] - seq[-1, 3]
                label = 1 if price_change > 0 else 0
            elif self.output_size == 3:
                # 3-class classification (buy/hold/sell) on the first move
                # inside the target window.
                price_change = np.diff(close_prices)[0]
                if price_change > 0.002:  # Significant rise
                    label = 0  # Buy
                elif price_change < -0.002:  # Significant drop
                    label = 2  # Sell
                else:
                    label = 1  # Hold
            else:
                raise ValueError(f"Unsupported output_size: {self.output_size}")
            y.append(label)

        # Convert to numpy arrays
        X = np.array(X)
        y = np.array(y)

        # Split into train/validation (80/20)
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]
        return X_train, y_train, X_val, y_val

    def prepare_prediction_data(self) -> np.ndarray:
        """Prepare most recent window for predictions.

        Returns:
            Array of shape ``(1, window_size, 5)`` holding the latest OHLCV rows.

        Raises:
            ValueError: If fewer than ``window_size`` candles are available.
        """
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)
        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray):
        """Convert prediction probabilities to trading signals.

        Returns one dict per prediction with ``action``, ``confidence`` and an
        ISO ``timestamp``. The class order for ``output_size == 3`` is
        BUY/HOLD/SELL, matching the labels built by ``prepare_training_data``.
        """
        signals = []
        for pred in predictions:
            if self.output_size == 1:
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = np.abs(pred[0] - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = np.argmax(pred)
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = pred[action_idx]
            else:
                signal = "HOLD"
                confidence = 0.0
            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })
        return signals

View File

@ -1,364 +0,0 @@
"""
Realtime Analyzer for Neural Network Trading System
This module implements real-time analysis of market data using trained neural network models.
"""
import logging
import time
import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime
import asyncio
import websockets
import json
import os
import pandas as pd
from collections import deque
logger = logging.getLogger(__name__)
class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to a real-time trade stream (websocket)
    - Aggregates tick data into candles for the timeframes 1s, 1m, 1h, 1d
    - Feeds the normalized candles of all timeframes to the model
    - Converts predictions into BUY/SELL/HOLD signals and logs them

    Trades are not executed; signals are only logged.
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface; its
                ``window_size`` attribute sets how many candles are fed to the model
            model: Trained neural network model exposing ``predict``
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
        self.running = False
        self.data_queue = Queue()
        self.prediction_interval = 10  # Seconds between predictions
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        # Event loop owned by the websocket thread; stop() needs it to close
        # the socket on the loop that created it.
        self._ws_loop = None
        self.tick_storage = deque(maxlen=10000)  # Store up to 10,000 ticks
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '1h': deque(maxlen=5000),
            '1d': deque(maxlen=5000)
        }
        logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")

    def start(self):
        """Start the realtime analysis process (three daemon threads)."""
        if self.running:
            logger.warning("Realtime analyzer already running")
            return
        self.running = True

        # Start WebSocket connection thread
        self.ws_thread = Thread(target=self._run_websocket, daemon=True)
        self.ws_thread.start()

        # Start data processing thread
        self.processing_thread = Thread(target=self._process_data, daemon=True)
        self.processing_thread.start()

        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()
        logger.info("Realtime analysis started")

    def stop(self):
        """Stop the realtime analysis process and wait briefly for the threads."""
        self.running = False

        ws = self.ws
        loop = getattr(self, '_ws_loop', None)
        if ws is not None and loop is not None and loop.is_running():
            # The socket belongs to the websocket thread's loop, so the close
            # must be scheduled there rather than run on a new loop.
            try:
                asyncio.run_coroutine_threadsafe(ws.close(), loop).result(timeout=5)
            except Exception as e:
                logger.warning(f"Error closing WebSocket: {str(e)}")

        for thread_attr in ('ws_thread', 'processing_thread', 'analysis_thread'):
            worker = getattr(self, thread_attr, None)
            if worker is not None:
                worker.join(timeout=1)
        logger.info("Realtime analysis stopped")

    def _run_websocket(self):
        """Thread function for running WebSocket connection."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self._ws_loop = loop
        try:
            loop.run_until_complete(self._connect_websocket())
        finally:
            self._ws_loop = None
            loop.close()

    async def _connect_websocket(self):
        """Connect to WebSocket and receive data, reconnecting while running."""
        while self.running:
            try:
                logger.info(f"Connecting to WebSocket: {self.ws_url}")
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")
                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)
                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.tick_storage.append(tick)
                                self.data_queue.put(tick)
                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            # Non-blocking pause so a close request scheduled
                            # by stop() can still run on this loop.
                            await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)  # Wait before reconnecting
            finally:
                # Never leave a reference to a closed connection behind.
                self.ws = None

    def _process_data(self):
        """Process incoming tick data into candles for all timeframes."""
        logger.info("Starting data processing thread")
        while self.running:
            try:
                # Process any new ticks
                while not self.data_queue.empty():
                    tick = self.data_queue.get()
                    price = tick['price']

                    # Process for each timeframe
                    for timeframe in self.timeframes:
                        interval = self._get_interval_seconds(timeframe)
                        if interval is None:
                            continue

                        # Floor the tick timestamp (ms) to the start of its candle
                        interval_ms = interval * 1000
                        candle_ts = int(tick['timestamp'] // interval_ms) * interval_ms

                        candles = self.candle_cache[timeframe]
                        last_candle = candles[-1] if candles else None
                        if last_candle is not None and last_candle['timestamp'] == candle_ts:
                            # Update current candle
                            last_candle['high'] = max(last_candle['high'], price)
                            last_candle['low'] = min(last_candle['low'], price)
                            last_candle['close'] = price
                            last_candle['volume'] += tick['volume']
                        else:
                            # First candle for this timeframe, or a new interval
                            candles.append({
                                'timestamp': candle_ts,
                                'open': price,
                                'high': price,
                                'low': price,
                                'close': price,
                                'volume': tick['volume']
                            })
                time.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in data processing: {str(e)}")
                time.sleep(1)

    def _get_interval_seconds(self, timeframe):
        """Convert timeframe string to seconds; None for unsupported values."""
        intervals = {
            '1s': 1,
            '1m': 60,
            '1h': 3600,
            '1d': 86400
        }
        return intervals.get(timeframe)

    def _analyze_data(self):
        """Thread function for analyzing data and generating signals."""
        logger.info("Starting analysis thread")
        last_prediction_time = 0
        while self.running:
            try:
                current_time = time.time()

                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue

                # Prepare input data from all timeframes
                input_data = {}
                valid = True
                for timeframe in self.timeframes:
                    # .get(): a timeframe without a cache entry is treated as
                    # "no data" instead of raising KeyError.
                    cached = self.candle_cache.get(timeframe)
                    if not cached:
                        logger.warning(f"No data available for timeframe {timeframe}")
                        valid = False
                        break

                    # Get last N candles for this timeframe
                    candles = list(cached)[-self.data_interface.window_size:]

                    # Convert to numpy array
                    ohlcv = np.array([
                        [c['open'], c['high'], c['low'], c['close'], c['volume']]
                        for c in candles
                    ])

                    # Normalize data (per-column z-score; epsilon avoids division by zero)
                    ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
                    input_data[timeframe] = ohlcv_normalized

                if not valid:
                    time.sleep(0.1)
                    continue

                # Make prediction using the model
                try:
                    prediction = self.model.predict(input_data)

                    # Get latest timestamp from 1s timeframe
                    latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)

                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe='multi',
                        timestamp=latest_ts
                    )
                    last_prediction_time = current_time
                except Exception as e:
                    logger.error(f"Error making prediction: {str(e)}")
                time.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)

    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for ('multi' for combined)
            timestamp: Timestamp of the prediction (ms)
        """
        # Convert prediction to trading signal
        signal, confidence = self._prediction_to_signal(prediction)

        # Convert timestamp to datetime; fall back to "now" for unusable values
        try:
            dt = datetime.fromtimestamp(timestamp / 1000)
        except (TypeError, ValueError, OverflowError, OSError):
            dt = datetime.now()

        # Log the signal with all timeframes
        logger.info(
            f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
            f"Timestamp: {dt}, "
            f"Signal: {signal} (Confidence: {confidence:.2f})"
        )
        # In a real implementation, we would execute trades here
        # For now, we'll just log the signals

    @staticmethod
    def _classify_prediction(pred):
        """Map one prediction to ``(signal, confidence)``.

        A scalar or 1-D prediction is read as a binary BUY probability in its
        first element. A prediction with more dimensions is read as class
        probabilities in SELL/HOLD/BUY order along the last axis; when several
        rows are present the last row is used.
        """
        if not hasattr(pred, 'shape'):
            pred = np.asarray(pred)

        if len(pred.shape) <= 1:
            # Binary classification
            buy_probability = float(pred[0]) if len(pred.shape) == 1 else float(pred)
            if buy_probability > 0.5:
                return "BUY", buy_probability
            return "SELL", 1 - buy_probability

        # Multi-class: reduce to one probability row before taking argmax so
        # the index refers to a class, not to a flattened position.
        class_probs = pred.reshape(-1, pred.shape[-1])[-1]
        class_idx = int(class_probs.argmax())
        return ["SELL", "HOLD", "BUY"][class_idx], float(class_probs[class_idx])

    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal and confidence.

        Args:
            prediction: Model prediction output (can be dict for multi-timeframe)

        Returns:
            tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
            confidence is probability (0-1). A dict is resolved by majority
            vote between BUY and SELL; ties (and an empty dict) give HOLD.
        """
        if not isinstance(prediction, dict):
            # Single prediction
            return self._classify_prediction(prediction)

        # Multi-timeframe prediction - combine signals
        if not prediction:
            return "HOLD", 0.0

        votes = [self._classify_prediction(pred) for pred in prediction.values()]
        buy_confidences = [c for s, c in votes if s == "BUY"]
        sell_confidences = [c for s, c in votes if s == "SELL"]

        # Simple voting system - count BUY/SELL signals
        if len(buy_confidences) > len(sell_confidences):
            return "BUY", float(np.mean(buy_confidences))
        if len(sell_confidences) > len(buy_confidences):
            return "SELL", float(np.mean(sell_confidences))
        return "HOLD", float(np.mean([c for _, c in votes]))

View File

@ -1,98 +0,0 @@
#!/usr/bin/env python3
"""
Immediate Model Cleanup Script
This script will clean up all existing model files and prepare the system
for fresh training with the new model management system.
"""
import logging
import sys
from model_manager import ModelManager
def main():
    """Interactively wipe stored models after an explicit confirmation.

    Prints what will be removed, shows the current storage usage reported by
    ``ModelManager`` (errors there are printed, not fatal), and proceeds only
    if the user types ``CLEANUP``. Exits with status 1 when the cleanup call
    raises.
    """
    # Configure logging for better output
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    wide_rule = "=" * 60
    intro = (
        wide_rule,
        "GOGO2 MODEL CLEANUP SYSTEM",
        wide_rule,
        "",
        "This script will:",
        "1. Delete ALL existing model files (.pt, .pth)",
        "2. Remove ALL checkpoint directories",
        "3. Clear model backup directories",
        "4. Reset the model registry",
        "5. Create clean directory structure",
        "",
        "WARNING: This action cannot be undone!",
        "",
    )
    for text in intro:
        print(text)

    # Calculate current space usage first
    try:
        stats = ModelManager().get_storage_stats()
        print("Current storage usage:")
        print(f"- Models: {stats['total_models']}")
        print(f"- Size: {stats['actual_size_mb']:.1f}MB")
    except Exception as e:
        print(f"Error checking current storage: {e}")
    print()

    # Ask for confirmation
    print("Type 'CLEANUP' to proceed with the cleanup:")
    if input("> ").strip() != "CLEANUP":
        print("Cleanup cancelled. No changes made.")
        return

    print()
    print("Starting cleanup...")
    print("-" * 40)
    try:
        # Create manager and run cleanup
        result = ModelManager().cleanup_all_existing_models(confirm=True)

        for text in ("", wide_rule, "CLEANUP COMPLETE", wide_rule):
            print(text)
        print(f"Files deleted: {result['deleted_files']}")
        print(f"Space freed: {result['freed_space_mb']:.1f} MB")
        print(f"Directories cleaned: {len(result['deleted_directories'])}")

        errors = result['errors']
        if errors:
            print(f"Errors encountered: {len(errors)}")
            print("Errors:")
            for error in errors[:5]:  # Show first 5 errors
                print(f" - {error}")
            hidden = len(errors) - 5
            if hidden > 0:
                print(f" ... and {hidden} more")

        outro = (
            "",
            "System is now ready for fresh model training!",
            "The following directories have been created:",
            "- models/best_models/",
            "- models/cnn/",
            "- models/rl/",
            "- models/checkpoints/",
            "- NN/models/saved/",
            "",
            "New models will be automatically managed by the ModelManager.",
        )
        for text in outro:
            print(text)
    except Exception as e:
        print(f"Error during cleanup: {e}")
        logging.exception("Cleanup failed")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,74 @@
import os
import shutil
from pathlib import Path
# Repository root: two levels up from this file's resolved location.
ROOT = Path(__file__).resolve().parents[1]

EXCLUDE_PREFIXES = (
    'COBY' + os.sep,  # Do not touch COBY subsystem
)
EXCLUDE_FILES = {
    'NN' + os.sep + 'training' + os.sep + 'enhanced_realtime_training.py',
}

delete_list_path = ROOT / 'DELETE_CANDIDATES.txt'
deleted_files: list[str] = []
kept_files: list[str] = []
# "<relative path>: <error>" for every removal that raised, so failures are
# reported in the summary instead of being silently dropped.
failures: list[str] = []

# Phase 1: delete the files named (one relative path per line) in the list.
if delete_list_path.exists():
    for line in delete_list_path.read_text(encoding='utf-8').splitlines():
        rel = line.strip()
        if not rel:
            continue
        # Skip excluded prefixes and explicitly excluded files
        if rel.startswith(EXCLUDE_PREFIXES) or rel in EXCLUDE_FILES:
            kept_files.append(rel)
            continue
        fp = ROOT / rel
        if fp.exists() and fp.is_file():
            try:
                fp.unlink()
            except OSError as exc:
                # Still counted as kept (as before), but the reason is recorded.
                kept_files.append(rel)
                failures.append(f'{rel}: {exc}')
            else:
                deleted_files.append(rel)

# Phase 2: remove tests directories outside COBY.
removed_dirs: list[str] = []
# Collect the matches up front so no directory is deleted while the
# rglob() traversal of the same tree is still in progress.
for d in list(ROOT.rglob('tests')):
    try:
        rel = str(d.relative_to(ROOT))
    except ValueError:
        continue
    if rel.startswith(EXCLUDE_PREFIXES):
        continue
    # is_dir() is False for a nested match whose parent was already removed.
    if d.is_dir():
        try:
            shutil.rmtree(d)
        except OSError as exc:
            failures.append(f'{rel}: {exc}')
        else:
            removed_dirs.append(rel)

# Write cleanup log / todo
log_lines = [
    'Cleanup run summary:',
    f'- Deleted files: {len(deleted_files)}',
]
log_lines.extend(f' - {x}' for x in deleted_files[:50])
if len(deleted_files) > 50:
    log_lines.append(f' ... and {len(deleted_files)-50} more')
log_lines.append(f'- Removed test directories: {len(removed_dirs)}')
log_lines.extend(f' - {x}' for x in removed_dirs[:50])
log_lines.append(f'- Kept (excluded): {len(kept_files)}')
if failures:
    log_lines.append(f'- Failed to remove: {len(failures)}')
    log_lines.extend(f' - {x}' for x in failures[:50])
(ROOT / 'CLEANUP_TODO.md').write_text('\n'.join(log_lines), encoding='utf-8')

print(f'Deleted files: {len(deleted_files)}')
print(f'Removed test dirs: {len(removed_dirs)}')
print(f'Kept (excluded): {len(kept_files)}')
if failures:
    print(f'Failed to remove: {len(failures)}')

View File

@ -1,193 +0,0 @@
#!/usr/bin/env python3
"""
Apply Trading System Fixes
This script applies fixes to the trading system to address:
1. Duplicate entry prices
2. P&L calculation issues
3. Position tracking problems
4. Trade display issues
Usage:
python apply_trading_fixes.py
"""
import os
import sys
import logging
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

# Setup logging
# FileHandler opens its file as soon as it is constructed, i.e. when this
# module is loaded, so the directory must exist before that point.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: log messages in this script contain non-ASCII
        # status glyphs that a locale-default encoding may not represent.
        logging.FileHandler('logs/trading_fixes.log', encoding='utf-8'),
    ]
)
logger = logging.getLogger(__name__)
def apply_fixes():
    """Check the fix modules on a test executor and write patch instructions.

    Steps:
        1. Import the fix modules; abort if they cannot be imported.
        2. Apply ``TradingExecutorFix`` to a throw-away ``TradingExecutor``
           and log whether the expected attributes appeared (best effort:
           a failure here is logged but does not abort).
        3. Write manual patch instructions to ``patch_instructions.txt``.

    Returns:
        bool: True when the fix modules import and the instruction file was
        written; False otherwise.
    """
    logger.info("=" * 70)
    logger.info("APPLYING TRADING SYSTEM FIXES")
    logger.info("=" * 70)

    # Import fixes (DashboardFix is imported only to confirm it is available)
    try:
        from core.trading_executor_fix import TradingExecutorFix
        from web.dashboard_fix import DashboardFix
        logger.info("Fix modules imported successfully")
    except ImportError as e:
        logger.error(f"Error importing fix modules: {e}")
        return False

    # Apply fixes to trading executor
    try:
        # Import trading executor
        from core.trading_executor import TradingExecutor

        # Create a test instance to apply fixes
        test_executor = TradingExecutor()

        # Apply fixes
        TradingExecutorFix.apply_fixes(test_executor)
        logger.info("Trading executor fixes applied successfully to test instance")

        # Verify fixes
        if hasattr(test_executor, 'price_cache_timestamp'):
            logger.info("✅ Price caching fix verified")
        else:
            logger.warning("❌ Price caching fix not verified")
        if hasattr(test_executor, 'trade_cooldown_seconds'):
            logger.info("✅ Trade cooldown fix verified")
        else:
            logger.warning("❌ Trade cooldown fix not verified")
        if hasattr(test_executor, '_check_trade_cooldown'):
            logger.info("✅ Trade cooldown check method verified")
        else:
            logger.warning("❌ Trade cooldown check method not verified")
    except Exception as e:
        logger.error(f"Error applying trading executor fixes: {e}")
        import traceback
        logger.error(traceback.format_exc())

    # Write patch instructions
    instructions = """
TRADING SYSTEM FIX INSTRUCTIONS
==============================
To apply the fixes to your trading system, follow these steps:
1. Add the following code to main.py just before the dashboard.run_server() call:
```python
# Apply trading system fixes
try:
from core.trading_executor_fix import TradingExecutorFix
from web.dashboard_fix import DashboardFix
# Apply fixes to trading executor
if trading_executor:
TradingExecutorFix.apply_fixes(trading_executor)
logger.info("✅ Trading executor fixes applied")
# Apply fixes to dashboard
if 'dashboard' in locals() and dashboard:
DashboardFix.apply_fixes(dashboard)
logger.info("✅ Dashboard fixes applied")
logger.info("Trading system fixes applied successfully")
except Exception as e:
logger.warning(f"Error applying trading system fixes: {e}")
```
2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method:
```python
# Apply dashboard fixes if available
try:
from web.dashboard_fix import DashboardFix
DashboardFix.apply_fixes(self)
logger.info("✅ Dashboard fixes applied during initialization")
except ImportError:
logger.warning("Dashboard fixes not available")
```
3. Run the system with the fixes applied:
```
python main.py
```
4. Monitor the logs for any issues with the fixes.
These fixes address:
- Duplicate entry prices
- P&L calculation issues
- Position tracking problems
- Trade display issues
- Rapid consecutive trades
"""
    try:
        # The text contains non-ASCII characters, so the encoding is set
        # explicitly rather than left to the platform default.
        with open('patch_instructions.txt', 'w', encoding='utf-8') as f:
            f.write(instructions)
        logger.info("Patch instructions written to patch_instructions.txt")
    except Exception as e:
        logger.error(f"Error creating patch: {e}")
        # Without the file, "see patch_instructions.txt" would be misleading.
        return False

    logger.info("=" * 70)
    logger.info("TRADING SYSTEM FIXES READY TO APPLY")
    logger.info("See patch_instructions.txt for instructions")
    logger.info("=" * 70)
    return True
if __name__ == "__main__":
    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Apply fixes and translate the outcome into a process exit status
    if apply_fixes():
        print("\nTrading system fixes ready to apply!")
        print("See patch_instructions.txt for instructions")
        exit_status = 0
    else:
        print("\nError preparing trading system fixes")
        exit_status = 1
    sys.exit(exit_status)

View File

@ -1,218 +0,0 @@
#!/usr/bin/env python3
"""
Apply Trading System Fixes to Main.py
This script applies the trading system fixes directly to main.py
to address the issues with duplicate entry prices and P&L calculation.
Usage:
python apply_trading_fixes_to_main.py
"""
import os
import sys
import logging
import re
from pathlib import Path
import shutil
from datetime import datetime
# Setup logging
# FileHandler opens its file as soon as it is constructed, i.e. when this
# module is loaded, so the directory must exist before that point.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: log messages in this script contain non-ASCII
        # status glyphs that a locale-default encoding may not represent.
        logging.FileHandler('logs/apply_fixes.log', encoding='utf-8'),
    ]
)
logger = logging.getLogger(__name__)
def backup_file(file_path):
    """Copy *file_path* to a sibling named ``<file_path>.backup_<timestamp>``.

    Returns:
        bool: True if the copy succeeded, False if anything raised (the
        error is logged).
    """
    try:
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        destination = f"{file_path}.backup_{stamp}"
        shutil.copy2(file_path, destination)
        logger.info(f"Created backup: {destination}")
    except Exception as e:
        logger.error(f"Error creating backup of {file_path}: {e}")
        return False
    return True
def apply_fixes_to_main():
    """Insert the trading-fix snippet into main.py before dashboard.run_server().

    A timestamped backup of main.py is created first. The snippet is inserted
    at the line break preceding the first ``dashboard.run_server(`` match.
    If the snippet's ``TradingExecutorFix.apply_fixes(trading_executor)``
    call is already present, the file is left untouched.

    Returns:
        bool: True if the file was patched or already contained the fix;
        False if the file, the backup, or the insertion point was
        unavailable, or on any I/O error.
    """
    main_py_path = "main.py"
    if not os.path.exists(main_py_path):
        logger.error(f"File {main_py_path} not found")
        return False

    # Create backup
    if not backup_file(main_py_path):
        logger.error("Failed to create backup, aborting")
        return False

    try:
        # Read main.py. The encoding is explicit because the snippet written
        # back below contains non-ASCII characters; with a locale-default
        # codec the write could fail after the file was already truncated.
        with open(main_py_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Running the script twice must not insert the snippet twice.
        if "TradingExecutorFix.apply_fixes(trading_executor)" in content:
            logger.info(f"Fixes already present in {main_py_path}, nothing to do")
            return True

        # Find the position to insert the fixes
        # Look for the line before dashboard.run_server()
        run_server_pattern = r"dashboard\.run_server\("
        match = re.search(run_server_pattern, content)
        if not match:
            logger.error("Could not find dashboard.run_server() call in main.py")
            return False

        # Find the position to insert the fixes (before the run_server call)
        insert_pos = content.rfind("\n", 0, match.start())
        if insert_pos == -1:
            logger.error("Could not find insertion point in main.py")
            return False

        # Prepare the fixes to insert
        # NOTE(review): the snippet is inserted verbatim; confirm its
        # indentation suits the insertion site in main.py.
        fixes_code = """
# Apply trading system fixes
try:
from core.trading_executor_fix import TradingExecutorFix
from web.dashboard_fix import DashboardFix
# Apply fixes to trading executor
if trading_executor:
TradingExecutorFix.apply_fixes(trading_executor)
logger.info("✅ Trading executor fixes applied")
# Apply fixes to dashboard
if 'dashboard' in locals() and dashboard:
DashboardFix.apply_fixes(dashboard)
logger.info("✅ Dashboard fixes applied")
logger.info("Trading system fixes applied successfully")
except Exception as e:
logger.warning(f"Error applying trading system fixes: {e}")
"""
        # Insert the fixes
        new_content = content[:insert_pos] + fixes_code + content[insert_pos:]

        # Write the modified content back to main.py
        with open(main_py_path, 'w', encoding='utf-8') as f:
            f.write(new_content)
        logger.info(f"Successfully applied fixes to {main_py_path}")
        return True
    except Exception as e:
        logger.error(f"Error applying fixes to {main_py_path}: {e}")
        return False
def apply_fixes_to_dashboard():
    """Insert the DashboardFix snippet into web/clean_dashboard.py.

    A timestamped backup is created first. The snippet is inserted right
    after the last ``logger.debug(...)`` call found after the first
    ``def __init__(self,`` match. If ``DashboardFix.apply_fixes(self)`` is
    already present, the file is left untouched.

    Returns:
        bool: True if the file was patched or already contained the fix;
        False if the file, the backup, or the insertion point was
        unavailable, or on any I/O error.
    """
    dashboard_py_path = "web/clean_dashboard.py"
    if not os.path.exists(dashboard_py_path):
        logger.error(f"File {dashboard_py_path} not found")
        return False

    # Create backup
    if not backup_file(dashboard_py_path):
        logger.error("Failed to create backup, aborting")
        return False

    try:
        # Read dashboard.py. The encoding is explicit because the snippet
        # written back below contains non-ASCII characters; with a
        # locale-default codec the write could fail after truncation.
        with open(dashboard_py_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Running the script twice must not insert the snippet twice.
        if "DashboardFix.apply_fixes(self)" in content:
            logger.info(f"Fixes already present in {dashboard_py_path}, nothing to do")
            return True

        # Find the position to insert the fixes
        # Look for the __init__ method
        init_pattern = r"def __init__\(self,"
        match = re.search(init_pattern, content)
        if not match:
            logger.error("Could not find __init__ method in dashboard.py")
            return False

        # The search covers everything after the __init__ signature, so the
        # "last" match is the last logger.debug(...) in the rest of the file.
        # NOTE(review): confirm that line really lies inside __init__.
        init_end_pattern = r"logger\.debug\(.*\)"
        init_end_matches = list(re.finditer(init_end_pattern, content[match.end():]))
        if not init_end_matches:
            logger.error("Could not find end of __init__ method in dashboard.py")
            return False

        # Match offsets are relative to the slice, hence the added base.
        last_debug_match = init_end_matches[-1]
        insert_pos = match.end() + last_debug_match.end()

        # Prepare the fixes to insert
        # NOTE(review): the snippet is inserted verbatim; confirm its
        # indentation suits the insertion site.
        fixes_code = """
# Apply dashboard fixes if available
try:
from web.dashboard_fix import DashboardFix
DashboardFix.apply_fixes(self)
logger.info("✅ Dashboard fixes applied during initialization")
except ImportError:
logger.warning("Dashboard fixes not available")
"""
        # Insert the fixes
        new_content = content[:insert_pos] + fixes_code + content[insert_pos:]

        # Write the modified content back to dashboard.py
        with open(dashboard_py_path, 'w', encoding='utf-8') as f:
            f.write(new_content)
        logger.info(f"Successfully applied fixes to {dashboard_py_path}")
        return True
    except Exception as e:
        logger.error(f"Error applying fixes to {dashboard_py_path}: {e}")
        return False
def main():
    """Patch main.py and the dashboard module, then report the outcome.

    Both patch steps always run. Returns 0 when both succeed, 1 otherwise.
    """
    rule = "=" * 70
    for text in (rule, "APPLYING TRADING SYSTEM FIXES TO MAIN.PY", rule):
        logger.info(text)

    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Apply fixes to main.py, then to dashboard.py
    main_success = apply_fixes_to_main()
    dashboard_success = apply_fixes_to_dashboard()

    if not (main_success and dashboard_success):
        failure_report = (
            rule,
            "FAILED TO APPLY SOME FIXES",
            rule,
            "Please check the logs for details",
            rule,
        )
        for text in failure_report:
            logger.error(text)
        return 1

    success_report = (
        rule,
        "TRADING SYSTEM FIXES APPLIED SUCCESSFULLY",
        rule,
        "The following issues have been fixed:",
        "1. Duplicate entry prices",
        "2. P&L calculation issues",
        "3. Position tracking problems",
        "4. Trade display issues",
        "5. Rapid consecutive trades",
        rule,
        "You can now run the trading system with the fixes applied:",
        "python main.py",
        rule,
    )
    for text in success_report:
        logger.info(text)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -1,189 +0,0 @@
#!/usr/bin/env python3
"""
Balance Trading Signals - Analyze and fix SHORT signal bias
This script analyzes the trading signals from the orchestrator and adjusts
the model weights to balance BUY and SELL signals.
"""
import os
import sys
import logging
import json
from pathlib import Path
from datetime import datetime
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Setup logging
setup_logging()
logger = logging.getLogger(__name__)
def analyze_trading_signals():
    """Log the BUY/SELL/HOLD mix per symbol and rebalance weights on SELL bias.

    For every orchestrator symbol the recent decisions are tallied and
    logged. When SELL decisions exceed twice the BUY decisions, per-model
    bias scores are derived from each decision's ``reasoning['models_used']``
    and an adjusted weight table is written via ``save_adjusted_weights``.
    """
    logger.info("Analyzing trading signals...")

    # Initialize components
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)

    # Get recent decisions
    all_decisions = {}
    for symbol in orchestrator.symbols:
        decisions = orchestrator.get_recent_decisions(symbol)
        all_decisions[symbol] = decisions

        # Count actions
        tally = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
        for decision in decisions:
            tally[decision.action] += 1
        total = sum(tally.values())
        if total <= 0:
            continue

        share = {action: count / total * 100 for action, count in tally.items()}
        logger.info(f"Symbol: {symbol}")
        logger.info(f" Total decisions: {total}")
        logger.info(f" BUY: {tally['BUY']} ({share['BUY']:.1f}%)")
        logger.info(f" SELL: {tally['SELL']} ({share['SELL']:.1f}%)")
        logger.info(f" HOLD: {tally['HOLD']} ({share['HOLD']:.1f}%)")

        # Check for bias: SELL signals more than twice the BUY signals
        if not share['SELL'] > share['BUY'] * 2:
            continue
        logger.warning(f" SELL bias detected: {share['SELL']:.1f}% vs {share['BUY']:.1f}%")

        # Adjust model weights to balance signals
        logger.info(" Adjusting model weights to balance signals...")
        model_weights = orchestrator.model_weights
        logger.info(f" Current model weights: {model_weights}")

        # Per-model action counters, one per currently weighted model
        per_model = {name: {'BUY': 0, 'SELL': 0, 'HOLD': 0} for name in model_weights}

        # Attribute each decision to the models listed in its reasoning
        for decision in decisions:
            reasoning = decision.reasoning
            if 'models_used' not in reasoning:
                continue
            for name in reasoning['models_used']:
                if name in per_model:
                    per_model[name][decision.action] += 1

        # Bias score per model (-100 to 100, negative = SELL bias, positive = BUY bias)
        model_bias = {}
        for name, actions in per_model.items():
            attributed = sum(actions.values())
            if attributed > 0:
                buy_pct = actions['BUY'] / attributed * 100
                sell_pct = actions['SELL'] / attributed * 100
                score = buy_pct - sell_pct
                model_bias[name] = score
                logger.info(f" Model {name}: Bias score = {score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)")

        # Adjust weights based on bias
        adjusted = {}
        for name, weight in model_weights.items():
            if name not in model_bias:
                adjusted[name] = weight
                continue
            bias = model_bias[name]
            if bias < -30:
                # Strong SELL bias: reduce weight by 30%, floor at 0.05
                adjusted[name] = max(0.05, weight * 0.7)
                logger.info(f" Reducing weight of {name} from {weight:.2f} to {adjusted[name]:.2f} due to SELL bias")
            elif bias > 10:
                # BUY bias: increase weight by 30%, cap at 0.5
                adjusted[name] = min(0.5, weight * 1.3)
                logger.info(f" Increasing weight of {name} from {weight:.2f} to {adjusted[name]:.2f} to balance SELL bias")
            else:
                adjusted[name] = weight

        # Save adjusted weights
        save_adjusted_weights(adjusted)
        logger.info(f" Adjusted weights: {adjusted}")
        logger.info(" Weights saved to 'adjusted_model_weights.json'")

    # Recommend next steps
    # NOTE(review): placement after the symbol loop is reconstructed from a
    # source whose indentation was lost -- confirm against the original.
    logger.info("\nRecommended actions:")
    logger.info("1. Update the model weights in the orchestrator")
    logger.info("2. Monitor trading signals for balance")
    logger.info("3. Consider retraining models with balanced data")
def save_adjusted_weights(weights):
    """Write *weights* to ``adjusted_model_weights.json``.

    The JSON object holds the current local time in ISO format, the given
    weights, and a fixed explanatory note, in that order.
    """
    payload = dict(
        timestamp=datetime.now().isoformat(),
        weights=weights,
        notes='Adjusted to balance BUY/SELL signals',
    )
    with open('adjusted_model_weights.json', 'w') as out:
        json.dump(payload, out, indent=2)
def apply_balanced_weights():
    """Load ``adjusted_model_weights.json`` and store it in the orchestrator.

    Only models already present in ``orchestrator.model_weights`` are
    updated; the orchestrator state is then saved.

    Returns:
        bool: True on success; False if the file is missing, holds no
        weights, or any step raises.
    """
    weights_file = 'adjusted_model_weights.json'
    try:
        # Check if weights file exists
        if not os.path.exists(weights_file):
            logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.")
            return False

        # Load adjusted weights
        with open(weights_file, 'r') as f:
            stored = json.load(f)
        weights = stored.get('weights', {})
        if not weights:
            logger.error("No weights found in the file.")
            return False
        logger.info(f"Loaded adjusted weights: {weights}")

        # Initialize components
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)

        # Apply weights, skipping models the orchestrator does not have
        for name, value in weights.items():
            if name not in orchestrator.model_weights:
                continue
            orchestrator.model_weights[name] = value

        # Save updated weights
        orchestrator._save_orchestrator_state()
        logger.info("Applied balanced weights to orchestrator.")
        logger.info("Restart the trading system for changes to take effect.")
        return True
    except Exception as e:
        logger.error(f"Error applying balanced weights: {e}")
        return False
if __name__ == "__main__":
    banner = "=" * 70
    for text in (banner, "TRADING SIGNAL BALANCE ANALYZER", banner):
        logger.info(text)

    # "apply" as the first CLI argument applies saved weights; anything
    # else (or no argument) runs the analysis.
    wants_apply = len(sys.argv) > 1 and sys.argv[1] == 'apply'
    action = apply_balanced_weights if wants_apply else analyze_trading_signals
    action()

View File

@ -1,163 +0,0 @@
import os
import sys
import logging
import importlib
import asyncio
from dotenv import load_dotenv
from safe_logging import setup_safe_logging
# Configure logging
setup_safe_logging()
logger = logging.getLogger("check_live_trading")
def check_dependencies():
    """Try to import every required package and report the missing ones.

    Returns:
        bool: True if every package imported, False otherwise.
    """
    required_packages = [
        "numpy", "pandas", "matplotlib", "mplfinance", "torch",
        "dotenv", "ccxt", "websockets", "tensorboard",
        "sklearn", "PIL", "asyncio"
    ]

    missing_packages = []
    for package in required_packages:
        # Every entry is imported under the name listed above.
        try:
            importlib.import_module(package)
            logger.info(f"{package} is installed")
        except ImportError:
            missing_packages.append(package)
            logger.error(f"{package} is NOT installed")

    if not missing_packages:
        return True
    logger.error(f"Missing packages: {', '.join(missing_packages)}")
    logger.info("Install missing packages with: pip install -r requirements.txt")
    return False
def check_api_keys():
    """Verify both MEXC credentials are set to non-placeholder values.

    Loads the ``.env`` file first.

    Returns:
        bool: False if either variable is unset, empty, or still equal to
        its placeholder; True otherwise.
    """
    load_dotenv()
    credentials = (
        (os.getenv('MEXC_API_KEY'), "your_api_key_here"),
        (os.getenv('MEXC_SECRET_KEY'), "your_secret_key_here"),
    )
    if any(not value or value == placeholder for value, placeholder in credentials):
        logger.error("❌ API keys are not properly configured in .env file")
        logger.info("Please update your .env file with valid MEXC API keys")
        return False

    logger.info("✅ API keys are configured")
    return True
def check_model_files():
    """Check that the three expected trained-model files exist.

    Returns:
        bool: True if all files exist, False if any is missing.
    """
    model_files = [
        "models/trading_agent_best_pnl.pt",
        "models/trading_agent_best_reward.pt",
        "models/trading_agent_final.pt"
    ]

    missing_models = []
    for model_file in model_files:
        if not os.path.exists(model_file):
            missing_models.append(model_file)
            logger.error(f"❌ Model file missing: {model_file}")
            continue
        logger.info(f"✅ Model file exists: {model_file}")

    if not missing_models:
        return True
    logger.warning("Some model files are missing. You need to train the model first.")
    return False
async def check_exchange_connection():
    """Test connection to MEXC exchange by fetching its market list.

    Returns:
        bool: True if markets were fetched; False if placeholder keys are
        configured or anything raises.
    """
    try:
        import ccxt

        # Load API keys
        load_dotenv()
        api_key = os.getenv('MEXC_API_KEY')
        secret_key = os.getenv('MEXC_SECRET_KEY')

        uses_placeholder = (
            api_key == "your_api_key_here"
            or secret_key == "your_secret_key_here"
        )
        if uses_placeholder:
            logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test")
            return False

        # Initialize exchange
        client_options = {
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True
        }
        exchange = ccxt.mexc(client_options)

        # Test connection by fetching markets
        markets = exchange.fetch_markets()
        logger.info("✅ Successfully connected to MEXC exchange")
        logger.info(f"✅ Found {len(markets)} markets")
        return True
    except Exception as e:
        logger.error(f"❌ Failed to connect to MEXC exchange: {e!s}")
        return False
def check_directories():
    """Create any of the required working directories that are missing.

    Returns:
        bool: Always True.
    """
    for directory in ("models", "runs", "trade_logs"):
        if os.path.exists(directory):
            continue
        logger.info(f"Creating directory: {directory}")
        os.makedirs(directory, exist_ok=True)

    logger.info("✅ All required directories exist")
    return True
async def main():
    """Run every pre-flight check and log a pass/fail summary.

    Returns:
        int: 0 if all checks passed, 1 otherwise.
    """
    logger.info("Running pre-flight checks for live trading...")

    # Checks run in this order; all of them run even if an earlier one fails.
    checks = []
    checks.append(("Dependencies", check_dependencies()))
    checks.append(("API Keys", check_api_keys()))
    checks.append(("Model Files", check_model_files()))
    checks.append(("Directories", check_directories()))
    checks.append(("Exchange Connection", await check_exchange_connection()))

    # Count failed checks
    failed_checks = len([name for name, passed in checks if not passed])

    # Print summary
    rule = "=" * 50
    logger.info("\n" + rule)
    logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY")
    logger.info(rule)
    for check_name, passed in checks:
        status = "✅ PASS" if passed else "❌ FAIL"
        logger.info(f"{check_name}: {status}")
    logger.info(rule)

    if failed_checks != 0:
        logger.error(f"{failed_checks} check(s) failed. Please fix the issues before running live trading.")
        return 1

    logger.info("🚀 All checks passed! You're ready for live trading.")
    logger.info("\nRun live trading with:")
    logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m")
    logger.info("\nFor real trading (after updating API keys):")
    logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50")
    return 0


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

View File

@ -1,77 +0,0 @@
#!/usr/bin/env python3
"""
Check MEXC Available Trading Symbols
"""
import os
import sys
import logging
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.trading_executor import TradingExecutor
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def check_mexc_symbols():
    """Log which trading symbols the configured exchange reports.

    Logs the ETH-related symbols, the first 20 USDT pairs, all USDC pairs,
    the status of four specific symbols, and a 30-symbol sample. Any error
    is logged; nothing is returned.
    """

    def log_sorted(symbols, limit=None):
        """Log *symbols* in sorted order, at most *limit* plus a remainder note."""
        ordered = sorted(symbols)
        shown = ordered if limit is None else ordered[:limit]
        for name in shown:
            logger.info(f" {name}")
        if limit is not None and len(symbols) > limit:
            logger.info(f" ... and {len(symbols) - limit} more")

    try:
        logger.info("=== MEXC SYMBOL AVAILABILITY CHECK ===")

        # Initialize trading executor
        executor = TradingExecutor("config.yaml")
        if not executor.exchange:
            logger.error("Failed to initialize exchange")
            return

        # Get all supported symbols
        logger.info("Fetching all supported symbols from MEXC...")
        supported_symbols = executor.exchange.get_api_symbols()
        logger.info(f"Total supported symbols: {len(supported_symbols)}")

        # Filter ETH-related symbols
        eth_symbols = [s for s in supported_symbols if 'ETH' in s]
        logger.info(f"ETH-related symbols ({len(eth_symbols)}):")
        log_sorted(eth_symbols)

        # Filter USDT pairs (show first 20)
        usdt_symbols = [s for s in supported_symbols if s.endswith('USDT')]
        logger.info(f"USDT pairs ({len(usdt_symbols)}):")
        log_sorted(usdt_symbols, limit=20)

        # Filter USDC pairs
        usdc_symbols = [s for s in supported_symbols if s.endswith('USDC')]
        logger.info(f"USDC pairs ({len(usdc_symbols)}):")
        log_sorted(usdc_symbols)

        # Check specific symbols we're interested in
        logger.info("Checking specific symbols:")
        for symbol in ('ETHUSDT', 'ETHUSDC', 'BTCUSDT', 'BTCUSDC'):
            status = "SUPPORTED" if symbol in supported_symbols else "NOT SUPPORTED"
            logger.info(f"{symbol} - {status}")

        # Show a sample of all available symbols
        logger.info("Sample of all available symbols:")
        log_sorted(supported_symbols, limit=30)
    except Exception as e:
        logger.error(f"Error checking MEXC symbols: {e}")


if __name__ == "__main__":
    check_mexc_symbols()

View File

@ -1,108 +0,0 @@
#!/usr/bin/env python3
"""
Cleanup Checkpoint Database
Remove invalid database entries and ensure consistency
"""
import logging
from pathlib import Path
from utils.database_manager import get_database_manager
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def cleanup_invalid_checkpoints():
    """Report database checkpoint entries whose file no longer exists.

    Nothing is deleted: invalid entries are printed, logged as warnings and
    counted; valid ones are printed.
    """
    print("=== Cleaning Up Invalid Checkpoint Entries ===")
    db_manager = get_database_manager()

    # Models whose checkpoints are inspected
    all_models = ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target', 'cob_rl', 'extrema_trainer', 'decision']
    removed_count = 0

    for model_name in all_models:
        for checkpoint in db_manager.list_checkpoints(model_name):
            if Path(checkpoint.file_path).exists():
                print(f"Valid checkpoint: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
                continue

            print(f"Removing invalid entry: {checkpoint.checkpoint_id} -> {checkpoint.file_path}")
            try:
                # For now, we'll just report - the system will handle missing files gracefully
                logger.warning(f"Invalid checkpoint file: {checkpoint.file_path}")
                removed_count += 1
            except Exception as e:
                logger.error(f"Failed to remove invalid checkpoint: {e}")

    print(f"Found {removed_count} invalid checkpoint entries")
def verify_checkpoint_loading():
    """Load the best checkpoint of three models and print what was found.

    For each model, prints the checkpoint id, whether its file exists, its
    loss attribute (or N/A), and the file size in MB when the file exists.
    Per-model errors are printed and do not stop the loop.
    """
    print("\n=== Verifying Checkpoint Loading ===")
    from utils.checkpoint_manager import load_best_checkpoint

    for model_name in ('dqn_agent', 'enhanced_cnn', 'dqn_agent_target'):
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                print(f"{model_name}: ❌ No valid checkpoint found")
                continue

            file_path, metadata = result
            file_exists = Path(file_path).exists()
            print(f"{model_name}:")
            print(f" ✅ Checkpoint found: {metadata.checkpoint_id}")
            print(f" 📁 File exists: {file_exists}")
            print(f" 📊 Loss: {getattr(metadata, 'loss', 'N/A')}")
            if file_exists:
                size_line = f" 💾 Size: {Path(file_path).stat().st_size / (1024*1024):.1f}MB"
            else:
                size_line = " 💾 Size: N/A"
            print(size_line)
        except Exception as e:
            print(f"{model_name}: ❌ Error loading checkpoint: {e}")
def test_checkpoint_system_integration():
    """Print whether best-checkpoint metadata can be read from the database.

    Queries the database manager for two models and prints the checkpoint
    id and recorded loss for each. Any exception is printed as a failed
    integration test.
    """
    print("\n=== Testing Orchestrator Integration ===")
    try:
        # Test database manager integration
        from utils.database_manager import get_database_manager
        db_manager = get_database_manager()

        # Test fast metadata access
        for model_name in ('dqn_agent', 'enhanced_cnn'):
            metadata = db_manager.get_best_checkpoint_metadata(model_name)
            if not metadata:
                print(f"{model_name}: ❌ No metadata found")
                continue
            print(f"{model_name}: ✅ Fast metadata access works")
            print(f" ID: {metadata.checkpoint_id}")
            print(f" Loss: {metadata.performance_metrics.get('loss', 'N/A')}")

        print("\n✅ Checkpoint system is ready for use!")
    except Exception as e:
        print(f"❌ Integration test failed: {e}")
def main():
    """Run the report, the loading check and the integration check in order."""
    steps = (
        cleanup_invalid_checkpoints,
        verify_checkpoint_loading,
        test_checkpoint_system_integration,
    )
    for step in steps:
        step()

    print("\n=== Cleanup Complete ===")
    print("The checkpoint system should now work without 'file not found' errors!")


if __name__ == "__main__":
    main()

View File

@ -1,186 +0,0 @@
#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script
This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
"""
import os
import logging
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any
import torch
from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class CheckpointCleanup:
def __init__(self):
    """Obtain the checkpoint manager and record the saved-models directory."""
    self.checkpoint_manager = get_checkpoint_manager()
    self.saved_models_dir = Path("NN/models/saved")
def analyze_existing_checkpoints(self) -> Dict[str, Any]:
logger.info("Analyzing existing checkpoint files...")
analysis = {
'total_files': 0,
'total_size_mb': 0.0,
'model_types': {},
'file_patterns': {},
'potential_duplicates': []
}
if not self.saved_models_dir.exists():
logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
return analysis
for pt_file in self.saved_models_dir.rglob("*.pt"):
try:
file_size_mb = pt_file.stat().st_size / (1024 * 1024)
analysis['total_files'] += 1
analysis['total_size_mb'] += file_size_mb
filename = pt_file.name
if 'cnn' in filename.lower():
model_type = 'cnn'
elif 'dqn' in filename.lower() or 'rl' in filename.lower():
model_type = 'rl'
elif 'agent' in filename.lower():
model_type = 'rl'
else:
model_type = 'unknown'
if model_type not in analysis['model_types']:
analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}
analysis['model_types'][model_type]['count'] += 1
analysis['model_types'][model_type]['size_mb'] += file_size_mb
base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
if base_name not in analysis['file_patterns']:
analysis['file_patterns'][base_name] = []
analysis['file_patterns'][base_name].append({
'path': str(pt_file),
'size_mb': file_size_mb,
'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
})
except Exception as e:
logger.error(f"Error analyzing {pt_file}: {e}")
for base_name, files in analysis['file_patterns'].items():
if len(files) > 5: # More than 5 files with same base name
analysis['potential_duplicates'].append({
'base_name': base_name,
'count': len(files),
'total_size_mb': sum(f['size_mb'] for f in files),
'files': files
})
logger.info(f"Analysis complete:")
logger.info(f" Total files: {analysis['total_files']}")
logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB")
logger.info(f" Model types: {analysis['model_types']}")
logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}")
return analysis
def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")
cleanup_results = {
'removed': 0,
'kept': 0,
'space_saved_mb': 0.0,
'details': []
}
analysis = self.analyze_existing_checkpoints()
for duplicate_group in analysis['potential_duplicates']:
base_name = duplicate_group['base_name']
files = duplicate_group['files']
# Sort by modification time (newest first)
files.sort(key=lambda x: x['modified'], reverse=True)
logger.info(f"Processing {base_name}: {len(files)} files")
# Keep only the 5 newest files
for i, file_info in enumerate(files):
if i < 5: # Keep first 5 (newest)
cleanup_results['kept'] += 1
cleanup_results['details'].append({
'action': 'kept',
'file': file_info['path']
})
else: # Remove the rest
if not dry_run:
try:
Path(file_info['path']).unlink()
logger.info(f"Removed: {file_info['path']}")
except Exception as e:
logger.error(f"Error removing {file_info['path']}: {e}")
continue
cleanup_results['removed'] += 1
cleanup_results['space_saved_mb'] += file_info['size_mb']
cleanup_results['details'].append({
'action': 'removed',
'file': file_info['path'],
'size_mb': file_info['size_mb']
})
logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
logger.info(f" Kept: {cleanup_results['kept']}")
logger.info(f" Removed: {cleanup_results['removed']}")
logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB")
return cleanup_results
def main():
    """Interactively analyze and prune duplicate checkpoint files.

    Runs an analysis, simulates the cleanup, and deletes files only after the
    user answers ``y`` at the confirmation prompt.
    """
    logger.info("=== Checkpoint Cleanup Tool ===")
    cleanup = CheckpointCleanup()

    # Analyze existing checkpoints
    logger.info("\n1. Analyzing existing checkpoints...")
    analysis = cleanup.analyze_existing_checkpoints()
    if analysis['total_files'] == 0:
        logger.info("No checkpoint files found.")
        return

    # Count the files beyond the five newest of every oversized group.
    keep = 5
    total_duplicates = sum(
        max(len(group['files']) - keep, 0)
        for group in analysis['potential_duplicates']
    )
    if total_duplicates <= 0:
        logger.info("No duplicate files found that need cleanup.")
        return
    logger.info(f"\nFound {total_duplicates} files that could be cleaned up")

    # Dry run first so the user sees the impact before anything is deleted.
    logger.info("\n2. Simulating cleanup...")
    dry_run_results = cleanup.cleanup_duplicates(dry_run=True)
    if dry_run_results['removed'] <= 0:
        logger.info("No files to remove.")
        return

    answer = input(f"\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
                   f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ")
    if answer.lower().strip() != 'y':
        logger.info("Cleanup cancelled.")
        return

    logger.info("\n3. Performing actual cleanup...")
    cleanup.cleanup_duplicates(dry_run=False)
    logger.info("\n=== Cleanup Complete ===")


if __name__ == "__main__":
    main()

View File

View File

@ -1,402 +0,0 @@
"""
API Rate Limiter and Error Handler
This module provides robust rate limiting and error handling for API requests,
specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors)
and other exchange API limitations.
Features:
- Exponential backoff for rate limiting
- IP rotation and proxy support
- Request queuing and throttling
- Error recovery strategies
- Thread-safe operations
"""
import asyncio
import logging
import time
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from collections import deque
import threading
from concurrent.futures import ThreadPoolExecutor
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logger = logging.getLogger(__name__)
@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    Raises:
        ValueError: If ``requests_per_second`` is not positive. The per-second
            check divides by this value, so zero would fail at request time.
    """
    requests_per_second: float = 0.5  # Very conservative for Binance
    requests_per_minute: int = 20
    requests_per_hour: int = 1000
    # Backoff configuration
    initial_backoff: float = 1.0
    max_backoff: float = 300.0  # 5 minutes max
    backoff_multiplier: float = 2.0
    # Error handling
    max_retries: int = 3
    retry_delay: float = 5.0
    # IP blocking detection
    block_detection_threshold: int = 3  # 3 consecutive 418s = blocked
    block_recovery_time: int = 3600  # 1 hour recovery time

    def __post_init__(self):
        # Fail at construction rather than with ZeroDivisionError on first use.
        if self.requests_per_second <= 0:
            raise ValueError("requests_per_second must be greater than 0")
@dataclass
class APIEndpoint:
    """Rate-limit bookkeeping for a single registered API endpoint."""
    # Identity and the limits that apply to this endpoint
    name: str
    base_url: str
    rate_limit: RateLimitConfig
    # Mutable request and error state
    last_request_time: float = 0.0
    request_count_minute: int = 0
    request_count_hour: int = 0
    consecutive_errors: int = 0
    blocked_until: Optional[datetime] = None
    # Request timestamps used for rate limiting; holds at most the 3600 newest
    request_history: deque = field(default_factory=lambda: deque(maxlen=3600))
class APIRateLimiter:
    """Thread-safe API rate limiter with error handling.

    Keeps request timestamps and error counters per registered endpoint,
    refuses requests that would exceed the configured per-second, per-minute
    or per-hour limits, and blocks an endpoint for a while after HTTP 418 or
    429 responses.
    """

    def __init__(self, config: Optional["RateLimitConfig"] = None):
        """Create a limiter.

        Args:
            config: Limits used for endpoints registered without their own
                configuration. A default ``RateLimitConfig`` is used if omitted.
        """
        self.config = config or RateLimitConfig()
        # Re-entrant: public methods call each other while holding it.
        self.lock = threading.RLock()
        # Endpoint tracking
        self.endpoints: Dict[str, "APIEndpoint"] = {}
        # Global rate limiting
        self.global_request_history = deque(maxlen=3600)
        self.global_blocked_until: Optional[datetime] = None
        # Request session with retry strategy
        self.session = self._create_session()
        # Background cleanup thread
        self.cleanup_thread = None
        self.is_running = False
        # Wakes the cleanup worker immediately when a stop is requested.
        self._stop_event = threading.Event()
        logger.info("API Rate Limiter initialized")
        logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m")

    def _create_session(self) -> "requests.Session":
        """Create a requests session with a retry strategy and browser-like headers."""
        session = requests.Session()
        retry_strategy = Retry(
            total=self.config.max_retries,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        for scheme in ("http://", "https://"):
            session.mount(scheme, adapter)
        # Headers to appear more legitimate
        # NOTE(review): 'br' is advertised here; confirm a brotli decoder is
        # installed, otherwise such responses may not be decoded.
        session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'application/json',
            'Accept-Language': 'en-US,en;q=0.9',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        })
        return session

    def register_endpoint(self, name: str, base_url: str,
                          rate_limit: Optional["RateLimitConfig"] = None):
        """Register an API endpoint for rate limiting.

        Args:
            name: Key used to refer to the endpoint in later calls.
            base_url: Base URL stored with the endpoint.
            rate_limit: Limits for this endpoint; the limiter's default
                configuration is used if omitted.
        """
        with self.lock:
            self.endpoints[name] = APIEndpoint(
                name=name,
                base_url=base_url,
                rate_limit=rate_limit or self.config
            )
        logger.info(f"Registered endpoint: {name} -> {base_url}")

    def start_background_cleanup(self):
        """Start the background cleanup thread (no-op if already running)."""
        if self.is_running:
            return
        self._stop_event.clear()
        self.is_running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
        self.cleanup_thread.start()
        logger.info("Started background cleanup thread")

    def stop_background_cleanup(self):
        """Stop the background cleanup thread and wait up to 5 seconds for it."""
        self.is_running = False
        # Interrupt the worker's wait so it exits now instead of after its nap.
        self._stop_event.set()
        if self.cleanup_thread:
            self.cleanup_thread.join(timeout=5)
        logger.info("Stopped background cleanup thread")

    @staticmethod
    def _drop_older_than(history, cutoff_time: float) -> None:
        """Pop timestamps older than ``cutoff_time`` from the left of ``history``."""
        while history and history[0] < cutoff_time:
            history.popleft()

    @staticmethod
    def _count_since(history, cutoff_time: float) -> int:
        """Count the timestamps in ``history`` that are newer than ``cutoff_time``."""
        return sum(1 for stamp in history if stamp > cutoff_time)

    def _cleanup_worker(self):
        """Background worker that prunes request history older than one hour."""
        while self.is_running:
            pause = 60  # Clean every minute
            try:
                current_time = time.time()
                cutoff_time = current_time - 3600  # 1 hour ago
                with self.lock:
                    self._drop_older_than(self.global_request_history, cutoff_time)
                    for endpoint in self.endpoints.values():
                        self._drop_older_than(endpoint.request_history, cutoff_time)
                        # Refresh the cached counters
                        endpoint.request_count_minute = self._count_since(
                            endpoint.request_history, current_time - 60)
                        endpoint.request_count_hour = len(endpoint.request_history)
            except Exception as e:
                logger.error(f"Error in cleanup worker: {e}")
                pause = 30
            # wait() returns True as soon as a stop is requested.
            if self._stop_event.wait(pause):
                break

    def can_make_request(self, endpoint_name: str) -> tuple[bool, float]:
        """
        Check if we can make a request to the endpoint.

        Returns:
            (can_make_request, wait_time_seconds). Unknown endpoints yield
            ``(False, 60.0)``.
        """
        with self.lock:
            current_time = time.time()
            # Read the clock once so a block that expires mid-check cannot
            # produce a negative wait time.
            now = datetime.now()

            # Check global blocking
            if self.global_blocked_until and now < self.global_blocked_until:
                return False, (self.global_blocked_until - now).total_seconds()

            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                logger.warning(f"Unknown endpoint: {endpoint_name}")
                return False, 60.0

            # Check endpoint blocking
            if endpoint.blocked_until and now < endpoint.blocked_until:
                return False, (endpoint.blocked_until - now).total_seconds()

            config = endpoint.rate_limit

            # Per-second rate limit
            min_interval = 1.0 / config.requests_per_second
            time_since_last = current_time - endpoint.last_request_time
            if time_since_last < min_interval:
                return False, min_interval - time_since_last

            history = endpoint.request_history

            # Per-minute rate limit
            if self._count_since(history, current_time - 60) >= config.requests_per_minute:
                return False, 60.0

            # Per-hour rate limit: count only the last hour, so the check stays
            # correct even when the background cleanup thread is not running.
            if self._count_since(history, current_time - 3600) >= config.requests_per_hour:
                return False, 3600.0

            return True, 0.0

    def make_request(self, endpoint_name: str, url: str, method: str = 'GET',
                     **kwargs) -> Optional["requests.Response"]:
        """
        Make a rate-limited request with error handling.

        Args:
            endpoint_name: Name of the registered endpoint
            url: Full URL to request
            method: HTTP method
            **kwargs: Additional arguments for requests (``timeout`` defaults to 10)

        Returns:
            The response when the status code is 200, otherwise None (unknown
            endpoint, rate limited, error status, or request exception).
        """
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                logger.error(f"Unknown endpoint: {endpoint_name}")
                return None
            can_request, wait_time = self.can_make_request(endpoint_name)
            if can_request:
                # Record the attempt before releasing the lock so concurrent
                # callers see it in their own limit checks.
                current_time = time.time()
                endpoint.last_request_time = current_time
                endpoint.request_history.append(current_time)
                self.global_request_history.append(current_time)

        # All sleeping happens outside the lock so other threads are not stalled.
        if not can_request:
            logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s")
            time.sleep(max(0.0, min(wait_time, 30)))  # Cap wait time
            return None

        # Add jitter to avoid thundering herd
        time.sleep(random.uniform(0.1, 0.5))

        try:
            kwargs.setdefault('timeout', 10)
            response = self.session.request(method, url, **kwargs)
            return self._handle_response(endpoint, endpoint_name, response)
        except requests.exceptions.RequestException as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Request exception for {endpoint_name}: {e}")
            return None
        except Exception as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Unexpected error for {endpoint_name}: {e}")
            return None

    def _handle_response(self, endpoint, endpoint_name: str, response):
        """Update the endpoint's error and block state from ``response``.

        Returns ``response`` for HTTP 200 and None for every other status.
        """
        with self.lock:
            status = response.status_code
            if status == 200:
                # Success - reset error counter
                endpoint.consecutive_errors = 0
                return response

            endpoint.consecutive_errors += 1
            limits = endpoint.rate_limit

            if status == 418:
                # Binance "I'm a teapot" - rate limited/blocked
                logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}")
                if endpoint.consecutive_errors >= limits.block_detection_threshold:
                    # We're likely IP blocked
                    block_time = datetime.now() + timedelta(seconds=limits.block_recovery_time)
                    endpoint.blocked_until = block_time
                    logger.error(f"Endpoint {endpoint_name} blocked until {block_time}")
            elif status == 429:
                # Too many requests
                logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}")
                # Exponential backoff; the exponent is capped so a long error
                # streak cannot overflow the float power.
                exponent = min(endpoint.consecutive_errors, 64)
                backoff_time = min(
                    limits.initial_backoff * (limits.backoff_multiplier ** exponent),
                    limits.max_backoff
                )
                endpoint.blocked_until = datetime.now() + timedelta(seconds=backoff_time)
                logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s")
            else:
                # Other error
                logger.warning(f"HTTP {status} for {endpoint_name}: {response.text[:200]}")
            return None

    def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]:
        """Get status information for an endpoint.

        Returns ``{'error': 'Unknown endpoint'}`` for unregistered names.
        """
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                return {'error': 'Unknown endpoint'}
            current_time = time.time()
            history = endpoint.request_history
            blocked_until = endpoint.blocked_until
            return {
                'name': endpoint.name,
                'base_url': endpoint.base_url,
                'consecutive_errors': endpoint.consecutive_errors,
                'blocked_until': blocked_until.isoformat() if blocked_until else None,
                'requests_last_minute': self._count_since(history, current_time - 60),
                'requests_last_hour': self._count_since(history, current_time - 3600),
                'last_request_time': endpoint.last_request_time,
                'can_make_request': self.can_make_request(endpoint_name)[0]
            }

    def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status for all endpoints."""
        # Snapshot the names under the lock; a concurrent registration could
        # otherwise resize the dict while it is being iterated.
        with self.lock:
            names = list(self.endpoints)
        return {name: self.get_endpoint_status(name) for name in names}

    def reset_endpoint(self, endpoint_name: str):
        """Reset an endpoint's error state (no-op for unknown endpoints)."""
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if endpoint:
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
                logger.info(f"Reset endpoint: {endpoint_name}")

    def reset_all_endpoints(self):
        """Reset all endpoints' error states and clear the global block."""
        with self.lock:
            for endpoint in self.endpoints.values():
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
            self.global_blocked_until = None
        logger.info("Reset all endpoints")
# Global rate limiter instance
_global_rate_limiter = None
# Serializes first-time construction so concurrent callers share one instance.
_global_rate_limiter_lock = threading.Lock()


def get_rate_limiter() -> APIRateLimiter:
    """Get the global rate limiter instance, creating it on first use.

    On first use the limiter is created, the ``binance_api`` and ``mexc_api``
    endpoints are registered and the background cleanup thread is started. The
    instance is published only once it is fully configured.
    """
    global _global_rate_limiter
    if _global_rate_limiter is None:
        with _global_rate_limiter_lock:
            # Re-check: another thread may have finished while we waited.
            if _global_rate_limiter is None:
                limiter = APIRateLimiter()
                # Register common endpoints
                limiter.register_endpoint(
                    'binance_api',
                    'https://api.binance.com',
                    RateLimitConfig(
                        requests_per_second=0.2,  # Very conservative
                        requests_per_minute=10,
                        requests_per_hour=500
                    )
                )
                limiter.register_endpoint(
                    'mexc_api',
                    'https://api.mexc.com',
                    RateLimitConfig(
                        requests_per_second=0.5,
                        requests_per_minute=20,
                        requests_per_hour=1000
                    )
                )
                limiter.start_background_cleanup()
                _global_rate_limiter = limiter
    return _global_rate_limiter

View File

@ -1,442 +0,0 @@
"""
Async Handler for UI Stability Fix
Properly handles all async operations in the dashboard with single event loop management,
proper exception handling, and timeout support to prevent async/await errors.
"""
import asyncio
import logging
import threading
import time
from typing import Any, Callable, Coroutine, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor
import functools
import weakref
logger = logging.getLogger(__name__)
class AsyncOperationError(Exception):
    """Raised when an async operation fails, times out, or cannot be started."""
class AsyncHandler:
    """
    Centralized async operation handler with single event loop management
    and proper exception handling for async operations.

    When no loop is supplied, the handler runs its own event loop in a daemon
    thread. Synchronous callers submit coroutines to that loop through the
    methods below. The handler can be used as a context manager; leaving the
    ``with`` block calls :meth:`stop`.
    """

    # Class-level default so is_running() is safe before __init__ completes.
    _stopped = False

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Initialize the async handler.

        Args:
            loop: Optional event loop to use. If None, creates a new one and
                runs it in a background thread.

        Raises:
            AsyncOperationError: If the background loop fails to start.
        """
        self._loop = loop
        self._thread = None
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler")
        self._running = False
        self._stopped = False
        self._callbacks = weakref.WeakSet()
        self._timeout_default = 30.0  # Default timeout for operations
        # Start the event loop in a separate thread if not provided
        if self._loop is None:
            self._start_event_loop_thread()
        logger.info("AsyncHandler initialized with event loop management")

    def _start_event_loop_thread(self):
        """Start the event loop in a separate thread and wait until it runs."""
        started = threading.Event()

        def on_loop_started():
            # Runs as the loop's first callback, so the loop really is running.
            self._running = True
            started.set()

        def run_event_loop():
            """Run the event loop in a separate thread."""
            loop = None
            try:
                loop = asyncio.new_event_loop()
                self._loop = loop
                asyncio.set_event_loop(loop)
                loop.call_soon(on_loop_started)
                logger.debug("Event loop started in separate thread")
                loop.run_forever()
            except Exception as e:
                logger.error(f"Error in event loop thread: {e}")
            finally:
                self._running = False
                # This thread owns the loop, so it releases its resources.
                if loop is not None and not loop.is_closed():
                    loop.close()
                # Unblock the starter even if the loop never came up.
                started.set()
                logger.debug("Event loop thread stopped")

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop")
        self._thread.start()
        # Wait for the loop to be ready
        started.wait(timeout=5.0)
        if not self._running:
            raise AsyncOperationError("Failed to start event loop within timeout")

    def _in_loop_thread(self) -> bool:
        """Return True when called from the thread running this handler's loop."""
        try:
            return asyncio.get_running_loop() is self._loop
        except RuntimeError:
            return False

    def is_running(self) -> bool:
        """Check if the async handler is running."""
        loop = self._loop
        if self._stopped or loop is None or loop.is_closed():
            return False
        # A caller-supplied loop never sets _running, so ask the loop itself.
        return self._running or loop.is_running()

    def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any:
        """
        Run an async coroutine safely with proper error handling and timeout.

        Args:
            coro: The coroutine to run
            timeout: Timeout in seconds (uses default if None)

        Returns:
            The result of the coroutine

        Raises:
            AsyncOperationError: If the operation fails or times out
        """
        # Imported locally: waiting on the future raises this class, which is
        # distinct from asyncio.TimeoutError before Python 3.11.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")
        timeout = timeout or self._timeout_default
        future = None
        try:
            # Schedule the coroutine on the event loop
            future = asyncio.run_coroutine_threadsafe(
                asyncio.wait_for(coro, timeout=timeout),
                self._loop
            )
            # One extra second so the in-loop timeout normally fires first.
            result = future.result(timeout=timeout + 1.0)
            logger.debug("Async operation completed successfully")
            return result
        except (asyncio.TimeoutError, FutureTimeoutError) as e:
            if future is not None:
                # Stop the abandoned work from lingering on the loop.
                future.cancel()
            logger.error(f"Async operation timed out after {timeout} seconds")
            raise AsyncOperationError(f"Operation timed out after {timeout} seconds") from e
        except Exception as e:
            logger.error(f"Async operation failed: {e}")
            raise AsyncOperationError(f"Async operation failed: {e}") from e

    def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None:
        """
        Schedule a coroutine to run asynchronously without waiting for result.

        Args:
            coro: The coroutine to schedule
            callback: Optional callback called with the result, or with None
                if the coroutine raised
        """
        if not self.is_running():
            logger.warning("Cannot schedule coroutine: AsyncHandler is not running")
            return

        async def wrapped_coro():
            """Wrapper to handle exceptions and callbacks"""
            try:
                result = await coro
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(f"Error in coroutine callback: {e}")
                return result
            except Exception as e:
                logger.error(f"Error in scheduled coroutine: {e}")
                if callback:
                    try:
                        callback(None)  # Call callback with None on error
                    except Exception as cb_e:
                        logger.error(f"Error in error callback: {cb_e}")

        try:
            asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop)
            logger.debug("Coroutine scheduled successfully")
        except Exception as e:
            logger.error(f"Failed to schedule coroutine: {e}")

    def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Create an asyncio task safely with proper error handling.

        Args:
            coro: The coroutine to create a task for
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        if not self.is_running():
            logger.warning("Cannot create task: AsyncHandler is not running")
            return None

        async def create_task():
            """Create the task in the event loop"""
            try:
                task = asyncio.create_task(coro, name=name)
                logger.debug(f"Task created: {name or 'unnamed'}")
                return task
            except Exception as e:
                logger.error(f"Failed to create task {name}: {e}")
                return None

        try:
            future = asyncio.run_coroutine_threadsafe(create_task(), self._loop)
            return future.result(timeout=5.0)
        except Exception as e:
            logger.error(f"Failed to create task {name}: {e}")
            return None

    async def handle_orchestrator_connection(self, orchestrator) -> bool:
        """
        Handle orchestrator connection with proper async patterns.

        Each of ``add_decision_callback``, ``start_cob_integration`` and
        ``start_continuous_trading`` is awaited if the orchestrator has it.

        Args:
            orchestrator: The orchestrator instance to connect to

        Returns:
            True if connection successful, False otherwise
        """
        try:
            logger.info("Connecting to orchestrator...")
            # Add decision callback if orchestrator supports it
            if hasattr(orchestrator, 'add_decision_callback'):
                await orchestrator.add_decision_callback(self._handle_trading_decision)
                logger.info("Decision callback added to orchestrator")
            # Start COB integration if available
            if hasattr(orchestrator, 'start_cob_integration'):
                await orchestrator.start_cob_integration()
                logger.info("COB integration started")
            # Start continuous trading if available
            if hasattr(orchestrator, 'start_continuous_trading'):
                await orchestrator.start_continuous_trading()
                logger.info("Continuous trading started")
            logger.info("Successfully connected to orchestrator")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to orchestrator: {e}")
            return False

    async def handle_cob_integration(self, cob_integration) -> bool:
        """
        Handle COB integration startup with proper async patterns.

        Args:
            cob_integration: The COB integration instance

        Returns:
            True if startup successful, False otherwise (including when the
            instance has no ``start`` method)
        """
        try:
            logger.info("Starting COB integration...")
            if not hasattr(cob_integration, 'start'):
                logger.warning("COB integration does not have start method")
                return False
            await cob_integration.start()
            logger.info("COB integration started successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to start COB integration: {e}")
            return False

    async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None:
        """
        Handle trading decision with proper async patterns.

        Currently this only logs the decision's action, symbol and confidence.

        Args:
            decision: The trading decision dictionary
        """
        try:
            logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}")
            # Process the decision (this would be customized based on needs)
            # For now, just log it
            symbol = decision.get('symbol', 'UNKNOWN')
            action = decision.get('action', 'HOLD')
            confidence = decision.get('confidence', 0.0)
            logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})")
        except Exception as e:
            logger.error(f"Error handling trading decision: {e}")

    def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
        """
        Run a blocking function in the thread pool executor.

        Args:
            func: The function to run
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            The result of the function

        Raises:
            AsyncOperationError: If the handler is not running, the function
                raises, or it does not finish within the default timeout
        """
        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")
        try:
            # Submit straight to the pool; a detour through the event loop
            # would only make the call depend on the loop being responsive.
            future = self._executor.submit(func, *args, **kwargs)
            result = future.result(timeout=self._timeout_default)
            logger.debug("Executor function completed successfully")
            return result
        except Exception as e:
            logger.error(f"Error running function in executor: {e}")
            raise AsyncOperationError(f"Executor function failed: {e}") from e

    def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Add a periodic task that runs at specified intervals.

        Args:
            coro_func: Function that returns a coroutine to run periodically
            interval: Interval in seconds between runs
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        # Check first so the runner coroutine is never created and then dropped.
        if not self.is_running():
            logger.warning("Cannot create task: AsyncHandler is not running")
            return None

        task_name = name or "periodic_task"

        async def periodic_runner():
            """Run the coroutine periodically"""
            logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)")
            try:
                while True:
                    try:
                        await coro_func()
                        logger.debug(f"Periodic task {task_name} completed")
                    except Exception as e:
                        # A failing run must not end the schedule.
                        logger.error(f"Error in periodic task {task_name}: {e}")
                    await asyncio.sleep(interval)
            except asyncio.CancelledError:
                logger.info(f"Periodic task {task_name} cancelled")
                raise
            except Exception as e:
                logger.error(f"Fatal error in periodic task {task_name}: {e}")

        # Unnamed tasks used to be labelled "periodic_None".
        label = f"periodic_{name}" if name else task_name
        return self.create_task_safely(periodic_runner(), name=label)

    def stop(self) -> None:
        """Stop the async handler and clean up resources.

        Cancels the loop's tasks, stops the loop, shuts down the executor and
        joins the background thread (if any). Errors are logged, not raised.
        """
        try:
            logger.info("Stopping AsyncHandler...")
            loop = self._loop
            in_loop_thread = self._in_loop_thread()
            if loop and not loop.is_closed() and loop.is_running():
                cancel_future = asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), loop)
                if in_loop_thread:
                    # Blocking here would stall the loop; stop it once the
                    # cancellation pass has finished.
                    cancel_future.add_done_callback(
                        lambda _done: loop.call_soon_threadsafe(loop.stop))
                else:
                    # Cancellation must run before the loop stops, otherwise
                    # the loop halts without ever cancelling the tasks.
                    try:
                        cancel_future.result(timeout=5.0)
                    except Exception as e:
                        cancel_future.cancel()
                        logger.warning(f"Task cancellation did not finish cleanly: {e}")
                    loop.call_soon_threadsafe(loop.stop)
            # Mark stopped only after the loop checks above have been made.
            self._stopped = True
            # Shutdown executor
            if self._executor:
                self._executor.shutdown(wait=True)
            # Wait for thread to finish (a thread cannot join itself)
            thread = self._thread
            if thread and thread.is_alive() and thread is not threading.current_thread():
                thread.join(timeout=5.0)
            self._running = False
            logger.info("AsyncHandler stopped successfully")
        except Exception as e:
            logger.error(f"Error stopping AsyncHandler: {e}")

    async def _cancel_all_tasks(self) -> None:
        """Cancel all running tasks except the one executing this coroutine."""
        try:
            current = asyncio.current_task()
            tasks = [
                task for task in asyncio.all_tasks(self._loop)
                if task is not current and not task.done()
            ]
            if tasks:
                logger.info(f"Cancelling {len(tasks)} running tasks")
                for task in tasks:
                    task.cancel()
                # Wait for tasks to be cancelled
                await asyncio.gather(*tasks, return_exceptions=True)
            logger.debug("All tasks cancelled")
        except Exception as e:
            logger.error(f"Error cancelling tasks: {e}")

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.stop()
class AsyncContextManager:
    """
    Context manager for async operations that ensures proper cleanup.

    Tasks created through :meth:`create_task` are tracked, and any that are
    still pending are cancelled when the ``with`` block exits.
    """

    def __init__(self, async_handler: "AsyncHandler"):
        """Remember the handler used to create tasks and start with no tasks."""
        self.async_handler = async_handler
        self.active_tasks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Cancel any active tasks
        for task in self.active_tasks:
            if task.done():
                continue
            # Tasks are not thread-safe: ask the owning loop to cancel rather
            # than calling cancel() from this thread.
            try:
                task.get_loop().call_soon_threadsafe(task.cancel)
            except RuntimeError:
                # The loop is already closed and cannot race with us.
                task.cancel()

    def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """Create a task and track it for cleanup.

        Returns:
            The created task, or None if the handler could not create it.
        """
        task = self.async_handler.create_task_safely(coro, name)
        if task:
            self.active_tasks.append(task)
        return task
def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler:
    """
    Build and return a new AsyncHandler.

    Args:
        loop: Event loop handed to the handler; passed through unchanged
            (None by default)

    Returns:
        The newly constructed AsyncHandler
    """
    handler = AsyncHandler(loop=loop)
    return handler
def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any:
    """
    Run a coroutine on a throwaway AsyncHandler and return its result.

    A fresh handler is created for the call and stopped afterwards, whether
    the coroutine succeeded or raised.

    Args:
        coro: The coroutine to run
        timeout: Timeout in seconds, forwarded to the handler

    Returns:
        The result of the coroutine
    """
    temporary_handler = AsyncHandler()
    try:
        return temporary_handler.run_async_safely(coro, timeout=timeout)
    finally:
        temporary_handler.stop()

View File

@ -1,952 +0,0 @@
"""
Bookmap Order Book Data Provider
This module integrates with Bookmap to gather:
- Current Order Book (COB) data
- Session Volume Profile (SVP) data
- Order book sweeps and momentum trades detection
- Real-time order size heatmap matrix (last 10 minutes)
- Level 2 market depth analysis
The data is processed and fed to CNN and DQN networks for enhanced trading decisions.
"""
import asyncio
import json
import logging
import time
import websockets
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from collections import deque, defaultdict
from dataclasses import dataclass
from threading import Thread, Lock
import requests
logger = logging.getLogger(__name__)
@dataclass
class OrderBookLevel:
    """One price level on one side of an order book."""
    price: float
    size: float
    orders: int
    side: str  # either 'bid' or 'ask'
    timestamp: datetime
@dataclass
class OrderBookSnapshot:
    """Order book state for one symbol at one point in time."""
    symbol: str
    timestamp: datetime
    # Levels for each side of the book
    bids: List[OrderBookLevel]
    asks: List[OrderBookLevel]
    # Summary prices stored alongside the levels
    spread: float
    mid_price: float
@dataclass
class VolumeProfileLevel:
    """Volume statistics recorded for a single price of a volume profile."""
    price: float
    # Total volume and its buy/sell parts
    volume: float
    buy_volume: float
    sell_volume: float
    trades_count: int
    vwap: float
@dataclass
class OrderFlowSignal:
    """A detected order flow event."""
    timestamp: datetime
    signal_type: str  # one of 'sweep', 'absorption', 'iceberg', 'momentum'
    price: float
    volume: float
    confidence: float
    description: str
class BookmapDataProvider:
"""
Real-time order book data provider using Bookmap-style analysis
Features:
- Level 2 order book monitoring
- Order flow detection (sweeps, absorptions)
- Volume profile analysis
- Order size heatmap generation
- Market microstructure analysis
"""
def __init__(self, symbols: List[str] = None, depth_levels: int = 20):
"""
Initialize Bookmap data provider
Args:
symbols: List of symbols to monitor
depth_levels: Number of order book levels to track
"""
self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
self.depth_levels = depth_levels
self.is_streaming = False
# Order book data storage
self.order_books: Dict[str, OrderBookSnapshot] = {}
self.order_book_history: Dict[str, deque] = {}
self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
# Heatmap data (10-minute rolling window)
self.heatmap_window = timedelta(minutes=10)
self.order_heatmaps: Dict[str, deque] = {}
self.price_levels: Dict[str, List[float]] = {}
# Order flow detection
self.flow_signals: Dict[str, deque] = {}
self.sweep_threshold = 0.8 # Minimum confidence for sweep detection
self.absorption_threshold = 0.7 # Minimum confidence for absorption
# Market microstructure metrics
self.bid_ask_spreads: Dict[str, deque] = {}
self.order_book_imbalances: Dict[str, deque] = {}
self.liquidity_metrics: Dict[str, Dict] = {}
# WebSocket connections
self.websocket_tasks: Dict[str, asyncio.Task] = {}
self.data_lock = Lock()
# Callbacks for CNN/DQN integration
self.cnn_callbacks: List[Callable] = []
self.dqn_callbacks: List[Callable] = []
# Performance tracking
self.update_counts = defaultdict(int)
self.last_update_times = {}
# Initialize data structures
for symbol in self.symbols:
self.order_book_history[symbol] = deque(maxlen=1000)
self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals
self.flow_signals[symbol] = deque(maxlen=500)
self.bid_ask_spreads[symbol] = deque(maxlen=1000)
self.order_book_imbalances[symbol] = deque(maxlen=1000)
self.liquidity_metrics[symbol] = {
'total_bid_size': 0.0,
'total_ask_size': 0.0,
'weighted_mid': 0.0,
'liquidity_ratio': 1.0
}
logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols")
logger.info(f"Tracking {depth_levels} order book levels per side")
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
"""Add callback for CNN model updates"""
self.cnn_callbacks.append(callback)
logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")
def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
    """Register a listener for DQN-oriented updates.

    Args:
        callback: Callable stored in ``self.dqn_callbacks``; the provider
            invokes registered callbacks as ``callback(symbol, payload)``.
    """
    # In-place extension keeps the same list object other code may hold.
    self.dqn_callbacks += [callback]
    logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
async def start_streaming(self):
    """Launch the background streaming tasks.

    Does nothing (beyond a warning) when streaming is already active.
    Otherwise one depth task and one trade task are created per symbol,
    plus a single shared analysis task; all are recorded in
    ``self.websocket_tasks`` under ``"<symbol>_depth"``,
    ``"<symbol>_trades"`` and ``"analysis"``.
    """
    if self.is_streaming:
        logger.warning("Bookmap streaming already active")
        return

    self.is_streaming = True
    logger.info("Starting Bookmap order book streaming")

    tasks = self.websocket_tasks
    for symbol in self.symbols:
        # Order book depth feed for this symbol.
        tasks[f"{symbol}_depth"] = asyncio.create_task(
            self._stream_order_book_depth(symbol)
        )
        # Trade feed, consumed by the order flow analysis.
        tasks[f"{symbol}_trades"] = asyncio.create_task(
            self._stream_trades(symbol)
        )

    # A single analysis loop serves every symbol.
    tasks["analysis"] = asyncio.create_task(self._continuous_analysis())

    logger.info(f"Started streaming for {len(self.symbols)} symbols")
async def stop_streaming(self):
    """Stop order book streaming and wait for every background task.

    All registered tasks are cancelled first and only then awaited, so a
    single misbehaving task cannot keep the others running. A task that
    finishes with an error other than cancellation is logged instead of
    aborting the shutdown.

    The task registry is snapshotted and cleared before any ``await`` so
    the dictionary is never mutated while it is being iterated.
    """
    if not self.is_streaming:
        return

    logger.info("Stopping Bookmap streaming")
    self.is_streaming = False

    pending = list(self.websocket_tasks.items())
    self.websocket_tasks.clear()

    # Request cancellation everywhere before waiting on anything.
    for _, task in pending:
        if not task.done():
            task.cancel()

    for name, task in pending:
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task died with its own error; shutdown must still finish.
            logger.warning(f"Task {name} ended with error during shutdown: {e}")

    logger.info("Bookmap streaming stopped")
async def _stream_order_book_depth(self, symbol: str):
    """Consume the Binance depth stream for ``symbol`` until streaming stops.

    Each received message is JSON-decoded and handed to
    ``_process_depth_update``. A failure on a single message is logged and
    skipped; a connection-level failure is logged and followed by a
    2 second pause before reconnecting.
    """
    binance_symbol = symbol.lower()
    url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"

    while self.is_streaming:
        try:
            async with websockets.connect(url) as websocket:
                logger.info(f"Order book depth WebSocket connected for {symbol}")
                async for message in websocket:
                    if not self.is_streaming:
                        break
                    try:
                        await self._process_depth_update(symbol, json.loads(message))
                    except Exception as e:
                        logger.warning(f"Error processing depth for {symbol}: {e}")
        except Exception as e:
            logger.error(f"Depth WebSocket error for {symbol}: {e}")
            if not self.is_streaming:
                # Shutdown was requested while the connection was failing.
                break
            await asyncio.sleep(2)
async def _stream_trades(self, symbol: str):
    """Consume the Binance trade stream for ``symbol`` until streaming stops.

    Each received message is JSON-decoded and handed to
    ``_process_trade_update``. A failure on a single message is logged and
    skipped; a connection-level failure is logged and followed by a
    2 second pause before reconnecting.
    """
    binance_symbol = symbol.lower()
    url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"

    while self.is_streaming:
        try:
            async with websockets.connect(url) as websocket:
                logger.info(f"Trade WebSocket connected for {symbol}")
                async for message in websocket:
                    if not self.is_streaming:
                        break
                    try:
                        await self._process_trade_update(symbol, json.loads(message))
                    except Exception as e:
                        logger.warning(f"Error processing trade for {symbol}: {e}")
        except Exception as e:
            logger.error(f"Trade WebSocket error for {symbol}: {e}")
            if not self.is_streaming:
                # Shutdown was requested while the connection was failing.
                break
            await asyncio.sleep(2)
async def _process_depth_update(self, symbol: str, data: Dict):
    """Turn a decoded depth message into a snapshot and refresh derived data.

    Args:
        symbol: Symbol the message belongs to.
        data: Decoded message; its ``'bids'`` and ``'asks'`` entries are
            read as sequences whose first two items are price and size.

    Any error is logged and swallowed, so a bad message never propagates
    to the caller.
    """
    try:
        timestamp = datetime.now()

        def build_levels(rows, side):
            # Binance doesn't provide order count, so each level reports 1.
            return [
                OrderBookLevel(
                    price=float(row[0]),
                    size=float(row[1]),
                    orders=1,
                    side=side,
                    timestamp=timestamp
                )
                for row in rows
            ]

        bids = build_levels(data.get('bids', []), 'bid')
        asks = build_levels(data.get('asks', []), 'ask')

        # Best bid first (descending), best ask first (ascending).
        bids.sort(key=lambda lvl: lvl.price, reverse=True)
        asks.sort(key=lambda lvl: lvl.price)

        # Spread and mid price stay 0.0 when either side is empty.
        spread = 0.0
        mid_price = 0.0
        if bids and asks:
            best_bid, best_ask = bids[0].price, asks[0].price
            spread = best_ask - best_bid
            mid_price = (best_bid + best_ask) / 2

        snapshot = OrderBookSnapshot(
            symbol=symbol,
            timestamp=timestamp,
            bids=bids,
            asks=asks,
            spread=spread,
            mid_price=mid_price
        )

        with self.data_lock:
            self.order_books[symbol] = snapshot
            self.order_book_history[symbol].append(snapshot)

        # Derived views, refreshed in the same order on every update.
        self._update_liquidity_metrics(symbol, snapshot)
        self._calculate_order_book_imbalance(symbol, snapshot)
        self._update_order_heatmap(symbol, snapshot)

        # Bookkeeping for get_statistics().
        self.update_counts[f"{symbol}_depth"] += 1
        self.last_update_times[f"{symbol}_depth"] = timestamp
    except Exception as e:
        logger.error(f"Error processing depth update for {symbol}: {e}")
async def _process_trade_update(self, symbol: str, data: Dict):
"""Process trade data for order flow analysis"""
try:
timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
price = float(data['p'])
quantity = float(data['q'])
is_buyer_maker = data['m']
# Analyze for order flow signals
await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
# Update volume profile
self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
self.update_counts[f"{symbol}_trades"] += 1
except Exception as e:
logger.error(f"Error processing trade for {symbol}: {e}")
def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
"""Update liquidity metrics from order book snapshot"""
try:
total_bid_size = sum(level.size for level in snapshot.bids)
total_ask_size = sum(level.size for level in snapshot.asks)
# Calculate weighted mid price
if snapshot.bids and snapshot.asks:
bid_weight = total_bid_size / (total_bid_size + total_ask_size)
ask_weight = total_ask_size / (total_bid_size + total_ask_size)
weighted_mid = (snapshot.bids[0].price * ask_weight +
snapshot.asks[0].price * bid_weight)
else:
weighted_mid = snapshot.mid_price
# Liquidity ratio (bid/ask balance)
if total_ask_size > 0:
liquidity_ratio = total_bid_size / total_ask_size
else:
liquidity_ratio = 1.0
self.liquidity_metrics[symbol] = {
'total_bid_size': total_bid_size,
'total_ask_size': total_ask_size,
'weighted_mid': weighted_mid,
'liquidity_ratio': liquidity_ratio,
'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
}
except Exception as e:
logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
"""Calculate order book imbalance ratio"""
try:
if not snapshot.bids or not snapshot.asks:
return
# Calculate imbalance for top N levels
n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
if total_bid_size + total_ask_size > 0:
imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
else:
imbalance = 0.0
self.order_book_imbalances[symbol].append({
'timestamp': snapshot.timestamp,
'imbalance': imbalance,
'bid_size': total_bid_size,
'ask_size': total_ask_size
})
except Exception as e:
logger.error(f"Error calculating imbalance for {symbol}: {e}")
def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
"""Update order size heatmap matrix"""
try:
# Create heatmap entry
heatmap_entry = {
'timestamp': snapshot.timestamp,
'mid_price': snapshot.mid_price,
'levels': {}
}
# Add bid levels
for level in snapshot.bids:
price_offset = level.price - snapshot.mid_price
heatmap_entry['levels'][price_offset] = {
'side': 'bid',
'size': level.size,
'price': level.price
}
# Add ask levels
for level in snapshot.asks:
price_offset = level.price - snapshot.mid_price
heatmap_entry['levels'][price_offset] = {
'side': 'ask',
'size': level.size,
'price': level.price
}
self.order_heatmaps[symbol].append(heatmap_entry)
# Clean old entries (keep 10 minutes)
cutoff_time = snapshot.timestamp - self.heatmap_window
while (self.order_heatmaps[symbol] and
self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
self.order_heatmaps[symbol].popleft()
except Exception as e:
logger.error(f"Error updating heatmap for {symbol}: {e}")
def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
"""Update volume profile with new trade"""
try:
# Initialize if not exists
if symbol not in self.volume_profiles:
self.volume_profiles[symbol] = []
# Find or create price level
price_level = None
for level in self.volume_profiles[symbol]:
if abs(level.price - price) < 0.01: # Price tolerance
price_level = level
break
if not price_level:
price_level = VolumeProfileLevel(
price=price,
volume=0.0,
buy_volume=0.0,
sell_volume=0.0,
trades_count=0,
vwap=price
)
self.volume_profiles[symbol].append(price_level)
# Update volume profile
volume = price * quantity
old_total = price_level.volume
price_level.volume += volume
price_level.trades_count += 1
if is_buyer_maker:
price_level.sell_volume += volume
else:
price_level.buy_volume += volume
# Update VWAP
if price_level.volume > 0:
price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume
except Exception as e:
logger.error(f"Error updating volume profile for {symbol}: {e}")
async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
quantity: float, is_buyer_maker: bool):
"""Analyze order flow for sweep and absorption patterns"""
try:
# Get recent order book data
if symbol not in self.order_book_history or not self.order_book_history[symbol]:
return
recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots
# Check for order book sweeps
sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
if sweep_signal:
self.flow_signals[symbol].append(sweep_signal)
await self._notify_flow_signal(symbol, sweep_signal)
# Check for absorption patterns
absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
if absorption_signal:
self.flow_signals[symbol].append(absorption_signal)
await self._notify_flow_signal(symbol, absorption_signal)
# Check for momentum trades
momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
if momentum_signal:
self.flow_signals[symbol].append(momentum_signal)
await self._notify_flow_signal(symbol, momentum_signal)
except Exception as e:
logger.error(f"Error analyzing order flow for {symbol}: {e}")
def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
"""Detect order book sweep patterns"""
try:
if len(snapshots) < 2:
return None
before_snapshot = snapshots[-2]
after_snapshot = snapshots[-1]
# Check if multiple levels were consumed
if is_buyer_maker: # Sell order, check ask side
levels_consumed = 0
total_consumed_size = 0
for level in before_snapshot.asks[:5]: # Check top 5 levels
if level.price <= price:
levels_consumed += 1
total_consumed_size += level.size
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='sweep',
price=price,
volume=quantity * price,
confidence=confidence,
description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
)
else: # Buy order, check bid side
levels_consumed = 0
total_consumed_size = 0
for level in before_snapshot.bids[:5]:
if level.price >= price:
levels_consumed += 1
total_consumed_size += level.size
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='sweep',
price=price,
volume=quantity * price,
confidence=confidence,
description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
)
return None
except Exception as e:
logger.error(f"Error detecting sweep for {symbol}: {e}")
return None
def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
price: float, quantity: float) -> Optional[OrderFlowSignal]:
"""Detect absorption patterns where large orders are absorbed without price movement"""
try:
if len(snapshots) < 3:
return None
# Check if large order was absorbed with minimal price impact
volume_threshold = 10000 # $10K minimum for absorption
price_impact_threshold = 0.001 # 0.1% max price impact
trade_value = price * quantity
if trade_value < volume_threshold:
return None
# Calculate price impact
price_before = snapshots[-3].mid_price
price_after = snapshots[-1].mid_price
price_impact = abs(price_after - price_before) / price_before
if price_impact < price_impact_threshold:
confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='absorption',
price=price,
volume=trade_value,
confidence=confidence,
description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact"
)
return None
except Exception as e:
logger.error(f"Error detecting absorption for {symbol}: {e}")
return None
def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
"""Detect momentum trades based on size and direction"""
try:
trade_value = price * quantity
momentum_threshold = 25000 # $25K minimum for momentum classification
if trade_value < momentum_threshold:
return None
# Calculate confidence based on trade size
confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
direction = "sell" if is_buyer_maker else "buy"
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='momentum',
price=price,
volume=trade_value,
confidence=confidence,
description=f"Large {direction}: ${trade_value:.0f}"
)
except Exception as e:
logger.error(f"Error detecting momentum for {symbol}: {e}")
return None
async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
"""Notify CNN and DQN models of order flow signals"""
try:
signal_data = {
'signal_type': signal.signal_type,
'price': signal.price,
'volume': signal.volume,
'confidence': signal.confidence,
'timestamp': signal.timestamp,
'description': signal.description
}
# Notify CNN callbacks
for callback in self.cnn_callbacks:
try:
callback(symbol, signal_data)
except Exception as e:
logger.warning(f"Error in CNN callback: {e}")
# Notify DQN callbacks
for callback in self.dqn_callbacks:
try:
callback(symbol, signal_data)
except Exception as e:
logger.warning(f"Error in DQN callback: {e}")
except Exception as e:
logger.error(f"Error notifying flow signal: {e}")
async def _continuous_analysis(self):
    """Push fresh feature vectors to the callbacks once per second.

    While streaming is active, every cycle sleeps one second and then,
    for each symbol, sends the CNN feature vector to the CNN callbacks
    and the DQN state vector to the DQN callbacks (skipping whichever is
    ``None``). Each callback receives its own payload dict. A failing
    callback is logged and skipped; any other failure is logged and
    followed by a 5 second pause.
    """
    while self.is_streaming:
        try:
            await asyncio.sleep(1)  # Analyze every second

            for symbol in self.symbols:
                cnn_features = self.get_cnn_features(symbol)
                if cnn_features is not None:
                    for callback in self.cnn_callbacks:
                        try:
                            callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
                        except Exception as e:
                            logger.warning(f"Error in CNN feature callback: {e}")

                dqn_features = self.get_dqn_state_features(symbol)
                if dqn_features is None:
                    continue
                for callback in self.dqn_callbacks:
                    try:
                        callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
                    except Exception as e:
                        logger.warning(f"Error in DQN state callback: {e}")
        except Exception as e:
            logger.error(f"Error in continuous analysis: {e}")
            await asyncio.sleep(5)
def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
    """Build the 100-element CNN feature vector for ``symbol``.

    Layout of the returned ``float32`` array:

    * 0-39: up to 20 bid levels as (size, price offset from mid) pairs,
      zero padded.
    * 40-79: up to 20 ask levels in the same form, zero padded.
    * 80-89: liquidity metrics, spread, level counts, mid price and the
      current second of the day.
    * 90-94: latest order book imbalance entry (zeros when none exists).
    * 95-99: counts of sweep / absorption / momentum signals from the
      last 60 seconds, their maximum confidence and their total volume.

    Signal age is measured with ``total_seconds()``; the previous
    ``.seconds`` attribute ignores whole days, which made signals older
    than a day look recent again.

    Returns:
        The feature vector, or ``None`` when no order book is known for
        ``symbol`` or an error occurred (errors are logged).
    """
    try:
        if symbol not in self.order_books:
            return None

        snapshot = self.order_books[symbol]
        mid_price = snapshot.mid_price
        features = []

        # 20 levels per side, two values per level, padded to 40 per side.
        for side in (snapshot.bids, snapshot.asks):
            levels = side[:20]
            for level in levels:
                features.append(level.size)
                features.append(level.price - mid_price)
            features.extend([0.0] * (2 * (20 - len(levels))))

        # Liquidity metrics (10 features)
        metrics = self.liquidity_metrics.get(symbol, {})
        features.extend([
            metrics.get('total_bid_size', 0.0),
            metrics.get('total_ask_size', 0.0),
            metrics.get('liquidity_ratio', 1.0),
            metrics.get('spread_bps', 0.0),
            snapshot.spread,
            metrics.get('weighted_mid', mid_price) - mid_price,
            len(snapshot.bids),
            len(snapshot.asks),
            mid_price,
            time.time() % 86400  # Time of day
        ])

        # Order book imbalance features (5 features)
        imbalances = self.order_book_imbalances[symbol]
        if imbalances:
            latest_imbalance = imbalances[-1]
            features.extend([
                latest_imbalance['imbalance'],
                latest_imbalance['bid_size'],
                latest_imbalance['ask_size'],
                latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
                abs(latest_imbalance['imbalance'])
            ])
        else:
            features.extend([0.0, 0.0, 0.0, 0.0, 0.0])

        # Flow signal features (5 features), restricted to the last minute.
        now = datetime.now()
        recent_signals = [
            s for s in self.flow_signals[symbol]
            if 0 <= (now - s.timestamp).total_seconds() < 60
        ]
        features.extend([
            sum(1 for s in recent_signals if s.signal_type == 'sweep'),
            sum(1 for s in recent_signals if s.signal_type == 'absorption'),
            sum(1 for s in recent_signals if s.signal_type == 'momentum'),
            max((s.confidence for s in recent_signals), default=0.0),
            sum(s.volume for s in recent_signals)
        ])

        return np.array(features, dtype=np.float32)
    except Exception as e:
        logger.error(f"Error generating CNN features for {symbol}: {e}")
        return None
def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
    """Build the 30-element DQN state vector for ``symbol``.

    Layout of the returned ``float32`` array:

    * 0-9 / 10-19: sizes of up to ten bid / ask levels, each divided by
      the combined size of those levels, zero padded (all zeros when the
      combined size is not positive).
    * 20: spread in basis points.
    * 21: liquidity imbalance derived from the stored liquidity ratio.
    * 22: mean confidence of the flow signals of the last 30 seconds.
    * 23: relative mid-price volatility (percent) over the last ten
      snapshots, 0 with fewer than ten snapshots or a non-positive mean.
    * 24-25: bid / ask depth ratio (level count capped at 20, over 20).
    * 26-28: sweep / absorption / momentum signal counts of the last
      30 seconds, divided by 10, 5 and 5.
    * 29: minute of the day divided by 1440.

    Fixes over the previous version: the three signal-count features were
    guarded by ``'..._count' in locals()`` checks that were never true, so
    they were always 0; they are now computed from the recent signals.
    Signal age uses ``total_seconds()`` because ``.seconds`` ignores whole
    days. The volatility no longer divides by a zero mean price.

    Returns:
        The state vector, or ``None`` when no order book is known for
        ``symbol`` or an error occurred (errors are logged).
    """
    try:
        if symbol not in self.order_books:
            return None

        snapshot = self.order_books[symbol]
        now = datetime.now()
        state_features = []

        # Normalized order book state (20 features)
        top_bids = snapshot.bids[:10]
        top_asks = snapshot.asks[:10]
        total_size = (sum(level.size for level in top_bids) +
                      sum(level.size for level in top_asks))
        if total_size > 0:
            for side in (top_bids, top_asks):
                state_features.extend(level.size / total_size for level in side)
                state_features.extend([0.0] * (10 - len(side)))
        else:
            state_features.extend([0.0] * 20)

        # Market state indicators (10 features)
        metrics = self.liquidity_metrics.get(symbol, {})
        spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0

        # Maps the bid/ask ratio to a value that is 0 when balanced.
        liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
        liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)

        # Flow signals of the last 30 seconds.
        recent_signals = [
            s for s in self.flow_signals[symbol]
            if 0 <= (now - s.timestamp).total_seconds() < 30
        ]
        flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)
        sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
        absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
        momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')

        # Price volatility from the last ten snapshots.
        price_volatility = 0
        history = self.order_book_history[symbol]
        if len(history) >= 10:
            recent_prices = [s.mid_price for s in list(history)[-10:]]
            mean_price = np.mean(recent_prices)
            if mean_price > 0:
                price_volatility = np.std(recent_prices) / mean_price

        state_features.extend([
            spread_pct * 10000,  # Spread in basis points
            liquidity_imbalance,
            flow_strength,
            price_volatility * 100,  # Volatility as percentage
            min(len(snapshot.bids), 20) / 20,  # Book depth ratio
            min(len(snapshot.asks), 20) / 20,
            sweep_count / 10,
            absorption_count / 5,
            momentum_count / 5,
            (now.hour * 60 + now.minute) / 1440  # Time of day normalized
        ])

        return np.array(state_features, dtype=np.float32)
    except Exception as e:
        logger.error(f"Error generating DQN features for {symbol}: {e}")
        return None
def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
    """Render the stored heatmap rows as a time x price-bucket matrix.

    Columns are buckets one basis point of the current mid price wide,
    centred on the mid price. Bid sizes are written as positive values
    and ask sizes as negative values; when several levels of one row fall
    into the same bucket the last one written wins.

    Bucket indices are computed with floor division. The previous
    ``int()`` truncation rounded towards zero, so offsets up to one
    bucket below the window were wrongly drawn into column 0.

    Args:
        symbol: Symbol to render.
        levels: Number of price buckets (matrix columns).

    Returns:
        Array of shape ``(rows, levels)`` covering at most the last 600
        stored rows, or ``None`` when there is no heatmap data, no
        current snapshot, a non-positive mid price (which would make the
        bucket width zero), or an error occurred (errors are logged).
    """
    try:
        if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
            return None

        current_snapshot = self.order_books.get(symbol)
        if not current_snapshot:
            return None

        mid_price = current_snapshot.mid_price
        price_step = mid_price * 0.0001  # 1 basis point steps
        if price_step <= 0:
            return None

        # Offset of the lowest bucket edge below the mid price.
        half_span = (levels / 2) * price_step

        history = self.order_heatmaps[symbol]
        time_window = min(600, len(history))  # 10 minutes max
        heatmap_matrix = np.zeros((time_window, levels))

        for t, entry in enumerate(list(history)[-time_window:]):
            for price_offset, level_data in entry['levels'].items():
                level_idx = int((price_offset + half_span) // price_step)
                if 0 <= level_idx < levels:
                    size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
                    heatmap_matrix[t, level_idx] = level_data['size'] * size_weight

        return heatmap_matrix
    except Exception as e:
        logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
        return None
def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
    """Export the volume profile of ``symbol`` as plain dicts.

    Returns:
        One dict per price level, ordered by ascending price, including a
        derived ``net_volume`` (buy minus sell volume); ``None`` when no
        profile exists for ``symbol`` or an error occurred (errors are
        logged).
    """
    try:
        if symbol not in self.volume_profiles:
            return None

        ordered_levels = sorted(self.volume_profiles[symbol], key=lambda lvl: lvl.price)
        return [
            {
                'price': level.price,
                'volume': level.volume,
                'buy_volume': level.buy_volume,
                'sell_volume': level.sell_volume,
                'trades_count': level.trades_count,
                'vwap': level.vwap,
                'net_volume': level.buy_volume - level.sell_volume
            }
            for level in ordered_levels
        ]
    except Exception as e:
        logger.error(f"Error getting volume profile for {symbol}: {e}")
        return None
def get_current_order_book(self, symbol: str) -> Optional[Dict]:
    """Export the latest order book of ``symbol`` as a plain dict.

    Returns:
        A dict with the snapshot timestamp (ISO format), mid price,
        spread, up to 20 levels per side, the stored liquidity metrics
        and the last five flow signals; ``None`` when no order book is
        known for ``symbol`` or an error occurred (errors are logged).
    """
    try:
        if symbol not in self.order_books:
            return None

        snapshot = self.order_books[symbol]

        def export_levels(side):
            return [{'price': l.price, 'size': l.size} for l in side[:20]]

        recent_signals = []
        for s in list(self.flow_signals[symbol])[-5:]:  # Last 5 signals
            recent_signals.append({
                'type': s.signal_type,
                'price': s.price,
                'volume': s.volume,
                'confidence': s.confidence,
                'timestamp': s.timestamp.isoformat()
            })

        return {
            'timestamp': snapshot.timestamp.isoformat(),
            'symbol': symbol,
            'mid_price': snapshot.mid_price,
            'spread': snapshot.spread,
            'bids': export_levels(snapshot.bids),
            'asks': export_levels(snapshot.asks),
            'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
            'recent_signals': recent_signals
        }
    except Exception as e:
        logger.error(f"Error getting order book for {symbol}: {e}")
        return None
def get_statistics(self) -> Dict[str, Any]:
    """Summarise the provider's current state.

    Returns:
        A dict with the configured symbols, the streaming flag, update
        counters, last update times (``datetime`` values rendered in ISO
        format, anything else passed through), and the number of active
        order books, stored flow signals, callbacks and tasks.
    """
    last_updates = {}
    for k, v in self.last_update_times.items():
        last_updates[k] = v.isoformat() if isinstance(v, datetime) else v

    signal_total = sum(len(signals) for signals in self.flow_signals.values())

    return {
        'symbols': self.symbols,
        'is_streaming': self.is_streaming,
        'update_counts': dict(self.update_counts),
        'last_update_times': last_updates,
        'order_books_active': len(self.order_books),
        'flow_signals_total': signal_total,
        'cnn_callbacks': len(self.cnn_callbacks),
        'dqn_callbacks': len(self.dqn_callbacks),
        'websocket_tasks': len(self.websocket_tasks)
    }

File diff suppressed because it is too large Load Diff

View File

@ -1,785 +0,0 @@
"""
CNN Training Pipeline with Comprehensive Data Storage and Replay
This module implements a robust CNN training pipeline that:
1. Integrates with the comprehensive training data collection system
2. Stores all backpropagation data for gradient replay
3. Enables retraining on most profitable setups
4. Maintains training episode profitability tracking
5. Supports both real-time and batch training modes
Key Features:
- Integration with TrainingDataCollector for data validation
- Gradient and loss storage for each training step
- Profitable episode prioritization and replay
- Comprehensive training metrics and validation
- Real-time pivot point prediction with outcome tracking
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
import json
import pickle
from collections import deque, defaultdict
import threading
from concurrent.futures import ThreadPoolExecutor
from .training_data_collector import (
TrainingDataCollector,
TrainingEpisode,
ModelInputPackage,
get_training_data_collector
)
logger = logging.getLogger(__name__)
@dataclass
class CNNTrainingStep:
    """Record of one CNN training step, including its backpropagation data."""

    # --- Identity ---
    step_id: str
    timestamp: datetime
    episode_id: str

    # --- Inputs fed to the model ---
    input_features: torch.Tensor
    target_labels: torch.Tensor

    # --- Forward pass results ---
    model_outputs: Dict[str, torch.Tensor]
    predictions: Dict[str, Any]
    confidence_scores: torch.Tensor

    # --- Loss breakdown ---
    total_loss: float
    pivot_prediction_loss: float
    confidence_loss: float
    regularization_loss: float

    # --- Backpropagation data ---
    gradients: Dict[str, torch.Tensor]  # gradient per parameter
    gradient_norms: Dict[str, float]  # gradient norm per parameter, for monitoring

    # --- Optional model / optimizer state ---
    model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    optimizer_state: Optional[Dict[str, Any]] = None

    # --- Training metadata ---
    learning_rate: float = 0.001
    batch_size: int = 32
    epoch: int = 0

    # --- Profitability tracking ---
    actual_profitability: Optional[float] = None
    prediction_accuracy: Optional[float] = None
    training_value: float = 0.0  # value of this step when selecting replays
@dataclass
class CNNTrainingSession:
    """Aggregate record of one CNN training session and its steps."""

    # --- Identity and time span ---
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # --- Configuration ---
    training_mode: str = 'real_time'  # 'real_time', 'batch', 'replay'
    symbol: str = ''

    # --- Recorded steps (a fresh list per session) ---
    training_steps: List[CNNTrainingStep] = field(default_factory=list)

    # --- Loss metrics ---
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')
    convergence_achieved: bool = False

    # --- Profitability metrics ---
    profitable_predictions: int = 0
    total_predictions: int = 0
    profitability_rate: float = 0.0

    # --- Replay prioritisation ---
    session_value: float = 0.0
class CNNPivotPredictor(nn.Module):
    """Conv1d + bidirectional LSTM + self-attention model with three heads.

    The heads produce pivot class logits, a pivot price estimate and a
    confidence value in (0, 1).
    """

    def __init__(self,
                 input_channels: int = 10,  # Multiple timeframes
                 sequence_length: int = 300,  # 300 bars
                 hidden_dim: int = 256,
                 num_pivot_classes: int = 3,  # high, low, none
                 dropout_rate: float = 0.2):
        super().__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim

        def conv_block(in_channels, out_channels, kernel_size):
            # "Same" padding keeps the sequence length unchanged.
            return [
                nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                          padding=kernel_size // 2),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]

        # Three convolutional blocks for pattern extraction.
        self.conv_layers = nn.Sequential(
            *conv_block(input_channels, 64, 7),
            *conv_block(64, 128, 5),
            *conv_block(128, 256, 3),
        )

        # LSTM for temporal dependencies.
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # The bidirectional LSTM doubles the feature width.
        feature_dim = hidden_dim * 2

        self.attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        def mlp_head(out_features):
            return nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, out_features)
            )

        # Output heads.
        self.pivot_classifier = mlp_head(num_pivot_classes)
        self.pivot_price_regressor = mlp_head(1)
        self.confidence_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-init Linear layers (zero bias); Kaiming-init Conv1d weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Run the predictor.

        Args:
            x: Input tensor [batch_size, input_channels, sequence_length]

        Returns:
            Dict with ``pivot_logits``, ``pivot_price``, ``confidence``,
            the feature vector the heads were computed from
            (``hidden_states``), ``attention_weights``, ``conv_features``
            and ``lstm_output``.
        """
        conv_features = self.conv_layers(x)  # [batch, 256, sequence_length]

        # The LSTM expects [batch, sequence, features].
        lstm_output, _ = self.lstm(conv_features.transpose(1, 2))

        # Self-attention over the LSTM outputs.
        attended_output, attention_weights = self.attention(
            lstm_output, lstm_output, lstm_output
        )

        # Predictions are made from the last timestep only.
        final_features = attended_output[:, -1, :]  # [batch, hidden_dim*2]

        return {
            'pivot_logits': self.pivot_classifier(final_features),
            'pivot_price': self.pivot_price_regressor(final_features),
            'confidence': self.confidence_head(final_features),
            'hidden_states': final_features,
            'attention_weights': attention_weights,
            'conv_features': conv_features,
            'lstm_output': lstm_output
        }
class CNNTrainingDataset(Dataset):
    """Torch dataset over the training episodes that pass validation."""

    def __init__(self, training_episodes: List[TrainingEpisode]):
        self.episodes = training_episodes
        self.valid_episodes = self._validate_episodes()

    def _validate_episodes(self) -> List[TrainingEpisode]:
        """Keep episodes that carry CNN features and a validated outcome.

        An episode whose attributes cannot be read is logged and skipped.
        """
        valid = []
        for episode in self.episodes:
            try:
                has_features = episode.input_package.cnn_features is not None
                if has_features and episode.actual_outcome.outcome_validated:
                    valid.append(episode)
            except Exception as e:
                logger.warning(f"Invalid episode {episode.episode_id}: {e}")

        logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
        return valid

    def __len__(self):
        return len(self.valid_episodes)

    def __getitem__(self, idx):
        """Return the features and training targets of one valid episode."""
        episode = self.valid_episodes[idx]
        outcome = episode.actual_outcome

        features = torch.from_numpy(episode.input_package.cnn_features).float()

        # Targets are derived from the recorded outcome.
        pivot_class = self._determine_pivot_class(outcome)

        return {
            'features': features,
            'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
            'pivot_price': torch.tensor(outcome.optimal_exit_price, dtype=torch.float),
            'confidence_target': torch.tensor(outcome.profitability_score, dtype=torch.float),
            'episode_id': episode.episode_id,
            'profitability': outcome.profitability_score
        }

    def _determine_pivot_class(self, outcome) -> int:
        """Map the 15-minute price change to a pivot class.

        Returns 0 (high pivot) above 0.5, 1 (low pivot) below -0.5 and
        2 (no significant pivot) otherwise.
        """
        change = outcome.price_change_15m
        if change > 0.5:
            return 0
        if change < -0.5:
            return 1
        return 2
class CNNTrainer:
"""CNN trainer with comprehensive data storage and replay capabilities"""
def __init__(self,
             model: CNNPivotPredictor,
             device: str = 'cuda',
             learning_rate: float = 0.001,
             storage_dir: str = "cnn_training_storage"):
    """Set up the model, optimizer, scheduler, storage and bookkeeping.

    Args:
        model: Network to train; it is moved to ``device``.
        device: Target device. When a CUDA device is requested (the
            default) but CUDA is not available, the trainer logs a
            warning and uses the CPU instead of failing.
        learning_rate: Initial learning rate for AdamW.
        storage_dir: Directory for training artefacts; created if missing.
    """
    # The default device is 'cuda'; without this check construction fails
    # on every host that has no usable CUDA runtime.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        logger.warning(f"CUDA device '{device}' requested but CUDA is not available; using CPU")
        device = 'cpu'

    self.model = model.to(device)
    self.device = device
    self.learning_rate = learning_rate

    # Storage
    self.storage_dir = Path(storage_dir)
    self.storage_dir.mkdir(parents=True, exist_ok=True)

    # Optimizer
    self.optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=learning_rate,
        weight_decay=1e-5
    )

    # Halve the learning rate after 10 epochs without improvement.
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, mode='min', patience=10, factor=0.5
    )

    # Training data collector
    self.data_collector = get_training_data_collector()

    # Training sessions storage
    self.training_sessions: List[CNNTrainingSession] = []
    self.current_session: Optional[CNNTrainingSession] = None

    # Training statistics
    self.training_stats = {
        'total_sessions': 0,
        'total_steps': 0,
        'best_validation_loss': float('inf'),
        'profitable_predictions': 0,
        'total_predictions': 0,
        'replay_sessions': 0
    }

    # Background training
    self.is_training = False
    self.training_thread = None

    logger.info("CNN Trainer initialized")
    logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
    logger.info(f"Storage directory: {self.storage_dir}")
def start_real_time_training(self, symbol: str):
    """Launch the background training worker for ``symbol``.

    When the trainer is already flagged as training, only a warning is
    logged and no second worker is created. The worker runs as a daemon
    thread.

    Args:
        symbol: Symbol forwarded to the worker.
    """
    if self.is_training:
        logger.warning("CNN training already running")
        return

    # Raise the flag before the thread exists; the worker loop reads it.
    self.is_training = True
    worker = threading.Thread(
        target=self._real_time_training_worker,
        args=(symbol,),
        daemon=True,
    )
    self.training_thread = worker
    worker.start()
    logger.info(f"Started real-time CNN training for {symbol}")
def stop_training(self):
    """Clear the training flag, wait briefly for the worker, flush state.

    The join is bounded to 10 seconds, so the worker thread may still be
    alive when this method returns. Any session still marked as current is
    finalized afterwards.
    """
    self.is_training = False

    worker = self.training_thread
    if worker:
        worker.join(timeout=10)

    if self.current_session:
        self._finalize_training_session()

    logger.info("CNN training stopped")
def _real_time_training_worker(self, symbol: str):
"""Real-time training worker"""
logger.info(f"Real-time CNN training worker started for {symbol}")
while self.is_training:
try:
# Get high-priority episodes for training
episodes = self.data_collector.get_high_priority_episodes(
symbol=symbol,
limit=100,
min_priority=0.3
)
if len(episodes) >= 32: # Minimum batch size
self._train_on_episodes(episodes, training_mode='real_time')
# Wait before next training cycle
threading.Event().wait(300) # Train every 5 minutes
except Exception as e:
logger.error(f"Error in real-time training worker: {e}")
threading.Event().wait(60) # Wait before retrying
logger.info(f"Real-time CNN training worker stopped for {symbol}")
def train_on_profitable_episodes(self,
                                 symbol: str,
                                 min_profitability: float = 0.7,
                                 max_episodes: int = 500) -> Dict[str, Any]:
    """Run a replay session over the best-scoring profitable episodes.

    Episodes stored for ``symbol`` are kept when their outcome is flagged
    profitable and scores at least ``min_profitability``; the survivors are
    ordered by score (highest first) and cut to ``max_episodes``.

    Args:
        symbol: Key into the collector's ``training_episodes`` mapping.
        min_profitability: Lowest accepted ``profitability_score``.
        max_episodes: Upper bound on the number of episodes trained on.

    Returns:
        The dict returned by ``_train_on_episodes``; or
        ``{'status': 'insufficient_data', 'episodes_found': n}`` when fewer
        than 10 episodes qualify; or ``{'status': 'error', 'error': msg}``
        when anything raises.
    """
    try:
        def score_of(episode):
            return episode.actual_outcome.profitability_score

        qualifying = []
        for episode in self.data_collector.training_episodes.get(symbol, []):
            outcome = episode.actual_outcome
            if outcome.is_profitable and outcome.profitability_score >= min_profitability:
                qualifying.append(episode)

        # Highest score first, then keep only the top slice.
        selected = sorted(qualifying, key=score_of, reverse=True)[:max_episodes]

        if len(selected) < 10:
            logger.warning(f"Insufficient profitable episodes for {symbol}: {len(selected)}")
            return {'status': 'insufficient_data', 'episodes_found': len(selected)}

        results = self._train_on_episodes(selected, training_mode='profitable_replay')
        logger.info(f"Trained on {len(selected)} profitable episodes for {symbol}")
        return results
    except Exception as e:
        logger.error(f"Error training on profitable episodes: {e}")
        return {'status': 'error', 'error': str(e)}
def _train_on_episodes(self,
                       episodes: List[TrainingEpisode],
                       training_mode: str = 'batch') -> Dict[str, Any]:
    """Run one training session over ``episodes`` and persist its record.

    The episodes are wrapped in a ``CNNTrainingDataset`` and consumed in
    shuffled batches of 32. Every batch produces a ``CNNTrainingStep``
    holding inputs, outputs, losses and per-parameter gradients. After the
    loop the session is summarised, the LR scheduler is stepped with the
    average loss, the session is saved to disk and registered in
    ``self.training_sessions``.

    Args:
        episodes: Episodes to train on; must not be empty.
        training_mode: Label stored on the session and used as the
            session-id prefix.

    Returns:
        ``{'status': 'success', 'session_id', 'average_loss',
        'total_steps', 'session_value'}`` on success, otherwise
        ``{'status': 'error', 'error': msg}``. ``self.current_session`` is
        reset to ``None`` on every exit path.
    """
    try:
        if not episodes:
            # Report empty input explicitly instead of relying on a
            # downstream failure while building batches.
            raise ValueError("no episodes provided")

        # Start new training session
        session = CNNTrainingSession(
            session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            start_timestamp=datetime.now(),
            training_mode=training_mode,
            symbol=episodes[0].input_package.symbol
        )
        self.current_session = session

        # Create dataset and dataloader
        dataset = CNNTrainingDataset(episodes)
        dataloader = DataLoader(
            dataset,
            batch_size=32,
            shuffle=True,
            num_workers=2
        )

        # Training loop
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move to device
            features = batch['features'].to(self.device)
            pivot_class = batch['pivot_class'].to(self.device)
            pivot_price = batch['pivot_price'].to(self.device)
            confidence_target = batch['confidence_target'].to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(features)

            # reshape(-1) instead of squeeze(): squeeze() turns a
            # one-sample batch into a 0-d tensor, which no longer matches
            # the (1,)-shaped targets.
            # Assumes each head emits one value per sample - TODO confirm.
            predicted_price = outputs['pivot_price'].reshape(-1)
            predicted_confidence = outputs['confidence'].reshape(-1)

            # Calculate losses
            classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
            regression_loss = F.mse_loss(predicted_price, pivot_price)
            confidence_loss = F.binary_cross_entropy(predicted_confidence, confidence_target)

            # Combined loss
            total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss

            # Backward pass
            total_batch_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Snapshot gradients before the optimizer step. They are copied
            # to CPU like every other tensor stored on the step, so a saved
            # session does not pin device memory.
            gradients = {}
            gradient_norms = {}
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    gradients[name] = param.grad.detach().to('cpu', copy=True)
                    gradient_norms[name] = param.grad.norm().item()

            # Optimizer step
            self.optimizer.step()

            # Detach once and reuse: the raw outputs still require grad,
            # which blocks any NumPy conversion done while extracting
            # predictions.
            recorded_outputs = {k: v.detach().cpu() for k, v in outputs.items()}

            # Create training step record.
            # NOTE(review): episode_id is a synthetic per-batch label, while
            # replay_high_value_sessions matches step.episode_id against the
            # collector's episode ids - confirm the intended linkage.
            step = CNNTrainingStep(
                step_id=f"{session.session_id}_step_{batch_idx}",
                timestamp=datetime.now(),
                episode_id=f"batch_{batch_idx}",
                input_features=features.detach().cpu(),
                target_labels=pivot_class.detach().cpu(),
                model_outputs=recorded_outputs,
                predictions=self._extract_predictions(recorded_outputs),
                confidence_scores=recorded_outputs['confidence'],
                total_loss=total_batch_loss.item(),
                pivot_prediction_loss=classification_loss.item(),
                confidence_loss=confidence_loss.item(),
                regularization_loss=0.0,
                gradients=gradients,
                gradient_norms=gradient_norms,
                learning_rate=self.optimizer.param_groups[0]['lr'],
                batch_size=features.size(0)
            )

            # Calculate training value for this step
            step.training_value = self._calculate_step_training_value(step, batch)

            # Add to session
            session.training_steps.append(step)
            total_loss += total_batch_loss.item()
            num_batches += 1

            # Log progress
            if batch_idx % 10 == 0:
                logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")

        # Finalize session
        session.end_timestamp = datetime.now()
        session.total_steps = num_batches
        session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
        session.best_loss = min(step.total_loss for step in session.training_steps)

        # Calculate session value
        session.session_value = self._calculate_session_value(session)

        # Update scheduler
        self.scheduler.step(session.average_loss)

        # Save session
        self._save_training_session(session)

        # Register the finished session so statistics and replay can see
        # it. The identity check avoids a duplicate when stop_training has
        # already finalized this same session from another thread.
        # NOTE(review): sessions keep per-step tensors; consider pruning
        # this list in long-running processes.
        if not any(existing is session for existing in self.training_sessions):
            self.training_sessions.append(session)

        # Update statistics
        self.training_stats['total_sessions'] += 1
        self.training_stats['total_steps'] += session.total_steps
        if training_mode == 'profitable_replay':
            self.training_stats['replay_sessions'] += 1

        logger.info(f"Training session completed: {session.session_id}")
        logger.info(f"Average loss: {session.average_loss:.4f}")
        logger.info(f"Session value: {session.session_value:.3f}")

        return {
            'status': 'success',
            'session_id': session.session_id,
            'average_loss': session.average_loss,
            'total_steps': session.total_steps,
            'session_value': session.session_value
        }
    except Exception as e:
        logger.error(f"Error in training session: {e}")
        return {'status': 'error', 'error': str(e)}
    finally:
        self.current_session = None
def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
"""Extract human-readable predictions from model outputs"""
try:
pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
predicted_class = torch.argmax(pivot_probs, dim=1)
return {
'pivot_class': predicted_class.cpu().numpy().tolist(),
'pivot_probabilities': pivot_probs.cpu().numpy().tolist(),
'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(),
'confidence': outputs['confidence'].cpu().numpy().tolist()
}
except Exception as e:
logger.warning(f"Error extracting predictions: {e}")
return {}
def _calculate_step_training_value(self,
                                   step: CNNTrainingStep,
                                   batch: Dict[str, Any]) -> float:
    """Score a training step for replay prioritization.

    The score is the sum of three terms, capped at 1.0:

    * ``1 / (1 + total_loss)`` when the loss is positive,
    * ``0.3 *`` the mean of ``batch['profitability']``,
    * the mean gradient norm divided by 10, capped at 0.2.

    Args:
        step: Step record providing ``total_loss`` and ``gradient_norms``.
        batch: Collated batch providing ``'profitability'``.

    Returns:
        The score, or ``0.0`` when the computation raises.
    """
    try:
        value = 0.0

        # Base value from loss (lower loss = higher value)
        if step.total_loss > 0:
            value += 1.0 / (1.0 + step.total_loss)

        # Bonus for high profitability episodes in batch. Cast to floating
        # point first: torch.mean rejects integer tensors, which would
        # otherwise zero the whole score through the except branch.
        profitability = torch.as_tensor(batch['profitability'], dtype=torch.float64)
        value += profitability.mean().item() * 0.3

        # Bonus for gradient magnitude (indicates learning). Skipped when
        # no gradients were recorded: the mean of an empty list is NaN and
        # would propagate into the returned score.
        grad_norms = list(step.gradient_norms.values())
        if grad_norms:
            value += min(float(np.mean(grad_norms)) / 10.0, 0.2)  # Cap at 0.2

        return min(value, 1.0)
    except Exception as e:
        logger.warning(f"Error calculating step training value: {e}")
        return 0.0
def _calculate_session_value(self, session: CNNTrainingSession) -> float:
    """Score a whole session for replay prioritization.

    The score is the mean step ``training_value`` plus two bonuses, capped
    at 1.0:

    * convergence: for sessions with more than 10 steps, the relative drop
      from the mean loss of the first 5 steps to that of the last 5, capped
      at 0.3 and only applied when the loss actually fell;
    * mode: 0.2 for ``'profitable_replay'`` sessions.

    Returns:
        The score; ``0.0`` for a session without steps or on any error.
    """
    try:
        steps = session.training_steps
        if not steps:
            return 0.0

        def mean_loss(window):
            return np.mean([entry.total_loss for entry in window])

        avg_step_value = np.mean([entry.training_value for entry in steps])

        convergence_bonus = 0.0
        if len(steps) > 10:
            early_loss = mean_loss(steps[:5])
            late_loss = mean_loss(steps[-5:])
            if early_loss > late_loss:
                convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)

        if session.training_mode == 'profitable_replay':
            mode_bonus = 0.2
        else:
            mode_bonus = 0.0

        return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)
    except Exception as e:
        logger.warning(f"Error calculating session value: {e}")
        return 0.0
def _save_training_session(self, session: CNNTrainingSession):
    """Write a session to disk as a pickle plus a JSON metadata sidecar.

    Both files are placed in ``<storage_dir>/<symbol>/sessions/`` and are
    named after the session id. Failures are logged and never raised.

    Args:
        session: Session to persist.
    """
    try:
        target_dir = self.storage_dir / session.symbol / 'sessions'
        target_dir.mkdir(parents=True, exist_ok=True)

        # Full session object
        pickle_path = target_dir / f"{session.session_id}.pkl"
        with open(pickle_path, 'wb') as pickle_out:
            pickle.dump(session, pickle_out)

        # Lightweight summary that can be read without unpickling
        finished_at = session.end_timestamp
        summary = dict(
            session_id=session.session_id,
            start_timestamp=session.start_timestamp.isoformat(),
            end_timestamp=finished_at.isoformat() if finished_at else None,
            training_mode=session.training_mode,
            symbol=session.symbol,
            total_steps=session.total_steps,
            average_loss=session.average_loss,
            best_loss=session.best_loss,
            session_value=session.session_value,
        )
        summary_path = target_dir / f"{session.session_id}_metadata.json"
        with open(summary_path, 'w') as summary_out:
            json.dump(summary, summary_out, indent=2)

        logger.debug(f"Saved training session: {session.session_id}")
    except Exception as e:
        logger.error(f"Error saving training session: {e}")
def _finalize_training_session(self):
"""Finalize current training session"""
if self.current_session:
self.current_session.end_timestamp = datetime.now()
self._save_training_session(self.current_session)
self.training_sessions.append(self.current_session)
self.current_session = None
def get_training_statistics(self) -> Dict[str, Any]:
    """Return a copy of the training counters with derived fields added.

    Returns:
        A shallow copy of ``self.training_stats`` extended with:

        * ``'recent_sessions'``: summaries of up to 10 sessions, newest
          ``start_timestamp`` first (only when any session is registered);
        * ``'profitability_rate'``: profitable / total predictions, or
          ``0.0`` when no predictions were counted.
    """
    snapshot = self.training_stats.copy()

    if self.training_sessions:
        newest_first = sorted(
            self.training_sessions,
            key=lambda entry: entry.start_timestamp,
            reverse=True,
        )
        summaries = []
        for entry in newest_first[:10]:
            summaries.append({
                'session_id': entry.session_id,
                'timestamp': entry.start_timestamp.isoformat(),
                'mode': entry.training_mode,
                'average_loss': entry.average_loss,
                'session_value': entry.session_value
            })
        snapshot['recent_sessions'] = summaries

    # Guard against division by zero before any prediction is counted.
    total = snapshot['total_predictions']
    snapshot['profitability_rate'] = (
        snapshot['profitable_predictions'] / total if total > 0 else 0.0
    )
    return snapshot
def replay_high_value_sessions(self,
                               symbol: str,
                               min_session_value: float = 0.7,
                               max_sessions: int = 10) -> Dict[str, Any]:
    """Retrain on the episodes behind the most valuable past sessions.

    Sessions registered for ``symbol`` with ``session_value`` of at least
    ``min_session_value`` are ranked by value (highest first) and cut to
    ``max_sessions``. For each one, the episode ids recorded on its steps
    are resolved against the data collector and, when any episode is found,
    a ``'high_value_replay'`` training run is started.

    Args:
        symbol: Symbol whose sessions and episodes are considered.
        min_session_value: Lowest accepted ``session_value``.
        max_sessions: Upper bound on the number of sessions replayed.

    Returns:
        ``{'status': 'success', 'sessions_replayed': n, 'sessions_found': m}``;
        ``{'status': 'no_high_value_sessions', 'sessions_found': 0}`` when
        nothing qualifies; ``{'status': 'error', 'error': msg}`` on failure.
    """
    try:
        # Find high-value sessions, best first, limited to max_sessions
        high_value_sessions = sorted(
            (s for s in self.training_sessions
             if s.symbol == symbol and s.session_value >= min_session_value),
            key=lambda s: s.session_value,
            reverse=True,
        )[:max_sessions]

        if not high_value_sessions:
            return {'status': 'no_high_value_sessions', 'sessions_found': 0}

        # Index the collector's episodes once instead of rescanning the
        # list for every id. setdefault keeps the first episode per id,
        # matching a front-to-back search.
        episodes_by_id = {}
        for ep in self.data_collector.training_episodes.get(symbol, []):
            episodes_by_id.setdefault(ep.episode_id, ep)

        # Replay sessions
        total_replayed = 0
        for session in high_value_sessions:
            # Unique ids in first-seen order, so the episode order is
            # deterministic.
            # NOTE(review): _train_on_episodes labels steps with synthetic
            # "batch_<n>" ids, which will not match collector episode ids -
            # confirm how steps are meant to link back to episodes.
            episode_ids = dict.fromkeys(step.episode_id for step in session.training_steps)
            episodes = [episodes_by_id[eid] for eid in episode_ids if eid in episodes_by_id]
            if episodes:
                self._train_on_episodes(episodes, training_mode='high_value_replay')
                total_replayed += 1

        logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
        return {
            'status': 'success',
            'sessions_replayed': total_replayed,
            'sessions_found': len(high_value_sessions)
        }
    except Exception as e:
        logger.error(f"Error replaying high-value sessions: {e}")
        return {'status': 'error', 'error': str(e)}
# Module-level singleton, created lazily by get_cnn_trainer()
cnn_trainer = None


def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
    """Return the shared ``CNNTrainer``, creating it on first use.

    Args:
        model: Model for the trainer when it has to be created; a default
            ``CNNPivotPredictor()`` is built when omitted. Ignored once the
            shared trainer already exists.

    Returns:
        The module-wide trainer instance.
    """
    global cnn_trainer
    if cnn_trainer is not None:
        return cnn_trainer

    chosen_model = CNNPivotPredictor() if model is None else model
    cnn_trainer = CNNTrainer(chosen_model)
    return cnn_trainer

View File

@ -1,864 +0,0 @@
# """
# Enhanced CNN Adapter for Standardized Input Format
# This module provides an adapter for the EnhancedCNN model to work with the standardized
# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
# """
# import torch
# import numpy as np
# import logging
# import os
# import random
# from datetime import datetime, timedelta
# from typing import Dict, List, Optional, Tuple, Any, Union
# from threading import Lock
# from .data_models import BaseDataInput, ModelOutput, create_model_output
# from NN.models.enhanced_cnn import EnhancedCNN
# from utils.inference_logger import log_model_inference
# logger = logging.getLogger(__name__)
# class EnhancedCNNAdapter:
# """
# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
# This adapter:
# 1. Converts BaseDataInput to the format expected by EnhancedCNN
# 2. Processes model outputs to create standardized ModelOutput
# 3. Manages model training with collected data
# 4. Handles checkpoint management
# """
# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
# """
# Initialize the EnhancedCNN adapter
# Args:
# model_path: Path to load model from, if None a new model is created
# checkpoint_dir: Directory to save checkpoints to
# """
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# self.model = None
# self.model_path = model_path
# self.checkpoint_dir = checkpoint_dir
# self.training_lock = Lock()
# self.training_data = []
# self.max_training_samples = 10000
# self.batch_size = 32
# self.learning_rate = 0.0001
# self.model_name = "enhanced_cnn"
# # Enhanced metrics tracking
# self.last_inference_time = None
# self.last_inference_duration = 0.0
# self.last_prediction_output = None
# self.last_training_time = None
# self.last_training_duration = 0.0
# self.last_training_loss = 0.0
# self.inference_count = 0
# self.training_count = 0
# # Create checkpoint directory if it doesn't exist
# os.makedirs(checkpoint_dir, exist_ok=True)
# # Initialize the model
# self._initialize_model()
# # Load checkpoint if available
# if model_path and os.path.exists(model_path):
# self._load_checkpoint(model_path)
# else:
# self._load_best_checkpoint()
# # Final device check and move
# self._ensure_model_on_device()
# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
# """Create realistic synthetic features instead of random data"""
# try:
# # Create realistic market-like features
# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
# ohlcv_start = 0
# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
# for frame_idx in range(300):
# # Create realistic price movement
# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
# current_price = base_price * (1 + price_change)
# # Realistic OHLCV values
# open_price = current_price
# high_price = current_price * torch.uniform(1.0, 1.005)
# low_price = current_price * torch.uniform(0.995, 1.0)
# close_price = current_price * torch.uniform(0.998, 1.002)
# volume = torch.uniform(500.0, 2000.0)
# # Set features
# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
# # BTC OHLCV features (1500 features: 300 frames x 5 features)
# btc_start = 6000
# btc_base_price = 50000.0
# for frame_idx in range(300):
# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
# current_price = btc_base_price * (1 + price_change)
# open_price = current_price
# high_price = current_price * torch.uniform(1.0, 1.01)
# low_price = current_price * torch.uniform(0.99, 1.0)
# close_price = current_price * torch.uniform(0.995, 1.005)
# volume = torch.uniform(100.0, 500.0)
# feature_idx = btc_start + frame_idx * 5
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
# # COB features (200 features) - realistic order book data
# cob_start = 7500
# for i in range(200):
# features[cob_start + i] = torch.uniform(0.0, 1000.0) # Realistic COB values
# # Technical indicators (100 features)
# indicator_start = 7700
# for i in range(100):
# features[indicator_start + i] = torch.uniform(-1.0, 1.0) # Normalized indicators
# # Last predictions (50 features)
# prediction_start = 7800
# for i in range(50):
# features[prediction_start + i] = torch.uniform(0.0, 1.0) # Probability values
# return features
# except Exception as e:
# logger.error(f"Error creating realistic synthetic features: {e}")
# # Fallback to small random variation
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
# return base_features + noise
# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
# """Create features from real market data if available"""
# try:
# # This would need to be implemented to use actual market data
# # For now, fall back to synthetic features
# return self._create_realistic_synthetic_features(symbol)
# except Exception as e:
# logger.error(f"Error creating realistic features: {e}")
# return self._create_realistic_synthetic_features(symbol)
# def _initialize_model(self):
# """Initialize the EnhancedCNN model"""
# try:
# # Calculate input shape based on BaseDataInput structure
# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# # BTC OHLCV: 300 frames x 5 features = 1500 features
# # COB: ±20 buckets x 4 metrics = 160 features
# # MA: 4 timeframes x 10 buckets = 40 features
# # Technical indicators: 100 features
# # Last predictions: 50 features
# # Total: 7850 features
# input_shape = 7850
# n_actions = 3 # BUY, SELL, HOLD
# # Create model
# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# # Ensure model is moved to the correct device
# self.model.to(self.device)
# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
# except Exception as e:
# logger.error(f"Error initializing EnhancedCNN model: {e}")
# raise
# def _load_checkpoint(self, checkpoint_path: str) -> bool:
# """Load model from checkpoint path"""
# try:
# if self.model and os.path.exists(checkpoint_path):
# success = self.model.load(checkpoint_path)
# if success:
# # Ensure model is moved to the correct device after loading
# self.model.to(self.device)
# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
# return True
# else:
# logger.warning(f"Failed to load model from {checkpoint_path}")
# return False
# else:
# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
# return False
# except Exception as e:
# logger.error(f"Error loading checkpoint: {e}")
# return False
# def _load_best_checkpoint(self) -> bool:
# """Load the best available checkpoint"""
# try:
# return self.load_best_checkpoint()
# except Exception as e:
# logger.error(f"Error loading best checkpoint: {e}")
# return False
# def load_best_checkpoint(self) -> bool:
# """Load the best checkpoint based on accuracy"""
# try:
# # Import checkpoint manager
# from utils.checkpoint_manager import CheckpointManager
# # Create checkpoint manager
# checkpoint_manager = CheckpointManager(
# checkpoint_dir=self.checkpoint_dir,
# max_checkpoints=10,
# metric_name="accuracy"
# )
# # Load best checkpoint
# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
# if not best_checkpoint_path:
# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
# return False
# # Load model
# success = self.model.load(best_checkpoint_path)
# if success:
# # Ensure model is moved to the correct device after loading
# self.model.to(self.device)
# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# # Log metrics
# metrics = best_checkpoint_metadata.get('metrics', {})
# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
# return True
# else:
# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
# return False
# except Exception as e:
# logger.error(f"Error loading best checkpoint: {e}")
# return False
# def _ensure_model_on_device(self):
# """Ensure model and all its components are on the correct device"""
# try:
# if self.model:
# self.model.to(self.device)
# # Also ensure the model's internal device is set correctly
# if hasattr(self.model, 'device'):
# self.model.device = self.device
# logger.debug(f"Model ensured on device {self.device}")
# except Exception as e:
# logger.error(f"Error ensuring model on device: {e}")
# def _create_default_output(self, symbol: str) -> ModelOutput:
# """Create default output when prediction fails"""
# return create_model_output(
# model_type='cnn',
# model_name=self.model_name,
# symbol=symbol,
# action='HOLD',
# confidence=0.0,
# metadata={'error': 'Prediction failed, using default output'}
# )
# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
# """Process hidden states for cross-model feeding"""
# processed_states = {}
# for key, value in hidden_states.items():
# if isinstance(value, torch.Tensor):
# # Convert tensor to numpy array
# processed_states[key] = value.cpu().numpy().tolist()
# else:
# processed_states[key] = value
# return processed_states
# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
# """
# Convert BaseDataInput to feature vector for EnhancedCNN
# Args:
# base_data: Standardized input data
# Returns:
# torch.Tensor: Feature vector for EnhancedCNN
# """
# try:
# # Use the get_feature_vector method from BaseDataInput
# features = base_data.get_feature_vector()
# # Validate feature quality before using
# self._validate_feature_quality(features)
# # Convert to torch tensor
# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
# return features_tensor
# except Exception as e:
# logger.error(f"Error converting BaseDataInput to features: {e}")
# # Return empty tensor with correct shape
# return torch.zeros(7850, dtype=torch.float32, device=self.device)
# def _validate_feature_quality(self, features: np.ndarray):
# """Validate that features are realistic and not synthetic/placeholder data"""
# try:
# if len(features) != 7850:
# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
# return
# # Check for all-zero or all-identical features (indicates placeholder data)
# if np.all(features == 0):
# logger.warning("Feature vector contains all zeros - likely placeholder data")
# return
# # Check for repetitive patterns in OHLCV data (first 6000 features)
# ohlcv_features = features[:6000]
# if len(ohlcv_features) >= 20:
# # Check if first 20 values are identical (indicates padding with same bar)
# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
# # Check for unrealistic values
# if np.any(features > 1e6) or np.any(features < -1e6):
# logger.warning("Feature vector contains unrealistic values")
# # Check for NaN or infinite values
# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
# logger.warning("Feature vector contains NaN or infinite values")
# except Exception as e:
# logger.error(f"Error validating feature quality: {e}")
# def predict(self, base_data: BaseDataInput) -> ModelOutput:
# """
# Make a prediction using the EnhancedCNN model
# Args:
# base_data: Standardized input data
# Returns:
# ModelOutput: Standardized model output
# """
# try:
# # Track inference timing
# start_time = datetime.now()
# inference_start = start_time.timestamp()
# # Convert BaseDataInput to features
# features = self._convert_base_data_to_features(base_data)
# # Ensure features has batch dimension
# if features.dim() == 1:
# features = features.unsqueeze(0)
# # Ensure model is on correct device before prediction
# self._ensure_model_on_device()
# # Set model to evaluation mode
# self.model.eval()
# # Make prediction
# with torch.no_grad():
# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
# # Get action and confidence
# action_probs = torch.softmax(q_values, dim=1)
# action_idx = torch.argmax(action_probs, dim=1).item()
# raw_confidence = float(action_probs[0, action_idx].item())
# # Validate confidence - prevent 100% confidence which indicates overfitting
# if raw_confidence >= 0.99:
# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
# # Cap confidence at 0.95 to prevent unrealistic predictions
# confidence = min(raw_confidence, 0.95)
# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
# else:
# confidence = raw_confidence
# # Map action index to action string
# actions = ['BUY', 'SELL', 'HOLD']
# action = actions[action_idx]
# # Extract pivot price prediction (simplified - take first value from price_pred)
# pivot_price = None
# if price_pred is not None and len(price_pred.squeeze()) > 0:
# # Get current price from base_data for context
# current_price = 0.0
# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
# current_price = base_data.ohlcv_1s[-1].close
# # Calculate pivot price as current price + predicted change
# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
# # Create predictions dictionary
# predictions = {
# 'action': action,
# 'buy_probability': float(action_probs[0, 0].item()),
# 'sell_probability': float(action_probs[0, 1].item()),
# 'hold_probability': float(action_probs[0, 2].item()),
# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
# 'pivot_price': pivot_price
# }
# # Create hidden states dictionary
# hidden_states = {
# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
# }
# # Calculate inference duration
# end_time = datetime.now()
# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
# # Update metrics
# self.last_inference_time = start_time
# self.last_inference_duration = inference_duration
# self.inference_count += 1
# # Store last prediction output for dashboard
# self.last_prediction_output = {
# 'action': action,
# 'confidence': confidence,
# 'pivot_price': pivot_price,
# 'timestamp': start_time,
# 'symbol': base_data.symbol
# }
# # Create metadata dictionary
# metadata = {
# 'model_version': '1.0',
# 'timestamp': start_time.isoformat(),
# 'input_shape': features.shape,
# 'inference_duration_ms': inference_duration,
# 'inference_count': self.inference_count
# }
# # Create ModelOutput
# model_output = ModelOutput(
# model_type='cnn',
# model_name=self.model_name,
# symbol=base_data.symbol,
# timestamp=start_time,
# confidence=confidence,
# predictions=predictions,
# hidden_states=hidden_states,
# metadata=metadata
# )
# # Log inference with full input data for training feedback
# log_model_inference(
# model_name=self.model_name,
# symbol=base_data.symbol,
# action=action,
# confidence=confidence,
# probabilities={
# 'BUY': predictions['buy_probability'],
# 'SELL': predictions['sell_probability'],
# 'HOLD': predictions['hold_probability']
# },
# input_features=features.cpu().numpy(), # Store full feature vector
# processing_time_ms=inference_duration,
# checkpoint_id=None, # Could be enhanced to track checkpoint
# metadata={
# 'base_data_input': {
# 'symbol': base_data.symbol,
# 'timestamp': base_data.timestamp.isoformat(),
# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
# 'has_cob_data': base_data.cob_data is not None,
# 'technical_indicators_count': len(base_data.technical_indicators),
# 'pivot_points_count': len(base_data.pivot_points),
# 'last_predictions_count': len(base_data.last_predictions)
# },
# 'model_predictions': {
# 'pivot_price': pivot_price,
# 'extrema_prediction': predictions['extrema'],
# 'price_prediction': predictions['price_prediction']
# }
# }
# )
# return model_output
# except Exception as e:
# logger.error(f"Error making prediction with EnhancedCNN: {e}")
# # Return default ModelOutput
# return create_model_output(
# model_type='cnn',
# model_name=self.model_name,
# symbol=base_data.symbol,
# action='HOLD',
# confidence=0.0
# )
# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
# """
# Add a training sample to the training data
# Args:
# symbol_or_base_data: Either a symbol string or BaseDataInput object
# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
# reward: Reward received for the action
# """
# try:
# # Handle both symbol string and BaseDataInput object
# if isinstance(symbol_or_base_data, str):
# # For cold start mode - create a simple training sample with current features
# # This is a simplified approach for rapid training
# symbol = symbol_or_base_data
# # Create a realistic feature vector instead of random data
# # Use actual market data if available, otherwise create realistic synthetic data
# try:
# # Try to get real market data first
# if hasattr(self, 'data_provider') and self.data_provider:
# # This would need to be implemented in the adapter
# features = self._create_realistic_features(symbol)
# else:
# # Create realistic synthetic features (not random)
# features = self._create_realistic_synthetic_features(symbol)
# except Exception as e:
# logger.warning(f"Could not create realistic features for {symbol}: {e}")
# # Fallback to small random variation instead of pure random
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
# features = base_features + noise
# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
# else:
# # Full BaseDataInput object
# base_data = symbol_or_base_data
# features = self._convert_base_data_to_features(base_data)
# symbol = base_data.symbol
# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
# # Convert action to index
# actions = ['BUY', 'SELL', 'HOLD']
# action_idx = actions.index(actual_action)
# # Add to training data
# with self.training_lock:
# self.training_data.append((features, action_idx, reward))
# # Limit training data size
# if len(self.training_data) > self.max_training_samples:
# # Sort by reward (highest first) and keep top samples
# self.training_data.sort(key=lambda x: x[2], reverse=True)
# self.training_data = self.training_data[:self.max_training_samples]
# except Exception as e:
# logger.error(f"Error adding training sample: {e}")
# def train(self, epochs: int = 1) -> Dict[str, float]:
# """
# Train the model with collected data and inference history
# Args:
# epochs: Number of epochs to train for
# Returns:
# Dict[str, float]: Training metrics
# """
# try:
# # Track training timing
# training_start_time = datetime.now()
# training_start = training_start_time.timestamp()
# with self.training_lock:
# # Get additional training data from inference history
# self._load_training_data_from_inference_history()
# # Check if we have enough data
# if len(self.training_data) < self.batch_size:
# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# # Ensure model is on correct device before training
# self._ensure_model_on_device()
# # Set model to training mode
# self.model.train()
# # Create optimizer
# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
# # Training metrics
# total_loss = 0.0
# correct_predictions = 0
# total_predictions = 0
# # Train for specified number of epochs
# for epoch in range(epochs):
# # Shuffle training data
# np.random.shuffle(self.training_data)
# # Process in batches
# for i in range(0, len(self.training_data), self.batch_size):
# batch = self.training_data[i:i+self.batch_size]
# # Skip if batch is too small
# if len(batch) < 2:
# continue
# # Prepare batch - ensure all tensors are on the correct device
# features = torch.stack([sample[0].to(self.device) for sample in batch])
# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
# # Zero gradients
# optimizer.zero_grad()
# # Forward pass
# q_values, _, _, _, _ = self.model(features)
# # Calculate loss (CrossEntropyLoss with reward weighting)
# # First, apply softmax to get probabilities
# probs = torch.softmax(q_values, dim=1)
# # Get probability of chosen action
# chosen_probs = probs[torch.arange(len(actions)), actions]
# # Calculate negative log likelihood loss
# nll_loss = -torch.log(chosen_probs + 1e-10)
# # Weight by reward (higher reward = higher weight)
# # Normalize rewards to [0, 1] range
# min_reward = rewards.min()
# max_reward = rewards.max()
# if max_reward > min_reward:
# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
# else:
# normalized_rewards = torch.ones_like(rewards)
# # Apply reward weighting (higher reward = higher weight)
# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
# # Mean loss
# loss = weighted_loss.mean()
# # Backward pass
# loss.backward()
# # Update weights
# optimizer.step()
# # Update metrics
# total_loss += loss.item()
# # Calculate accuracy
# predicted_actions = torch.argmax(q_values, dim=1)
# correct_predictions += (predicted_actions == actions).sum().item()
# total_predictions += len(actions)
# # Validate training - detect overfitting
# if total_predictions > 0:
# current_accuracy = correct_predictions / total_predictions
# if current_accuracy >= 0.99:
# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
# # Add regularization to prevent overfitting
# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
# loss = loss + l2_reg
# logger.info("Added L2 regularization to prevent overfitting")
# # Calculate final metrics
# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# # Calculate training duration
# training_end_time = datetime.now()
# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
# # Update training metrics
# self.last_training_time = training_start_time
# self.last_training_duration = training_duration
# self.last_training_loss = avg_loss
# self.training_count += 1
# # Save checkpoint
# self._save_checkpoint(avg_loss, accuracy)
# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
# return {
# 'loss': avg_loss,
# 'accuracy': accuracy,
# 'samples': len(self.training_data),
# 'duration_ms': training_duration,
# 'training_count': self.training_count
# }
# except Exception as e:
# logger.error(f"Error training model: {e}")
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
# def _save_checkpoint(self, loss: float, accuracy: float):
# """
# Save model checkpoint
# Args:
# loss: Training loss
# accuracy: Training accuracy
# """
# try:
# # Import checkpoint manager
# from utils.checkpoint_manager import CheckpointManager
# # Create checkpoint manager
# checkpoint_manager = CheckpointManager(
# checkpoint_dir=self.checkpoint_dir,
# max_checkpoints=10,
# metric_name="accuracy"
# )
# # Create temporary model file
# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
# self.model.save(temp_path)
# # Create metrics
# metrics = {
# 'loss': loss,
# 'accuracy': accuracy,
# 'samples': len(self.training_data)
# }
# # Create metadata
# metadata = {
# 'timestamp': datetime.now().isoformat(),
# 'model_name': self.model_name,
# 'input_shape': self.model.input_shape,
# 'n_actions': self.model.n_actions
# }
# # Save checkpoint
# checkpoint_path = checkpoint_manager.save_checkpoint(
# model_name=self.model_name,
# model_path=f"{temp_path}.pt",
# metrics=metrics,
# metadata=metadata
# )
# # Delete temporary model file
# if os.path.exists(f"{temp_path}.pt"):
# os.remove(f"{temp_path}.pt")
# logger.info(f"Model checkpoint saved to {checkpoint_path}")
# except Exception as e:
# logger.error(f"Error saving checkpoint: {e}")
# def _load_training_data_from_inference_history(self):
# """Load training data from inference history for continuous learning"""
# try:
# from utils.database_manager import get_database_manager
# db_manager = get_database_manager()
# # Get recent inference records with input features
# inference_records = db_manager.get_inference_records_for_training(
# model_name=self.model_name,
# hours_back=24, # Last 24 hours
# limit=1000
# )
# if not inference_records:
# logger.debug("No inference records found for training")
# return
# # Convert inference records to training samples
# # For now, use a simple approach: treat high-confidence predictions as ground truth
# for record in inference_records:
# if record.input_features is not None and record.confidence > 0.7:
# # Convert action to index
# actions = ['BUY', 'SELL', 'HOLD']
# if record.action in actions:
# action_idx = actions.index(record.action)
# # Use confidence as a proxy for reward (high confidence = good prediction)
# reward = record.confidence * 2 - 1 # Scale to [-1, 1]
# # Convert features to tensor
# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
# # Add to training data if not already present (avoid duplicates)
# sample_exists = any(
# torch.equal(features_tensor, existing[0])
# for existing in self.training_data
# )
# if not sample_exists:
# self.training_data.append((features_tensor, action_idx, reward))
# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
# except Exception as e:
# logger.error(f"Error loading training data from inference history: {e}")
# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
# """
# Evaluate past predictions against actual market outcomes
# Args:
# hours_back: How many hours back to evaluate
# Returns:
# Dict with evaluation metrics
# """
# try:
# from utils.database_manager import get_database_manager
# db_manager = get_database_manager()
# # Get inference records from the specified time period
# inference_records = db_manager.get_inference_records_for_training(
# model_name=self.model_name,
# hours_back=hours_back,
# limit=100
# )
# if not inference_records:
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
# # For now, use a simple evaluation based on confidence
# # In a real implementation, this would compare against actual price movements
# correct_predictions = 0
# total_predictions = len(inference_records)
# # Simple heuristic: high confidence predictions are more likely to be correct
# for record in inference_records:
# if record.confidence > 0.8: # High confidence threshold
# correct_predictions += 1
# elif record.confidence > 0.6: # Medium confidence
# correct_predictions += 0.5
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
# return {
# 'accuracy': accuracy,
# 'total_predictions': total_predictions,
# 'correct_predictions': correct_predictions
# }
# except Exception as e:
# logger.error(f"Error evaluating predictions: {e}")
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}

View File

@ -1,464 +0,0 @@
"""
Enhanced Trading Orchestrator
Central coordination hub for the multi-modal trading system that manages:
- Data subscription and management
- Model inference coordination
- Cross-model data feeding
- Training pipeline orchestration
- Decision making using Mixture of Experts
"""
import asyncio
import logging
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from core.data_provider import DataProvider
from core.trading_action import TradingAction
from utils.tensorboard_logger import TensorBoardLogger
logger = logging.getLogger(__name__)
@dataclass
class ModelOutput:
    """Common, extensible result record produced by any model type."""

    # Model family: 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_type: str
    # Identifier of the specific model
    model_name: str
    symbol: str
    timestamp: datetime
    confidence: float
    # Model-specific prediction payload
    predictions: Dict[str, Any]
    # Optional internal states, intended for cross-model feeding
    hidden_states: Optional[Dict[str, Any]] = None
    # Any additional information attached by the producer
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class BaseDataInput:
    """Single input record shared by all models."""

    symbol: str
    timestamp: datetime
    # Multi-timeframe OHLCV
    ohlcv_data: Dict[str, Any] = field(default_factory=dict)
    # COB buckets for the 1s timeframe (None when unavailable)
    cob_data: Optional[Dict[str, Any]] = None
    technical_indicators: Dict[str, float] = field(default_factory=dict)
    pivot_points: List[Any] = field(default_factory=list)
    # Latest output from all models
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)
    # Order flow and similar microstructure information
    market_microstructure: Dict[str, Any] = field(default_factory=dict)
@dataclass
class COBData:
    """Cumulative order book snapshot aggregated into price buckets."""

    symbol: str
    timestamp: datetime
    current_price: float
    # Width of one bucket: $1 for ETH, $10 for BTC
    bucket_size: float
    # price -> {bid_volume, ask_volume, etc.}
    price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict)
    # price -> imbalance ratio
    bid_ask_imbalance: Dict[float, float] = field(default_factory=dict)
    # price -> VWAP within bucket
    volume_weighted_prices: Dict[float, float] = field(default_factory=dict)
    # Various order flow indicators
    order_flow_metrics: Dict[str, float] = field(default_factory=dict)
class EnhancedTradingOrchestrator:
    """
    Enhanced Trading Orchestrator implementing the design specification.

    Coordinates data flow, model inference, and decision making for the
    multi-modal trading system: it buffers ticks, turns raw order book
    updates into bucketed ``COBData``, builds the RL state vector and
    produces one ``TradingAction`` per symbol.
    """

    # Layout of the RL state vector. Every section starts at a fixed offset so
    # a given feature always lands at the same index, no matter how much data
    # was available for the preceding sections.
    RL_STATE_SIZE = 13400
    FEATURES_PER_SLOT = 40                # features reserved per candle / bucket
    OHLCV_SECTION_OFFSET = 0              # 4000 features: 100 candles x 40
    OHLCV_SECTION_SIZE = 4000
    COB_SECTION_OFFSET = 4000             # 8000 features: 200 buckets x 40
    COB_SECTION_SIZE = 8000
    # Offsets 12000-12999 are reserved for technical indicators; those are
    # currently written inside the per-candle OHLCV slots instead.
    MICROSTRUCTURE_SECTION_OFFSET = 13000  # 400 features
    MICROSTRUCTURE_SECTION_SIZE = 400

    # Indicator columns copied into each candle slot (positions 5..14).
    INDICATOR_COLUMNS = ('sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
                         'macd', 'bb_upper', 'bb_lower', 'atr', 'adx')

    COB_BUCKETS_PER_SIDE = 20   # buckets created on each side of the mid price
    MAX_TICKS_PER_SYMBOL = 1000  # tick buffer length kept per symbol

    def __init__(self, data_provider: DataProvider, symbols: List[str],
                 enhanced_rl_training: bool = False,
                 model_registry: Optional[Dict] = None):
        """
        Initialize the enhanced orchestrator.

        Args:
            data_provider: Source of candles, ticks and COB updates.
            symbols: Symbols this orchestrator tracks.
            enhanced_rl_training: Flag stored for consumers; not used here.
            model_registry: Optional registry of models; defaults to ``{}``.
        """
        self.data_provider = data_provider
        self.symbols = symbols
        self.enhanced_rl_training = enhanced_rl_training
        self.model_registry = model_registry or {}

        # Data management
        self.data_buffers = {symbol: {} for symbol in symbols}
        self.last_update_times = {symbol: {} for symbol in symbols}

        # Model output storage
        self.model_outputs = {symbol: {} for symbol in symbols}
        self.model_output_history = {symbol: {} for symbol in symbols}

        # Training pipeline
        self.training_data = {symbol: [] for symbol in symbols}
        self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")

        # COB integration
        self.cob_data = {symbol: None for symbol in symbols}

        # Performance tracking
        self.performance_metrics = {
            'inference_count': 0,
            'successful_states': 0,
            'total_episodes': 0
        }

        logger.info("Enhanced Trading Orchestrator initialized")

    async def start_cob_integration(self):
        """Subscribe to COB updates from the data provider (errors are logged)."""
        try:
            self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
            logger.info("COB integration started")
        except Exception as e:
            logger.error(f"Error starting COB integration: {e}")

    async def start_realtime_processing(self):
        """Subscribe to tick data for every tracked symbol (errors are logged)."""
        try:
            for symbol in self.symbols:
                self.data_provider.subscribe_to_ticks(
                    callback=self._on_tick_data,
                    symbols=[symbol],
                    subscriber_name=f"orchestrator_{symbol}"
                )
            logger.info("Real-time processing started")
        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")

    def _on_cob_data_update(self, symbol: str, cob_data: dict):
        """Store the processed form of a raw COB update for ``symbol``."""
        try:
            self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
            logger.debug(f"COB data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")

    def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
        """
        Process raw COB data into the structured ``COBData`` format.

        Args:
            symbol: Symbol the update belongs to; selects the bucket width
                (1.0 when it contains 'ETH', otherwise 10.0).
            cob_data: Raw update; reads ``stats['mid_price']`` and the
                ``bids`` / ``asks`` sequences of ``(price, volume)`` pairs.

        Returns:
            COBData with bucketed volumes, imbalances, per-bucket VWAP and
            aggregate order flow metrics. On any error an empty COBData with
            ``current_price=0`` is returned and the error is logged.
        """
        # Bound before the try block so the error path can always use it.
        bucket_size = 10.0
        try:
            if 'ETH' in symbol:
                bucket_size = 1.0

            stats = cob_data.get('stats', {})
            current_price = stats.get('mid_price', 0)

            cob = COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                current_price=current_price,
                bucket_size=bucket_size
            )

            bids = cob_data.get('bids', [])
            asks = cob_data.get('asks', [])

            # Bucket keys are built with the same "integer index * bucket_size"
            # formula used for the lookups below. Anchoring them on the raw mid
            # price would put the keys off-grid and no order would ever match.
            centre_index = round(current_price / bucket_size)
            span = self.COB_BUCKETS_PER_SIDE
            for offset in range(-span, span + 1):
                bucket_price = (centre_index + offset) * bucket_size
                cob.price_buckets[bucket_price] = {
                    'bid_volume': 0.0,
                    'ask_volume': 0.0
                }

            # Sum of price * volume per bucket, needed for a real VWAP.
            notional = dict.fromkeys(cob.price_buckets, 0.0)

            for volume_key, levels in (('bid_volume', bids), ('ask_volume', asks)):
                for price, volume in levels:
                    bucket_price = round(price / bucket_size) * bucket_size
                    bucket = cob.price_buckets.get(bucket_price)
                    if bucket is None:
                        continue  # outside the tracked range
                    bucket[volume_key] += volume
                    notional[bucket_price] += price * volume

            for price, volumes in cob.price_buckets.items():
                bid_vol = volumes['bid_volume']
                ask_vol = volumes['ask_volume']
                total_vol = bid_vol + ask_vol
                if total_vol > 0:
                    cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
                    cob.volume_weighted_prices[price] = notional[price] / total_vol
                else:
                    cob.bid_ask_imbalance[price] = 0.0
                    cob.volume_weighted_prices[price] = price

            # Totals are computed first: the ratio must not read them back from
            # the metrics dict while that dict is still being constructed.
            total_bid_volume = sum(v['bid_volume'] for v in cob.price_buckets.values())
            total_ask_volume = sum(v['ask_volume'] for v in cob.price_buckets.values())
            cob.order_flow_metrics = {
                'total_bid_volume': total_bid_volume,
                'total_ask_volume': total_ask_volume,
                'bid_ask_ratio': 0.0 if total_ask_volume == 0 else total_bid_volume / total_ask_volume
            }

            return cob
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")
            return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size)

    def _on_tick_data(self, tick):
        """Append ``tick`` to its symbol's buffer, keeping only the newest ticks."""
        try:
            symbol = tick.symbol

            # Symbols that were not passed to __init__ get their containers
            # created on first sight, for the buffer and the timestamps alike.
            ticks = self.data_buffers.setdefault(symbol, {}).setdefault('ticks', [])
            ticks.append(tick)

            overflow = len(ticks) - self.MAX_TICKS_PER_SYMBOL
            if overflow > 0:
                self.data_buffers[symbol]['ticks'] = ticks[overflow:]

            self.last_update_times.setdefault(symbol, {})['tick'] = datetime.now()
            logger.debug(f"Tick data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing tick data: {e}")

    def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """
        Build the comprehensive RL state with 13,400 features.

        Layout (unused positions stay zero):
            0     - 3999   latest 1s candles, newest first, 40 slots each:
                           O, H, L, C, V followed by INDICATOR_COLUMNS
            4000  - 11999  COB buckets in ascending price order, 40 slots each
            12000 - 12999  reserved (indicators live in the candle slots)
            13000 - 13399  order flow metrics

        Args:
            symbol: Symbol to build the state for.

        Returns:
            np.ndarray: float32 state vector with 13,400 features, or None if
            building the state failed (the error is logged).
        """
        try:
            state = np.zeros(self.RL_STATE_SIZE, dtype=np.float32)
            slot = self.FEATURES_PER_SLOT

            ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
            cob_data = self.cob_data.get(symbol)

            # 1. OHLCV features
            if ohlcv_data is not None and not ohlcv_data.empty:
                candle_slots = self.OHLCV_SECTION_SIZE // slot
                for i in range(min(candle_slots, len(ohlcv_data))):
                    base = self.OHLCV_SECTION_OFFSET + i * slot
                    row = ohlcv_data.iloc[-(i + 1)]
                    state[base] = row.get('open', 0) / 100000  # Normalized
                    state[base + 1] = row.get('high', 0) / 100000
                    state[base + 2] = row.get('low', 0) / 100000
                    state[base + 3] = row.get('close', 0) / 100000
                    state[base + 4] = row.get('volume', 0) / 1000000

                    # Each indicator owns a fixed position, so a missing
                    # column leaves a zero instead of shifting the others.
                    for position, col in enumerate(self.INDICATOR_COLUMNS, start=5):
                        if col in row:
                            state[base + position] = row[col] / 100000

            if cob_data is not None:
                flow = cob_data.order_flow_metrics

                # 2. COB features
                bucket_slots = self.COB_SECTION_SIZE // slot
                bucket_prices = sorted(cob_data.price_buckets)[:bucket_slots]
                for i, price in enumerate(bucket_prices):
                    base = self.COB_SECTION_OFFSET + i * slot
                    bucket = cob_data.price_buckets[price]
                    state[base] = bucket.get('bid_volume', 0) / 1000000  # Normalized
                    state[base + 1] = bucket.get('ask_volume', 0) / 1000000
                    state[base + 2] = cob_data.bid_ask_imbalance.get(price, 0)
                    state[base + 3] = cob_data.volume_weighted_prices.get(price, price) / 100000
                    state[base + 4] = flow.get('total_bid_volume', 0) / 10000000
                    state[base + 5] = flow.get('total_ask_volume', 0) / 10000000
                    state[base + 6] = flow.get('bid_ask_ratio', 0)

                # 4. Market microstructure features
                micro_base = self.MICROSTRUCTURE_SECTION_OFFSET
                metrics = list(flow.values())[:self.MICROSTRUCTURE_SECTION_SIZE]
                for i, metric in enumerate(metrics):
                    state[micro_base + i] = metric

            self.performance_metrics['successful_states'] += 1
            logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")

            self.tensorboard_logger.log_state_metrics(
                symbol=symbol,
                state_info={
                    'size': len(state),
                    'quality': 1.0,
                    'feature_counts': {
                        'total': len(state),
                        'non_zero': np.count_nonzero(state)
                    }
                },
                step=self.performance_metrics['successful_states']
            )

            return state
        except Exception as e:
            logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
            return None

    def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
        """
        Calculate the enhanced pivot-based reward.

        The reward is the sum of a PnL term (``net_pnl / 100``), a confidence
        term, a low-volatility preference, an order flow term and a 0.5 bonus
        for buying near a LOW pivot or selling near a HIGH pivot.

        Args:
            trade_decision: Trading decision with ``action`` and ``confidence``.
            market_data: Market context (``volatility``, ``order_flow_strength``,
                ``near_pivot``, ``pivot_type``).
            trade_outcome: Actual trade results (``net_pnl``).

        Returns:
            float: Enhanced reward value, or 0.0 if the calculation failed.
        """
        try:
            pnl_reward = trade_outcome.get('net_pnl', 0) / 100  # Normalize

            confidence = trade_decision.get('confidence', 0.5)
            confidence_reward = confidence * 0.2

            volatility = market_data.get('volatility', 0.01)
            volatility_reward = (1.0 - volatility * 10) * 0.1  # Prefer low volatility

            order_flow = market_data.get('order_flow_strength', 0)
            order_flow_reward = order_flow * 0.2

            pivot_bonus = 0.0
            if market_data.get('near_pivot', False):
                action = trade_decision.get('action', '').upper()
                pivot_type = market_data.get('pivot_type', '').upper()
                # Bonus for buying near support or selling near resistance
                if (action, pivot_type) in (('BUY', 'LOW'), ('SELL', 'HIGH')):
                    pivot_bonus = 0.5

            enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus

            episode = self.performance_metrics['total_episodes']
            self.tensorboard_logger.log_scalars('Rewards/Components', {
                'pnl_component': pnl_reward,
                'confidence': confidence_reward,
                'volatility': volatility_reward,
                'order_flow': order_flow_reward,
                'pivot_bonus': pivot_bonus
            }, episode)
            self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, episode)

            logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
            return enhanced_reward
        except Exception as e:
            logger.error(f"Error calculating enhanced pivot reward: {e}")
            return 0.0

    async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
        """
        Make coordinated trading decisions for all tracked symbols.

        NOTE: the decision itself is a placeholder derived from simple
        statistics of the state vector; no RL model is invoked here.

        Returns:
            Dict[str, TradingAction]: Trading actions keyed by symbol. Symbols
            whose state could not be built are omitted; ``{}`` on error.
        """
        try:
            decisions = {}

            for symbol in self.symbols:
                rl_state = self.build_comprehensive_rl_state(symbol)
                if rl_state is None:
                    logger.warning(f"Failed to build state for {symbol}, skipping decision")
                    continue

                self.performance_metrics['total_episodes'] += 1

                # Mock RL decision (a real implementation would call the RL model)
                action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
                confidence = float(min(1.0, max(0.0, np.std(rl_state) * 10)))

                decisions[symbol] = TradingAction(
                    symbol=symbol,
                    timestamp=datetime.now(),
                    action=action,
                    confidence=confidence,
                    source='rl_orchestrator'
                )
                logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")

            self.performance_metrics['inference_count'] += 1
            return decisions
        except Exception as e:
            logger.error(f"Error making coordinated decisions: {e}")
            return {}

    def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
        """
        Calculate the correlation between the 1m closes of two symbols.

        Args:
            symbol1: First symbol.
            symbol2: Second symbol.

        Returns:
            float: Correlation coefficient (-1 to 1). 0.0 when data is missing,
            fewer than 10 aligned candles exist, the correlation is undefined,
            or an error occurred.
        """
        try:
            data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
            data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)

            if data1 is None or data2 is None or data1.empty or data2.empty:
                return 0.0

            # Align the two series on their shared index values.
            merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')
            if len(merged) < 10:
                return 0.0

            correlation = merged['close_1'].corr(merged['close_2'])
            return 0.0 if np.isnan(correlation) else float(correlation)
        except Exception as e:
            logger.error(f"Error calculating symbol correlation: {e}")
            return 0.0
```

View File

@ -1,775 +0,0 @@
"""
Enhanced Training Integration Module
This module provides comprehensive integration between the training data collection system,
CNN training pipeline, RL training pipeline, and your existing infrastructure.
Key Features:
- Real-time integration with existing DataProvider
- Coordinated training across CNN and RL models
- Automatic outcome validation and profitability tracking
- Integration with existing COB RL model
- Performance monitoring and optimization
- Seamless connection to existing orchestrator and trading executor
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass
import threading
import time
from pathlib import Path
# Import existing components
from .data_provider import DataProvider
from .orchestrator import Orchestrator
from .trading_executor import TradingExecutor
# Import our training system components
from .training_data_collector import (
TrainingDataCollector,
get_training_data_collector
)
from .cnn_training_pipeline import (
CNNPivotPredictor,
CNNTrainer,
get_cnn_trainer
)
from .rl_training_pipeline import (
RLTradingAgent,
RLTrainer,
get_rl_trainer
)
from .training_integration import TrainingIntegration
# Import existing RL model
# The logger must exist before the optional import below, because the
# ImportError fallback logs through it.
logger = logging.getLogger(__name__)

try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    # Optional dependency: callers check COBRLModelInterface for None.
    logger.warning("Could not import COBRLModelInterface - using fallback")
    COBRLModelInterface = None
@dataclass
class EnhancedTrainingConfig:
    """Tunable settings for the comprehensive training integration."""

    # --- Data collection ---
    collection_interval: float = 1.0
    min_data_completeness: float = 0.8

    # --- Training triggers ---
    min_episodes_for_cnn_training: int = 100
    min_experiences_for_rl_training: int = 200
    training_frequency_minutes: int = 30

    # --- Profitability thresholds ---
    min_profitability_for_replay: float = 0.1
    high_profitability_threshold: float = 0.5

    # --- Model integration ---
    use_existing_cob_rl_model: bool = True
    enable_cross_model_learning: bool = True

    # --- Performance optimization ---
    max_concurrent_training_sessions: int = 2
    enable_background_validation: bool = True
class EnhancedTrainingIntegration:
"""Enhanced training integration with existing infrastructure"""
def __init__(self,
             data_provider: DataProvider,
             orchestrator: Orchestrator = None,
             trading_executor: TradingExecutor = None,
             config: EnhancedTrainingConfig = None):
    """
    Wire the training components to the existing infrastructure.

    Args:
        data_provider: Market data source shared with the rest of the system.
        orchestrator: Optional orchestrator whose predictions are included.
        trading_executor: Optional trading executor; only stored here.
        config: Integration settings; a default EnhancedTrainingConfig is
            created when omitted.
    """
    self.data_provider = data_provider
    self.orchestrator = orchestrator
    self.trading_executor = trading_executor
    self.config = config if config else EnhancedTrainingConfig()

    # Training data collection
    self.data_collector = get_training_data_collector()

    # CNN model and its trainer
    self.cnn_model = CNNPivotPredictor()
    self.cnn_trainer = get_cnn_trainer(self.cnn_model)

    # The existing COB RL model is used only when enabled AND importable
    self.existing_rl_model = None
    if self.config.use_existing_cob_rl_model and COBRLModelInterface:
        self.existing_rl_model = COBRLModelInterface()
        logger.info("Using existing COB RL model")

    # RL agent and its trainer
    self.rl_agent = RLTradingAgent()
    self.rl_trainer = get_rl_trainer(self.rl_agent)

    # Run state and worker threads
    self.is_running = False
    self.training_threads = {}
    self.validation_thread = None

    # Counters describing the integration's activity
    self.integration_stats = dict(
        total_data_packages=0,
        cnn_training_sessions=0,
        rl_training_sessions=0,
        profitable_predictions=0,
        total_predictions=0,
        cross_model_improvements=0,
        last_update=datetime.now(),
    )

    # Predictions awaiting validation, and their recorded outcomes
    self.recent_predictions = {}
    self.prediction_outcomes = {}

    # Per-model performance history used for cross-model learning
    self.model_performance_history = {name: [] for name in ('cnn', 'rl', 'orchestrator')}

    logger.info("Enhanced Training Integration initialized")
    cnn_param_count = sum(p.numel() for p in self.cnn_model.parameters())
    logger.info(f"CNN model parameters: {cnn_param_count:,}")
    rl_param_count = sum(p.numel() for p in self.rl_agent.parameters())
    logger.info(f"RL agent parameters: {rl_param_count:,}")
    logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")
def start_enhanced_integration(self):
    """
    Start data collection, CNN real-time training and the worker threads.

    Calling this while the integration is already running only logs a
    warning. All worker threads are daemon threads.
    """
    if self.is_running:
        logger.warning("Enhanced training integration already running")
        return

    self.is_running = True

    # Data collection
    self.data_collector.start_collection()

    # CNN real-time training, one session per symbol
    if self.config.min_episodes_for_cnn_training > 0:
        for symbol in self.data_provider.symbols:
            self.cnn_trainer.start_real_time_training(symbol)

    # Coordinated training thread
    coordinator = threading.Thread(
        target=self._training_coordinator_worker,
        daemon=True
    )
    self.training_threads['coordinator'] = coordinator
    coordinator.start()

    # Enhanced data collection thread
    collector = threading.Thread(
        target=self._enhanced_data_collection_worker,
        daemon=True
    )
    self.training_threads['data_collector'] = collector
    collector.start()

    # Outcome validation runs only when enabled in the config
    if self.config.enable_background_validation:
        self.validation_thread = threading.Thread(
            target=self._outcome_validation_worker,
            daemon=True
        )
        self.validation_thread.start()

    logger.info("Enhanced training integration started")
def stop_enhanced_integration(self):
    """
    Stop collection and training, then wait for the worker threads.

    Training threads are joined with a 10 second timeout each, the
    validation thread (if any) with a 5 second timeout.
    """
    # Workers poll this flag and leave their loops once it is False
    self.is_running = False

    self.data_collector.stop_collection()
    self.cnn_trainer.stop_training()

    for thread_name in self.training_threads:
        self.training_threads[thread_name].join(timeout=10)
        logger.info(f"Stopped {thread_name} thread")

    validator = self.validation_thread
    if validator:
        validator.join(timeout=5)

    logger.info("Enhanced training integration stopped")
def _enhanced_data_collection_worker(self):
    """
    Thread body: collect training data for every symbol until stopped.

    One pass over all symbols is followed by a pause of
    ``config.collection_interval`` seconds; after an error the worker logs
    it and backs off for 5 seconds instead.
    """
    logger.info("Enhanced data collection worker started")

    while self.is_running:
        try:
            symbols = self.data_provider.symbols
            for current_symbol in symbols:
                self._collect_enhanced_training_data(current_symbol)
            time.sleep(self.config.collection_interval)
        except Exception as e:
            logger.error(f"Error in enhanced data collection: {e}")
            time.sleep(5)

    logger.info("Enhanced data collection worker stopped")
def _collect_enhanced_training_data(self, symbol: str):
    """
    Collect one training data package for ``symbol``, including predictions.

    Nothing is collected when market data is missing or fails validation.
    When the collector returns an episode id, the predictions are kept for
    later outcome validation and an RL experience is added if an RL action
    was produced. Errors are logged, never raised.
    """
    try:
        market_data = self._get_comprehensive_market_data(symbol)
        if not market_data:
            return
        if not self._validate_market_data(market_data):
            return

        # Predictions from every available model
        model_predictions = self._get_all_model_predictions(symbol, market_data)

        # Model-specific feature representations
        cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
        rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)

        episode_id = self.data_collector.collect_training_data(
            symbol=symbol,
            ohlcv_data=market_data['ohlcv'],
            tick_data=market_data['ticks'],
            cob_data=market_data['cob'],
            technical_indicators=market_data['indicators'],
            pivot_points=market_data['pivots'],
            cnn_features=cnn_features,
            rl_state=rl_state,
            orchestrator_context=market_data['context'],
            model_predictions=model_predictions
        )
        if not episode_id:
            return

        # Keep the predictions so their outcome can be validated later
        prediction_record = {
            'timestamp': datetime.now(),
            'symbol': symbol,
            'predictions': model_predictions,
            'market_data': market_data
        }
        self.recent_predictions[episode_id] = prediction_record

        # An RL experience requires an RL action
        if 'rl_action' in model_predictions:
            self._add_rl_experience(symbol, market_data, model_predictions, episode_id)

        self.integration_stats['total_data_packages'] += 1
    except Exception as e:
        logger.error(f"Error collecting enhanced training data for {symbol}: {e}")
def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
    """
    Gather market data for ``symbol`` from all sources.

    Returns:
        Dict with the keys 'ohlcv' (timeframe -> DataFrame, only non-empty
        frames), 'ticks', 'cob', 'indicators', 'pivots' and 'context';
        an empty dict if anything failed (the error is logged).
    """
    timeframes = ('1s', '1m', '5m', '15m', '1h', '1d')
    try:
        # OHLCV frames per timeframe; empty or missing results are skipped
        frames = {}
        for tf in timeframes:
            candles = self.data_provider.get_historical_data(symbol, tf, limit=300, refresh=True)
            if candles is None or candles.empty:
                continue
            frames[tf] = candles

        # Dict values are evaluated top to bottom, so the sources are
        # queried in this order.
        return {
            'ohlcv': frames,
            'ticks': self._get_recent_tick_data(symbol),
            'cob': self._get_cob_data(symbol),
            'indicators': self._get_technical_indicators(symbol),
            'pivots': self._get_pivot_points(symbol),
            'context': self._get_market_context(symbol),
        }
    except Exception as e:
        logger.error(f"Error getting comprehensive market data: {e}")
        return {}
def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
"""Get predictions from all available models"""
predictions = {}
try:
# CNN predictions
if self.cnn_model and market_data.get('ohlcv'):
cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
if cnn_features is not None:
cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)
# Reshape for CNN (add channel dimension)
cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels
with torch.no_grad():
cnn_outputs = self.cnn_model(cnn_input)
predictions['cnn'] = {
'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
'confidence': cnn_outputs['confidence'].cpu().numpy(),
'timestamp': datetime.now()
}
# RL predictions
if self.rl_agent and market_data.get('cob'):
rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
if rl_state is not None:
action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
predictions['rl'] = {
'action': action,
'confidence': confidence,
'timestamp': datetime.now()
}
predictions['rl_action'] = action
# Existing COB RL model predictions
if self.existing_rl_model and market_data.get('cob'):
cob_features = market_data['cob'].get('cob_features', [])
if cob_features and len(cob_features) >= 2000:
cob_array = np.array(cob_features[:2000], dtype=np.float32)
cob_prediction = self.existing_rl_model.predict(cob_array)
predictions['cob_rl'] = {
'predicted_direction': cob_prediction.get('predicted_direction', 1),
'confidence': cob_prediction.get('confidence', 0.5),
'value': cob_prediction.get('value', 0.0),
'timestamp': datetime.now()
}
# Orchestrator predictions (if available)
if self.orchestrator:
try:
# This would integrate with your orchestrator's prediction method
orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
if orchestrator_prediction:
predictions['orchestrator'] = orchestrator_prediction
except Exception as e:
logger.debug(f"Could not get orchestrator prediction: {e}")
return predictions
except Exception as e:
logger.error(f"Error getting model predictions: {e}")
return {}
def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any], episode_id: str):
"""Add RL experience to the training buffer"""
try:
# Create RL state
state = self._create_enhanced_rl_state(symbol, market_data, predictions)
if state is None:
return
# Get action from predictions
action = predictions.get('rl_action', 1) # Default to HOLD
# Calculate immediate reward (placeholder - would be updated with actual outcome)
reward = 0.0
# Create next state (same as current for now - would be updated)
next_state = state.copy()
# Market context
market_context = {
'symbol': symbol,
'episode_id': episode_id,
'timestamp': datetime.now(),
'market_session': market_data['context'].get('market_session', 'unknown'),
'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
}
# Add experience
experience_id = self.rl_trainer.add_experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=False,
market_context=market_context,
cnn_predictions=predictions.get('cnn'),
confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
)
if experience_id:
logger.debug(f"Added RL experience: {experience_id}")
except Exception as e:
logger.error(f"Error adding RL experience: {e}")
def _training_coordinator_worker(self):
    """Loop that periodically checks every symbol for pending training.

    Runs until ``self.is_running`` becomes false.  After each full pass over
    ``self.data_provider.symbols`` it sleeps for
    ``config.training_frequency_minutes`` minutes; after an error it logs
    and retries one minute later.
    """
    logger.info("Training coordinator worker started")
    while self.is_running:
        try:
            for tracked_symbol in self.data_provider.symbols:
                self._check_and_trigger_training(tracked_symbol)
            pause_seconds = self.config.training_frequency_minutes * 60
            time.sleep(pause_seconds)
        except Exception as e:
            logger.error(f"Error in training coordinator: {e}")
            time.sleep(60)
    logger.info("Training coordinator worker stopped")
def _check_and_trigger_training(self, symbol: str):
"""Check conditions and trigger training if needed"""
try:
# Get training episodes and experiences
episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)
# Check CNN training conditions
if len(episodes) >= self.config.min_episodes_for_cnn_training:
profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
if len(profitable_episodes) >= 20: # Minimum profitable episodes
logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
results = self.cnn_trainer.train_on_profitable_episodes(
symbol=symbol,
min_profitability=self.config.min_profitability_for_replay,
max_episodes=len(profitable_episodes)
)
if results.get('status') == 'success':
self.integration_stats['cnn_training_sessions'] += 1
logger.info(f"CNN training completed for {symbol}")
# Check RL training conditions
buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
total_experiences = buffer_stats.get('total_experiences', 0)
if total_experiences >= self.config.min_experiences_for_rl_training:
profitable_experiences = buffer_stats.get('profitable_experiences', 0)
if profitable_experiences >= 50: # Minimum profitable experiences
logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
results = self.rl_trainer.train_on_profitable_experiences(
min_profitability=self.config.min_profitability_for_replay,
max_experiences=min(profitable_experiences, 500),
batch_size=32
)
if results.get('status') == 'success':
self.integration_stats['rl_training_sessions'] += 1
logger.info("RL training completed")
except Exception as e:
logger.error(f"Error checking training conditions for {symbol}: {e}")
def _outcome_validation_worker(self):
    """Loop that re-checks stored predictions against later prices.

    Calls ``_validate_recent_predictions`` every five minutes while
    ``self.is_running`` is true; after an error it logs and retries one
    minute later.
    """
    check_interval_seconds = 300  # five minutes between validation passes
    retry_delay_seconds = 60

    logger.info("Outcome validation worker started")
    while self.is_running:
        try:
            self._validate_recent_predictions()
            time.sleep(check_interval_seconds)
        except Exception as e:
            logger.error(f"Error in outcome validation: {e}")
            time.sleep(retry_delay_seconds)
    logger.info("Outcome validation worker stopped")
def _validate_recent_predictions(self):
"""Validate recent predictions against actual outcomes.

Every entry of ``self.recent_predictions`` that is at least one hour old is
passed to ``_calculate_prediction_outcome``.  A produced outcome is stored
in ``self.prediction_outcomes`` and counted in ``self.integration_stats``.
Entries collected in ``validated_predictions`` are deleted from
``self.recent_predictions`` after the loop.  Errors are logged and swallowed.
"""
try:
current_time = datetime.now()
validation_delay = timedelta(hours=1) # Wait 1 hour to validate
validated_predictions = []
# NOTE(review): recent_predictions is also written by the data-collection
# code (``self.recent_predictions[episode_id] = {...}``).  If that runs on a
# different thread, iterating the live dict here could raise
# "dictionary changed size during iteration" - confirm, and consider
# iterating over ``list(self.recent_predictions.items())``.
for episode_id, prediction_data in self.recent_predictions.items():
prediction_time = prediction_data['timestamp']
if current_time - prediction_time >= validation_delay:
# Validate this prediction
outcome = self._calculate_prediction_outcome(prediction_data)
if outcome:
self.prediction_outcomes[episode_id] = outcome
# Update RL experience if exists
if 'rl_action' in prediction_data['predictions']:
self._update_rl_experience_outcome(episode_id, outcome)
# Update statistics
if outcome['is_profitable']:
self.integration_stats['profitable_predictions'] += 1
self.integration_stats['total_predictions'] += 1
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the append below runs for every due prediction or only when an outcome
# was produced.  In the second case, entries whose outcome stays None are
# never removed - confirm against the original file.
validated_predictions.append(episode_id)
# Remove validated predictions (done after the loop so the dict is not
# modified while it is being iterated)
for episode_id in validated_predictions:
del self.recent_predictions[episode_id]
if validated_predictions:
logger.info(f"Validated {len(validated_predictions)} predictions")
except Exception as e:
logger.error(f"Error validating predictions: {e}")
def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Calculate actual outcome for a prediction"""
try:
symbol = prediction_data['symbol']
prediction_time = prediction_data['timestamp']
# Get price data after prediction
current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)
if current_df is None or current_df.empty:
return None
# Find price at prediction time and current price
prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
if prediction_price.empty:
return None
base_price = float(prediction_price['close'].iloc[-1])
current_price = float(current_df['close'].iloc[-1])
# Calculate outcome
price_change = (current_price - base_price) / base_price
is_profitable = abs(price_change) > 0.005 # 0.5% threshold
return {
'episode_id': prediction_data.get('episode_id'),
'base_price': base_price,
'current_price': current_price,
'price_change': price_change,
'is_profitable': is_profitable,
'profitability_score': abs(price_change) * 10, # Scale to 0-1 range
'validation_time': datetime.now()
}
except Exception as e:
logger.error(f"Error calculating prediction outcome: {e}")
return None
def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
"""Update RL experience with actual outcome"""
try:
# Find the experience ID associated with this episode
# This is a simplified approach - in practice you'd maintain better mapping
actual_profit = outcome['price_change']
# Determine optimal action based on outcome
if outcome['price_change'] > 0.01:
optimal_action = 2 # BUY
elif outcome['price_change'] < -0.01:
optimal_action = 0 # SELL
else:
optimal_action = 1 # HOLD
# Update experience (this would need proper experience ID mapping)
# For now, we'll update the most recent experience
# In practice, you'd maintain a mapping between episodes and experiences
except Exception as e:
logger.error(f"Error updating RL experience outcome: {e}")
def get_integration_statistics(self) -> Dict[str, Any]:
    """Return a combined statistics snapshot for the whole integration.

    Starts from a shallow copy of ``self.integration_stats`` and adds the
    statistics of the data collector and both trainers, a few live gauges,
    and ``overall_profitability_rate`` (0.0 while nothing was validated).
    """
    stats = self.integration_stats.copy()
    stats.update({
        # Component statistics
        'data_collector': self.data_collector.get_collection_statistics(),
        'cnn_trainer': self.cnn_trainer.get_training_statistics(),
        'rl_trainer': self.rl_trainer.get_training_statistics(),
        # Live gauges
        'is_running': self.is_running,
        'active_symbols': len(self.data_provider.symbols),
        'recent_predictions_count': len(self.recent_predictions),
        'validated_outcomes_count': len(self.prediction_outcomes),
    })

    validated_total = stats['total_predictions']
    stats['overall_profitability_rate'] = (
        stats['profitable_predictions'] / validated_total if validated_total > 0 else 0.0
    )
    return stats
def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
    """Run CNN and/or RL training on demand.

    Args:
        training_type: ``'cnn'``, ``'rl'`` or ``'all'``.
        symbol: Restrict CNN training to this symbol; when falsy, every
            symbol of the data provider is trained.

    Returns:
        ``{'status': 'success', 'results': {...}}`` with one ``cnn_<symbol>``
        entry per trained symbol and/or an ``rl`` entry, or
        ``{'status': 'error', 'error': <message>}`` if anything raised.
    """
    results = {}
    try:
        wants_cnn = training_type in ('all', 'cnn')
        if wants_cnn:
            targets = [symbol] if symbol else self.data_provider.symbols
            for target in targets:
                results[f'cnn_{target}'] = self.cnn_trainer.train_on_profitable_episodes(
                    symbol=target,
                    min_profitability=0.1,
                    max_episodes=200
                )

        wants_rl = training_type in ('all', 'rl')
        if wants_rl:
            results['rl'] = self.rl_trainer.train_on_profitable_experiences(
                min_profitability=0.1,
                max_experiences=500,
                batch_size=32
            )

        return {'status': 'success', 'results': results}
    except Exception as e:
        logger.error(f"Error in manual training trigger: {e}")
        return {'status': 'error', 'error': str(e)}
# Helper methods (simplified implementations)
def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
"""Get recent tick data"""
# Implementation would get tick data from data provider
return []
def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
"""Get COB data"""
# Implementation would get COB data from data provider
return {}
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get technical indicators"""
# Implementation would get indicators from data provider
return {}
def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
"""Get pivot points"""
# Implementation would get pivot points from data provider
return []
def _get_market_context(self, symbol: str) -> Dict[str, Any]:
"""Get market context"""
return {
'symbol': symbol,
'timestamp': datetime.now(),
'market_session': 'unknown',
'volatility_regime': 'unknown'
}
def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
"""Validate market data completeness"""
required_fields = ['ohlcv', 'indicators']
return all(field in market_data for field in required_fields)
def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
"""Create enhanced CNN features"""
try:
# Simplified feature creation
features = []
# Add OHLCV features
for timeframe in ['1m', '5m', '15m', '1h']:
if timeframe in market_data.get('ohlcv', {}):
df = market_data['ohlcv'][timeframe]
if not df.empty:
ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
if len(ohlcv_values) > 0:
recent_values = ohlcv_values[-60:].flatten()
features.extend(recent_values)
# Pad to target size
target_size = 3000 # 10 channels * 300 sequence length
if len(features) < target_size:
features.extend([0.0] * (target_size - len(features)))
else:
features = features[:target_size]
return np.array(features, dtype=np.float32)
except Exception as e:
logger.warning(f"Error creating CNN features: {e}")
return None
def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
"""Create enhanced RL state"""
try:
state_features = []
# Add market features
if '1m' in market_data.get('ohlcv', {}):
df = market_data['ohlcv']['1m']
if not df.empty:
latest = df.iloc[-1]
state_features.extend([
latest['open'], latest['high'],
latest['low'], latest['close'], latest['volume']
])
# Add technical indicators
indicators = market_data.get('indicators', {})
for value in indicators.values():
state_features.append(value)
# Add model predictions as features
if predictions:
if 'cnn' in predictions:
cnn_pred = predictions['cnn']
state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0]))
state_features.append(cnn_pred.get('confidence', [0.0])[0])
if 'cob_rl' in predictions:
cob_pred = predictions['cob_rl']
state_features.append(cob_pred.get('predicted_direction', 1))
state_features.append(cob_pred.get('confidence', 0.5))
# Pad to target size
target_size = 2000
if len(state_features) < target_size:
state_features.extend([0.0] * (target_size - len(state_features)))
else:
state_features = state_features[:target_size]
return np.array(state_features, dtype=np.float32)
except Exception as e:
logger.warning(f"Error creating RL state: {e}")
return None
def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Get orchestrator prediction"""
# This would integrate with your orchestrator
return None
# Global instance
enhanced_training_integration = None


def get_enhanced_training_integration(data_provider: DataProvider = None,
                                      orchestrator: Orchestrator = None,
                                      trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
    """Return the module-wide integration instance, creating it on first use.

    The arguments are only used by the call that creates the instance; once
    it exists they are ignored.

    Raises:
        ValueError: on the first call when ``data_provider`` is ``None``.
    """
    global enhanced_training_integration

    # Fast path: already created.
    if enhanced_training_integration is not None:
        return enhanced_training_integration

    if data_provider is None:
        raise ValueError("DataProvider required for first initialization")

    enhanced_training_integration = EnhancedTrainingIntegration(
        data_provider, orchestrator, trading_executor
    )
    return enhanced_training_integration

View File

@ -1,8 +0,0 @@
# MEXC Web Client Module
#
# This module provides web-based trading capabilities for MEXC futures trading
# which is not supported by their official API.
from .mexc_futures_client import MEXCFuturesWebClient
__all__ = ['MEXCFuturesWebClient']

View File

@ -1,555 +0,0 @@
#!/usr/bin/env python3
"""
MEXC Auto Browser with Request Interception
This script automatically spawns a ChromeDriver instance and captures
all MEXC futures trading requests in real-time, including full request
and response data needed for reverse engineering.
"""
import logging
import time
import json
import sys
import os
from typing import Dict, List, Optional, Any
from datetime import datetime
import threading
import queue
# Selenium imports
try:
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, WebDriverException
from webdriver_manager.chrome import ChromeDriverManager
except ImportError:
print("Please install selenium and webdriver-manager:")
print("pip install selenium webdriver-manager")
sys.exit(1)
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class MEXCRequestInterceptor:
"""
Automatically spawns ChromeDriver and intercepts all MEXC API requests
"""
def __init__(self, headless: bool = False, save_to_file: bool = True):
"""
Initialize the request interceptor
Args:
headless: Run browser in headless mode
save_to_file: Save captured requests to JSON file
"""
self.driver = None
self.headless = headless
self.save_to_file = save_to_file
self.captured_requests = []
self.captured_responses = []
self.session_cookies = {}
self.monitoring = False
self.request_queue = queue.Queue()
# File paths for saving data
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.requests_file = f"mexc_requests_{self.timestamp}.json"
self.cookies_file = f"mexc_cookies_{self.timestamp}.json"
def setup_browser(self):
"""Setup Chrome browser with necessary options.

Builds the Chrome options, chooses a ``--user-data-dir`` below
``./chrome_user_data`` (an existing ``session_*`` directory picked
interactively via ``input()``, or a new one named after ``self.timestamp``),
enables browser and performance logging, and starts Chrome.  If Chrome fails
to start, a second attempt is made with a fresh ``..._fallback`` profile.

Returns:
The started driver, which is also stored in ``self.driver``.
"""
chrome_options = webdriver.ChromeOptions()
# Enable headless mode if needed
# NOTE(review): indentation is lost in this view; it is unclear whether the
# --disable-gpu / --window-size / --disable-extensions arguments are applied
# only in headless mode or always - confirm against the original file.
if self.headless:
chrome_options.add_argument('--headless')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--window-size=1920,1080')
chrome_options.add_argument('--disable-extensions')
# Set up Chrome options with a user data directory to persist session
user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data')
os.makedirs(user_data_base_dir, exist_ok=True)
# Check for existing session directories
session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')]
session_dirs.sort(reverse=True) # Sort descending to get the most recent first
user_data_dir = None
if session_dirs:
# Blocks on stdin: the user decides whether to reuse a stored profile.
use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y'
if use_existing:
print("Available sessions:")
for i, session in enumerate(session_dirs[:5], 1): # Show up to 5 most recent
print(f"{i}. {session}")
choice = input("Enter session number (default 1) or any other key for most recent: ")
# Any number within the full list is accepted, not only the five shown;
# every other input selects the most recent session.
if choice.isdigit() and 1 <= int(choice) <= len(session_dirs):
selected_session = session_dirs[int(choice) - 1]
else:
selected_session = session_dirs[0]
user_data_dir = os.path.join(user_data_base_dir, selected_session)
print(f"Using session: {selected_session}")
# No stored session chosen (or none exists): create a new profile directory.
if user_data_dir is None:
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}')
os.makedirs(user_data_dir, exist_ok=True)
print(f"Creating new session: session_{self.timestamp}")
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
# Enable logging to capture JS console output and network activity
# (the 'performance' log is what _monitor_requests reads)
chrome_options.set_capability('goog:loggingPrefs', {
'browser': 'ALL',
'performance': 'ALL'
})
try:
self.driver = webdriver.Chrome(options=chrome_options)
except Exception as e:
# Retry once with a brand-new profile directory and rebuilt options.
print(f"Failed to start browser with session: {e}")
print("Falling back to a new session...")
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback')
os.makedirs(user_data_dir, exist_ok=True)
print(f"Creating fallback session: session_{self.timestamp}_fallback")
chrome_options = webdriver.ChromeOptions()
if self.headless:
chrome_options.add_argument('--headless')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--window-size=1920,1080')
chrome_options.add_argument('--disable-extensions')
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
chrome_options.set_capability('goog:loggingPrefs', {
'browser': 'ALL',
'performance': 'ALL'
})
# A failure of this second attempt is not caught here and propagates.
self.driver = webdriver.Chrome(options=chrome_options)
return self.driver
def start_monitoring(self):
    """Open the MEXC futures page and run the interactive capture session.

    Starts the browser, waits up to 10 seconds for the page body, launches
    the request monitor on a daemon thread and then blocks in the
    interactive menu.

    Returns:
        True once the interactive session has ended normally, False if any
        step raised (the error is logged).
    """
    logger.info("Starting MEXC Request Interceptor...")
    mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap"
    try:
        self.driver = self.setup_browser()

        logger.info(f"Navigating to: {mexc_url}")
        self.driver.get(mexc_url)

        body_present = EC.presence_of_element_located((By.TAG_NAME, "body"))
        WebDriverWait(self.driver, 10).until(body_present)

        for banner in (
            "✅ MEXC page loaded successfully!",
            "📝 Please log in manually in the browser window",
            "🔍 Request monitoring is now active...",
        ):
            logger.info(banner)

        # Capture runs in the background while the menu owns the foreground.
        self.monitoring = True
        threading.Thread(target=self._monitor_requests, daemon=True).start()

        self._wait_for_login()
    except Exception as e:
        logger.error(f"Failed to start monitoring: {e}")
        return False
    return True
def _wait_for_login(self):
    """Run the interactive console menu until the user chooses to exit.

    Each menu entry maps to one helper of this class.  Choosing ``0`` leaves
    the loop, after which monitoring is stopped and the browser is closed.
    """
    logger.info("\n" + "="*60)
    logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE")
    logger.info("="*60)

    menu_lines = (
        "\nOptions:",
        "1. Check login status",
        "2. Extract current cookies",
        "3. Show captured requests summary",
        "4. Save captured data to files",
        "5. Perform test trade (manual)",
        "6. Monitor for 60 seconds",
        "0. Stop and exit",
    )
    actions = {
        "1": self._check_login_status,
        "2": self._extract_cookies,
        "3": self._show_requests_summary,
        "4": self._save_all_data,
        "5": self._guide_test_trade,
        "6": lambda: self._monitor_for_duration(60),
    }

    while True:
        for line in menu_lines:
            print(line)
        choice = input("\nEnter choice (0-6): ").strip()
        if choice == "0":
            break
        action = actions.get(choice)
        if action is None:
            print("Invalid choice. Please try again.")
        else:
            action()

    self.stop_monitoring()
def _check_login_status(self):
"""Check if user is logged into MEXC"""
try:
cookies = self.driver.get_cookies()
auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint']
found_auth = []
for cookie in cookies:
if cookie['name'] in auth_cookies and cookie['value']:
found_auth.append(cookie['name'])
if len(found_auth) >= 2:
print("✅ LOGIN DETECTED - You appear to be logged in!")
print(f" Found auth cookies: {', '.join(found_auth)}")
return True
else:
print("❌ NOT LOGGED IN - Please log in to MEXC in the browser")
print(" Missing required authentication cookies")
return False
except Exception as e:
print(f"❌ Error checking login: {e}")
return False
def _extract_cookies(self):
"""Extract and display current session cookies"""
try:
cookies = self.driver.get_cookies()
cookie_dict = {}
for cookie in cookies:
cookie_dict[cookie['name']] = cookie['value']
self.session_cookies = cookie_dict
print(f"\n📊 Extracted {len(cookie_dict)} cookies:")
# Show important cookies
important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
for name in important:
if name in cookie_dict:
value = cookie_dict[name]
display_value = value[:20] + "..." if len(value) > 20 else value
print(f"{name}: {display_value}")
else:
print(f"{name}: Missing")
# Save cookies to file
if self.save_to_file:
with open(self.cookies_file, 'w') as f:
json.dump(cookie_dict, f, indent=2)
print(f"\n💾 Cookies saved to: {self.cookies_file}")
except Exception as e:
print(f"❌ Error extracting cookies: {e}")
def _monitor_requests(self):
"""Background thread to monitor network requests"""
last_log_count = 0
while self.monitoring:
try:
# Get performance logs
logs = self.driver.get_log('performance')
for log in logs:
try:
message = json.loads(log['message'])
method = message.get('message', {}).get('method', '')
# Capture network requests
if method == 'Network.requestWillBeSent':
self._process_request(message['message']['params'])
elif method == 'Network.responseReceived':
self._process_response(message['message']['params'])
except (json.JSONDecodeError, KeyError) as e:
continue
# Show progress every 10 new requests
if len(self.captured_requests) >= last_log_count + 10:
last_log_count = len(self.captured_requests)
logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses")
except Exception as e:
if self.monitoring: # Only log if we're still supposed to be monitoring
logger.debug(f"Monitor error: {e}")
time.sleep(0.5) # Check every 500ms
def _process_request(self, request_data):
"""Process a captured network request.

``request_data`` is the ``params`` object of a ``Network.requestWillBeSent``
performance-log event.  Requests whose URL passes ``_is_mexc_request`` are
recorded in ``self.captured_requests``; futures and captcha requests are
also echoed to the console, and captcha/robot requests are logged in detail
together with captcha-related page elements.  Errors are logged at debug
level and swallowed.
"""
try:
url = request_data.get('request', {}).get('url', '')
# Filter for MEXC API requests
if self._is_mexc_request(url):
# Flat record of the fields needed to replay or analyse the request later
request_info = {
'type': 'request',
'timestamp': datetime.now().isoformat(),
'url': url,
'method': request_data.get('request', {}).get('method', ''),
'headers': request_data.get('request', {}).get('headers', {}),
'postData': request_data.get('request', {}).get('postData', ''),
'requestId': request_data.get('requestId', '')
}
self.captured_requests.append(request_info)
# Show important requests immediately
if ('futures.mexc.com' in url or 'captcha' in url):
print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}")
if request_info['postData']:
# Only the first 100 characters of the body are printed
print(f" 📄 POST Data: {request_info['postData'][:100]}...")
# Enhanced captcha detection and detailed logging
# NOTE(review): indentation is lost in this view; it is unclear whether this
# section only runs for URLs that passed the MEXC filter above or for every
# request - confirm against the original file.
if 'captcha' in url.lower() or 'robot' in url.lower():
logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}")
logger.info(f" Headers: {request_data.get('request', {}).get('headers', {})}")
if request_data.get('request', {}).get('postData', ''):
logger.info(f" Data: {request_data.get('request', {}).get('postData', '')}")
# Attempt to capture related JavaScript or DOM elements (if possible)
# NOTE(review): this method is invoked from _monitor_requests, which is
# started on a background thread, so these execute_script calls may run
# concurrently with other use of the driver - confirm this is safe.
if self.driver is not None:
try:
# outerHTML of the first <script> whose src contains "captcha"
js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';")
logger.info(f" Related JS Snippet: {js_snippet}")
except Exception as e:
logger.warning(f" Could not capture JS snippet: {e}")
try:
# outerHTML of the first <div> whose id contains "captcha"
dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';")
logger.info(f" Related DOM Element: {dom_element}")
except Exception as e:
logger.warning(f" Could not capture DOM element: {e}")
else:
logger.warning(" Driver not initialized, cannot capture JS or DOM elements")
except Exception as e:
logger.debug(f"Error processing request: {e}")
def _process_response(self, response_data):
"""Process a captured network response"""
try:
url = response_data.get('response', {}).get('url', '')
# Filter for MEXC API responses
if self._is_mexc_request(url):
response_info = {
'type': 'response',
'timestamp': datetime.now().isoformat(),
'url': url,
'status': response_data.get('response', {}).get('status', 0),
'headers': response_data.get('response', {}).get('headers', {}),
'requestId': response_data.get('requestId', '')
}
self.captured_responses.append(response_info)
# Show important responses immediately
if ('futures.mexc.com' in url or 'captcha' in url):
status = response_info['status']
status_emoji = "" if status == 200 else ""
print(f" {status_emoji} RESPONSE: {status} for {url}")
except Exception as e:
logger.debug(f"Error processing response: {e}")
def _is_mexc_request(self, url: str) -> bool:
"""Check if URL is a relevant MEXC API request"""
mexc_indicators = [
'futures.mexc.com',
'ucgateway/captcha_api',
'api/v1/private',
'api/v3/order',
'mexc.com/api'
]
return any(indicator in url for indicator in mexc_indicators)
def _show_requests_summary(self):
"""Show summary of captured requests"""
print(f"\n📊 CAPTURE SUMMARY:")
print(f" Total Requests: {len(self.captured_requests)}")
print(f" Total Responses: {len(self.captured_responses)}")
# Group by URL pattern
url_counts = {}
for req in self.captured_requests:
base_url = req['url'].split('?')[0] # Remove query params
url_counts[base_url] = url_counts.get(base_url, 0) + 1
print("\n🔗 Top URLs:")
for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {count}x {url}")
# Show recent futures API calls
futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']]
if futures_requests:
print(f"\n🚀 Futures API Calls: {len(futures_requests)}")
for req in futures_requests[-3:]: # Show last 3
print(f" {req['method']} {req['url']}")
def _save_all_data(self):
"""Save all captured data to files"""
if not self.save_to_file:
print("File saving is disabled")
return
try:
# Save requests
with open(self.requests_file, 'w') as f:
json.dump({
'requests': self.captured_requests,
'responses': self.captured_responses,
'summary': {
'total_requests': len(self.captured_requests),
'total_responses': len(self.captured_responses),
'capture_session': self.timestamp
}
}, f, indent=2)
# Save cookies if we have them
if self.session_cookies:
with open(self.cookies_file, 'w') as f:
json.dump(self.session_cookies, f, indent=2)
print(f"\n💾 Data saved to:")
print(f" 📋 Requests: {self.requests_file}")
if self.session_cookies:
print(f" 🍪 Cookies: {self.cookies_file}")
# Extract and save CAPTCHA tokens from captured requests
captcha_tokens = self.extract_captcha_tokens()
if captcha_tokens:
captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json"
with open(captcha_file, 'w') as f:
json.dump(captcha_tokens, f, indent=2)
logger.info(f"Saved CAPTCHA tokens to {captcha_file}")
else:
logger.warning("No CAPTCHA tokens found in captured requests")
except Exception as e:
print(f"❌ Error saving data: {e}")
def _guide_test_trade(self):
    """Print test-trade instructions, then monitor for two minutes.

    Waits for the user to press Enter before the 120-second monitoring
    window starts.
    """
    instructions = (
        "\n🧪 TEST TRADE GUIDE:",
        "1. Make sure you're logged into MEXC",
        "2. Go to the trading interface",
        "3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)",
        "4. Watch the console for captured API calls",
        "\n⚠️ IMPORTANT: Use very small amounts for testing!",
    )
    for line in instructions:
        print(line)

    input("\nPress Enter when you're ready to start monitoring...")
    self._monitor_for_duration(120)  # Monitor for 2 minutes
def _monitor_for_duration(self, seconds: int):
"""Monitor requests for a specific duration"""
print(f"\n🔍 Monitoring requests for {seconds} seconds...")
print("Perform your trading actions now!")
start_time = time.time()
initial_count = len(self.captured_requests)
while time.time() - start_time < seconds:
current_count = len(self.captured_requests)
new_requests = current_count - initial_count
remaining = seconds - int(time.time() - start_time)
print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True)
time.sleep(1)
final_count = len(self.captured_requests)
new_total = final_count - initial_count
print(f"\n✅ Monitoring complete! Captured {new_total} new requests")
def stop_monitoring(self):
    """Stop monitoring, close the browser and save what was captured.

    Safe to call more than once: this method is reached both at the end of
    the interactive menu and from ``main``'s cleanup, so the driver
    reference is cleared before quitting and a second call no longer tries
    to quit a browser that is already closed.  A failure while closing the
    browser is logged instead of raised so the final save below still runs.
    """
    logger.info("Stopping request monitoring...")
    self.monitoring = False

    # Detach the driver first so repeated calls (and the monitor thread)
    # see that it is gone.
    driver, self.driver = self.driver, None
    if driver:
        try:
            driver.quit()
            logger.info("Browser closed")
        except Exception as e:
            logger.warning(f"Error while closing browser: {e}")

    # Final save
    if self.save_to_file and (self.captured_requests or self.captured_responses):
        self._save_all_data()
        logger.info("Final data save complete")
def extract_captcha_tokens(self):
"""Extract CAPTCHA tokens from captured requests"""
captcha_tokens = []
for request in self.captured_requests:
if 'captcha-token' in request.get('headers', {}):
token = request['headers']['captcha-token']
captcha_tokens.append({
'token': token,
'url': request.get('url', ''),
'timestamp': request.get('timestamp', '')
})
elif 'captcha' in request.get('url', '').lower():
response = request.get('response', {})
if response and 'captcha-token' in response.get('headers', {}):
token = response['headers']['captcha-token']
captcha_tokens.append({
'token': token,
'url': request.get('url', ''),
'timestamp': request.get('timestamp', '')
})
return captcha_tokens
def main():
    """Interactive entry point for the request interceptor.

    Prints a short banner, asks whether Chrome should run headless, starts
    monitoring and always shuts the interceptor down on the way out.
    """
    banner = (
        "🚀 MEXC Request Interceptor with ChromeDriver",
        "=" * 50,
        "This will automatically:",
        "✅ Download/setup ChromeDriver",
        "✅ Open MEXC futures page",
        "✅ Capture all API requests/responses",
        "✅ Extract session cookies",
        "✅ Save data to JSON files",
        "\nPress Ctrl+C to stop at any time",
    )
    for line in banner:
        print(line)

    # Ask for preferences
    answer = input("\nRun in headless mode? (y/n): ")
    run_headless = answer.lower().strip() == 'y'
    interceptor = MEXCRequestInterceptor(headless=run_headless, save_to_file=True)

    try:
        if not interceptor.start_monitoring():
            print("❌ Failed to start monitoring")
    except KeyboardInterrupt:
        print("\n\n⏹️ Stopping interceptor...")
    except Exception as exc:
        print(f"\n❌ Error: {exc}")
    finally:
        # Runs on success, failure and Ctrl+C alike.
        interceptor.stop_monitoring()
        print("\n👋 Goodbye!")


if __name__ == "__main__":
    main()

View File

@ -1,358 +0,0 @@
"""
MEXC Browser Automation for Cookie Extraction and Request Monitoring
This module uses Selenium to automate browser interactions and extract
session cookies and request data for MEXC futures trading.
"""
import logging
import time
import json
from typing import Dict, List, Optional, Any
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
logger = logging.getLogger(__name__)
class MEXCBrowserAutomation:
    """
    Browser automation for MEXC futures trading session management

    Drives a Selenium-controlled Chrome instance to open the MEXC futures
    page, wait for a manual login, read the session cookies and collect
    network responses from Chrome's performance log.
    """

    def __init__(self, headless: bool = False, proxy: Optional[str] = None):
        """
        Initialize browser automation

        Args:
            headless: Run browser in headless mode
            proxy: HTTP proxy to use (format: host:port)
        """
        self.driver = None
        self.headless = headless
        self.proxy = proxy
        self.logged_in = False

    def setup_chrome_driver(self) -> webdriver.Chrome:
        """Setup Chrome driver with appropriate options

        Returns:
            A started Chrome WebDriver with performance logging enabled.

        Raises:
            WebDriverException: If Chrome could not be started.
        """
        chrome_options = Options()
        if self.headless:
            chrome_options.add_argument("--headless")

        # Basic Chrome options for automation
        chrome_options.add_argument("--no-sandbox")
        chrome_options.add_argument("--disable-dev-shm-usage")
        chrome_options.add_argument("--disable-blink-features=AutomationControlled")
        chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
        chrome_options.add_experimental_option('useAutomationExtension', False)

        # Set user agent to avoid detection
        chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")

        # Proxy setup if provided
        if self.proxy:
            chrome_options.add_argument(f"--proxy-server=http://{self.proxy}")

        # Enable network logging; monitor_network_requests() reads the
        # "performance" log that this capability switches on.
        chrome_options.add_argument("--enable-logging")
        chrome_options.add_argument("--log-level=0")
        chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})

        # Automatically download and setup ChromeDriver
        service = Service(ChromeDriverManager().install())
        try:
            driver = webdriver.Chrome(service=service, options=chrome_options)
            # Execute script to avoid detection
            driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")
            return driver
        except WebDriverException as e:
            logger.error(f"Failed to setup Chrome driver: {e}")
            raise

    def start_browser(self):
        """Start the browser session (no-op when one is already running)"""
        if self.driver is None:
            logger.info("Starting Chrome browser for MEXC automation")
            self.driver = self.setup_chrome_driver()
            logger.info("Browser started successfully")

    def stop_browser(self):
        """Stop the browser session; safe to call more than once"""
        if self.driver:
            logger.info("Stopping browser")
            self.driver.quit()
            self.driver = None

    def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"):
        """
        Navigate to MEXC futures trading page, starting the browser if needed

        Args:
            symbol: Trading symbol to navigate to
        """
        if not self.driver:
            self.start_browser()

        url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap"
        logger.info(f"Navigating to MEXC futures: {url}")
        self.driver.get(url)

        # Wait for page to load; a timeout is logged but not raised.
        try:
            WebDriverWait(self.driver, 10).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )
            logger.info("MEXC futures page loaded")
        except TimeoutException:
            logger.error("Timeout waiting for MEXC page to load")

    def wait_for_login(self, timeout: int = 300) -> bool:
        """
        Wait for user to manually log in to MEXC

        Login is detected once both authentication cookies (``uc_token`` and
        ``u_id``) are present with non-empty values.

        Args:
            timeout: Maximum time to wait for login (seconds)

        Returns:
            bool: True if login detected, False if timeout or the browser
            has not been started
        """
        if not self.driver:
            # Without this guard every poll would raise, be swallowed below,
            # and the caller would wait out the full timeout for nothing.
            logger.error("Browser not started")
            return False

        logger.info("Please log in to MEXC manually in the browser window")
        logger.info("Waiting for login completion...")

        required_cookies = {'uc_token', 'u_id'}
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                # Track distinct names: the same cookie set for two domains
                # must not count as two different authentication cookies.
                present = {
                    cookie['name']
                    for cookie in self.driver.get_cookies()
                    if cookie['name'] in required_cookies and cookie['value']
                }
                if present == required_cookies:
                    logger.info("Login detected!")
                    self.logged_in = True
                    return True
            except Exception as e:
                logger.debug(f"Error checking login status: {e}")
            time.sleep(2)  # Check every 2 seconds

        logger.error(f"Login timeout after {timeout} seconds")
        return False

    def extract_session_cookies(self) -> Dict[str, str]:
        """
        Extract all cookies from current browser session

        Returns:
            Dictionary of cookie name-value pairs; empty when the browser is
            not running or the cookies could not be read
        """
        if not self.driver:
            logger.error("Browser not started")
            return {}

        cookies = {}
        try:
            browser_cookies = self.driver.get_cookies()
            for cookie in browser_cookies:
                cookies[cookie['name']] = cookie['value']
            logger.info(f"Extracted {len(cookies)} cookies from browser session")

            # Log important cookies (without values for security)
            important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
            for cookie_name in important_cookies:
                if cookie_name in cookies:
                    logger.info(f"Found important cookie: {cookie_name}")
                else:
                    logger.warning(f"Missing important cookie: {cookie_name}")
            return cookies
        except Exception as e:
            logger.error(f"Failed to extract cookies: {e}")
            return {}

    def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]:
        """
        Monitor network requests for the specified duration

        Polls Chrome's performance log once per second and keeps the
        ``Network.responseReceived`` events whose URL looks like a MEXC
        futures / captcha / private API call.

        Args:
            duration: How long to monitor requests (seconds)

        Returns:
            List of captured network requests
        """
        if not self.driver:
            logger.error("Browser not started")
            return []

        logger.info(f"Starting network monitoring for {duration} seconds")
        logger.info("Please perform trading actions in the browser (open/close positions)")

        start_time = time.time()
        captured_requests = []
        while time.time() - start_time < duration:
            try:
                # Get performance logs (network requests)
                logs = self.driver.get_log('performance')
                for log in logs:
                    message = json.loads(log['message'])
                    # Filter for relevant MEXC API requests
                    if message.get('message', {}).get('method') != 'Network.responseReceived':
                        continue
                    response = message['message']['params']['response']
                    url = response.get('url', '')
                    # Look for futures API calls
                    if ('futures.mexc.com' in url or
                            'ucgateway/captcha_api' in url or
                            'api/v1/private' in url):
                        request_data = {
                            'url': url,
                            # NOTE(review): this stores the response MIME type
                            # under the key 'method', not the HTTP verb; key
                            # kept as-is for consumers of the saved JSON.
                            'method': response.get('mimeType', ''),
                            'status': response.get('status'),
                            'headers': response.get('headers', {}),
                            'timestamp': log['timestamp']
                        }
                        captured_requests.append(request_data)
                        logger.info(f"Captured request: {url}")
            except Exception as e:
                logger.debug(f"Error in network monitoring: {e}")
            time.sleep(1)

        logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests")
        return captured_requests

    def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200):
        """
        Attempt to perform a test trade to capture the complete request flow

        Currently only prompts the user and waits 30 seconds; no UI elements
        are clicked.

        Args:
            symbol: Trading symbol
            volume: Position size
            leverage: Leverage multiplier
        """
        if not self.logged_in:
            logger.error("Not logged in - cannot perform test trade")
            return

        logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x")
        logger.info("This will attempt to click trading interface elements")
        try:
            # This would need to be implemented based on MEXC's specific UI elements
            # For now, just wait and let user perform manual actions
            logger.info("Please manually place a small test trade while monitoring is active")
            time.sleep(30)
        except Exception as e:
            logger.error(f"Error during test trade: {e}")

    def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]:
        """
        Complete session capture workflow

        Navigates to the futures page, waits for a manual login, extracts
        cookies and monitors network traffic for two minutes. The browser is
        always closed before returning.

        Args:
            symbol: Trading symbol to use

        Returns:
            Dictionary containing cookies and captured requests on success,
            or ``{'success': False, 'error': ...}`` on failure
        """
        logger.info("Starting full MEXC session capture")
        try:
            # Start browser and navigate to MEXC
            self.navigate_to_mexc_futures(symbol)

            # Wait for manual login
            if not self.wait_for_login():
                return {'success': False, 'error': 'Login timeout'}

            # Extract session cookies
            cookies = self.extract_session_cookies()
            if not cookies:
                return {'success': False, 'error': 'Failed to extract cookies'}

            # Monitor network requests while user performs actions
            logger.info("Starting network monitoring - please perform trading actions now")
            requests = self.monitor_network_requests(duration=120)  # 2 minutes
            return {
                'success': True,
                'cookies': cookies,
                'network_requests': requests,
                'timestamp': int(time.time())
            }
        except Exception as e:
            logger.error(f"Error in session capture: {e}")
            return {'success': False, 'error': str(e)}
        finally:
            self.stop_browser()
def main():
    """Run a full session capture from the command line.

    Opens a visible browser, performs the capture workflow and, on success,
    writes the result to a timestamped JSON file in the working directory.
    """
    logging.basicConfig(level=logging.INFO)
    for line in (
        "MEXC Browser Automation - Session Capture",
        "This will open a browser window for you to log into MEXC",
        "Make sure you have Chrome browser installed",
    ):
        print(line)

    automation = MEXCBrowserAutomation(headless=False)
    try:
        outcome = automation.full_session_capture()
        if not outcome['success']:
            print(f"Session capture failed: {outcome['error']}")
        else:
            print(f"\nSession capture successful!")
            print(f"Extracted {len(outcome['cookies'])} cookies")
            print(f"Captured {len(outcome['network_requests'])} network requests")

            # Save results to file
            target_path = f"mexc_session_capture_{int(time.time())}.json"
            with open(target_path, 'w') as handle:
                json.dump(outcome, handle, indent=2)
            print(f"Results saved to: {target_path}")
    except KeyboardInterrupt:
        print("\nSession capture interrupted by user")
    except Exception as exc:
        print(f"Error: {exc}")
    finally:
        # stop_browser() is a no-op when the browser is already closed.
        automation.stop_browser()


if __name__ == "__main__":
    main()

View File

@ -1,525 +0,0 @@
"""
MEXC Futures Web Client
This module implements a web-based client for MEXC futures trading
since their official API doesn't support futures (leverage) trading.
It mimics browser behavior by replicating the exact HTTP requests
that the web interface makes.
"""
import logging
import requests
import time
import json
import hmac
import hashlib
import base64
from typing import Dict, List, Optional, Any
from datetime import datetime
import uuid
from urllib.parse import urlencode
import glob
import os
logger = logging.getLogger(__name__)
class MEXCSessionManager:
    """In-memory holder for the most recently obtained CAPTCHA token."""

    def __init__(self):
        self.captcha_token = None

    def get_captcha_token(self) -> str:
        """Return the stored token, or an empty string when none is set."""
        return self.captcha_token or ""

    def save_captcha_token(self, token: str):
        """Remember ``token`` for later captcha verification calls."""
        self.captcha_token = token
        logger.info("MEXC: Captcha token saved in session manager")
class MEXCFuturesWebClient:
    """
    MEXC Futures Web Client that mimics browser behavior for futures trading.

    Since MEXC's official API doesn't support futures, this client replicates
    the exact HTTP requests made by their web interface.

    NOTE: request signing and the ``mhash`` / ``p0`` / ``k0`` / ``chash``
    order fields are still placeholders (see the private helpers near the end
    of the class).
    """

    def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True):
        """
        Initialize the MEXC Futures Web Client

        Args:
            api_key: API key for authentication
            api_secret: API secret for authentication
            user_id: User ID for authentication
            base_url: Base URL for the MEXC website
            headless: Whether to run the browser in headless mode
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.user_id = user_id
        self.base_url = base_url
        self.is_authenticated = False
        self.headless = headless
        self.session = requests.Session()
        self.session_manager = MEXCSessionManager()
        self.captcha_url = f'{base_url}/ucgateway/captcha_api'
        self.futures_api_url = "https://futures.mexc.com/api/v1"

        # Filled in by load_session_cookies(); defined here so the attributes
        # exist even before any cookies have been loaded.
        self.auth_token = None
        self.fingerprint = None
        self.visitor_id = None

        # Setup default headers that mimic a real browser
        self.setup_browser_headers()

    def setup_browser_headers(self):
        """Setup default headers that mimic Chrome browser"""
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
            'Accept': '*/*',
            'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8',
            'Accept-Encoding': 'gzip, deflate, br',
            'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache',
            'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap',
            'Language': 'English',
            'X-Language': 'en-GB',
            'trochilus-trace-id': self._new_trace_id(),
            'trochilus-uid': str(self.user_id) if self.user_id is not None else ''
        })

    @staticmethod
    def _new_trace_id() -> str:
        """Build a fresh ``trochilus-trace-id`` value (uuid4 plus 4-digit suffix)."""
        return f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}"

    def load_session_cookies(self, cookies: Dict[str, str]):
        """
        Load session cookies from browser

        Updates ``is_authenticated`` to reflect the cookies just loaded, so
        loading an incomplete set after a complete one clears the flag.

        Args:
            cookies: Dictionary of cookie name-value pairs
        """
        for name, value in cookies.items():
            self.session.cookies.set(name, value)

        # Extract important session info from cookies
        self.auth_token = cookies.get('uc_token')
        self.user_id = cookies.get('u_id')
        self.fingerprint = cookies.get('x-mxc-fingerprint')
        self.visitor_id = cookies.get('mexc_fingerprint_visitorId')

        # Keep the session-wide default header in step with the new user id.
        self.session.headers['trochilus-uid'] = str(self.user_id) if self.user_id is not None else ''

        self.is_authenticated = bool(self.auth_token and self.user_id)
        if self.is_authenticated:
            logger.info("MEXC: Loaded authenticated session")
        else:
            logger.warning("MEXC: Session cookies incomplete - authentication may fail")

    def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]:
        """
        Extract cookies from a browser cookie string

        Args:
            cookie_string: Raw cookie string from browser (copy from Network tab)

        Returns:
            Dictionary of parsed cookies; segments without ``=`` are ignored
        """
        cookies = {}
        for pair in cookie_string.split(';'):
            name, sep, value = pair.strip().partition('=')
            if sep:
                # Strip both sides so "a = b" parses like "a=b".
                cookies[name.strip()] = value.strip()
        return cookies

    def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool:
        """
        Verify captcha for robot trading protection

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            side: 'openlong', 'closelong', 'openshort', 'closeshort'
            leverage: Leverage string (e.g., '200X')

        Returns:
            bool: True if captcha verification successful
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot verify captcha - not authenticated")
            return False

        # Build captcha endpoint URL
        endpoint = f"robot.future.{side}.{symbol}.{leverage}"
        url = f"{self.captcha_url}/{endpoint}"

        # Attempt to get captcha token from session manager
        captcha_token = self.session_manager.get_captcha_token()
        if not captcha_token:
            logger.warning("MEXC: No captcha token available, attempting to fetch from browser")
            captcha_token = self._extract_captcha_token_from_browser()
            if captcha_token:
                self.session_manager.save_captcha_token(captcha_token)
            else:
                logger.error("MEXC: Failed to extract captcha token from browser")
                return False

        headers = {
            'Content-Type': 'application/json',
            'Language': 'en-GB',
            'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap',
            'trochilus-uid': self.user_id if self.user_id else '',
            'trochilus-trace-id': self._new_trace_id(),
            'captcha-token': captcha_token
        }
        logger.info(f"MEXC: Verifying captcha for {endpoint}")
        try:
            response = self.session.get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}")
                return False
            data = response.json()
            if data.get('success'):
                logger.info(f"MEXC: Captcha verified successfully for {endpoint}")
                return True
            logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}")
            return False
        except Exception as e:
            logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}")
            return False

    def _extract_captcha_token_from_browser(self) -> str:
        """
        Extract captcha token from browser session using stored cookies or requests.

        This method looks for the most recent mexc_captcha_tokens JSON file
        in the current working directory to retrieve a token.

        Returns:
            The token string, or "" when no usable token file is found.
        """
        try:
            # Look for the most recent mexc_captcha_tokens file
            captcha_files = glob.glob("mexc_captcha_tokens_*.json")
            if not captcha_files:
                logger.error("MEXC: No CAPTCHA token files found")
                return ""

            latest_file = max(captcha_files, key=os.path.getctime)
            logger.info(f"MEXC: Using CAPTCHA token file {latest_file}")
            with open(latest_file, 'r') as f:
                captcha_data = json.load(f)

            if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0:
                # Takes the first entry; presumably the file lists the newest
                # token first - TODO confirm against the writer of these files.
                return captcha_data[0].get('token', '')
            logger.error("MEXC: No valid CAPTCHA tokens found in file")
            return ""
        except Exception as e:
            logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}")
            return ""

    def generate_signature(self, method: str, path: str, params: Dict[str, Any],
                           timestamp: int, nonce: int) -> str:
        """
        Generate signature for MEXC futures API requests

        This is a placeholder - the actual signature generation would need
        to be reverse-engineered from the browser's JavaScript. For now it
        returns an empty string and relies on cookie authentication.
        """
        return ""

    def _build_order_data(self, symbol: str, side: int, position_field: str, position_value: int,
                          volume: float, leverage: int, price: Optional[float]) -> Dict[str, Any]:
        """Assemble the JSON payload shared by all four order methods.

        ``position_field`` is ``'openType'`` or ``'closeType'``. A ``price``
        turns the market order (type '5') into a limit order (type '1').
        """
        order_data = {
            'symbol': symbol,
            'side': side,  # 1 = long, 2 = short (closing uses the opposite side)
            position_field: position_value,
            'type': '5',  # Market order
            'vol': volume,
            'leverage': leverage,
            'marketCeiling': False,
            'priceProtect': '0',
            'ts': int(time.time() * 1000),
            'mhash': self._generate_mhash(),
            'mtoken': self.visitor_id
        }
        # Add price for limit orders
        if price is not None:
            order_data['price'] = price
            order_data['type'] = '1'  # Limit order

        # Add encrypted parameters (placeholders until reverse-engineered)
        order_data['p0'] = self._encrypt_p0(order_data)
        order_data['k0'] = self._encrypt_k0(order_data)
        order_data['chash'] = self._generate_chash(order_data)
        return order_data

    def _order_headers(self, order_data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the per-request headers for ``/private/order/create``."""
        timestamp = order_data['ts']
        nonce = timestamp  # the nonce reuses the order's millisecond timestamp
        return {
            'Authorization': self.auth_token,
            'Content-Type': 'application/json',
            'Language': 'English',
            'x-language': 'en-GB',
            'x-mxc-nonce': str(nonce),
            'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
            'trochilus-uid': self.user_id,
            'trochilus-trace-id': self._new_trace_id(),
            'Referer': 'https://www.mexc.com/'
        }

    def open_long_position(self, symbol: str, volume: float, leverage: int = 200,
                           price: Optional[float] = None) -> Dict[str, Any]:
        """
        Open a long futures position

        Sends an OPTIONS preflight before the POST, mirroring the browser.

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            volume: Position size (contracts)
            leverage: Leverage multiplier (default 200)
            price: Limit price (None for market order)

        Returns:
            dict: Order response with order ID, or ``success: False`` and an
            ``error`` message
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot open position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        # First verify captcha
        if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'):
            logger.error("MEXC: Captcha verification failed for opening long position")
            return {'success': False, 'error': 'Captcha verification failed'}

        order_data = self._build_order_data(symbol, 1, 'openType', 2, volume, leverage, price)
        headers = self._order_headers(order_data)
        url = f"{self.futures_api_url}/private/order/create"
        try:
            # First make OPTIONS request (preflight)
            options_response = self.session.options(url, headers=headers, timeout=10)
            if options_response.status_code != 200:
                logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}")
                return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'}

            # Now make the actual POST request
            response = self.session.post(url, json=order_data, headers=headers, timeout=15)
            if response.status_code != 200:
                logger.error(f"MEXC: Order request failed with status {response.status_code}")
                return {'success': False, 'error': f'HTTP {response.status_code}'}

            data = response.json()
            if data.get('success') and data.get('code') == 0:
                order_id = data.get('data', {}).get('orderId')
                logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}")
                return {
                    'success': True,
                    'order_id': order_id,
                    'timestamp': data.get('data', {}).get('ts'),
                    'symbol': symbol,
                    'side': 'long',
                    'volume': volume,
                    'leverage': leverage
                }
            logger.error(f"MEXC: Order failed: {data}")
            return {'success': False, 'error': data.get('msg', 'Unknown error')}
        except Exception as e:
            logger.error(f"MEXC: Order execution error: {e}")
            return {'success': False, 'error': str(e)}

    def close_long_position(self, symbol: str, volume: float, leverage: int = 200,
                            price: Optional[float] = None) -> Dict[str, Any]:
        """
        Close a long futures position

        Args:
            symbol: Trading symbol (e.g., 'ETH_USDT')
            volume: Position size to close (contracts)
            leverage: Leverage multiplier
            price: Limit price (None for market order)

        Returns:
            dict: Order response
        """
        if not self.is_authenticated:
            logger.error("MEXC: Cannot close position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        # First verify captcha
        if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'):
            logger.error("MEXC: Captcha verification failed for closing long position")
            return {'success': False, 'error': 'Captcha verification failed'}

        # Close side is opposite of the position's side
        order_data = self._build_order_data(symbol, 2, 'closeType', 1, volume, leverage, price)
        return self._execute_order(order_data, 'close_long')

    def open_short_position(self, symbol: str, volume: float, leverage: int = 200,
                            price: Optional[float] = None) -> Dict[str, Any]:
        """Open a short futures position (same arguments as open_long_position)"""
        if not self.is_authenticated:
            logger.error("MEXC: Cannot open position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'):
            return {'success': False, 'error': 'Captcha verification failed'}

        order_data = self._build_order_data(symbol, 2, 'openType', 2, volume, leverage, price)
        return self._execute_order(order_data, 'open_short')

    def close_short_position(self, symbol: str, volume: float, leverage: int = 200,
                             price: Optional[float] = None) -> Dict[str, Any]:
        """Close a short futures position (same arguments as close_long_position)"""
        if not self.is_authenticated:
            logger.error("MEXC: Cannot close position - not authenticated")
            return {'success': False, 'error': 'Not authenticated'}

        if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'):
            return {'success': False, 'error': 'Captcha verification failed'}

        # Close side is opposite of the position's side
        order_data = self._build_order_data(symbol, 1, 'closeType', 1, volume, leverage, price)
        return self._execute_order(order_data, 'close_short')

    def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]:
        """Common order execution logic: POST the payload and normalise the reply"""
        headers = self._order_headers(order_data)
        url = f"{self.futures_api_url}/private/order/create"
        try:
            response = self.session.post(url, json=order_data, headers=headers, timeout=15)
            if response.status_code != 200:
                logger.error(f"MEXC: {action} request failed with status {response.status_code}")
                return {'success': False, 'error': f'HTTP {response.status_code}'}

            data = response.json()
            if data.get('success') and data.get('code') == 0:
                order_id = data.get('data', {}).get('orderId')
                logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}")
                return {
                    'success': True,
                    'order_id': order_id,
                    'timestamp': data.get('data', {}).get('ts'),
                    'action': action
                }
            logger.error(f"MEXC: {action} failed: {data}")
            return {'success': False, 'error': data.get('msg', 'Unknown error')}
        except Exception as e:
            logger.error(f"MEXC: {action} execution error: {e}")
            return {'success': False, 'error': str(e)}

    # Placeholder methods for encryption/hashing - these need proper implementation
    def _generate_mhash(self) -> str:
        """Generate mhash parameter (needs reverse engineering)"""
        return "a0015441fd4c3b6ba427b894b76cb7dd"  # Placeholder from request dump

    def _encrypt_p0(self, order_data: Dict[str, Any]) -> str:
        """Encrypt p0 parameter (needs reverse engineering)"""
        return "placeholder_p0_encryption"  # This needs proper implementation

    def _encrypt_k0(self, order_data: Dict[str, Any]) -> str:
        """Encrypt k0 parameter (needs reverse engineering)"""
        return "placeholder_k0_encryption"  # This needs proper implementation

    def _generate_chash(self, order_data: Dict[str, Any]) -> str:
        """Generate chash parameter (needs reverse engineering)"""
        return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8"  # Placeholder

    def get_account_info(self) -> Dict[str, Any]:
        """Get account information including positions and balances (not implemented)"""
        if not self.is_authenticated:
            return {'success': False, 'error': 'Not authenticated'}
        # This would need to be implemented by reverse engineering the account info endpoints
        logger.info("MEXC: Account info endpoint not yet implemented")
        return {'success': False, 'error': 'Not implemented'}

    def get_open_positions(self) -> List[Dict[str, Any]]:
        """Get list of open futures positions (not implemented; always empty)"""
        if not self.is_authenticated:
            return []
        # This would need to be implemented by reverse engineering the positions endpoint
        logger.info("MEXC: Open positions endpoint not yet implemented")
        return []

View File

@ -1,259 +0,0 @@
"""
MEXC Session Manager
Helper utilities for managing MEXC web sessions and extracting cookies from browser.
"""
import logging
import json
import re
from typing import Dict, Optional, Any
from pathlib import Path
logger = logging.getLogger(__name__)
class MEXCSessionManager:
"""
Helper class for managing MEXC web sessions and extracting browser cookies
"""
def __init__(self):
self.session_file = Path("mexc_session.json")
def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]:
"""
Extract cookies from browser Network tab cookie header
Args:
cookie_header: Raw cookie string from browser (copy from Request Headers)
Returns:
Dictionary of parsed cookies
"""
cookies = {}
# Remove 'Cookie: ' prefix if present
if cookie_header.startswith('Cookie: '):
cookie_header = cookie_header[8:]
elif cookie_header.startswith('cookie: '):
cookie_header = cookie_header[8:]
# Split by semicolon and parse each cookie
cookie_pairs = cookie_header.split(';')
for pair in cookie_pairs:
pair = pair.strip()
if '=' in pair:
name, value = pair.split('=', 1)
cookies[name.strip()] = value.strip()
logger.info(f"Extracted {len(cookies)} cookies from browser")
return cookies
def validate_session_cookies(self, cookies: Dict[str, str]) -> bool:
"""
Validate that essential cookies are present for authentication
Args:
cookies: Dictionary of cookie name-value pairs
Returns:
bool: True if cookies appear valid for authentication
"""
required_cookies = [
'uc_token', # User authentication token
'u_id', # User ID
'x-mxc-fingerprint', # Browser fingerprint
'mexc_fingerprint_visitorId' # Visitor ID
]
missing_cookies = []
for cookie_name in required_cookies:
if cookie_name not in cookies or not cookies[cookie_name]:
missing_cookies.append(cookie_name)
if missing_cookies:
logger.warning(f"Missing required cookies: {missing_cookies}")
return False
logger.info("All required cookies are present")
return True
def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None):
"""
Save session cookies to file for reuse
Args:
cookies: Dictionary of cookies to save
metadata: Optional metadata about the session
"""
session_data = {
'cookies': cookies,
'metadata': metadata or {},
'timestamp': int(time.time())
}
try:
with open(self.session_file, 'w') as f:
json.dump(session_data, f, indent=2)
logger.info(f"Session saved to {self.session_file}")
except Exception as e:
logger.error(f"Failed to save session: {e}")
def load_session(self) -> Optional[Dict[str, str]]:
"""
Load session cookies from file
Returns:
Dictionary of cookies if successful, None otherwise
"""
if not self.session_file.exists():
logger.info("No saved session found")
return None
try:
with open(self.session_file, 'r') as f:
session_data = json.load(f)
cookies = session_data.get('cookies', {})
timestamp = session_data.get('timestamp', 0)
# Check if session is too old (24 hours)
import time
if time.time() - timestamp > 24 * 3600:
logger.warning("Saved session is too old (>24h), may be expired")
if self.validate_session_cookies(cookies):
logger.info("Loaded valid session from file")
return cookies
else:
logger.warning("Loaded session has invalid cookies")
return None
except Exception as e:
logger.error(f"Failed to load session: {e}")
return None
def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]:
"""
Extract cookies from a curl command copied from browser
Args:
curl_command: Complete curl command from browser "Copy as cURL"
Returns:
Dictionary of extracted cookies
"""
cookies = {}
# Find cookie header in curl command
cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
if not cookie_match:
cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
if cookie_match:
cookie_header = cookie_match.group(1)
cookies = self.extract_cookies_from_network_tab(cookie_header)
logger.info(f"Extracted {len(cookies)} cookies from curl command")
else:
logger.warning("No cookie header found in curl command")
return cookies
def print_cookie_extraction_guide(self):
    """Write the browser cookie-extraction walkthrough to stdout, framed by separator rules."""
    separator = "=" * 80
    print("\n" + separator)
    print("MEXC COOKIE EXTRACTION GUIDE")
    print(separator)
    # The guide body is emitted verbatim as a single multi-line string
    print("""
To extract cookies from your browser for MEXC futures trading:
METHOD 1: Browser Network Tab
1. Open MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT
2. Open browser Developer Tools (F12)
3. Go to Network tab
4. Try to place a small futures trade (it will fail, but we need the request)
5. Find the request to 'futures.mexc.com' in the Network tab
6. Right-click on the request -> Copy -> Copy request headers
7. Find the 'Cookie:' line and copy everything after 'Cookie: '
METHOD 2: Copy as cURL
1. Follow steps 1-5 above
2. Right-click on the futures API request -> Copy -> Copy as cURL
3. Paste the entire cURL command
METHOD 3: Manual Cookie Extraction
1. While logged into MEXC, press F12 -> Application/Storage tab
2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com'
3. Copy the values for these important cookies:
- uc_token
- u_id
- x-mxc-fingerprint
- mexc_fingerprint_visitorId
IMPORTANT NOTES:
- Cookies expire after some time (usually 24 hours)
- You must be logged into MEXC futures (not just spot trading)
- Keep your cookies secure - they provide access to your account
- Test with small amounts first
Example usage:
session_manager = MEXCSessionManager()
# Method 1: From cookie header
cookie_header = "uc_token=ABC123; u_id=DEF456; ..."
cookies = session_manager.extract_cookies_from_network_tab(cookie_header)
# Method 2: From cURL command
curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'"
cookies = session_manager.extract_from_curl_command(curl_cmd)
# Save session for reuse
session_manager.save_session(cookies)
""")
    print(separator)
if __name__ == "__main__":
    # Interactive entry point: show the guide, then load or capture a session.
    # NOTE(review): this import binds `time` at module level when run as a
    # script; the session-saving code above calls time.time() - confirm the
    # module imports time at the top before removing it.
    import time

    manager = MEXCSessionManager()
    manager.print_cookie_extraction_guide()

    print("\n".join((
        "\nWould you like to:",
        "1. Load saved session",
        "2. Extract cookies from clipboard",
        "3. Exit",
    )))
    choice = input("\nEnter choice (1-3): ").strip()

    if choice == "1":
        cookies = manager.load_session()
        if not cookies:
            print("No valid saved session found")
        else:
            print(f"\nLoaded {len(cookies)} cookies from saved session")
            if manager.validate_session_cookies(cookies):
                print("Session appears valid for trading")
            else:
                print("Warning: Session may be incomplete or expired")
    elif choice == "2":
        print("\nPaste your cookie header or cURL command:")
        user_input = input().strip()
        # A pasted cURL command starts with 'curl'; anything else is treated
        # as a raw cookie header
        if user_input.startswith('curl'):
            cookies = manager.extract_from_curl_command(user_input)
        else:
            cookies = manager.extract_cookies_from_network_tab(user_input)

        extracted_ok = cookies and manager.validate_session_cookies(cookies)
        if not extracted_ok:
            print("Failed to extract valid cookies")
        else:
            print(f"\nSuccessfully extracted {len(cookies)} valid cookies")
            save = input("Save session for reuse? (y/n): ").strip().lower()
            if save == 'y':
                manager.save_session(cookies)
    else:
        # Any other input (including "3") ends the script
        print("Goodbye!")

View File

@ -1,346 +0,0 @@
#!/usr/bin/env python3
"""
Test MEXC Futures Web Client
This script demonstrates how to use the MEXC Futures Web Client
for futures trading that isn't supported by their official API.
IMPORTANT: This requires extracting cookies from your browser session.
"""
import logging
import sys
import os
import time
import json
import uuid
# Add the project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from mexc_futures_client import MEXCFuturesWebClient
from session_manager import MEXCSessionManager
# Setup logging: INFO level; each record shows timestamp, logger name, level and message
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every function in this script
logger = logging.getLogger(__name__)
# Constants
# Contract symbol main() passes to the leverage, ticker and order calls
SYMBOL = "ETH_USDT"
# Leverage main() sets via update_leverage() and uses to size the test order.
# NOTE(review): the test_* helpers in this script hard-code 200 instead - confirm which value is intended
LEVERAGE = 300
# Credentials JSON expected in the same directory as this script; read directly by main()
CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
# Read credentials from mexc_credentials.json in JSON format
def load_credentials():
    """Read cookies and both captcha tokens from mexc_credentials.json.

    The file is looked up next to this script and is expected to hold a
    top-level ``credentials`` object.

    Returns:
        Tuple ``(cookies, captcha_token_open, captcha_token_close)``. Missing
        keys fall back to ``{}`` / ``''``; on any read or parse error the
        error is logged and the values gathered so far (normally the empty
        defaults) are returned.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    credentials_file = os.path.join(script_dir, 'mexc_credentials.json')

    cookies, captcha_token_open, captcha_token_close = {}, '', ''
    try:
        with open(credentials_file, 'r') as f:
            data = json.load(f)
        section = data.get('credentials', {})
        cookies = section.get('cookies', {})
        captcha_token_open = section.get('captcha_token_open', '')
        captcha_token_close = section.get('captcha_token_close', '')
        logger.info(f"Loaded credentials from {credentials_file}")
    except Exception as e:
        logger.error(f"Error loading credentials: {e}")
    return cookies, captcha_token_open, captcha_token_close
def test_basic_connection():
"""Test basic connection and authentication.

Cookies are taken from the first source that yields any: the saved session,
a hard-coded cookie dump file one directory above this script, or text pasted
by the user (raw cookie header or cURL command). A MEXCFuturesWebClient is then
built, the cookies and extra request headers are applied to its session, and
client.is_authenticated is checked.

Returns:
False when the user provides no input, when no cookies can be extracted
from the input, or when the client reports it is not authenticated;
True otherwise.
"""
logger.info("Testing MEXC Futures Web Client")
# Initialize session manager
session_manager = MEXCSessionManager()
# Try to load saved session first
cookies = session_manager.load_session()
if not cookies:
# Explicitly load the cookies from the file we have
# NOTE(review): the file name embeds a fixed timestamp - confirm this dump is still the intended fallback
cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json')
if os.path.exists(cookies_file):
try:
with open(cookies_file, 'r') as f:
cookies = json.load(f)
logger.info(f"Loaded cookies from {cookies_file}")
except Exception as e:
logger.error(f"Failed to load cookies from {cookies_file}: {e}")
cookies = None
else:
logger.error(f"Cookies file not found at {cookies_file}")
cookies = None
# Last resort: ask the user to paste cookies interactively
if not cookies:
print("\nNo saved session found. You need to extract cookies from your browser.")
session_manager.print_cookie_extraction_guide()
print("\nPaste your cookie header or cURL command (or press Enter to exit):")
user_input = input().strip()
if not user_input:
print("No input provided. Exiting.")
return False
# Extract cookies from user input
if user_input.startswith('curl'):
cookies = session_manager.extract_from_curl_command(user_input)
else:
cookies = session_manager.extract_cookies_from_network_tab(user_input)
if not cookies:
logger.error("Failed to extract cookies from input")
return False
# Validate and save session
# NOTE(review): confirm whether this step is meant to run only for freshly pasted cookies or for every cookie source
if session_manager.validate_session_cookies(cookies):
session_manager.save_session(cookies)
logger.info("Session saved for future use")
else:
logger.warning("Extracted cookies may be incomplete")
# Initialize the web client
client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True)
# Load cookies into the client's session
for name, value in cookies.items():
client.session.cookies.set(name, value)
# Update headers to include additional parameters from captured requests
client.session.headers.update({
# Trace id: a random UUID followed by a 4-digit value derived from the current time in milliseconds
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'trochilus-uid': cookies.get('u_id', ''),
'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap',
'Language': 'English',
'X-Language': 'en-GB'
})
if not client.is_authenticated:
logger.error("Failed to authenticate with extracted cookies")
return False
logger.info("Successfully authenticated with MEXC")
logger.info(f"User ID: {client.user_id}")
# Only the first 20 characters of the token are logged
logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token")
# NOTE(review): the configured client is discarded here; only the boolean outcome is returned
return True
def test_captcha_verification(client: MEXCFuturesWebClient):
    """Run one captcha verification for an ETH_USDT long open at 200X.

    Args:
        client: Web client whose ``verify_captcha`` is exercised.

    Returns:
        Whatever ``client.verify_captcha`` returned; the outcome is also logged.
    """
    logger.info("Testing captcha verification...")
    verified = client.verify_captcha('ETH_USDT', 'openlong', '200X')
    if not verified:
        logger.warning("Captcha verification failed - this may be normal if no position is being opened")
    else:
        logger.info("Captcha verification successful")
    return verified
def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True):
    """Open a 1-volume ETH_USDT long at 200x, or only rehearse it.

    Args:
        client: Web client used for the captcha check / order call.
        dry_run: When truthy (default) no order is sent; only the captcha
            verification for an 'openlong' is performed.

    Returns:
        Dry run: the result of ``client.verify_captcha``.
        Live: True when the order result reports success, False otherwise.
    """
    if dry_run:
        logger.info("DRY RUN: Testing position opening (no actual trade)")
    else:
        logger.warning("LIVE TRADING: Opening actual position!")

    symbol, volume, leverage = 'ETH_USDT', 1, 200
    logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")

    if dry_run:
        logger.info("DRY RUN: Would attempt to open position here")
        # Only the captcha verification part is exercised in a dry run
        return client.verify_captcha(symbol, 'openlong', f'{leverage}X')

    outcome = client.open_long_position(symbol, volume, leverage)
    if not outcome['success']:
        logger.error(f"Failed to open position: {outcome['error']}")
        return False

    logger.info("Position opened successfully!")
    logger.info(f"Order ID: {outcome['order_id']}")
    logger.info(f"Timestamp: {outcome['timestamp']}")
    return True
def test_position_opening_live(client):
    """Call ``client.open_long_position`` for 1 volume of ETH_USDT at 200x and log the result.

    Nothing is returned; success or failure is only reported through the logger.
    """
    symbol, volume, leverage = "ETH_USDT", 1, 200
    logger.info("LIVE TRADING: Opening actual position!")
    logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")

    result = client.open_long_position(symbol, volume, leverage)
    if not result.get('success'):
        logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}")
        return
    logger.info(f"Successfully opened position: {result}")
def interactive_menu(client: MEXCFuturesWebClient):
    """Loop over a text menu that triggers the individual test helpers.

    The loop ends only when the user enters "0"; unknown input prints a hint
    and shows the menu again.

    Args:
        client: Web client handed to each selected test.
    """
    menu_lines = (
        "\n" + "="*50,
        "MEXC Futures Web Client Test Menu",
        "="*50,
        "1. Test captcha verification",
        "2. Test position opening (DRY RUN)",
        "3. Test position opening (LIVE - BE CAREFUL!)",
        "4. Test position closing (DRY RUN)",
        "5. Show session info",
        "6. Refresh session",
        "0. Exit",
    )

    while True:
        for menu_line in menu_lines:
            print(menu_line)
        choice = input("\nEnter choice (0-6): ").strip()

        if choice == "0":
            print("Goodbye!")
            break

        if choice == "1":
            test_captcha_verification(client)
        elif choice == "2":
            test_position_opening(client, dry_run=True)
        elif choice == "3":
            test_position_opening_live(client)
        elif choice == "4":
            # Closing is rehearsed through the captcha check only
            logger.info("DRY RUN: Position closing test")
            if client.verify_captcha('ETH_USDT', 'closelong', '200X'):
                logger.info("DRY RUN: Would close position here")
            else:
                logger.warning("Captcha verification failed for position closing")
        elif choice == "5":
            print("\nSession Information:")
            print(f"Authenticated: {client.is_authenticated}")
            print(f"User ID: {client.user_id}")
            # Without a token the bare text "None" is printed (no label)
            token_line = f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None"
            print(token_line)
            print(f"Fingerprint: {client.fingerprint}")
            print(f"Visitor ID: {client.visitor_id}")
        elif choice == "6":
            # "Refresh" only re-displays the extraction guide
            MEXCSessionManager().print_cookie_extraction_guide()
        else:
            print("Invalid choice. Please try again.")
def main():
    """Run the scripted sequence: authenticate, set leverage, open then close a LONG.

    Returns:
        False when the credentials file cannot be read/parsed or the client
        is not authenticated; None after every step succeeded.

    Exits the process with status 1 when no cookies were loaded or when the
    leverage, ticker or order calls do not answer with code 200.

    NOTE(review): this calls client.create_order with LEVERAGE - presumably
    live orders; confirm before running against a funded account.
    """
    for banner_line in (
        "MEXC Futures Web Client Test",
        "WARNING: This is experimental software for futures trading",
        "Use at your own risk and test with small amounts first!",
    ):
        print(banner_line)

    def _accepted(response):
        """True when the response is truthy and carries code 200."""
        return bool(response and response.get('code') == 200)

    # First pass: lenient load through the helper
    cookies, captcha_token_open, captcha_token_close = load_credentials()
    if not cookies:
        logger.error("Failed to load cookies from credentials file")
        sys.exit(1)

    client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='')
    for name, value in cookies.items():
        client.session.cookies.set(name, value)
    client.captcha_token_open = captcha_token_open
    client.captcha_token_close = captcha_token_close

    # Second pass: strict load straight from the JSON file (all keys required)
    try:
        with open(CREDENTIALS_FILE, 'r') as f:
            credentials_data = json.load(f)
        cookies = credentials_data['credentials']['cookies']
        captcha_token_open = credentials_data['credentials']['captcha_token_open']
        captcha_token_close = credentials_data['credentials']['captcha_token_close']
        client.load_session_cookies(cookies)
        client.session_manager.save_captcha_token(captcha_token_open)  # Assuming this is for opening
    except FileNotFoundError:
        logger.error(f"Credentials file not found at {CREDENTIALS_FILE}")
        return False
    except json.JSONDecodeError as e:
        logger.error(f"Error loading credentials: {e}")
        return False
    except KeyError as e:
        logger.error(f"Missing key in credentials file: {e}")
        return False

    if not client.is_authenticated:
        logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json")
        return False
    logger.info("Successfully authenticated with MEXC")

    # Leverage
    leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE)
    if not _accepted(leverage_response):
        logger.error(f"Failed to set leverage: {leverage_response}")
        sys.exit(1)
    logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}")

    # Current price
    ticker = client.get_ticker_data(symbol=SYMBOL)
    if not _accepted(ticker):
        logger.error(f"Failed to get ticker data: {ticker}")
        sys.exit(1)
    current_price = float(ticker['data']['last'])
    logger.info(f"Current {SYMBOL} price: {current_price}")

    # Order size for a small test trade (e.g., $10 worth)
    trade_usdt = 10.0
    order_qty = round((trade_usdt / current_price) * LEVERAGE, 3)
    logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x")

    # Test 1: open LONG (side 1 = BUY, position_side 1 = LONG, order_type 1 = LIMIT)
    logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}")
    open_long_order = client.create_order(
        symbol=SYMBOL,
        side=1,
        position_side=1,
        order_type=1,
        price=current_price,
        vol=order_qty
    )
    if not _accepted(open_long_order):
        logger.error(f"❌ Failed to open LONG position: {open_long_order}")
        sys.exit(1)
    logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}")

    # Test 2: close LONG (side 2 = SELL, reduce-only)
    logger.info(f"Closing LONG position for {SYMBOL}")
    close_long_order = client.create_order(
        symbol=SYMBOL,
        side=2,
        position_side=1,
        order_type=1,
        price=current_price,
        vol=order_qty,
        reduce_only=True
    )
    if not _accepted(close_long_order):
        logger.error(f"❌ Failed to close LONG position: {close_long_order}")
        sys.exit(1)
    logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}")

    logger.info("All tests completed successfully!")


if __name__ == "__main__":
    main()

View File

@ -1,595 +0,0 @@
"""
Negative Case Trainer - Intensive Training on Losing Trades
This module focuses on learning from losses to prevent future mistakes.
Stores negative cases in testcases/negative folder for reuse and retraining.
Supports simultaneous inference and training.
"""
import os
import json
import logging
import pickle
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import deque
import numpy as np
import pandas as pd
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@dataclass
class NegativeCase:
    """One losing trade captured for intensive retraining.

    Field notes:
        action: 'BUY' or 'SELL'.
        tick_data: 15 minutes of tick data around the trade.
        what_should_have_been_done: the trainer's own analysis in this module
            fills this with 'HOLD', 'BUY' or 'SELL'.
        training_priority: 1-5, 5 being highest priority.
        retraining_count / last_retrained: bookkeeping updated after each
            training session on this case.
    """
    # Identity
    case_id: str
    timestamp: datetime
    symbol: str
    action: str
    # Trade outcome
    entry_price: float
    exit_price: float
    loss_amount: float
    loss_percentage: float
    confidence_used: float
    # Market context
    market_state_before: Dict[str, Any]
    market_state_after: Dict[str, Any]
    tick_data: List[Dict[str, Any]]
    technical_indicators: Dict[str, float]
    # Analysis
    what_should_have_been_done: str
    lesson_learned: str
    training_priority: int
    # Retraining bookkeeping
    retraining_count: int = 0
    last_retrained: Optional[datetime] = None
@dataclass
class TrainingSession:
    """Progress record for one intensive training run over negative cases.

    Field notes:
        cases_trained: case_ids covered by this session.
        inference_paused: set while a critical training phase holds inference back.
        training_active: defaults to True for a newly created session.
    """
    session_id: str
    start_time: datetime
    cases_trained: List[str]
    epochs_completed: int
    loss_improvement: float
    accuracy_improvement: float
    # Runtime flags
    inference_paused: bool = False
    training_active: bool = True
class NegativeCaseTrainer:
"""
Intensive trainer focused on learning from losing trades with checkpoint management
Features:
- Stores all losing trades as negative cases
- Intensive retraining on losses
- Simultaneous inference and training
- Persistent storage in testcases/negative
- Priority-based training (bigger losses = higher priority)
- Checkpoint management for training progress
"""
def __init__(self, storage_dir: str = "testcases/negative",
             model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
    """Set up trainer state, storage and the background training thread.

    Args:
        storage_dir: Root folder for persisted cases, sessions and models.
        model_name: Name handed to the checkpoint helpers.
        enable_checkpoints: When True, a training-integration handle is
            fetched and the best checkpoint is loaded during construction.
    """
    # Case storage plus the queue/locks shared with the background thread
    self.storage_dir = storage_dir
    self.stored_cases: List[NegativeCase] = []
    self.training_queue = deque(maxlen=1000)
    self.training_lock = threading.Lock()
    self.inference_lock = threading.Lock()

    # Checkpoint bookkeeping
    self.model_name = model_name
    self.enable_checkpoints = enable_checkpoints
    if enable_checkpoints:
        self.training_integration = get_training_integration()
    else:
        self.training_integration = None
    self.training_session_count = 0
    self.best_loss_reduction = 0.0
    self.checkpoint_frequency = 25  # a checkpoint every 25 training sessions

    # Training knobs
    self.max_concurrent_training = 3  # max parallel training sessions
    self.intensive_training_epochs = 50  # base epochs per negative case
    self.priority_multiplier = 2.0  # epoch multiplier applied with the case priority

    # Inference/training coordination flags
    self.inference_active = True
    self.training_active = False
    self.current_training_sessions: List[TrainingSession] = []

    # Running statistics
    self.total_cases_processed = 0
    self.total_training_time = 0.0
    self.accuracy_improvements = []

    # Disk state first, then any saved progress
    self._initialize_storage()
    self._load_existing_cases()
    if self.enable_checkpoints:
        self.load_best_checkpoint()

    # Daemon worker that drains the training queue
    worker = threading.Thread(target=self._background_training_loop, daemon=True)
    self.training_thread = worker
    worker.start()

    logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
    logger.info(f"Storage directory: {self.storage_dir}")
    logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
    logger.info("Background training thread started")
def _initialize_storage(self):
    """Create the storage folder tree and an empty case index if none exists.

    Errors are logged and swallowed; nothing is returned.
    """
    try:
        # Root first, then the three sub-folders
        for suffix in ("", "/cases", "/sessions", "/models"):
            os.makedirs(f"{self.storage_dir}{suffix}", exist_ok=True)

        # Seed the index only when it is missing so existing data survives
        index_path = f"{self.storage_dir}/case_index.json"
        if not os.path.exists(index_path):
            empty_index = {"cases": [], "last_updated": datetime.now().isoformat()}
            with open(index_path, 'w') as fh:
                json.dump(empty_index, fh)

        logger.info(f"Storage initialized at {self.storage_dir}")
    except Exception as e:
        logger.error(f"Error initializing storage: {e}")
def _load_existing_cases(self):
    """Populate ``stored_cases`` from the pickled case files listed in the index.

    Index entries whose case file is missing are skipped; a case file that
    fails to load is logged as a warning and skipped as well.
    """
    try:
        index_path = f"{self.storage_dir}/case_index.json"
        if os.path.exists(index_path):
            with open(index_path, 'r') as fh:
                index = json.load(fh)

            for entry in index.get("cases", []):
                case_path = f"{self.storage_dir}/cases/{entry['case_id']}.pkl"
                if not os.path.exists(case_path):
                    continue
                try:
                    # NOTE(review): pickle.load can execute code embedded in the
                    # file; this is only safe while the storage directory is trusted
                    with open(case_path, 'rb') as case_fh:
                        loaded = pickle.load(case_fh)
                    self.stored_cases.append(loaded)
                except Exception as e:
                    logger.warning(f"Error loading case {entry['case_id']}: {e}")

        logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")
    except Exception as e:
        logger.error(f"Error loading existing cases: {e}")
def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
    """
    Add a losing trade as a negative case for intensive training.

    The case is persisted, appended to ``stored_cases`` and queued for the
    background trainer. Priority (1-5) is derived from the loss percentage.

    Args:
        trade_info: Trade information including P&L. Reads 'symbol',
            'timestamp', 'action', 'price' (required) and 'pnl', 'value',
            'confidence' (optional).
        market_data: Market state and tick data around the trade.

    Returns:
        case_id: Unique identifier for the negative case, or "" when the
        case could not be created (the error is logged).
    """
    try:
        # Generate unique case ID (second resolution + symbol without '/')
        case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"

        # A missing, None or zero 'value' falls back to 1 (the default used
        # for a missing key) so the trade is still recorded instead of being
        # dropped by a ZeroDivisionError/TypeError; same for a None 'pnl'.
        loss_amount = abs(trade_info.get('pnl') or 0)
        trade_value = trade_info.get('value') or 1
        loss_percentage = (loss_amount / trade_value) * 100

        # Bigger losses get higher priority: >10% -> 5, >5% -> 4, >2% -> 3, >1% -> 2, else 1
        priority = 1
        for threshold, level in ((10, 5), (5, 4), (2, 3), (1, 2)):
            if loss_percentage > threshold:
                priority = level
                break

        # Analyze what should have been done
        what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
        lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)

        negative_case = NegativeCase(
            case_id=case_id,
            timestamp=trade_info['timestamp'],
            symbol=trade_info['symbol'],
            action=trade_info['action'],
            entry_price=trade_info['price'],
            exit_price=market_data.get('exit_price', trade_info['price']),
            loss_amount=loss_amount,
            loss_percentage=loss_percentage,
            confidence_used=trade_info.get('confidence', 0.5),
            market_state_before=market_data.get('state_before', {}),
            market_state_after=market_data.get('state_after', {}),
            tick_data=market_data.get('tick_data', []),
            technical_indicators=market_data.get('technical_indicators', {}),
            what_should_have_been_done=what_should_have_been_done,
            lesson_learned=lesson_learned,
            training_priority=priority
        )

        # Persist first, then expose the case to the background trainer
        self._store_case(negative_case)
        with self.training_lock:
            self.training_queue.append(negative_case)
            self.stored_cases.append(negative_case)

        # Logged at ERROR level on purpose so losses stand out in the log
        logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
        logger.error(f"Lesson: {lesson_learned}")
        return case_id
    except Exception as e:
        logger.error(f"Error adding losing trade: {e}")
        return ""
def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
"""Analyze what the optimal action should have been"""
try:
# Simple analysis based on price movement
entry_price = trade_info['price']
exit_price = market_data.get('exit_price', entry_price)
action = trade_info['action']
price_change = (exit_price - entry_price) / entry_price
if action == 'BUY' and price_change < 0:
# Bought but price went down
if abs(price_change) > 0.005: # >0.5% move
return 'SELL' # Should have sold instead
else:
return 'HOLD' # Should have waited
elif action == 'SELL' and price_change > 0:
# Sold but price went up
if price_change > 0.005: # >0.5% move
return 'BUY' # Should have bought instead
else:
return 'HOLD' # Should have waited
else:
return 'HOLD' # Should have done nothing
except Exception as e:
logger.error(f"Error analyzing optimal action: {e}")
return 'HOLD'
def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
"""Generate a lesson learned from the losing trade"""
try:
action = trade_info['action']
symbol = trade_info['symbol']
loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
confidence = trade_info.get('confidence', 0.5)
if optimal_action == 'HOLD':
return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
elif optimal_action == 'BUY' and action == 'SELL':
return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
elif optimal_action == 'SELL' and action == 'BUY':
return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
else:
return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."
except Exception as e:
logger.error(f"Error generating lesson: {e}")
return "Learn from this loss to improve future decisions."
def _store_case(self, case: NegativeCase):
    """Persist one case: pickle the full object, then append a summary to the JSON index.

    Errors are logged and swallowed; nothing is returned.
    """
    try:
        # 1) Full object as a pickle
        case_path = f"{self.storage_dir}/cases/{case.case_id}.pkl"
        with open(case_path, 'wb') as case_fh:
            pickle.dump(case, case_fh)

        # 2) Read-modify-write of the index file
        index_path = f"{self.storage_dir}/case_index.json"
        with open(index_path, 'r') as index_fh:
            index = json.load(index_fh)

        summary = {
            'case_id': case.case_id,
            'timestamp': case.timestamp.isoformat(),
            'symbol': case.symbol,
            'loss_amount': case.loss_amount,
            'loss_percentage': case.loss_percentage,
            'training_priority': case.training_priority,
            'retraining_count': case.retraining_count
        }
        index['cases'].append(summary)
        index['last_updated'] = datetime.now().isoformat()

        with open(index_path, 'w') as index_fh:
            json.dump(index, index_fh, indent=2)

        logger.info(f"Stored negative case: {case.case_id}")
    except Exception as e:
        logger.error(f"Error storing case: {e}")
def _background_training_loop(self):
    """Background loop for intensive training on negative cases.

    Runs forever: takes the highest-priority queued case (earliest first
    among equals), trains on it, pauses 2 s, and repeats. When the queue is
    empty it waits 5 s; after an unexpected error it waits 10 s.

    The lock is held only while the queue is inspected and modified, never
    while sleeping, so producers adding cases are not blocked by the idle wait.
    """
    logger.info("Background training loop started")
    while True:
        try:
            case_to_train = None
            with self.training_lock:
                if self.training_queue:
                    # max() returns the first maximal element, i.e. the
                    # earliest-queued case among those with top priority
                    case_to_train = max(self.training_queue, key=lambda c: c.training_priority)
                    self.training_queue.remove(case_to_train)

            if case_to_train is None:
                time.sleep(5)  # Wait for new cases (lock already released)
                continue

            # Start intensive training session
            self._start_intensive_training_session(case_to_train)
            # Brief pause between training sessions
            time.sleep(2)
        except Exception as e:
            logger.error(f"Error in background training loop: {e}")
            time.sleep(10)  # Wait longer on error
def _start_intensive_training_session(self, case: NegativeCase):
    """Run one (currently simulated) intensive training session for a negative case.

    The number of epochs is ``intensive_training_epochs * training_priority *
    priority_multiplier``. For priority >= 4 cases, every 10th epoch briefly
    marks the session as pausing inference. Results are written via
    ``_store_training_session`` and the running statistics are updated.

    The session is always removed from ``current_training_sessions`` when the
    method ends, even after an error, so a failed run cannot leave
    ``training_active`` or ``inference_paused`` stuck.

    Args:
        case: The negative case to train on; its retraining counters are updated.
    """
    session = None
    try:
        session_id = f"session_{case.case_id}_{int(time.time())}"
        session = TrainingSession(
            session_id=session_id,
            start_time=datetime.now(),
            cases_trained=[case.case_id],
            epochs_completed=0,
            loss_improvement=0.0,
            accuracy_improvement=0.0
        )
        self.current_training_sessions.append(session)
        self.training_active = True
        logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
        logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")

        # Calculate training epochs based on priority
        epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)

        # Simulate intensive training (replace with actual model training)
        for epoch in range(epochs):
            critical_phase = case.training_priority >= 4 and epoch % 10 == 0

            # Pause inference during critical training phases
            if critical_phase:
                with self.inference_lock:
                    session.inference_paused = True
                    try:
                        time.sleep(0.1)  # Brief pause for critical training
                    finally:
                        session.inference_paused = False

            # Simulate training step
            session.epochs_completed = epoch + 1

            # Log progress for high priority cases
            if critical_phase:
                logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")
            time.sleep(0.05)  # Simulate training time

        # Update case retraining info
        case.retraining_count += 1
        case.last_retrained = datetime.now()

        # Calculate improvements (simulated)
        session.loss_improvement = np.random.uniform(0.1, 0.5)  # 10-50% improvement
        session.accuracy_improvement = np.random.uniform(0.05, 0.2)  # 5-20% improvement

        # Store training session results
        self._store_training_session(session)

        # Update statistics
        self.total_cases_processed += 1
        self.total_training_time += (datetime.now() - session.start_time).total_seconds()
        self.accuracy_improvements.append(session.accuracy_improvement)

        logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
        logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")
    except Exception as e:
        logger.error(f"Error in intensive training session: {e}")
    finally:
        # Always deregister the session, on success and on failure alike
        if session is not None:
            session.inference_paused = False
            if session in self.current_training_sessions:
                self.current_training_sessions.remove(session)
        if not self.current_training_sessions:
            self.training_active = False
def _store_training_session(self, session: TrainingSession):
    """Write a JSON summary of a finished session into ``<storage_dir>/sessions``.

    The end time recorded is the moment of writing. Errors are logged and swallowed.
    """
    try:
        target_path = f"{self.storage_dir}/sessions/{session.session_id}.json"
        record = {
            'session_id': session.session_id,
            'start_time': session.start_time.isoformat(),
            'end_time': datetime.now().isoformat(),
            'cases_trained': session.cases_trained,
            'epochs_completed': session.epochs_completed,
            'loss_improvement': session.loss_improvement,
            'accuracy_improvement': session.accuracy_improvement
        }
        with open(target_path, 'w') as out:
            json.dump(record, out, indent=2)
    except Exception as e:
        logger.error(f"Error storing training session: {e}")
def can_inference_proceed(self) -> bool:
"""Check if inference can proceed (not blocked by critical training)"""
with self.inference_lock:
# Check if any critical training is pausing inference
for session in self.current_training_sessions:
if session.inference_paused:
return False
return True
def get_training_stats(self) -> Dict[str, Any]:
"""Get training statistics"""
try:
avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0
return {
'total_negative_cases': len(self.stored_cases),
'cases_in_queue': len(self.training_queue),
'total_cases_processed': self.total_cases_processed,
'total_training_time': self.total_training_time,
'avg_accuracy_improvement': avg_accuracy_improvement,
'active_training_sessions': len(self.current_training_sessions),
'training_active': self.training_active,
'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
'storage_directory': self.storage_dir
}
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {}
def get_recent_lessons(self, count: int = 5) -> List[str]:
"""Get recent lessons learned from negative cases"""
try:
recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
return [case.lesson_learned for case in recent_cases]
except Exception as e:
logger.error(f"Error getting recent lessons: {e}")
return []
def retrain_all_cases(self):
    """Queue every stored negative case for retraining (for periodic retraining).

    Cases already waiting in the training queue are not queued twice. The
    log line reports how many cases were actually added, not the total
    number of stored cases.

    NOTE(review): the queue is created with maxlen=1000, so adding beyond
    that silently drops the oldest queued cases - confirm this is acceptable.
    """
    try:
        logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")
        added = 0
        with self.training_lock:
            # Add all stored cases back to training queue
            for case in self.stored_cases:
                if case not in self.training_queue:
                    self.training_queue.append(case)
                    added += 1
        logger.warning(f"Added {added} cases to retraining queue")
    except Exception as e:
        logger.error(f"Error retraining all cases: {e}")
def load_best_checkpoint(self):
    """Restore trainer counters from the best saved checkpoint, if there is one.

    Does nothing when checkpoints are disabled or no checkpoint is found.
    Only keys present in the checkpoint overwrite the current values.
    Failures are logged as a warning and otherwise ignored.
    """
    try:
        if not self.enable_checkpoints:
            return

        # Module-level helper with the same name as this method
        result = load_best_checkpoint(self.model_name)
        if not result:
            return

        file_path, metadata = result
        checkpoint = torch.load(file_path, map_location='cpu')

        # Training state fields copied over one-to-one when present
        restorable = (
            'training_session_count',
            'best_loss_reduction',
            'total_cases_processed',
            'total_training_time',
            'accuracy_improvements',
        )
        for field_name in restorable:
            if field_name in checkpoint:
                setattr(self, field_name, checkpoint[field_name])

        logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
    except Exception as e:
        logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
    """Record a training session and persist the trainer state when warranted.

    Every call (with checkpoints enabled) increments the session counter and
    raises ``best_loss_reduction`` if ``loss_improvement`` beats it.  A
    checkpoint is written when the save is forced, when a new best was just
    set, or when the session counter is a multiple of
    ``self.checkpoint_frequency``.

    Args:
        loss_improvement: Loss reduction achieved by the latest session.
        force_save: Write a checkpoint regardless of the other conditions.

    Returns:
        True if the checkpoint manager reported a saved checkpoint, False
        otherwise (disabled, not due, manager returned nothing, or an error
        occurred -- errors are logged, not raised).

    NOTE(review): the bare ``save_checkpoint(...)`` call below resolves to a
    module-level function of the same name, not to this method -- confirm
    against the file's imports.
    """
    try:
        if not self.enable_checkpoints:
            return False

        self.training_session_count += 1

        is_new_best = loss_improvement > self.best_loss_reduction
        if is_new_best:
            self.best_loss_reduction = loss_improvement

        # The interval test comes last so it is only evaluated when neither
        # of the other two reasons to save applies.
        if not (
            force_save
            or is_new_best
            or self.training_session_count % self.checkpoint_frequency == 0
        ):
            return False

        gains = self.accuracy_improvements
        mean_gain = sum(gains) / len(gains) if gains else 0.0
        elapsed = self.total_training_time
        cases_per_second = self.total_cases_processed / elapsed if elapsed > 0 else 0.0

        # Plain-data snapshot of the trainer; this is what gets stored.
        trainer_state = {
            'training_session_count': self.training_session_count,
            'best_loss_reduction': self.best_loss_reduction,
            'total_cases_processed': self.total_cases_processed,
            'total_training_time': self.total_training_time,
            'accuracy_improvements': self.accuracy_improvements,
            'storage_dir': self.storage_dir,
            'max_concurrent_training': self.max_concurrent_training,
            'intensive_training_epochs': self.intensive_training_epochs
        }
        # Figures handed to the checkpoint manager alongside the snapshot.
        manager_metrics = {
            'loss_reduction': self.best_loss_reduction,
            'avg_accuracy_improvement': mean_gain,
            'total_cases_processed': self.total_cases_processed,
            'training_efficiency': cases_per_second
        }
        session_info = {
            'session': self.training_session_count,
            'cases_processed': self.total_cases_processed,
            'training_time_hours': self.total_training_time / 3600
        }

        metadata = save_checkpoint(
            model=trainer_state,  # a data dict is stored in place of a model
            model_name=self.model_name,
            model_type="negative_case_trainer",
            performance_metrics=manager_metrics,
            training_metadata=session_info,
            force_save=force_save
        )
        if not metadata:
            return False

        logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
        return True
    except Exception as e:
        logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
        return False

View File

@ -1,277 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Decision Fusion System
Central NN that merges all model outputs + market data for final trading decisions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
@dataclass
class ModelPrediction:
    """One model's output, expressed in the common format used for fusion.

    Fields:
        model_name: Name of the model that produced the prediction.
        prediction_type: One of 'price', 'direction' or 'action'.
        value: -1 to 1 for direction predictions, the actual price for
            price predictions.
        confidence: 0 to 1.
        timestamp: Time attached to the prediction.
        metadata: Optional free-form extras; defaults to None.
    """
    model_name: str
    prediction_type: str
    value: float
    confidence: float
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None
@dataclass
class MarketContext:
    """Snapshot of the market that accompanies a fusion decision request.

    Fields:
        symbol: Instrument the snapshot describes.
        current_price: Latest price.
        price_change_1m: Price change over the last minute.
        price_change_5m: Price change over the last five minutes.
        volume_ratio: Volume ratio (the baseline is not defined here).
        volatility: Volatility measure (units are not defined here).
        timestamp: Time of the snapshot.
    """
    symbol: str
    current_price: float
    price_change_1m: float
    price_change_5m: float
    volume_ratio: float
    volatility: float
    timestamp: datetime
@dataclass
class FusionDecision:
    """Final trading decision produced by the fusion network.

    Fields:
        action: 'BUY', 'SELL' or 'HOLD'.
        confidence: 0 to 1.
        expected_return: Expected return percentage.
        risk_score: 0 to 1, higher means riskier.
        position_size: Recommended position size.
        reasoning: Human-readable explanation of the decision.
        model_contributions: How much each model contributed, keyed by
            model name.
        timestamp: Time the decision was made.
    """
    action: str
    confidence: float
    expected_return: float
    risk_score: float
    position_size: float
    reasoning: str
    model_contributions: Dict[str, float]
    timestamp: datetime
class DecisionFusionNetwork(nn.Module):
    """Small NN that fuses model predictions with market context.

    A shared trunk reduces the input vector to a 16-value representation.
    Three linear heads read from it: action logits over BUY/SELL/HOLD, a
    confidence passed through a sigmoid, and an expected return passed
    through tanh.
    """

    def __init__(self, input_dim: int = 32, hidden_dim: int = 64):
        """Build the trunk and the three output heads.

        Args:
            input_dim: Length of the input feature vector.
            hidden_dim: Width of the first hidden layer; the second hidden
                layer is half as wide.
        """
        super().__init__()
        self.fusion_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 16)
        )
        # Output heads
        self.action_head = nn.Linear(16, 3)  # BUY, SELL, HOLD
        self.confidence_head = nn.Linear(16, 1)
        self.return_head = nn.Linear(16, 1)

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through fusion network.

        Args:
            features: Tensor whose last dimension is ``input_dim``.

        Returns:
            Dict with ``'action_probs'`` (softmax over the last axis, three
            entries), ``'confidence'`` (sigmoid output) and
            ``'expected_return'`` (tanh output).  The two scalar heads have
            their trailing size-1 axis removed, so a ``(batch, input_dim)``
            input yields ``(batch,)`` tensors for every batch size.
        """
        fusion_output = self.fusion_layers(features)
        action_logits = self.action_head(fusion_output)
        # Normalise over the class axis; dim=-1 equals dim=1 for batched
        # input and also accepts a single un-batched vector.
        action_probs = F.softmax(action_logits, dim=-1)
        confidence = torch.sigmoid(self.confidence_head(fusion_output))
        expected_return = torch.tanh(self.return_head(fusion_output))
        # Drop only the head axis.  A bare squeeze() would also remove the
        # batch axis when the batch has one row, making the output rank
        # depend on the batch size.
        return {
            'action_probs': action_probs,
            'confidence': confidence.squeeze(-1),
            'expected_return': expected_return.squeeze(-1)
        }
class NeuralDecisionFusion:
    """Main NN-based decision fusion system.

    Keeps the most recent prediction of every model, combines those with a
    market snapshot into a 32-value feature vector and runs that vector
    through a ``DecisionFusionNetwork`` to obtain an action, a confidence and
    an expected return.
    """

    def __init__(self, training_mode: bool = True):
        """Create the fusion network on CUDA when available, else on CPU.

        Args:
            training_mode: Stored on the instance and reported by
                ``get_status``; it is not otherwise used in this class.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network = DecisionFusionNetwork().to(self.device)
        self.training_mode = training_mode
        self.registered_models = {}
        self.last_predictions = {}
        logger.info(f"Neural Decision Fusion initialized on {self.device}")

    def register_model(self, model_name: str, model_type: str, prediction_format: str):
        """Register a model that will provide predictions."""
        self.registered_models[model_name] = {
            'type': model_type,
            'format': prediction_format,
            'prediction_count': 0
        }
        logger.info(f"Registered NN model: {model_name} ({model_type})")

    def add_prediction(self, prediction: ModelPrediction):
        """Store ``prediction`` as the latest output of its model.

        The per-model counter is only advanced for registered models, but the
        prediction is kept (and used for decisions) either way.
        """
        self.last_predictions[prediction.model_name] = prediction
        if prediction.model_name in self.registered_models:
            self.registered_models[prediction.model_name]['prediction_count'] += 1
        logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} "
                     f"(confidence: {prediction.confidence:.3f})")

    def make_decision(self, symbol: str, market_context: MarketContext,
                      min_confidence: float = 0.25) -> Optional[FusionDecision]:
        """Make NN-driven trading decision.

        Args:
            symbol: Accepted for interface compatibility; not read here.
            market_context: Market snapshot fed into the feature vector.
            min_confidence: Below this network confidence the action is
                forced to 'HOLD'.

        Returns:
            The decision, or None when there are no predictions yet, when the
            features could not be built, or when an error occurred (errors
            are logged, not raised).
        """
        try:
            if len(self.last_predictions) < 1:
                logger.debug("No NN predictions available")
                return None

            # Prepare features
            features = self._prepare_features(market_context)
            if features is None:
                return None

            # Run NN inference
            action_probs, confidence, expected_return = self._run_inference(features)

            # Determine action
            action_idx = np.argmax(action_probs)
            actions = ['BUY', 'SELL', 'HOLD']
            action = actions[action_idx]

            # Check confidence threshold
            if confidence < min_confidence:
                action = 'HOLD'
                logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD")

            # Calculate position size
            position_size = self._calculate_position_size(confidence, expected_return)

            # Generate reasoning
            reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs)

            # Calculate risk score and model contributions
            risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5)
            model_contributions = self._calculate_model_contributions()

            decision = FusionDecision(
                action=action,
                confidence=confidence,
                expected_return=expected_return,
                risk_score=risk_score,
                position_size=position_size,
                reasoning=reasoning,
                model_contributions=model_contributions,
                timestamp=datetime.now()
            )
            logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, "
                        f"return: {expected_return:.3f}, size: {position_size:.4f})")
            return decision
        except Exception as e:
            logger.error(f"Error in NN decision making: {e}")
            return None

    def _run_inference(self, features: np.ndarray):
        """Run one gradient-free forward pass and return plain Python values.

        The network is switched to eval mode for the pass and then put back
        into whatever mode it was in, so making a decision no longer leaves a
        network that is being trained with dropout permanently disabled.

        Returns:
            ``(action_probs, confidence, expected_return)`` where
            ``action_probs`` is a NumPy array and the other two are floats.
        """
        was_training = self.network.training
        self.network.eval()
        try:
            with torch.no_grad():
                features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
                outputs = self.network(features_tensor)
                action_probs = outputs['action_probs'][0].cpu().numpy()
                # .item() accepts any one-element tensor, so this works
                # whether the head output is 0-d or has shape (1,).
                confidence = outputs['confidence'].cpu().item()
                expected_return = outputs['expected_return'].cpu().item()
        finally:
            self.network.train(was_training)
        return action_probs, confidence, expected_return

    def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]:
        """Prepare feature vector for NN.

        Layout of the 32 values: (value, confidence) pairs of up to seven
        models from index 0, market context at 16-20, time features at 21-22
        and model-agreement statistics at 23-25.  Unused entries stay 0.
        Models are placed in the iteration order of ``last_predictions``.

        Returns:
            The vector, or None if building it failed (the error is logged).
        """
        try:
            features = np.zeros(32)

            # Model predictions (slots 0-15)
            idx = 0
            for prediction in self.last_predictions.values():
                if idx >= 14:  # Leave room for other features
                    break
                features[idx] = prediction.value
                features[idx + 1] = prediction.confidence
                idx += 2

            # Market context (slots 16-31)
            features[16] = np.tanh(context.price_change_1m * 100)  # 1m change
            features[17] = np.tanh(context.price_change_5m * 100)  # 5m change
            features[18] = np.tanh(context.volume_ratio - 1)  # Volume ratio
            features[19] = np.tanh(context.volatility * 100)  # Volatility
            features[20] = context.current_price / 10000.0  # Normalized price

            # Time features
            now = context.timestamp
            features[21] = now.hour / 24.0
            features[22] = now.weekday() / 7.0

            # Model agreement features
            if len(self.last_predictions) >= 2:
                values = [p.value for p in self.last_predictions.values()]
                features[23] = np.mean(values)  # Average prediction
                features[24] = np.std(values)  # Prediction variance
                features[25] = len(self.last_predictions)  # Model count
            return features
        except Exception as e:
            logger.error(f"Error preparing NN features: {e}")
            return None

    def _calculate_position_size(self, confidence: float, expected_return: float) -> float:
        """Calculate position size based on NN outputs.

        Starts from a base of 0.01, scales it by a confidence multiplier
        limited to [0.1, 2.0] and by ``1 + |expected_return| / 2``, then
        clamps the result to [0.001, 0.05].
        """
        base_size = 0.01  # 0.01 ETH base
        # Scale by confidence
        confidence_multiplier = max(0.1, min(2.0, confidence * 1.5))
        # Scale by expected return
        return_multiplier = 1.0 + abs(expected_return) * 0.5
        final_size = base_size * confidence_multiplier * return_multiplier
        return max(0.001, min(0.05, final_size))

    def _generate_reasoning(self, action: str, confidence: float,
                            expected_return: float, action_probs: np.ndarray) -> str:
        """Generate human-readable reasoning as ' | '-separated fragments."""
        reasons = []
        if action == 'BUY':
            reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})")
        elif action == 'SELL':
            reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})")
        else:
            reasons.append("NN suggests HOLD")

        if confidence > 0.7:
            reasons.append("High confidence")
        elif confidence > 0.5:
            reasons.append("Moderate confidence")
        else:
            reasons.append("Low confidence")

        if abs(expected_return) > 0.01:
            direction = "positive" if expected_return > 0 else "negative"
            reasons.append(f"Expected {direction} return: {expected_return:.2%}")

        reasons.append(f"Based on {len(self.last_predictions)} NN models")
        return " | ".join(reasons)

    def _calculate_model_contributions(self) -> Dict[str, float]:
        """Return each model's share of the summed prediction confidence.

        The result is empty when there are no predictions or when the
        confidences do not sum to a positive number.
        """
        contributions = {}
        total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0
        if total_confidence > 0:
            for model_name, prediction in self.last_predictions.items():
                contributions[model_name] = prediction.confidence / total_confidence
        return contributions

    def get_status(self) -> Dict[str, Any]:
        """Get NN fusion system status."""
        return {
            'device': str(self.device),
            'training_mode': self.training_mode,
            'registered_models': len(self.registered_models),
            'recent_predictions': len(self.last_predictions),
            'model_parameters': sum(p.numel() for p in self.network.parameters())
        }

Binary file not shown.

View File

@ -1,649 +0,0 @@
"""
Real-Time Tick Processing Neural Network Module
This module acts as a Neural Network DPS (Data Processing System) alternative,
processing raw tick data with ultra-low latency and feeding processed features
to trading models in real-time.
Features:
- Real-time tick ingestion with volume processing
- Neural network feature extraction from tick streams
- Ultra-low latency processing (sub-millisecond)
- Volume-weighted price analysis
- Microstructure pattern detection
- Real-time feature streaming to models
"""
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Deque
from collections import deque
from threading import Thread, Lock
import websockets
import json
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class TickData:
    """A single raw trade tick.

    Fields:
        timestamp: Time of the trade.
        price: Trade price.
        volume: Traded quantity.
        side: 'buy' or 'sell'.
        trade_id: Optional trade identifier; defaults to None.
    """
    timestamp: datetime
    price: float
    volume: float
    side: str
    trade_id: Optional[str] = None
@dataclass
class ProcessedTickFeatures:
    """Feature bundle derived from a window of ticks, ready for models.

    Fields:
        timestamp: Time the bundle refers to.
        price_features: Price-based features.
        volume_features: Volume-based features.
        microstructure_features: Market microstructure features.
        neural_features: Features extracted by the neural network.
        confidence: Quality confidence of the features.
    """
    timestamp: datetime
    price_features: np.ndarray
    volume_features: np.ndarray
    microstructure_features: np.ndarray
    neural_features: np.ndarray
    confidence: float
class TickProcessingNN(nn.Module):
    """
    Neural network that turns a sequence of per-tick feature rows into one
    feature vector plus a confidence score.

    Each tick is encoded by a small MLP, the encoded sequence goes through a
    two-layer LSTM and a self-attention layer, and the last attended step is
    projected by three 16-wide heads whose outputs are fused.
    """

    def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64):
        """Create the layers.

        Args:
            input_size: Number of features per tick.
            hidden_size: Width of the encoder, LSTM and attention layers.
            output_size: Length of the fused feature vector.
        """
        super().__init__()
        head_width = 16

        # Per-tick encoder
        encoder_layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        self.tick_encoder = nn.Sequential(*encoder_layers)

        # Temporal modelling, then attention over the LSTM outputs
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)

        # Specialised projections of the attended summary
        self.price_head = nn.Linear(hidden_size, head_width)
        self.volume_head = nn.Linear(hidden_size, head_width)
        self.microstructure_head = nn.Linear(hidden_size, head_width)

        # Fusion of the three head outputs (3 * 16 = 48 inputs)
        self.feature_fusion = nn.Sequential(
            nn.Linear(3 * head_width, output_size),
            nn.ReLU(),
            nn.Linear(output_size, output_size),
        )

        # Confidence in (0, 1) computed from the fused features
        self.confidence_head = nn.Sequential(
            nn.Linear(output_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract features from a batch of tick sequences.

        Args:
            tick_sequence: [batch, sequence_length, features]
        Returns:
            features: [batch, output_size] - extracted features
            confidence: [batch, 1] - feature confidence
        """
        # Unpacking three dimensions rejects input that is not rank 3.
        _batch, _steps, _width = tick_sequence.shape

        temporal, _ = self.lstm(self.tick_encoder(tick_sequence))
        attended, _ = self.attention(temporal, temporal, temporal)

        # Only the final time step of the attended sequence is used.
        summary = attended[:, -1, :]

        heads = (self.price_head, self.volume_head, self.microstructure_head)
        combined = torch.cat([head(summary) for head in heads], dim=1)

        fused = self.feature_fusion(combined)
        return fused, self.confidence_head(fused)
class RealTimeTickProcessor:
    """
    Real-time tick processing system with neural network feature extraction.

    Acts as a DPS alternative for low latency tick processing.  For every
    configured symbol it keeps a bounded buffer of raw ticks filled from a
    Binance trade WebSocket, and a worker thread that converts the most
    recent ticks into ``ProcessedTickFeatures`` and passes them to the
    registered subscriber callbacks.
    """

    # Number of values ``_ticks_to_features`` produces per tick; it matches
    # the ``input_size`` the network is created with in ``__init__``.
    _TICK_FEATURE_WIDTH = 9

    def __init__(self, symbols: Optional[List[str]] = None, tick_buffer_size: int = 1000):
        """Initialize the real-time tick processor.

        Args:
            symbols: Symbols to process; defaults to ETH/USDT and BTC/USDT.
            tick_buffer_size: Maximum number of raw ticks kept per symbol.
        """
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
        self.tick_buffer_size = tick_buffer_size

        # Tick storage buffers
        self.tick_buffers: Dict[str, Deque[TickData]] = {}
        self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {}

        # Initialize buffers for each symbol
        for symbol in self.symbols:
            self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size)
            self.processed_features[symbol] = deque(maxlen=100)  # Keep last 100 processed features

        # Neural network for feature extraction
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tick_nn = TickProcessingNN(input_size=9).to(self.device)
        self.tick_nn.eval()  # Start in evaluation mode

        # Processing parameters
        self.processing_window = 50  # Number of ticks to process at once
        self.min_ticks_for_processing = 10  # Minimum ticks before processing

        # Real-time streaming
        self.streaming = False
        self.websocket_tasks = {}
        self.processing_threads = {}

        # Performance tracking
        self.processing_times = deque(maxlen=1000)
        self.tick_counts = {symbol: 0 for symbol in self.symbols}

        # Thread safety
        self.data_lock = Lock()

        # Feature subscribers (models that want real-time features)
        self.feature_subscribers = []

        logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}")
        logger.info(f"Neural network device: {self.device}")
        logger.info(f"Tick buffer size: {tick_buffer_size}")

    @staticmethod
    def _callback_name(callback) -> str:
        """Return a printable name for a subscriber callback.

        ``functools.partial`` objects and callable instances have no
        ``__name__``; their ``repr`` is used instead so logging never raises.
        """
        return getattr(callback, '__name__', None) or repr(callback)

    def add_feature_subscriber(self, callback):
        """Add a callable ``callback(symbol, features)`` to receive processed features."""
        self.feature_subscribers.append(callback)
        logger.info(f"Added feature subscriber: {self._callback_name(callback)}")

    def remove_feature_subscriber(self, callback):
        """Remove a feature subscriber; unknown callbacks are ignored."""
        if callback in self.feature_subscribers:
            self.feature_subscribers.remove(callback)
            logger.info(f"Removed feature subscriber: {self._callback_name(callback)}")

    async def start_processing(self):
        """Start one WebSocket task and one daemon worker thread per symbol."""
        logger.info("Starting real-time tick processing...")
        self.streaming = True

        # Start WebSocket streams for each symbol
        for symbol in self.symbols:
            task = asyncio.create_task(self._websocket_stream(symbol))
            self.websocket_tasks[symbol] = task

            # Start processing thread for each symbol
            thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True)
            thread.start()
            self.processing_threads[symbol] = thread

        logger.info("Real-time tick processing started")

    async def stop_processing(self):
        """Stop streaming and cancel the WebSocket tasks.

        The worker threads are not joined here; they leave their loop on
        their own once they observe ``self.streaming`` being False.
        """
        logger.info("Stopping real-time tick processing...")
        self.streaming = False

        # Cancel WebSocket tasks
        for symbol, task in self.websocket_tasks.items():
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        self.websocket_tasks.clear()

        logger.info("Real-time tick processing stopped")

    async def _websocket_stream(self, symbol: str):
        """Read trade messages for ``symbol`` until streaming stops.

        Connection errors are logged and followed by a reconnect after two
        seconds; a malformed message is logged and skipped.
        """
        binance_symbol = symbol.replace('/', '').lower()
        url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"

        while self.streaming:
            try:
                async with websockets.connect(url) as websocket:
                    logger.info(f"Tick WebSocket connected for {symbol}")
                    async for message in websocket:
                        if not self.streaming:
                            break
                        try:
                            data = json.loads(message)
                            await self._process_raw_tick(symbol, data)
                        except Exception as e:
                            logger.warning(f"Error processing tick for {symbol}: {e}")
            except Exception as e:
                logger.error(f"WebSocket error for {symbol}: {e}")
                if self.streaming:
                    logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...")
                    await asyncio.sleep(2)

    async def _process_raw_tick(self, symbol: str, raw_data: Dict):
        """Convert one raw trade message into a ``TickData`` and buffer it.

        Reads the keys ``'T'`` (time in milliseconds), ``'p'`` (price),
        ``'q'`` (quantity), ``'m'`` (buyer-is-maker flag) and the optional
        ``'t'`` (trade id).  Errors are logged, not raised.
        """
        try:
            # Extract tick information
            tick = TickData(
                timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000),
                price=float(raw_data['p']),
                volume=float(raw_data['q']),
                side='buy' if raw_data['m'] == False else 'sell',  # m=true means buyer is market maker (sell)
                trade_id=raw_data.get('t')
            )

            # Add to buffer
            with self.data_lock:
                self.tick_buffers[symbol].append(tick)
                self.tick_counts[symbol] += 1
        except Exception as e:
            logger.error(f"Error processing raw tick for {symbol}: {e}")

    def _processing_loop(self, symbol: str):
        """Worker loop for one symbol; runs while ``self.streaming`` is True.

        Features are recomputed only when the newest buffered tick differs
        from the one seen on the previous pass.  Without that check the loop
        would rerun the network on identical data every millisecond and send
        duplicate features to every subscriber.
        """
        logger.info(f"Starting processing loop for {symbol}")
        last_processed_tick = None

        while self.streaming:
            try:
                # Check if we have enough ticks to process
                with self.data_lock:
                    buffer = self.tick_buffers[symbol]
                    tick_count = len(buffer)
                    newest_tick = buffer[-1] if buffer else None

                has_new_data = newest_tick is not last_processed_tick
                if tick_count >= self.min_ticks_for_processing and has_new_data:
                    # Remember the tick before processing so a window that
                    # fails is not retried until fresh data arrives.
                    last_processed_tick = newest_tick
                    start_time = time.time()

                    # Process ticks
                    features = self._extract_neural_features(symbol)
                    if features is not None:
                        # Store processed features
                        with self.data_lock:
                            self.processed_features[symbol].append(features)

                        # Notify subscribers
                        self._notify_feature_subscribers(symbol, features)

                        # Track processing time
                        processing_time = (time.time() - start_time) * 1000  # Convert to ms
                        self.processing_times.append(processing_time)
                        if len(self.processing_times) % 100 == 0:
                            avg_time = np.mean(list(self.processing_times))
                            logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms")

                # Small sleep to prevent CPU overload
                time.sleep(0.001)  # 1ms polling interval
            except Exception as e:
                logger.error(f"Error in processing loop for {symbol}: {e}")
                time.sleep(0.01)  # Longer sleep on error

    def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
        """Build a ``ProcessedTickFeatures`` from the newest ticks of ``symbol``.

        Returns:
            The features, or None when fewer than
            ``self.min_ticks_for_processing`` ticks are buffered or when an
            error occurred (the error is logged).
        """
        try:
            with self.data_lock:
                # Get recent ticks
                recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:]

            if len(recent_ticks) < self.min_ticks_for_processing:
                return None

            # Convert ticks to neural network input
            tick_features = self._ticks_to_features(recent_ticks)

            # Process with neural network
            with torch.no_grad():
                tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device)
                neural_features, confidence = self.tick_nn(tick_tensor)
                neural_features = neural_features.cpu().numpy().flatten()
                confidence = confidence.cpu().numpy().item()

            # Extract traditional features
            price_features = self._extract_price_features(recent_ticks)
            volume_features = self._extract_volume_features(recent_ticks)
            microstructure_features = self._extract_microstructure_features(recent_ticks)

            # Create processed features object
            return ProcessedTickFeatures(
                timestamp=recent_ticks[-1].timestamp,
                price_features=price_features,
                volume_features=volume_features,
                microstructure_features=microstructure_features,
                neural_features=neural_features,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"Error extracting neural features for {symbol}: {e}")
            return None

    def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray:
        """Convert tick data to neural network input features.

        Each tick becomes nine values: price, volume, buy flag, epoch
        timestamp, relative price change, volume ratio and seconds since the
        previous tick, then the deviation of price and volume from their
        five-tick moving averages (0 for the first five ticks).

        Returns:
            float32 array of shape ``(self.processing_window, 9)``.  Shorter
            inputs are zero-padded at the front, longer ones keep only the
            most recent rows, and an empty input yields all zeros.
        """
        target_length = self.processing_window
        if not ticks:
            return np.zeros((target_length, self._TICK_FEATURE_WIDTH), dtype=np.float32)

        features = []
        for i, tick in enumerate(ticks):
            tick_features = [
                tick.price,
                tick.volume,
                1.0 if tick.side == 'buy' else 0.0,  # Buy/sell indicator
                tick.timestamp.timestamp(),  # Timestamp
            ]

            # Add relative features if we have previous ticks
            if i > 0:
                prev_tick = ticks[i - 1]
                # A zero previous price has no meaningful relative change.
                if prev_tick.price != 0:
                    price_change = (tick.price - prev_tick.price) / prev_tick.price
                else:
                    price_change = 0.0
                volume_ratio = tick.volume / (prev_tick.volume + 1e-8)
                time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds()
                tick_features.extend([price_change, volume_ratio, time_delta])
            else:
                tick_features.extend([0.0, 1.0, 0.0])  # Default values for first tick

            # Add moving averages if we have enough data
            if i >= 5:
                window = ticks[i - 4:i + 1]
                price_ma = np.mean([t.price for t in window])
                volume_ma = np.mean([t.volume for t in window])
                tick_features.extend([
                    (tick.price - price_ma) / price_ma if price_ma != 0 else 0.0,  # Price deviation from MA
                    (tick.volume - volume_ma) / (volume_ma + 1e-8)  # Volume deviation from MA
                ])
            else:
                tick_features.extend([0.0, 0.0])

            features.append(tick_features)

        # Pad or truncate to fixed size
        if len(features) < target_length:
            # Pad with zeros
            padding = [[0.0] * self._TICK_FEATURE_WIDTH] * (target_length - len(features))
            features = padding + features
        elif len(features) > target_length:
            # Take the most recent ticks
            features = features[-target_length:]

        return np.array(features, dtype=np.float32)

    def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract price-based features.

        Returns seven float32 values: last price, mean, standard deviation,
        high, low, total return over the window, and a momentum term (mean of
        the last 5 prices relative to the mean of the last 10; 0 with fewer
        than 10 ticks).  ``ticks`` must not be empty.
        """
        prices = np.array([tick.price for tick in ticks])
        features = [
            prices[-1],  # Current price
            np.mean(prices),  # Average price
            np.std(prices),  # Price volatility
            np.max(prices),  # High
            np.min(prices),  # Low
            (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0,  # Total return
        ]

        # Price momentum features
        if len(prices) >= 10:
            short_ma = np.mean(prices[-5:])
            long_ma = np.mean(prices[-10:])
            momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
            features.append(momentum)
        else:
            features.append(0.0)

        return np.array(features, dtype=np.float32)

    def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract volume-based features.

        Returns seven float32 values: total, mean and standard deviation of
        volume, buy volume, sell volume, volume imbalance (buy share minus
        0.5) and the last price's relative deviation from the VWAP.
        """
        volumes = np.array([tick.volume for tick in ticks])
        # np.sum of an empty list is 0.0, so no emptiness checks are needed.
        total_buy = np.sum([tick.volume for tick in ticks if tick.side == 'buy'])
        total_sell = np.sum([tick.volume for tick in ticks if tick.side == 'sell'])

        features = [
            np.sum(volumes),  # Total volume
            np.mean(volumes),  # Average volume
            np.std(volumes),  # Volume volatility
            total_buy,  # Buy volume
            total_sell,  # Sell volume
        ]

        # Volume imbalance
        total_volume = total_buy + total_sell
        if total_volume > 0:
            buy_ratio = total_buy / total_volume
            volume_imbalance = buy_ratio - 0.5  # -0.5 to 0.5 range
        else:
            volume_imbalance = 0.0
        features.append(volume_imbalance)

        # VWAP (Volume Weighted Average Price)
        if np.sum(volumes) > 0:
            prices = np.array([tick.price for tick in ticks])
            vwap = np.sum(prices * volumes) / np.sum(volumes)
            current_price = ticks[-1].price
            vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0
        else:
            vwap_deviation = 0.0
        features.append(vwap_deviation)

        return np.array(features, dtype=np.float32)

    def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray:
        """Extract market microstructure features.

        Returns four float32 values: trade frequency (inverse mean gap
        between ticks), a price-impact measure (correlation between absolute
        relative price changes and the volume of the tick that produced each
        change), a spread proxy from the last five prices, and the order flow
        imbalance by trade count.
        """
        features = []

        # Trade frequency
        if len(ticks) >= 2:
            time_deltas = [(ticks[i].timestamp - ticks[i - 1].timestamp).total_seconds()
                           for i in range(1, len(ticks))]
            avg_time_delta = np.mean(time_deltas)
            trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0
        else:
            trade_frequency = 0.0
        features.append(trade_frequency)

        # Price impact features
        prices = [tick.price for tick in ticks]
        volumes = [tick.volume for tick in ticks]
        price_impact = 0.0
        if len(prices) >= 3:
            # Build (change, volume) pairs together so each change stays
            # matched with its own tick even when a zero price is skipped.
            pairs = [((prices[i] - prices[i - 1]) / prices[i - 1], volumes[i])
                     for i in range(1, len(prices)) if prices[i - 1] != 0]
            if len(pairs) >= 2:
                abs_changes = np.abs([change for change, _ in pairs])
                paired_volumes = np.array([volume for _, volume in pairs], dtype=float)
                # Correlation is undefined for a constant series; treat as 0.
                if np.std(abs_changes) > 0 and np.std(paired_volumes) > 0:
                    price_impact = np.corrcoef(abs_changes, paired_volumes)[0, 1]
                    if np.isnan(price_impact):
                        price_impact = 0.0
        features.append(price_impact)

        # Bid-ask spread proxy (using price volatility)
        spread_proxy = 0.0
        if len(prices) >= 5:
            recent_prices = prices[-5:]
            mean_price = np.mean(recent_prices)
            if mean_price != 0:
                spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / mean_price
        features.append(spread_proxy)

        # Order flow imbalance (already calculated in volume features, but different perspective)
        buy_count = sum(1 for tick in ticks if tick.side == 'buy')
        sell_count = len(ticks) - buy_count
        total_trades = len(ticks)
        if total_trades > 0:
            order_flow_imbalance = (buy_count - sell_count) / total_trades
        else:
            order_flow_imbalance = 0.0
        features.append(order_flow_imbalance)

        return np.array(features, dtype=np.float32)

    def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures):
        """Call every subscriber with ``(symbol, features)``.

        A subscriber that raises is logged and does not prevent the
        remaining subscribers from being called.
        """
        for callback in self.feature_subscribers:
            try:
                callback(symbol, features)
            except Exception as e:
                logger.error(f"Error notifying feature subscriber {self._callback_name(callback)}: {e}")

    def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
        """Get the latest processed features for a symbol, or None if there are none."""
        with self.data_lock:
            if symbol in self.processed_features and self.processed_features[symbol]:
                return self.processed_features[symbol][-1]
            return None

    def get_processing_stats(self) -> Dict[str, Any]:
        """Get processing performance statistics.

        The ``'processing_performance'`` entry is only present once at least
        one processing time has been recorded.
        """
        stats = {
            'symbols': self.symbols,
            'streaming': self.streaming,
            'tick_counts': dict(self.tick_counts),
            'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols},
            'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols},
            'subscribers': len(self.feature_subscribers)
        }

        # Take one snapshot so all four figures describe the same samples.
        times = list(self.processing_times)
        if times:
            stats['processing_performance'] = {
                'avg_time_ms': np.mean(times),
                'min_time_ms': np.min(times),
                'max_time_ms': np.max(times),
                'std_time_ms': np.std(times)
            }
        return stats

    def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100):
        """Train the tick processing neural network.

        Uses Adam (lr 0.001) and mean squared error between the network's
        feature output and each batch's targets.  The network is returned to
        eval mode when training ends, including when it ends with an error.

        Args:
            training_data: ``(batch_features, batch_targets)`` pairs.
            epochs: Number of passes over ``training_data``.
        """
        logger.info("Training tick processing neural network...")
        if len(training_data) == 0:
            # Nothing to learn from; also avoids dividing by zero below.
            logger.warning("No training data supplied; skipping tick network training")
            return

        self.tick_nn.train()
        try:
            optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001)
            criterion = nn.MSELoss()

            for epoch in range(epochs):
                total_loss = 0.0
                for batch_features, batch_targets in training_data:
                    optimizer.zero_grad()

                    # Convert to tensors
                    features_tensor = torch.FloatTensor(batch_features).to(self.device)
                    targets_tensor = torch.FloatTensor(batch_targets).to(self.device)

                    # Forward pass (the confidence output is not trained here)
                    outputs, _ = self.tick_nn(features_tensor)

                    # Calculate loss
                    loss = criterion(outputs, targets_tensor)

                    # Backward pass
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                if epoch % 10 == 0:
                    avg_loss = total_loss / len(training_data)
                    logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}")
        finally:
            self.tick_nn.eval()

        logger.info("Neural network training completed")
# Integration with existing orchestrator
def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor):
    """Forward the tick processor's features to an orchestrator.

    Registers a subscriber on ``tick_processor`` that repackages every
    ``ProcessedTickFeatures`` as a plain dict and hands it to
    ``orchestrator.process_realtime_features`` -- if the orchestrator has
    such a method; otherwise the features are dropped silently.
    """
    def feature_callback(symbol: str, features: ProcessedTickFeatures):
        """Repackage one feature bundle and pass it on; errors are logged."""
        try:
            # Flatten the dataclass into the dict format passed downstream.
            payload = {'symbol': symbol}
            for field_name in (
                'timestamp',
                'neural_features',
                'price_features',
                'volume_features',
                'microstructure_features',
                'confidence',
            ):
                payload[field_name] = getattr(features, field_name)

            if hasattr(orchestrator, 'process_realtime_features'):
                orchestrator.process_realtime_features(payload)
        except Exception as e:
            logger.error(f"Error integrating features with orchestrator: {e}")

    tick_processor.add_feature_subscriber(feature_callback)
    logger.info("Tick processor integrated with orchestrator")
# Factory function for easy creation
def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor:
    """Build a ``RealTimeTickProcessor``, defaulting to ETH/USDT and BTC/USDT.

    Args:
        symbols: Symbols to process; the two defaults are used when None.

    Returns:
        The newly created processor (not yet started).
    """
    symbols = ['ETH/USDT', 'BTC/USDT'] if symbols is None else symbols
    processor = RealTimeTickProcessor(symbols=symbols)
    logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}")
    return processor

View File

@ -1,453 +0,0 @@
"""
Retrospective Training System
This module implements a retrospective training system that:
1. Triggers training when trades close with known P&L outcomes
2. Uses captured model inputs from trade entry to train models
3. Optimizes for profit by learning from profitable vs unprofitable patterns
4. Supports simultaneous inference and training without weight reloading
5. Implements reinforcement learning with immediate reward feedback
"""
import logging
import threading
import time
import queue
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
import numpy as np
from collections import deque
logger = logging.getLogger(__name__)
@dataclass
class TrainingCase:
    """A closed trade, packaged for retrospective training.

    Fields:
        case_id: Identifier of the case.
        symbol: Traded instrument.
        action: 'BUY' or 'SELL'.
        entry_price / exit_price: Prices at entry and exit.
        entry_time / exit_time: Times of entry and exit.
        pnl: Profit or loss of the trade.
        fees: Fees paid.
        confidence: Confidence recorded for the trade.
        model_inputs: Model inputs captured for the trade.
        market_state: Market state captured for the trade.
        outcome_label: 1 for profit, 0 for loss, 2 for breakeven.
        reward_signal: Scaled reward for RL training.
        leverage: Leverage used; defaults to 1.0.
    """
    case_id: str
    symbol: str
    action: str
    entry_price: float
    exit_price: float
    entry_time: datetime
    exit_time: datetime
    pnl: float
    fees: float
    confidence: float
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    outcome_label: int
    reward_signal: float
    leverage: float = 1.0
class RetrospectiveTrainer:
    """Retrospective training system for real-time model optimization.

    Closed trades are converted into labelled ``TrainingCase`` records, kept
    in a bounded in-memory history, and periodically replayed into the models
    exposed by the orchestrator from a background daemon thread.
    """

    def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the retrospective trainer.

        Args:
            orchestrator: Optional object whose ``rl_agent`` and
                ``extrema_trainer`` attributes are used when present and
                truthy. Model training is skipped when it is falsy.
            config: Optional settings. Recognised keys: ``batch_size``,
                ``min_cases_for_training``, ``profit_threshold``,
                ``training_frequency_seconds`` and ``max_training_cases``.
        """
        self.orchestrator = orchestrator
        self.config = config or {}

        # Training configuration
        self.batch_size = self.config.get('batch_size', 32)
        self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
        self.profit_threshold = self.config.get('profit_threshold', 0.0)
        self.training_frequency = self.config.get('training_frequency_seconds', 120)  # 2 minutes
        self.max_training_cases = self.config.get('max_training_cases', 1000)

        # Training state
        # NOTE(review): nothing in this class reads from training_queue, so it
        # gains one entry per completed trade and is never drained here.
        # Presumably an external consumer exists - confirm, otherwise it grows
        # without bound.
        self.training_queue = queue.Queue()
        self.completed_cases = deque(maxlen=self.max_training_cases)
        self.training_stats = {
            'total_cases': 0,
            'profitable_cases': 0,
            'loss_cases': 0,
            'breakeven_cases': 0,
            'avg_profit': 0.0,
            'last_training_time': datetime.now(),
            'training_sessions': 0,
            'model_updates': 0
        }

        # Threading
        self.training_thread = None
        self.is_training_active = False
        self.training_lock = threading.Lock()

        logger.info("RetrospectiveTrainer initialized")
        logger.info(f"Configuration: batch_size={self.batch_size}, "
                    f"min_cases={self.min_cases_for_training}, "
                    f"training_freq={self.training_frequency}s")

    def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
        """Record a closed trade and start a training session when one is due.

        Args:
            trade_record: Trade fields; see ``_create_training_case`` for the
                keys that are read.
            model_inputs: Inputs captured for the trade.

        Returns:
            True when the case was stored, False when it could not be built or
            an error occurred.
        """
        try:
            case = self._create_training_case(trade_record, model_inputs)
            if case is None:
                return False

            # Add to completed cases
            self.completed_cases.append(case)
            self.training_queue.put(case)

            # Update statistics
            self.training_stats['total_cases'] += 1
            if case.outcome_label == 1:  # Profit
                self.training_stats['profitable_cases'] += 1
            elif case.outcome_label == 0:  # Loss
                self.training_stats['loss_cases'] += 1
            else:  # Breakeven
                self.training_stats['breakeven_cases'] += 1

            # Average over the cases currently held in memory (bounded deque),
            # not over every case ever seen.
            total_pnl = sum(c.pnl for c in self.completed_cases)
            self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)

            logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
                        f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")

            # Trigger training if we have enough cases
            self._maybe_trigger_training()
            return True
        except Exception as e:
            logger.error(f"Error adding completed trade for retrospective training: {e}")
            return False

    def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional['TrainingCase']:
        """Create a training case from a trade record and its model inputs.

        Reads ``symbol``, ``side``, ``pnl``, ``fees``, ``confidence``,
        ``entry_price``, ``exit_price``, ``entry_time``, ``exit_time`` and
        ``leverage`` from ``trade_record``, each with a default.

        Returns:
            The new ``TrainingCase``, or None if construction failed.
        """
        try:
            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Net P&L after fees is what gets labelled and stored as ``pnl``.
            net_pnl = pnl - fees

            # Label the outcome; the reward is net P&L scaled by confidence
            # and a factor of 10, clamped to [-10, 10].
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                reward_signal = min(10.0, net_pnl * confidence * 10)
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                reward_signal = max(-10.0, net_pnl * confidence * 10)
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0

            # Case id embeds a second-resolution timestamp; every '.' in the
            # final string is replaced with 'p'.
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            return TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )
        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None

    def _maybe_trigger_training(self):
        """Start a training session if enough cases and enough time have accumulated."""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()
        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")

    def _start_training_session(self):
        """Start a training session in a background daemon thread."""
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")
        except Exception as e:
            logger.error(f"Error starting training session: {e}")

    def _run_training_session(self):
        """Run one complete training session (thread target)."""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()
                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['last_training_time'] = datetime.now()
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += len(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")
        except Exception as e:
            logger.error(f"Error in retrospective training session: {e}")
            import traceback
            logger.error(traceback.format_exc())
        finally:
            self.is_training_active = False

    def _train_models(self) -> Dict[str, Any]:
        """Train the available models on the stored cases.

        Returns:
            A dict keyed by model name ('dqn', 'extrema') with each model's
            result, or ``{'error': ...}`` when nothing could be trained.
        """
        results = {}
        try:
            # Snapshot first: this runs on the training thread while
            # add_completed_trade may append to the deque, and iterating a
            # deque that is mutated concurrently raises RuntimeError.
            cases = list(self.completed_cases)
            profitable_cases = [c for c in cases if c.outcome_label == 1]
            loss_cases = [c for c in cases if c.outcome_label == 0]
            if not profitable_cases and not loss_cases:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            # Train DQN agent if available
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                try:
                    # The feedback list is built and counted here; it is not
                    # passed to the extrema trainer in this method.
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")
                    # Record the failure the same way the DQN branch does.
                    results['extrema'] = {'error': str(e)}

            return results
        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}

    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Replay stored cases into the DQN agent and run a few training steps.

        Returns:
            A dict with ``experiences_added`` and ``training_steps`` (the
            number of steps that actually produced a loss), or an ``error`` /
            ``training_error`` entry on failure.
        """
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return {'error': 'DQN agent not available'}

            dqn_agent = self.orchestrator.rl_agent
            experiences_added = 0

            # Iterate a snapshot so concurrent appends cannot abort the loop.
            for case in list(self.completed_cases):
                try:
                    state = self._extract_state_vector(case.model_inputs)
                    if state is None:
                        continue

                    # Action mapping: BUY=0, anything else=1
                    action = 0 if case.action == 'BUY' else 1
                    reward = case.reward_signal

                    # The trade is closed, so the transition is terminal; an
                    # all-zero array stands in for the next state.
                    next_state = np.zeros_like(state)
                    done = True

                    if hasattr(dqn_agent, 'add_experience'):
                        dqn_agent.add_experience(state, action, reward, next_state, done)
                        experiences_added += 1
                except Exception as e:
                    logger.debug(f"Error adding DQN experience: {e}")
                    continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                # Conservative schedule: one step per four experiences, max 10.
                planned_steps = min(10, experiences_added // 4)
                completed_steps = 0
                try:
                    for _ in range(planned_steps):
                        # A None loss ends the loop early.
                        if dqn_agent.train() is None:
                            break
                        completed_steps += 1
                    return {
                        'experiences_added': experiences_added,
                        'training_steps': completed_steps,
                        'method': 'retrospective_experience_replay'
                    }
                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}
        except Exception as e:
            logger.error(f"Error in DQN retrospective training: {e}")
            return {'error': str(e)}

    def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
        """Extract a DQN state vector from captured model inputs.

        A pre-built ``model_inputs['dqn_state']['state_vector']`` is used when
        present; otherwise a 15-element float32 vector is assembled from
        ``market_state`` and ``technical_indicators``, with 0.0 for any
        missing key.

        Returns:
            The state array, or None if extraction raised.
        """
        try:
            # Try to get pre-built RL state
            if 'dqn_state' in model_inputs:
                state = model_inputs['dqn_state']
                if isinstance(state, dict) and 'state_vector' in state:
                    return np.array(state['state_vector'])

            market_state = model_inputs.get('market_state', {})
            indicators = model_inputs.get('technical_indicators', {})

            price_keys = ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']
            volume_keys = ['volume_current', 'volume_sma_20', 'volume_ratio']
            indicator_keys = ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']

            features = [market_state.get(key, 0.0) for key in price_keys + volume_keys]
            features.extend(indicators.get(key, 0.0) for key in indicator_keys)

            return np.array(features, dtype=np.float32)
        except Exception as e:
            logger.debug(f"Error extracting state vector: {e}")
            return None

    def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
        """Build feedback records from the profit and loss cases (breakeven is skipped)."""
        try:
            # Snapshot for the same reason as in _train_models.
            return [
                {
                    'symbol': case.symbol,
                    'action': case.action,
                    'entry_price': case.entry_price,
                    'exit_price': case.exit_price,
                    'was_profitable': case.outcome_label == 1,
                    'reward_signal': case.reward_signal,
                    'market_state': case.market_state
                }
                for case in list(self.completed_cases)
                if case.outcome_label in [0, 1]
            ]
        except Exception as e:
            logger.error(f"Error creating extrema feedback: {e}")
            return []

    def get_training_stats(self) -> Dict[str, Any]:
        """Return a copy of the training statistics plus live derived metrics."""
        stats = self.training_stats.copy()
        # Snapshot so the figures below are mutually consistent and safe to
        # compute while the deque is being appended to.
        cases = list(self.completed_cases)
        stats['total_cases_in_memory'] = len(cases)
        stats['training_queue_size'] = self.training_queue.qsize()
        stats['is_training_active'] = self.is_training_active

        # Calculate profit metrics
        if cases:
            profitable_count = sum(1 for c in cases if c.pnl > 0)
            stats['profit_rate'] = profitable_count / len(cases)
            stats['total_pnl'] = sum(c.pnl for c in cases)
            stats['avg_reward'] = sum(c.reward_signal for c in cases) / len(cases)
        return stats

    def force_training_session(self) -> bool:
        """Force a training session regardless of timing constraints.

        Returns:
            True if a session start was requested, False if training is
            already active, no cases exist, or an error occurred.
        """
        try:
            if self.is_training_active:
                logger.warning("Training already in progress")
                return False
            if len(self.completed_cases) < 1:
                logger.warning("No completed cases available for training")
                return False

            logger.info("RETROSPECTIVE: Forcing training session")
            self._start_training_session()
            return True
        except Exception as e:
            logger.error(f"Error forcing training session: {e}")
            return False

    def stop(self):
        """Stop the retrospective trainer, waiting up to 10 seconds for the thread."""
        try:
            self.is_training_active = False
            if self.training_thread and self.training_thread.is_alive():
                self.training_thread.join(timeout=10)
            logger.info("RetrospectiveTrainer stopped")
        except Exception as e:
            logger.error(f"Error stopping RetrospectiveTrainer: {e}")
def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
    """Construct a RetrospectiveTrainer with the given orchestrator and config.

    Both arguments are forwarded unchanged to the RetrospectiveTrainer
    constructor.
    """
    trainer = RetrospectiveTrainer(orchestrator=orchestrator, config=config)
    return trainer

View File

@ -1,529 +0,0 @@
"""
RL Training Pipeline with Comprehensive Experience Storage and Replay
This module implements a robust RL training pipeline that:
1. Stores all training experiences with profitability metrics
2. Implements profit-weighted experience replay
3. Tracks gradient information for each training step
4. Enables retraining on most profitable trading sequences
5. Maintains comprehensive trading episode analysis
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import pickle
from collections import deque
import threading
import random
from .training_data_collector import get_training_data_collector
logger = logging.getLogger(__name__)
@dataclass
class RLExperience:
    """A single RL transition together with its trading context and replay metadata."""

    # Identification
    experience_id: str
    timestamp: datetime
    episode_id: str

    # Core transition (state, action, reward, next_state, done)
    state: np.ndarray
    action: int  # 0=SELL, 1=HOLD, 2=BUY
    reward: float
    next_state: np.ndarray
    done: bool

    # Context recorded alongside the transition
    market_context: Dict[str, Any]
    cnn_predictions: Optional[Dict[str, Any]] = None
    confidence_score: float = 0.0

    # Realised trading outcome; None until it is known
    actual_profit: Optional[float] = None
    actual_holding_time: Optional[timedelta] = None
    optimal_action: Optional[int] = None

    # Scores used to value the experience for replay
    experience_value: float = 0.0
    profitability_score: float = 0.0
    learning_priority: float = 0.0

    # Bookkeeping updated when the experience is sampled for training
    times_trained: int = 0
    last_trained: Optional[datetime] = None
class ProfitWeightedExperienceBuffer:
    """Bounded experience buffer with profit-weighted sampling for replay.

    Experiences are stored by id; ``experience_order`` tracks insertion order
    for eviction and ``profitable_experiences`` lists the ids whose
    ``actual_profit`` is positive.
    """

    def __init__(self, max_size: int = 100000):
        """Create an empty buffer that holds at most ``max_size`` experiences."""
        self.max_size = max_size
        self.experiences: Dict[str, 'RLExperience'] = {}
        self.experience_order: deque = deque(maxlen=max_size)
        self.profitable_experiences: List[str] = []
        self.total_experiences = 0
        self.total_profitable = 0

    def _evict_oldest(self):
        """Remove the oldest experience from every index."""
        oldest_id = self.experience_order.popleft()
        self.experiences.pop(oldest_id, None)
        if oldest_id in self.profitable_experiences:
            self.profitable_experiences.remove(oldest_id)

    def add_experience(self, experience: 'RLExperience'):
        """Add an experience, evicting the oldest one when the buffer is full.

        ``total_experiences`` and ``total_profitable`` are lifetime counters
        and are not decremented on eviction.
        """
        try:
            # Evict explicitly *before* appending. The order deque has a
            # maxlen, so appending to a full deque silently drops its head;
            # evicting afterwards would then remove the second-oldest entry
            # and leave the oldest one in the dict forever.
            if self.experience_order and len(self.experience_order) >= self.max_size:
                self._evict_oldest()

            exp_id = experience.experience_id
            self.experiences[exp_id] = experience
            self.experience_order.append(exp_id)

            if experience.actual_profit is not None and experience.actual_profit > 0:
                self.profitable_experiences.append(exp_id)
                self.total_profitable += 1

            self.total_experiences += 1
        except Exception as e:
            logger.error(f"Error adding experience to buffer: {e}")

    def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List['RLExperience']:
        """Sample a batch, optionally reserving half of it for profitable experiences.

        Args:
            batch_size: Number of experiences wanted.
            prioritize_profitable: When True and more than ``batch_size // 2``
                profitable experiences exist, half the batch is drawn from
                them and the rest from the remaining experiences.

        Returns:
            ``batch_size`` distinct experiences, or every stored experience
            when fewer than ``batch_size`` are available. Sampled experiences
            get ``times_trained`` incremented and ``last_trained`` set.
        """
        try:
            if len(self.experiences) < batch_size:
                return list(self.experiences.values())

            if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
                profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
                profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)

                # Draw the rest from ids not already chosen so the batch never
                # contains the same experience twice.
                chosen = set(profitable_ids)
                remaining_pool = [exp_id for exp_id in self.experiences if exp_id not in chosen]
                remaining_ids = random.sample(remaining_pool, batch_size - profitable_sample_size)
                sampled_ids = profitable_ids + remaining_ids
            else:
                # Random sampling from all experiences
                sampled_ids = random.sample(list(self.experiences), batch_size)

            sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]

            # Update training counts
            for experience in sampled_experiences:
                experience.times_trained += 1
                experience.last_trained = datetime.now()
            return sampled_experiences
        except Exception as e:
            logger.error(f"Error sampling batch: {e}")
            return list(self.experiences.values())[:batch_size]

    def get_most_profitable_experiences(self, limit: int = 100) -> List['RLExperience']:
        """Return up to ``limit`` profitable experiences, highest profit first."""
        try:
            profitable_experiences = [
                self.experiences[exp_id] for exp_id in self.profitable_experiences
                if exp_id in self.experiences
            ]
            profitable_experiences.sort(
                key=lambda x: x.actual_profit if x.actual_profit else 0,
                reverse=True
            )
            return profitable_experiences[:limit]
        except Exception as e:
            logger.error(f"Error getting profitable experiences: {e}")
            return []
class RLTradingAgent(nn.Module):
    """RL trading agent: a shared state encoder feeding Q-value, policy and value heads."""

    def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
        """Build the networks.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of the first encoder layer; later layers use
                ``hidden_dim // 2`` and ``hidden_dim // 4``.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # State processing network
        self.state_processor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU()
        )

        # Q-value network
        self.q_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim)
        )

        # Policy network (outputs action probabilities)
        self.policy_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim),
            nn.Softmax(dim=-1)
        )

        # Value network (single scalar per state)
        self.value_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, state):
        """Run the encoder and all three heads.

        Returns:
            A dict with ``q_values``, ``policy_probs``, ``state_value`` and the
            encoder output ``processed_state``.
        """
        processed_state = self.state_processor(state)
        return {
            'q_values': self.q_network(processed_state),
            'policy_probs': self.policy_network(processed_state),
            'state_value': self.value_network(processed_state),
            'processed_state': processed_state
        }

    def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
        """Select an action with an epsilon-greedy policy.

        Switches the module to eval mode (it is not switched back here).

        Args:
            state: A 1-D numpy state vector (a batch dimension is added) or an
                already batched tensor of a single state.
            epsilon: Probability of choosing a uniformly random action.

        Returns:
            ``(action, confidence)``. Confidence is the fixed value 0.33 for a
            random action, otherwise the largest softmax probability over the
            Q-values.
        """
        self.eval()
        with torch.no_grad():
            if isinstance(state, np.ndarray):
                state = torch.from_numpy(state).float().unsqueeze(0)
            if torch.is_tensor(state):
                # Inputs built from numpy live on the CPU; move them to
                # wherever the module's weights are (e.g. after .to('cuda')),
                # otherwise the forward pass fails with a device mismatch.
                state = state.to(next(self.parameters()).device)

            outputs = self.forward(state)

            if random.random() < epsilon:
                action = random.randint(0, self.action_dim - 1)
                confidence = 0.33
            else:
                q_values = outputs['q_values']
                action = torch.argmax(q_values, dim=1).item()
                q_softmax = F.softmax(q_values, dim=1)
                confidence = torch.max(q_softmax).item()
            return action, confidence
@dataclass
class RLTrainingStep:
    """Record of one optimisation step, including the gradients it produced."""

    # Identification
    step_id: str
    timestamp: datetime
    batch_experiences: List[str]  # ids of the experiences in the batch

    # Loss values for the step
    total_loss: float
    q_loss: float
    policy_loss: float

    # Per-parameter gradient tensors and their norms, keyed by parameter name
    gradients: Dict[str, torch.Tensor]
    gradient_norms: Dict[str, float]

    # Hyper-parameters in effect for the step
    learning_rate: float = 0.001
    batch_size: int = 32

    # Performance figures for the batch
    batch_profitability: float = 0.0
    correct_actions: int = 0
    total_actions: int = 0
    step_value: float = 0.0
@dataclass
class RLTrainingSession:
    """Summary of one RL training session and the steps it ran."""

    # Identification and timing; end_timestamp stays None until it is set
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # What was trained
    training_mode: str = 'experience_replay'
    symbol: str = ''

    # Per-step records and aggregate loss figures
    training_steps: List['RLTrainingStep'] = field(default_factory=list)
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')

    # Action and profitability counters
    profitable_actions: int = 0
    total_actions: int = 0
    profitability_rate: float = 0.0
    session_value: float = 0.0
class RLTrainer:
    """RL trainer with experience storage, Q-learning replay and session persistence."""

    def __init__(self, agent: 'RLTradingAgent', device: str = 'cuda', storage_dir: str = "rl_training_storage"):
        """Set up the trainer.

        Args:
            agent: The network to train; it is moved to ``device``.
            device: Torch device string. A CUDA device that is not available
                falls back to ``'cpu'`` instead of raising.
            storage_dir: Directory (created if missing) for saved sessions.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            logger.warning(f"Device '{device}' requested but CUDA is not available; falling back to CPU")
            device = 'cpu'

        self.agent = agent.to(device)
        self.device = device
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
        self.experience_buffer = ProfitWeightedExperienceBuffer()
        self.data_collector = get_training_data_collector()

        # NOTE(review): nothing in this class appends to training_sessions, so
        # 'recent_sessions' in get_training_statistics stays absent unless a
        # caller fills the list - confirm intent. Sessions hold cloned
        # gradient tensors, so retaining them here would be memory-heavy.
        self.training_sessions: List['RLTrainingSession'] = []
        self.current_session: Optional['RLTrainingSession'] = None

        self.gamma = 0.99  # discount factor for the Q-learning target
        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'total_experiences': 0,
            'profitable_actions': 0,
            'total_actions': 0,
            'average_reward': 0.0
        }
        logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")

    def add_experience(self, state: np.ndarray, action: int, reward: float,
                       next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
                       cnn_predictions: Optional[Dict[str, Any]] = None,
                       confidence_score: float = 0.0) -> Optional[str]:
        """Wrap a transition in an ``RLExperience`` and add it to the buffer.

        Returns:
            The generated experience id, or None if storing failed.
        """
        try:
            experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
            experience = RLExperience(
                experience_id=experience_id,
                timestamp=datetime.now(),
                episode_id=market_context.get('episode_id', 'unknown'),
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                market_context=market_context,
                cnn_predictions=cnn_predictions,
                confidence_score=confidence_score
            )
            self.experience_buffer.add_experience(experience)
            self.training_stats['total_experiences'] += 1
            return experience_id
        except Exception as e:
            logger.error(f"Error adding experience: {e}")
            return None

    def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
        """Train on profit-prioritised batches drawn from the experience buffer.

        Args:
            batch_size: Experiences per optimisation step.
            num_batches: Maximum number of steps to attempt; a batch that
                comes back smaller than ``batch_size`` is skipped.

        Returns:
            ``{'status': 'success', 'session_id', 'average_loss',
            'total_steps'}`` where ``total_steps`` counts the steps actually
            executed, or ``{'status': 'error', 'error': ...}``.
        """
        return self._run_replay_session(
            lambda size: self.experience_buffer.sample_batch(size, True),
            batch_size=batch_size,
            num_batches=num_batches,
            training_mode='experience_replay',
        )

    def _run_replay_session(self, sample_batch, batch_size: int, num_batches: int,
                            training_mode: str) -> Dict[str, Any]:
        """Run one Q-learning session using ``sample_batch(batch_size)`` as the batch source."""
        try:
            session = RLTrainingSession(
                session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode=training_mode
            )
            self.current_session = session
            self.agent.train()
            total_loss = 0.0

            for batch_idx in range(num_batches):
                experiences = sample_batch(batch_size)
                if len(experiences) < batch_size:
                    continue

                # Stack into single numpy arrays first; building a tensor from
                # a Python list of arrays converts element by element.
                states = torch.from_numpy(
                    np.array([exp.state for exp in experiences], dtype=np.float32)).to(self.device)
                next_states = torch.from_numpy(
                    np.array([exp.next_state for exp in experiences], dtype=np.float32)).to(self.device)
                actions = torch.tensor([exp.action for exp in experiences], dtype=torch.long, device=self.device)
                rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float32, device=self.device)
                dones = torch.tensor([exp.done for exp in experiences], dtype=torch.bool, device=self.device)

                # Forward pass
                self.optimizer.zero_grad()
                current_q_values = self.agent(states)['q_values']

                # Bootstrapped target from the same network; terminal
                # transitions contribute only their reward.
                with torch.no_grad():
                    next_q_values = self.agent(next_states)['q_values']
                    max_next_q_values = torch.max(next_q_values, dim=1)[0]
                    target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)

                # Loss on the Q-value of the action that was actually taken
                current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)
                q_loss.backward()

                # Record gradients before clipping modifies them
                gradients = {}
                gradient_norms = {}
                for name, param in self.agent.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.clone().detach()
                        gradient_norms[name] = param.grad.norm().item()

                torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
                self.optimizer.step()

                step_loss = q_loss.item()
                session.training_steps.append(RLTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    batch_experiences=[exp.experience_id for exp in experiences],
                    total_loss=step_loss,
                    q_loss=step_loss,
                    policy_loss=0.0,
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    batch_size=len(experiences)
                ))
                total_loss += step_loss
                session.best_loss = min(session.best_loss, step_loss)

            # Finalize using the steps that really ran: skipped batches must
            # not dilute the average or inflate the step counters.
            executed_steps = len(session.training_steps)
            session.end_timestamp = datetime.now()
            session.total_steps = executed_steps
            session.average_loss = total_loss / executed_steps if executed_steps else 0.0
            self._save_training_session(session)

            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps

            logger.info(f"RL training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")
            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps
            }
        except Exception as e:
            logger.error(f"Error in RL training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def train_on_profitable_experiences(self, min_profitability: float = 0.1,
                                        max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
        """Train only on experiences whose realised profit is at least ``min_profitability``.

        Returns:
            The session result with ``training_mode='profitable_replay'`` and
            ``experiences_used`` added, ``{'status': 'insufficient_data',
            'experiences_found': n}`` when fewer than ``batch_size``
            experiences qualify, or ``{'status': 'error', ...}``.
        """
        try:
            profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
            filtered_experiences = [
                exp for exp in profitable_experiences
                if exp.actual_profit is not None and exp.actual_profit >= min_profitability
            ]
            if len(filtered_experiences) < batch_size:
                return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}

            logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
            num_batches = len(filtered_experiences) // batch_size

            # Sample straight from the filtered list instead of temporarily
            # replacing the shared buffer's sample_batch method.
            def profitable_sample_batch(size):
                return random.sample(filtered_experiences, min(size, len(filtered_experiences)))

            results = self._run_replay_session(
                profitable_sample_batch,
                batch_size=batch_size,
                num_batches=num_batches,
                training_mode='profitable_replay',
            )
            results['training_mode'] = 'profitable_replay'
            results['experiences_used'] = len(filtered_experiences)
            return results
        except Exception as e:
            logger.error(f"Error training on profitable experiences: {e}")
            return {'status': 'error', 'error': str(e)}

    def _save_training_session(self, session: 'RLTrainingSession'):
        """Persist a session as a pickle plus a small JSON metadata file.

        Files are written under ``<storage_dir>/sessions``. Errors are logged
        and swallowed so that a failed save does not fail the training call.
        """
        try:
            session_dir = self.storage_dir / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            # NOTE: pickle files must only ever be loaded back from trusted
            # storage; unpickling untrusted data can execute arbitrary code.
            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss
            }
            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def get_training_statistics(self) -> Dict[str, Any]:
        """Return a copy of the counters plus summaries of the 10 most recent sessions.

        ``recent_sessions`` is only present when ``training_sessions`` is
        non-empty.
        """
        stats = self.training_stats.copy()
        if self.training_sessions:
            recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss
                }
                for s in recent_sessions
            ]
        return stats
# Module-level singleton, created on first use by get_rl_trainer()
rl_trainer = None
def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
    """Return the shared RLTrainer, creating it on the first call.

    Args:
        agent: Agent to wrap when the trainer is first created; a default
            RLTradingAgent is built when omitted. Ignored once the shared
            trainer exists.
    """
    global rl_trainer
    if rl_trainer is not None:
        return rl_trainer
    initial_agent = RLTradingAgent() if agent is None else agent
    rl_trainer = RLTrainer(initial_agent)
    return rl_trainer

View File

@ -1,460 +0,0 @@
"""
Robust COB (Consolidated Order Book) Provider
This module provides a robust COB data provider that handles:
- HTTP 418 errors from Binance (rate limiting)
- Thread safety issues
- API rate limiting and backoff
- Fallback data sources
- Error recovery strategies
Features:
- Automatic rate limiting and backoff
- Multiple exchange support with fallbacks
- Thread-safe operations
- Comprehensive error handling
- Data validation and integrity checking
"""
import asyncio
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from collections import deque
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from .api_rate_limiter import get_rate_limiter, RateLimitConfig
logger = logging.getLogger(__name__)
@dataclass
class COBData:
    """Snapshot of a consolidated order book with metrics derived from it."""

    symbol: str
    timestamp: datetime
    bids: List[Tuple[float, float]]  # [(price, quantity), ...]
    asks: List[Tuple[float, float]]  # [(price, quantity), ...]

    # Derived metrics, filled in by __post_init__ when both sides are present
    spread: float = 0.0
    mid_price: float = 0.0
    total_bid_volume: float = 0.0
    total_ask_volume: float = 0.0

    # Data quality
    data_source: str = 'unknown'
    quality_score: float = 1.0

    def __post_init__(self):
        """Derive spread, mid price, side volumes and quality score.

        Nothing is computed unless both ``bids`` and ``asks`` are non-empty;
        the first entry of each side is used as the best level.
        """
        # NOTE(review): a one-sided or empty book therefore keeps whatever
        # quality_score was passed in (1.0 by default) - confirm this is
        # intended.
        if not (self.bids and self.asks):
            return

        best_bid = self.bids[0][0]
        best_ask = self.asks[0][0]
        self.spread = best_ask - best_bid
        self.mid_price = (best_ask + best_bid) / 2
        self.total_bid_volume = sum(size for _price, size in self.bids)
        self.total_ask_volume = sum(size for _price, size in self.asks)

        # Completeness: fraction of 20 levels present on the thinner side,
        # capped at 1.0.
        bid_depth = len(self.bids) / 20
        ask_depth = len(self.asks) / 20
        self.quality_score = min(bid_depth, ask_depth, 1.0)
class RobustCOBProvider:
"""Robust COB provider with error handling and rate limiting"""
def __init__(self, symbols: List[str] = None):
    """Set up caches, counters and worker infrastructure for the given symbols.

    Args:
        symbols: Symbols to track; ETHUSDT and BTCUSDT are used when this is
            None or empty.
    """
    self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']

    # Shared rate limiter obtained from the api_rate_limiter module
    self.rate_limiter = get_rate_limiter()

    # Re-entrant lock guarding the shared dictionaries below
    self.lock = threading.RLock()

    # Latest order book per symbol, with the time each entry was cached
    self.cob_cache: Dict[str, COBData] = dict()
    self.cache_timestamps: Dict[str, datetime] = dict()
    self.cache_ttl = timedelta(seconds=5)  # 5 second cache TTL

    # Consecutive-failure counters and time of the last good fetch
    self.error_counts: Dict[str, int] = dict()
    self.last_successful_fetch: Dict[str, datetime] = dict()

    # Background fetching: one thread per symbol plus a small worker pool
    self.is_running = False
    self.fetch_threads: Dict[str, threading.Thread] = dict()
    self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher")

    # Last-resort data per symbol
    self.fallback_data: Dict[str, COBData] = dict()

    # Request counters, all starting at zero
    self.fetch_stats = dict.fromkeys(
        [
            'total_requests',
            'successful_requests',
            'failed_requests',
            'rate_limited_requests',
            'cache_hits',
            'fallback_uses',
        ],
        0,
    )

    logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}")
def start_background_fetching(self):
    """Launch one daemon fetch thread per symbol.

    Does nothing except log a warning when fetching is already running.
    """
    if self.is_running:
        logger.warning("Background fetching already running")
        return

    self.is_running = True

    # Each symbol gets its own worker; the thread is registered in
    # fetch_threads before it is started.
    for sym in self.symbols:
        worker = threading.Thread(
            name=f"COB-{sym}",
            target=self._background_fetch_worker,
            args=(sym,),
            daemon=True,
        )
        self.fetch_threads[sym] = worker
        worker.start()

    symbol_count = len(self.symbols)
    logger.info(f"Started background COB fetching for {symbol_count} symbols")
def stop_background_fetching(self):
    """Stop background COB data fetching.

    Clears the running flag, waits up to 5 seconds for each fetch thread,
    then shuts the executor down.
    """
    self.is_running = False

    # Workers observe is_running on their next loop iteration; a worker that
    # is still sleeping may outlive the 5 second join timeout.
    for symbol, thread in self.fetch_threads.items():
        thread.join(timeout=5)
        logger.debug(f"Stopped COB fetching for {symbol}")

    # Executor.shutdown() accepts no ``timeout`` argument; passing one raises
    # TypeError, so the executor would never be shut down.
    self.executor.shutdown(wait=True)
    logger.info("Stopped background COB fetching")
def _background_fetch_worker(self, symbol: str):
"""Background worker for fetching COB data"""
logger.info(f"Started COB fetching worker for {symbol}")
while self.is_running:
try:
# Fetch COB data
cob_data = self._fetch_cob_data_safe(symbol)
if cob_data:
with self.lock:
self.cob_cache[symbol] = cob_data
self.cache_timestamps[symbol] = datetime.now()
self.last_successful_fetch[symbol] = datetime.now()
self.error_counts[symbol] = 0 # Reset error count on success
logger.debug(f"Updated COB cache for {symbol}")
else:
with self.lock:
self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1
logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}")
# Wait before next fetch (adaptive based on errors)
error_count = self.error_counts.get(symbol, 0)
base_interval = 2.0 # Base 2 second interval
backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0) # Max 60s
time.sleep(backoff_interval)
except Exception as e:
logger.error(f"Error in COB fetching worker for {symbol}: {e}")
time.sleep(10) # Wait 10s on unexpected errors
logger.info(f"Stopped COB fetching worker for {symbol}")
def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]:
"""Safely fetch COB data with error handling"""
try:
self.fetch_stats['total_requests'] += 1
# Try Binance first
cob_data = self._fetch_binance_cob(symbol)
if cob_data:
self.fetch_stats['successful_requests'] += 1
return cob_data
# Try MEXC as fallback
cob_data = self._fetch_mexc_cob(symbol)
if cob_data:
self.fetch_stats['successful_requests'] += 1
cob_data.data_source = 'mexc_fallback'
return cob_data
# Use cached fallback data if available
if symbol in self.fallback_data:
self.fetch_stats['fallback_uses'] += 1
fallback = self.fallback_data[symbol]
fallback.timestamp = datetime.now()
fallback.data_source = 'fallback_cache'
fallback.quality_score *= 0.5 # Reduce quality score for old data
return fallback
self.fetch_stats['failed_requests'] += 1
return None
except Exception as e:
logger.error(f"Error fetching COB data for {symbol}: {e}")
self.fetch_stats['failed_requests'] += 1
return None
def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]:
"""Fetch COB data from Binance with rate limiting"""
try:
url = f"https://api.binance.com/api/v3/depth"
params = {
'symbol': symbol,
'limit': 100 # Get 100 levels
}
# Use rate limiter
response = self.rate_limiter.make_request(
'binance_api',
url,
method='GET',
params=params
)
if not response:
self.fetch_stats['rate_limited_requests'] += 1
return None
if response.status_code != 200:
logger.warning(f"Binance COB API returned {response.status_code} for {symbol}")
return None
data = response.json()
# Parse order book data
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
if not bids or not asks:
logger.warning(f"Empty order book data from Binance for {symbol}")
return None
cob_data = COBData(
symbol=symbol,
timestamp=datetime.now(),
bids=bids,
asks=asks,
data_source='binance'
)
# Store as fallback for future use
self.fallback_data[symbol] = cob_data
return cob_data
except Exception as e:
logger.error(f"Error fetching Binance COB for {symbol}: {e}")
return None
def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]:
"""Fetch COB data from MEXC as fallback"""
try:
url = f"https://api.mexc.com/api/v3/depth"
params = {
'symbol': symbol,
'limit': 100
}
response = self.rate_limiter.make_request(
'mexc_api',
url,
method='GET',
params=params
)
if not response or response.status_code != 200:
return None
data = response.json()
# Parse order book data
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
if not bids or not asks:
return None
return COBData(
symbol=symbol,
timestamp=datetime.now(),
bids=bids,
asks=asks,
data_source='mexc'
)
except Exception as e:
logger.debug(f"Error fetching MEXC COB for {symbol}: {e}")
return None
def get_cob_data(self, symbol: str) -> Optional[COBData]:
"""Get COB data for a symbol (from cache or fresh fetch)"""
with self.lock:
# Check cache first
if symbol in self.cob_cache:
cached_data = self.cob_cache[symbol]
cache_time = self.cache_timestamps.get(symbol, datetime.min)
# Return cached data if still fresh
if datetime.now() - cache_time < self.cache_ttl:
self.fetch_stats['cache_hits'] += 1
return cached_data
# If background fetching is running, return cached data even if stale
if self.is_running and symbol in self.cob_cache:
return self.cob_cache[symbol]
# Fetch fresh data if not running background fetching
if not self.is_running:
return self._fetch_cob_data_safe(symbol)
return None
def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]:
"""
Get COB features for ML models
Args:
symbol: Trading symbol
feature_count: Number of features to return
Returns:
Numpy array of COB features or None if no data
"""
cob_data = self.get_cob_data(symbol)
if not cob_data:
return None
try:
features = []
# Basic market metrics
features.extend([
cob_data.mid_price,
cob_data.spread,
cob_data.total_bid_volume,
cob_data.total_ask_volume,
cob_data.quality_score
])
# Bid levels (price and volume)
max_levels = min(len(cob_data.bids), 20)
for i in range(max_levels):
price, volume = cob_data.bids[i]
features.extend([price, volume])
# Pad bid levels if needed
for i in range(max_levels, 20):
features.extend([0.0, 0.0])
# Ask levels (price and volume)
max_levels = min(len(cob_data.asks), 20)
for i in range(max_levels):
price, volume = cob_data.asks[i]
features.extend([price, volume])
# Pad ask levels if needed
for i in range(max_levels, 20):
features.extend([0.0, 0.0])
# Calculate additional features
if len(cob_data.bids) > 0 and len(cob_data.asks) > 0:
# Volume imbalance
bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5])
ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5])
volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0
features.append(volume_imbalance)
# Price levels
bid_price_levels = [price for price, _ in cob_data.bids[:10]]
ask_price_levels = [price for price, _ in cob_data.asks[:10]]
features.extend(bid_price_levels + ask_price_levels)
# Pad or truncate to desired feature count
if len(features) < feature_count:
features.extend([0.0] * (feature_count - len(features)))
else:
features = features[:feature_count]
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error creating COB features for {symbol}: {e}")
return None
def get_provider_status(self) -> Dict[str, Any]:
"""Get provider status and statistics"""
with self.lock:
status = {
'is_running': self.is_running,
'symbols': self.symbols,
'cache_status': {},
'error_counts': self.error_counts.copy(),
'last_successful_fetch': {
symbol: timestamp.isoformat()
for symbol, timestamp in self.last_successful_fetch.items()
},
'fetch_stats': self.fetch_stats.copy(),
'rate_limiter_status': self.rate_limiter.get_all_endpoint_status()
}
# Cache status for each symbol
for symbol in self.symbols:
cache_time = self.cache_timestamps.get(symbol)
status['cache_status'][symbol] = {
'has_data': symbol in self.cob_cache,
'cache_time': cache_time.isoformat() if cache_time else None,
'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None,
'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0
}
return status
def reset_errors(self):
"""Reset error counts and rate limiter"""
with self.lock:
self.error_counts.clear()
self.rate_limiter.reset_all_endpoints()
logger.info("Reset all error counts and rate limiter")
def force_refresh(self, symbol: str = None):
"""Force refresh COB data for symbol(s)"""
symbols_to_refresh = [symbol] if symbol else self.symbols
for sym in symbols_to_refresh:
# Clear cache to force refresh
with self.lock:
if sym in self.cob_cache:
del self.cob_cache[sym]
if sym in self.cache_timestamps:
del self.cache_timestamps[sym]
logger.info(f"Forced refresh for {sym}")
# Process-wide provider, created lazily by get_cob_provider()
_global_cob_provider = None


def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider:
    """Return the shared provider, creating it on the first call.

    ``symbols`` is only passed to the constructor by the call that creates
    the instance; later calls return the existing provider as it is.
    """
    global _global_cob_provider
    provider = _global_cob_provider
    if provider is None:
        provider = RobustCOBProvider(symbols)
        _global_cob_provider = provider
    return provider

View File

@ -1,350 +0,0 @@
#!/usr/bin/env python3
"""
Shared COB Service - Eliminates Redundant COB Implementations
This service provides a singleton COB integration that can be shared across:
- Dashboard components
- RL trading systems
- Enhanced orchestrators
- Training pipelines
Instead of each component creating its own COBIntegration instance,
they all share this single service, eliminating redundant connections.
"""
import asyncio
import logging
import weakref
from typing import Dict, List, Optional, Any, Callable, Set
from datetime import datetime
from threading import Lock
from dataclasses import dataclass
from .cob_integration import COBIntegration
from .multi_exchange_cob_provider import COBSnapshot
from .data_provider import DataProvider
logger = logging.getLogger(__name__)
@dataclass
class COBSubscription:
    """Represents a subscription to COB updates.

    One instance is stored per subscriber in ``SharedCOBService.subscribers``,
    keyed by ``subscriber_id``.
    """
    # Unique ID generated by SharedCOBService.subscribe(); used to unsubscribe
    subscriber_id: str
    # Invoked as callback(symbol, data); may be a plain or an async function
    callback: Callable
    # Symbols the subscriber wants updates for; None means every symbol
    symbol_filter: Optional[List[str]] = None
    callback_type: str = "general"  # general, cnn, dqn, dashboard
class SharedCOBService:
    """
    Shared COB Service - Singleton pattern for unified COB data access

    This service eliminates redundant COB integrations by providing a single
    shared instance that all components can subscribe to.  Updates from the
    underlying ``COBIntegration`` are cached and fanned out to subscribers
    according to their callback type and optional symbol filter.
    """
    _instance: Optional['SharedCOBService'] = None
    _lock = Lock()

    def __new__(cls, *args, **kwargs):
        """Return the single instance, creating it on first use.

        The second ``None`` check inside the lock prevents two threads that
        both saw ``None`` from each creating an instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None):
        """Initialize shared COB service.

        ``__init__`` runs on every ``SharedCOBService(...)`` call, so all but
        the first call return immediately and their arguments are ignored.

        Args:
            symbols: Symbols to track; defaults to BTC/USDT and ETH/USDT.
            data_provider: Provider handed to ``COBIntegration`` on start.
        """
        if hasattr(self, '_initialized'):
            return
        self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
        self.data_provider = data_provider
        # Single COB integration instance
        self.cob_integration: Optional[COBIntegration] = None
        self.is_running = False
        # Subscriber management
        self.subscribers: Dict[str, COBSubscription] = {}
        self.subscriber_counter = 0
        self.subscription_lock = Lock()
        # Strong references to in-flight async callback tasks; the event loop
        # itself only keeps weak references to tasks.
        self._callback_tasks: Set[asyncio.Task] = set()
        # Cached data for immediate access
        self.latest_snapshots: Dict[str, COBSnapshot] = {}
        self.latest_cnn_features: Dict[str, Any] = {}
        self.latest_dqn_states: Dict[str, Any] = {}
        # Performance tracking
        self.total_subscribers = 0
        self.update_count = 0
        self.start_time = None
        self._initialized = True
        logger.info(f"SharedCOBService initialized for symbols: {self.symbols}")

    async def start(self) -> None:
        """Start the shared COB service.

        Creates the ``COBIntegration`` on first start, registers the internal
        callbacks on it and starts it.  Errors are logged and re-raised.
        """
        if self.is_running:
            logger.warning("SharedCOBService already running")
            return
        logger.info("Starting SharedCOBService...")
        try:
            if self.cob_integration is None:
                self.cob_integration = COBIntegration(
                    data_provider=self.data_provider,
                    symbols=self.symbols
                )
                # Register internal callbacks
                self.cob_integration.add_cnn_callback(self._on_cob_cnn_update)
                self.cob_integration.add_dqn_callback(self._on_cob_dqn_update)
                self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update)
            await self.cob_integration.start()
            self.is_running = True
            self.start_time = datetime.now()
            logger.info("SharedCOBService started successfully")
            logger.info(f"Active subscribers: {len(self.subscribers)}")
        except Exception as e:
            logger.error(f"Error starting SharedCOBService: {e}")
            raise

    async def stop(self) -> None:
        """Stop the shared COB service and notify subscribers.

        Every callable subscriber callback is invoked with
        ``("SHUTDOWN", None)``; coroutine results are awaited so that async
        subscribers actually receive the notification.  Errors are logged,
        not raised.
        """
        if not self.is_running:
            return
        logger.info("Stopping SharedCOBService...")
        try:
            if self.cob_integration:
                await self.cob_integration.stop()
            self.is_running = False
            # Iterate over a snapshot so callbacks may unsubscribe safely
            with self.subscription_lock:
                subscriptions = list(self.subscribers.values())
            for subscription in subscriptions:
                try:
                    if callable(subscription.callback):
                        result = subscription.callback("SHUTDOWN", None)
                        if asyncio.iscoroutine(result):
                            await result
                except Exception as e:
                    logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
            logger.info("SharedCOBService stopped")
        except Exception as e:
            logger.error(f"Error stopping SharedCOBService: {e}")

    def subscribe(self,
                  callback: Callable,
                  callback_type: str = "general",
                  symbol_filter: Optional[List[str]] = None,
                  subscriber_name: Optional[str] = None) -> str:
        """
        Subscribe to COB updates

        Args:
            callback: Function to call on updates
            callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard')
            symbol_filter: Only receive updates for these symbols (None = all)
            subscriber_name: Optional name for the subscriber

        Returns:
            Subscription ID for unsubscribing
        """
        with self.subscription_lock:
            self.subscriber_counter += 1
            subscriber_id = f"{callback_type}_{self.subscriber_counter}"
            if subscriber_name:
                subscriber_id = f"{subscriber_name}_{subscriber_id}"
            self.subscribers[subscriber_id] = COBSubscription(
                subscriber_id=subscriber_id,
                callback=callback,
                symbol_filter=symbol_filter,
                callback_type=callback_type
            )
            self.total_subscribers += 1
            logger.info(f"New subscriber: {subscriber_id} ({callback_type})")
            logger.info(f"Total active subscribers: {len(self.subscribers)}")
            return subscriber_id

    def unsubscribe(self, subscriber_id: str) -> bool:
        """
        Unsubscribe from COB updates

        Args:
            subscriber_id: ID returned from subscribe()

        Returns:
            True if successfully unsubscribed
        """
        with self.subscription_lock:
            if subscriber_id not in self.subscribers:
                logger.warning(f"Subscriber not found: {subscriber_id}")
                return False
            del self.subscribers[subscriber_id]
            logger.info(f"Unsubscribed: {subscriber_id}")
            logger.info(f"Remaining subscribers: {len(self.subscribers)}")
            return True

    # Internal callback handlers
    async def _on_cob_cnn_update(self, symbol: str, data: Dict):
        """Handle CNN feature updates from COB integration"""
        try:
            self.latest_cnn_features[symbol] = data
            await self._notify_subscribers("cnn", symbol, data)
        except Exception as e:
            logger.error(f"Error in CNN update handler: {e}")

    async def _on_cob_dqn_update(self, symbol: str, data: Dict):
        """Handle DQN state updates from COB integration"""
        try:
            self.latest_dqn_states[symbol] = data
            await self._notify_subscribers("dqn", symbol, data)
        except Exception as e:
            logger.error(f"Error in DQN update handler: {e}")

    async def _on_cob_dashboard_update(self, symbol: str, data: Dict):
        """Handle dashboard updates from COB integration"""
        try:
            # Store snapshot if it's a COBSnapshot
            if hasattr(data, 'volume_weighted_mid'):  # Duck typing for COBSnapshot
                self.latest_snapshots[symbol] = data
            # _notify_subscribers() already includes "general" subscribers for
            # every callback type; a second "general" pass delivered each
            # dashboard update to them twice.
            await self._notify_subscribers("dashboard", symbol, data)
            self.update_count += 1
        except Exception as e:
            logger.error(f"Error in dashboard update handler: {e}")

    async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any):
        """Notify all relevant subscribers of an update.

        A subscriber is relevant when its type equals ``callback_type`` or is
        "general", and its symbol filter is unset or contains ``symbol``.
        Async callbacks are scheduled as tasks; plain callbacks run inline.
        """
        try:
            # Snapshot under the lock: subscribe()/unsubscribe() may run in
            # another thread and resize the dict during iteration.
            with self.subscription_lock:
                subscriptions = list(self.subscribers.values())
            for subscription in subscriptions:
                if subscription.callback_type not in (callback_type, "general"):
                    continue
                if subscription.symbol_filter is not None and symbol not in subscription.symbol_filter:
                    continue
                try:
                    if asyncio.iscoroutinefunction(subscription.callback):
                        task = asyncio.create_task(subscription.callback(symbol, data))
                        # Keep the task referenced until it finishes
                        self._callback_tasks.add(task)
                        task.add_done_callback(self._callback_tasks.discard)
                    else:
                        subscription.callback(symbol, data)
                except Exception as e:
                    logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
        except Exception as e:
            logger.error(f"Error notifying subscribers: {e}")

    # Public data access methods
    def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
        """Get latest COB snapshot for a symbol"""
        if self.cob_integration:
            return self.cob_integration.get_cob_snapshot(symbol)
        return self.latest_snapshots.get(symbol)

    def get_cnn_features(self, symbol: str) -> Optional[Any]:
        """Get latest CNN features for a symbol"""
        return self.latest_cnn_features.get(symbol)

    def get_dqn_state(self, symbol: str) -> Optional[Any]:
        """Get latest DQN state for a symbol"""
        return self.latest_dqn_states.get(symbol)

    def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
        """Get detailed market depth analysis"""
        if self.cob_integration:
            return self.cob_integration.get_market_depth_analysis(symbol)
        return None

    def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
        """Get liquidity breakdown by exchange"""
        if self.cob_integration:
            return self.cob_integration.get_exchange_breakdown(symbol)
        return None

    def get_price_buckets(self, symbol: str) -> Optional[Dict]:
        """Get fine-grain price buckets"""
        if self.cob_integration:
            return self.cob_integration.get_price_buckets(symbol)
        return None

    def get_session_volume_profile(self, symbol: str) -> Optional[Dict]:
        """Get session volume profile, if the underlying provider supports it"""
        if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'):
            return self.cob_integration.cob_provider.get_session_volume_profile(symbol)
        return None

    def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
        """Get real-time statistics formatted for NN models"""
        if self.cob_integration:
            return self.cob_integration.get_realtime_stats_for_nn(symbol)
        return {}

    def get_service_statistics(self) -> Dict[str, Any]:
        """Get service statistics.

        Statistics reported by the COB integration, when present, are merged
        into the result and override service keys of the same name.
        """
        uptime = None
        if self.start_time:
            uptime = (datetime.now() - self.start_time).total_seconds()
        subscribers_by_type: Dict[str, int] = {}
        for subscription in self.subscribers.values():
            kind = subscription.callback_type
            subscribers_by_type[kind] = subscribers_by_type.get(kind, 0) + 1
        base_stats = {
            'service_name': 'SharedCOBService',
            'is_running': self.is_running,
            'symbols': self.symbols,
            'total_subscribers': len(self.subscribers),
            'lifetime_subscribers': self.total_subscribers,
            'update_count': self.update_count,
            'uptime_seconds': uptime,
            'subscribers_by_type': subscribers_by_type
        }
        # Get COB integration stats if available
        if self.cob_integration:
            base_stats.update(self.cob_integration.get_statistics())
        return base_stats
# Global service instance access functions
def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Return the SharedCOBService singleton.

    The arguments are forwarded to the constructor, which only uses them the
    first time the singleton is initialized.
    """
    service = SharedCOBService(symbols=symbols, data_provider=data_provider)
    return service
async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Obtain the shared COB service, start it and return it.

    Exceptions raised while starting propagate to the caller.
    """
    shared = get_shared_cob_service(symbols=symbols, data_provider=data_provider)
    await shared.start()
    return shared
async def stop_shared_cob_service():
    """Stop the shared COB service if one has been created.

    Nothing is done when the singleton does not exist yet.  Previously this
    call created it with default symbols just to stop it, which made a later
    ``get_shared_cob_service(symbols=...)`` ignore the requested symbols.
    """
    service = SharedCOBService._instance
    if service is None:
        return
    await service.stop()

View File

@ -1,425 +0,0 @@
"""
Shared Data Manager for UI Stability Fix
Manages data sharing between processes through files with proper locking
and atomic operations to prevent corruption and conflicts.
"""
import json
import os
import time
import tempfile
import platform
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, Union
from pathlib import Path
import logging
# Windows-compatible file locking
if platform.system() == "Windows":
import msvcrt
else:
import fcntl
logger = logging.getLogger(__name__)
@dataclass
class ProcessStatus:
    """Model for process status information"""
    name: str
    pid: int
    status: str  # 'running', 'stopped', 'error'
    start_time: datetime
    last_heartbeat: datetime
    memory_usage: float
    cpu_usage: float
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary with the datetimes as ISO-8601 strings."""
        data = asdict(self)
        data['start_time'] = self.start_time.isoformat()
        data['last_heartbeat'] = self.last_heartbeat.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus':
        """Create an instance from a dictionary produced by ``to_dict``.

        The input dictionary is left untouched; parsing happens on a copy.
        """
        fields = dict(data)
        fields['start_time'] = datetime.fromisoformat(fields['start_time'])
        fields['last_heartbeat'] = datetime.fromisoformat(fields['last_heartbeat'])
        return cls(**fields)
@dataclass
class TrainingStatus:
    """Model for training status information"""
    is_running: bool
    current_epoch: int
    total_epochs: int
    loss: float
    accuracy: float
    last_update: datetime
    model_path: str
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary with ``last_update`` as an ISO-8601 string."""
        data = asdict(self)
        data['last_update'] = self.last_update.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus':
        """Create an instance from a dictionary produced by ``to_dict``.

        The input dictionary is left untouched; parsing happens on a copy.
        """
        fields = dict(data)
        fields['last_update'] = datetime.fromisoformat(fields['last_update'])
        return cls(**fields)
@dataclass
class DashboardState:
    """Model for dashboard state information"""
    is_connected: bool
    last_data_update: datetime
    active_connections: int
    error_count: int
    performance_metrics: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary with ``last_data_update`` as an ISO-8601 string."""
        data = asdict(self)
        data['last_data_update'] = self.last_data_update.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState':
        """Create an instance from a dictionary produced by ``to_dict``.

        The input dictionary is left untouched; parsing happens on a copy.
        """
        fields = dict(data)
        fields['last_data_update'] = datetime.fromisoformat(fields['last_data_update'])
        return cls(**fields)
class SharedDataManager:
    """
    Manages data sharing between processes through files with proper locking
    and atomic operations to prevent corruption and conflicts.

    Each data type lives in its own JSON file inside ``data_dir``.  Writes go
    to a temporary file in the same directory that is then moved over the
    target with ``os.replace``.
    """

    def __init__(self, data_dir: str = "shared_data"):
        """
        Initialize the shared data manager

        Args:
            data_dir: Directory to store shared data files; it is created,
                including missing parent directories, if it does not exist
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Define file paths for different data types
        self.training_status_file = self.data_dir / "training_status.json"
        self.dashboard_state_file = self.data_dir / "dashboard_state.json"
        self.process_status_file = self.data_dir / "process_status.json"
        self.market_data_file = self.data_dir / "market_data.json"
        self.model_metrics_file = self.data_dir / "model_metrics.json"
        logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}")

    def _lock_file(self, file_handle, exclusive=True):
        """Cross-platform file locking.

        Args:
            file_handle: Open file object to lock
            exclusive: Request an exclusive lock (True) or a shared one
                (False); only honoured on the ``fcntl`` path
        """
        if platform.system() == "Windows":
            # The same one-byte LK_LOCK is taken for shared and exclusive requests
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
            except IOError:
                pass  # File locking may not be available in all scenarios
        else:
            lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
            fcntl.flock(file_handle.fileno(), lock_type)

    def _unlock_file(self, file_handle):
        """Cross-platform file unlocking"""
        if platform.system() == "Windows":
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
            except IOError:
                pass
        else:
            fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)

    def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None:
        """
        Write JSON data atomically with file locking

        Args:
            file_path: Path to the file to write
            data: Data to write as JSON

        Raises:
            Exception: Any error during writing is logged and re-raised after
                the temporary file has been removed.
        """
        temp_path = None
        try:
            # The temporary file must share the directory with the target so
            # that os.replace() stays on one filesystem.
            temp_fd, temp_path = tempfile.mkstemp(
                dir=file_path.parent,
                prefix=f".{file_path.name}.",
                suffix=".tmp"
            )
            with os.fdopen(temp_fd, 'w', encoding='utf-8') as temp_file:
                self._lock_file(temp_file, exclusive=True)
                try:
                    json.dump(data, temp_file, indent=2, default=str)
                    temp_file.flush()
                    os.fsync(temp_file.fileno())
                finally:
                    # Release the lock even when serialization fails
                    self._unlock_file(temp_file)
            # Atomically replace the original file
            os.replace(temp_path, file_path)
            logger.debug(f"Successfully wrote data to {file_path}")
        except Exception as e:
            # Clean up temporary file if it exists
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            logger.error(f"Failed to write data to {file_path}: {e}")
            raise

    def _read_json_safe(self, file_path: Path) -> Dict[str, Any]:
        """
        Read JSON data safely with file locking

        Args:
            file_path: Path to the file to read

        Returns:
            Dictionary containing the JSON data; an empty dictionary when the
            file is missing, unreadable or not valid JSON
        """
        if not file_path.exists():
            logger.debug(f"File {file_path} does not exist, returning empty dict")
            return {}
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                self._lock_file(file, exclusive=False)
                try:
                    data = json.load(file)
                finally:
                    self._unlock_file(file)
            logger.debug(f"Successfully read data from {file_path}")
            return data
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in {file_path}: {e}")
            return {}
        except Exception as e:
            logger.error(f"Failed to read data from {file_path}: {e}")
            return {}

    def _write_payload(self, file_path: Path, build_payload, label: str) -> None:
        """Build a payload, write it to ``file_path``, and log under ``label``.

        ``build_payload`` is called inside the ``try`` so that serialization
        errors are logged like write errors before being re-raised.
        """
        try:
            self._write_json_atomic(file_path, build_payload())
            logger.debug(f"{label} written successfully")
        except Exception as e:
            logger.error(f"Failed to write {label.lower()}: {e}")
            raise

    def _read_model(self, file_path: Path, model_cls, label: str):
        """Read ``file_path`` into ``model_cls``; None when empty or on error."""
        try:
            data = self._read_json_safe(file_path)
            if not data:
                return None
            return model_cls.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to read {label}: {e}")
            return None

    def _read_dict(self, file_path: Path, label: str) -> Dict[str, Any]:
        """Read ``file_path`` as a plain dictionary; ``{}`` on error."""
        try:
            return self._read_json_safe(file_path)
        except Exception as e:
            logger.error(f"Failed to read {label}: {e}")
            return {}

    @staticmethod
    def _with_timestamp(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``data`` with a current ISO-8601 'timestamp' entry."""
        stamped = dict(data)
        stamped['timestamp'] = datetime.now().isoformat()
        return stamped

    def write_training_status(self, status: TrainingStatus) -> None:
        """
        Write training status to shared file

        Args:
            status: TrainingStatus object to write
        """
        self._write_payload(self.training_status_file, status.to_dict, "Training status")

    def read_training_status(self) -> Optional[TrainingStatus]:
        """
        Read training status from shared file

        Returns:
            TrainingStatus object or None if not available
        """
        return self._read_model(self.training_status_file, TrainingStatus, "training status")

    def write_dashboard_state(self, state: DashboardState) -> None:
        """
        Write dashboard state to shared file

        Args:
            state: DashboardState object to write
        """
        self._write_payload(self.dashboard_state_file, state.to_dict, "Dashboard state")

    def read_dashboard_state(self) -> Optional[DashboardState]:
        """
        Read dashboard state from shared file

        Returns:
            DashboardState object or None if not available
        """
        return self._read_model(self.dashboard_state_file, DashboardState, "dashboard state")

    def write_process_status(self, status: ProcessStatus) -> None:
        """
        Write process status to shared file

        Args:
            status: ProcessStatus object to write
        """
        self._write_payload(self.process_status_file, status.to_dict, "Process status")

    def read_process_status(self) -> Optional[ProcessStatus]:
        """
        Read process status from shared file

        Returns:
            ProcessStatus object or None if not available
        """
        return self._read_model(self.process_status_file, ProcessStatus, "process status")

    def write_market_data(self, data: Dict[str, Any]) -> None:
        """
        Write market data to shared file

        A 'timestamp' entry is added to the stored copy; the caller's
        dictionary is not modified.

        Args:
            data: Market data dictionary to write
        """
        self._write_payload(self.market_data_file, lambda: self._with_timestamp(data), "Market data")

    def read_market_data(self) -> Dict[str, Any]:
        """
        Read market data from shared file

        Returns:
            Dictionary containing market data
        """
        return self._read_dict(self.market_data_file, "market data")

    def write_model_metrics(self, metrics: Dict[str, Any]) -> None:
        """
        Write model metrics to shared file

        A 'timestamp' entry is added to the stored copy; the caller's
        dictionary is not modified.

        Args:
            metrics: Model metrics dictionary to write
        """
        self._write_payload(self.model_metrics_file, lambda: self._with_timestamp(metrics), "Model metrics")

    def read_model_metrics(self) -> Dict[str, Any]:
        """
        Read model metrics from shared file

        Returns:
            Dictionary containing model metrics
        """
        return self._read_dict(self.model_metrics_file, "model metrics")

    def cleanup(self) -> None:
        """
        Clean up shared data files and remove the directory if it ends up empty.

        Errors are logged, not raised.
        """
        try:
            for file_path in [
                self.training_status_file,
                self.dashboard_state_file,
                self.process_status_file,
                self.market_data_file,
                self.model_metrics_file
            ]:
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Removed {file_path}")
            # Remove directory if empty
            if self.data_dir.exists() and not any(self.data_dir.iterdir()):
                self.data_dir.rmdir()
                logger.debug(f"Removed empty directory {self.data_dir}")
        except Exception as e:
            logger.error(f"Failed to cleanup shared data: {e}")

    def get_data_age(self, data_type: str) -> Optional[float]:
        """
        Get the age of data in seconds, based on the file's modification time

        Args:
            data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics')

        Returns:
            Age in seconds or None if the type is unknown or the file doesn't exist
        """
        file_map = {
            'training': self.training_status_file,
            'dashboard': self.dashboard_state_file,
            'process': self.process_status_file,
            'market': self.market_data_file,
            'metrics': self.model_metrics_file
        }
        file_path = file_map.get(data_type)
        if not file_path or not file_path.exists():
            return None
        try:
            return time.time() - file_path.stat().st_mtime
        except Exception as e:
            logger.error(f"Failed to get data age for {data_type}: {e}")
            return None

View File

@ -1,59 +0,0 @@
"""
Trading Action Module
Defines the TradingAction class used throughout the trading system.
"""
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List
@dataclass
class TradingAction:
    """A single trade decision together with the context that produced it."""
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    quantity: float
    confidence: float
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]

    def __post_init__(self):
        """Reject invalid field values.

        Raises:
            ValueError: For an unknown action, a confidence outside
                [0.0, 1.0], a negative quantity or a non-positive price.
                Checks run in that order; the first failure is reported.
        """
        problem = None
        if self.action not in ('BUY', 'SELL', 'HOLD'):
            problem = f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'"
        elif self.confidence < 0.0 or self.confidence > 1.0:
            problem = f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0"
        elif self.quantity < 0:
            problem = f"Invalid quantity: {self.quantity}. Must be non-negative"
        elif self.price <= 0:
            problem = f"Invalid price: {self.price}. Must be positive"
        if problem is not None:
            raise ValueError(problem)

    def to_dict(self) -> Dict[str, Any]:
        """Return the action as a dictionary with an ISO-8601 timestamp.

        ``reasoning`` is included by reference, not copied.
        """
        payload = {
            name: getattr(self, name)
            for name in ('symbol', 'action', 'quantity', 'confidence', 'price')
        }
        payload['timestamp'] = self.timestamp.isoformat()
        payload['reasoning'] = self.reasoning
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
        """Build an action from a dictionary shaped like ``to_dict`` output."""
        kwargs = {
            name: data[name]
            for name in ('symbol', 'action', 'quantity', 'confidence', 'price')
        }
        kwargs['timestamp'] = datetime.fromisoformat(data['timestamp'])
        kwargs['reasoning'] = data['reasoning']
        return cls(**kwargs)

View File

@ -1,401 +0,0 @@
"""
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
logger = logging.getLogger(__name__)
class TradingExecutorFix:
    """
    Fixes for the TradingExecutor class to address entry/exit price issues
    and improve P&L calculation accuracy.

    The fixes are installed by replacing methods on the wrapped executor
    instance (see ``apply_fixes``). Each wrapper adds validation or
    bookkeeping around the executor's original method. The original method
    is invoked at most once per wrapper call, except when pre-trade
    validation itself crashes, in which case the wrapper falls back to a
    single unvalidated call.
    """

    def __init__(self, trading_executor):
        """
        Initialize the fix with a reference to the trading executor

        Args:
            trading_executor: The TradingExecutor instance to fix
        """
        self.trading_executor = trading_executor

        # Cooldown tracking
        self.last_trade_time = {}  # {symbol: datetime of last successful trade}
        self.min_trade_cooldown = 30  # 30 seconds minimum between trades

        # Price history for duplicate-entry validation
        self.recent_entry_prices = {}  # {symbol: [recent_prices]}
        self.max_price_history = 10  # Keep last 10 entry prices

        # Position reset tracking
        self.position_reset_flags = {}  # {symbol: bool}

        # Price update tracking
        self.last_price_update = {}  # {symbol: datetime of last update}
        self.price_update_threshold = 5  # 5 seconds max since last price update

        # P&L verification
        self.trade_history = {}  # {symbol: [trade_records]}

        logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")

    def apply_fixes(self):
        """Apply all fixes to the trading executor"""
        self._patch_execute_action()
        self._patch_close_position()
        self._patch_calculate_pnl()
        self._patch_update_prices()
        logger.info("All trading executor fixes applied successfully")

    def _patch_execute_action(self):
        """Patch the execute_action method to add price validation and cooldown"""
        original_execute_action = self.trading_executor.execute_action

        def execute_action_with_fixes(decision):
            """Enhanced execute_action with price validation and cooldown.

            Returns False when a pre-trade check rejects the decision,
            otherwise whatever the original ``execute_action`` returns.
            """
            try:
                symbol = decision.symbol
                action = decision.action
                current_time = datetime.now()

                # 1. Check cooldown period
                last_trade = self.last_trade_time.get(symbol)
                if last_trade is not None:
                    time_since_last_trade = (current_time - last_trade).total_seconds()
                    if time_since_last_trade < self.min_trade_cooldown:
                        logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
                        return False

                # 2. Validate price freshness
                last_update = self.last_price_update.get(symbol)
                if last_update is not None:
                    time_since_update = (current_time - last_update).total_seconds()
                    if time_since_update > self.price_update_threshold:
                        logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
                        # Force price refresh so the next attempt can pass
                        self._refresh_price(symbol)
                        return False

                # 3. Validate entry price against recent history.
                # _get_current_price() returns 0.0 when no price is known;
                # that sentinel is not a real entry price, so it must neither
                # be matched against nor stored in the history.
                current_price = self._get_current_price(symbol)
                price_is_valid = bool(current_price) and current_price > 0
                if price_is_valid and current_price in self.recent_entry_prices.get(symbol, ()):
                    logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
                    return False

                # 4. Ensure position is properly reset before new entry
                # NOTE(review): this check does not look at `action`, so an
                # exit order for an active position is rejected here as
                # well - confirm that is intended.
                if not self._ensure_position_reset(symbol):
                    logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
                    return False
            except Exception as e:
                # Validation itself failed; fall back to the unvalidated
                # original. Nothing has been executed yet at this point.
                logger.error(f"Error in execute_action_with_fixes: {e}")
                return original_execute_action(decision)

            # Execute the original action exactly once. It is deliberately
            # outside any try/except that would call it again.
            result = original_execute_action(decision)
            if not result:
                return result

            # Bookkeeping only: a failure here must not re-submit the trade.
            try:
                self.last_trade_time[symbol] = current_time
                if price_is_valid:
                    history = self.recent_entry_prices.setdefault(symbol, [])
                    history.append(current_price)
                    # Keep only the most recent prices
                    if len(history) > self.max_price_history:
                        self.recent_entry_prices[symbol] = history[-self.max_price_history:]
                # Mark position as active
                self.position_reset_flags[symbol] = False
                logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
            except Exception as e:
                logger.error(f"Error in execute_action_with_fixes: {e}")
            return result

        # Replace the original method
        self.trading_executor.execute_action = execute_action_with_fixes
        logger.info("Patched execute_action with price validation and cooldown")

    def _patch_close_position(self):
        """Patch the close_position method to ensure proper position reset"""
        original_close_position = self.trading_executor.close_position

        def close_position_with_fixes(symbol, **kwargs):
            """Enhanced close_position with reset flagging and trade recording."""
            trade_record = None
            try:
                # Snapshot price and position BEFORE closing: the original
                # close may remove the position from `positions`.
                exit_price = self._get_current_price(symbol)
                positions = getattr(self.trading_executor, 'positions', None)
                if positions is not None and symbol in positions:
                    trade_record = self._build_trade_record(symbol, positions[symbol], exit_price)
            except Exception as e:
                logger.error(f"Error in close_position_with_fixes: {e}")
                return original_close_position(symbol, **kwargs)

            # Close exactly once; bookkeeping errors below never re-close.
            result = original_close_position(symbol, **kwargs)
            if not result:
                return result

            try:
                # Mark position as reset
                self.position_reset_flags[symbol] = True
                if trade_record is not None:
                    self.trade_history.setdefault(symbol, []).append(trade_record)
                    logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
            except Exception as e:
                logger.error(f"Error in close_position_with_fixes: {e}")
            return result

        # Replace the original method
        self.trading_executor.close_position = close_position_with_fixes
        logger.info("Patched close_position with proper reset logic")

    def _build_trade_record(self, symbol, position, exit_price):
        """Build the verification record for a position about to be closed."""
        # One timestamp for the whole record keeps exit_time and
        # hold_time_seconds consistent with each other.
        now = datetime.now()
        entry_time = getattr(position, 'entry_time', now)
        return {
            'symbol': symbol,
            'entry_time': entry_time,
            'exit_time': now,
            'entry_price': getattr(position, 'entry_price', 0),
            'exit_price': exit_price,
            'size': getattr(position, 'size', 0),
            'side': getattr(position, 'side', 'UNKNOWN'),
            'pnl': self._calculate_verified_pnl(position, exit_price),
            'fees': getattr(position, 'fees', 0),
            'hold_time_seconds': (now - entry_time).total_seconds(),
        }

    def _patch_calculate_pnl(self):
        """Patch the calculate_pnl method to ensure accurate P&L calculation"""
        original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)

        def calculate_pnl_with_fixes(position, current_price=None):
            """Enhanced calculate_pnl with verification.

            Returns the original result unless it differs from the verified
            value by more than 0.01, in which case the verified value wins.
            """
            try:
                # If no original method, implement our own
                if original_calculate_pnl is None:
                    return self._calculate_verified_pnl(position, current_price)

                original_pnl = original_calculate_pnl(position, current_price)
                verified_pnl = self._calculate_verified_pnl(position, current_price)

                # If there's a significant difference, log it and use ours
                if abs(original_pnl - verified_pnl) > 0.01:
                    logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
                    return verified_pnl
                return original_pnl
            except Exception as e:
                logger.error(f"Error in calculate_pnl_with_fixes: {e}")
                if original_calculate_pnl:
                    return original_calculate_pnl(position, current_price)
                return 0.0

        # Install the wrapper whether or not the executor had the method
        self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
        if original_calculate_pnl:
            logger.info("Patched calculate_pnl with verification")
        else:
            logger.info("Added calculate_pnl method with verification")

    def _patch_update_prices(self):
        """Patch the update_prices method to track price updates"""
        original_update_prices = getattr(self.trading_executor, 'update_prices', None)

        def update_prices_with_tracking(prices):
            """Enhanced update_prices with timestamp tracking"""
            if original_update_prices:
                # Called exactly once; tracking errors below never repeat it.
                result = original_update_prices(prices)
            else:
                # No original method: update the executor's prices directly
                try:
                    if hasattr(self.trading_executor, 'current_prices'):
                        self.trading_executor.current_prices.update(prices)
                except Exception as e:
                    logger.error(f"Error in update_prices_with_tracking: {e}")
                    return False
                result = True

            # Track update timestamps
            try:
                current_time = datetime.now()
                for symbol in prices:
                    self.last_price_update[symbol] = current_time
            except Exception as e:
                logger.error(f"Error in update_prices_with_tracking: {e}")
            return result

        # Install the wrapper whether or not the executor had the method
        self.trading_executor.update_prices = update_prices_with_tracking
        if original_update_prices:
            logger.info("Patched update_prices with timestamp tracking")
        else:
            logger.info("Added update_prices method with timestamp tracking")

    def _calculate_verified_pnl(self, position, current_price=None):
        """Calculate verified net P&L for a position.

        Reads ``entry_price``, ``size``, ``side``, ``leverage`` and ``fees``
        from ``position`` (with defaults when absent). Returns 0.0 when no
        price can be determined or on any error.
        """
        try:
            entry_price = getattr(position, 'entry_price', 0)
            size = getattr(position, 'size', 0)
            side = getattr(position, 'side', 'UNKNOWN')
            leverage = getattr(position, 'leverage', 1.0)
            fees = getattr(position, 'fees', 0.0)

            # If current_price is not provided, try to look it up
            if current_price is None:
                symbol = getattr(position, 'symbol', None)
                if not symbol:
                    return 0.0
                current_price = self._get_current_price(symbol)

            # Gross P&L depends on the position side
            if side == 'LONG':
                pnl = (current_price - entry_price) * size * leverage
            elif side == 'SHORT':
                pnl = (entry_price - current_price) * size * leverage
            else:
                pnl = 0.0

            # Subtract fees for net P&L
            return pnl - fees
        except Exception as e:
            logger.error(f"Error calculating verified P&L: {e}")
            return 0.0

    def _get_current_price(self, symbol):
        """Get current price for a symbol with fallbacks.

        Tries the executor's ``current_prices``, then its data provider's
        ``get_current_price``, then the COB ``mid_price``. Returns 0.0 when
        none of them yields a price or on any error.
        """
        try:
            executor = self.trading_executor

            # Try to get from trading executor
            if hasattr(executor, 'current_prices') and symbol in executor.current_prices:
                return executor.current_prices[symbol]

            # Try to get from data provider
            if hasattr(executor, 'data_provider'):
                data_provider = executor.data_provider
                if hasattr(data_provider, 'get_current_price'):
                    price = data_provider.get_current_price(symbol)
                    if price and price > 0:
                        return price

            # Try to get from COB data
            if hasattr(executor, 'latest_cob_data') and symbol in executor.latest_cob_data:
                cob_data = executor.latest_cob_data[symbol]
                if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
                    return cob_data.stats['mid_price']

            # Default fallback
            return 0.0
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return 0.0

    def _refresh_price(self, symbol):
        """Force a price refresh for a symbol; return True on success."""
        try:
            executor = self.trading_executor
            if hasattr(executor, 'data_provider'):
                data_provider = executor.data_provider
                if hasattr(data_provider, 'fetch_current_price'):
                    price = data_provider.fetch_current_price(symbol)
                    if price and price > 0:
                        # Update trading executor price
                        if hasattr(executor, 'current_prices'):
                            executor.current_prices[symbol] = price
                        # Update timestamp
                        self.last_price_update[symbol] = datetime.now()
                        logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
                        return True

            logger.warning(f"Failed to refresh price for {symbol}")
            return False
        except Exception as e:
            logger.error(f"Error refreshing price for {symbol}: {e}")
            return False

    def _ensure_position_reset(self, symbol):
        """Ensure position is properly reset before new entry.

        Returns False when the executor holds a position for ``symbol``
        whose ``active`` attribute is truthy, or on error; True otherwise.
        """
        try:
            executor = self.trading_executor

            # Check if we have an active position
            if hasattr(executor, 'positions') and symbol in executor.positions:
                position = executor.positions[symbol]
                if position and getattr(position, 'active', False):
                    logger.warning(f"Position already active for {symbol}, cannot enter new position")
                    return False

            # A False flag means a trade was opened and never marked closed:
            # drop the stale entry and mark the symbol as reset.
            if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
                if hasattr(executor, 'positions'):
                    executor.positions.pop(symbol, None)
                logger.info(f"Forced position reset for {symbol}")
                self.position_reset_flags[symbol] = True

            return True
        except Exception as e:
            logger.error(f"Error ensuring position reset for {symbol}: {e}")
            return False

    def get_trade_history(self, symbol=None):
        """Get verified trade history (one symbol's list, or the whole dict)"""
        if symbol:
            return self.trade_history.get(symbol, [])
        return self.trade_history

    def get_price_update_status(self):
        """Get price update status for all symbols"""
        status = {}
        current_time = datetime.now()
        for symbol, timestamp in self.last_price_update.items():
            time_since_update = (current_time - timestamp).total_seconds()
            status[symbol] = {
                'last_update': timestamp,
                'seconds_ago': time_since_update,
                'is_fresh': time_since_update <= self.price_update_threshold
            }
        return status

View File

@ -1,795 +0,0 @@
"""
Comprehensive Training Data Collection System
This module implements a robust training data collection system that:
1. Captures all model inputs with validation and completeness checks
2. Stores training data packages with future outcome validation
3. Detects rapid price changes for high-value training examples
4. Enables replay and retraining on most profitable setups
5. Maintains data integrity and traceability
Key Features:
- Real-time data package creation with all model inputs
- Future outcome validation (profitable vs unprofitable predictions)
- Rapid price change detection for premium training examples
- Comprehensive data validation and completeness verification
- Backpropagation data storage for gradient replay
- Training episode profitability tracking and ranking
"""
import asyncio
import json
import logging
import numpy as np
import pandas as pd
import pickle
import torch
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field, asdict
from collections import deque
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger(__name__)
@dataclass
class ModelInputPackage:
    """Snapshot of every model input captured for one symbol at one timestamp.

    ``data_hash``, ``completeness_score`` and ``validation_flags`` are derived
    in ``__post_init__``; values passed for them by the caller are overwritten.
    """
    timestamp: datetime
    symbol: str

    # Market data inputs
    ohlcv_data: Dict[str, pd.DataFrame]  # {timeframe: DataFrame}
    tick_data: List[Dict[str, Any]]  # Raw tick data
    cob_data: Dict[str, Any]  # Consolidated Order Book data
    technical_indicators: Dict[str, float]  # All technical indicators
    pivot_points: List[Dict[str, Any]]  # Detected pivot points

    # Model-specific inputs
    cnn_features: np.ndarray  # CNN input features
    rl_state: np.ndarray  # RL state representation
    orchestrator_context: Dict[str, Any]  # Orchestrator context

    # Cross-model inputs (outputs from other models)
    cnn_predictions: Optional[Dict[str, Any]] = None
    rl_predictions: Optional[Dict[str, Any]] = None
    orchestrator_decision: Optional[Dict[str, Any]] = None

    # Data validation (derived, see __post_init__)
    data_hash: str = ""
    completeness_score: float = 0.0
    validation_flags: Dict[str, bool] = field(default_factory=dict)

    def __post_init__(self):
        """Derive hash, completeness and validation flags, in that order."""
        self.data_hash = self._calculate_hash()
        self.completeness_score = self._calculate_completeness()
        # _validate_data reads completeness_score, so it must run last.
        self.validation_flags = self._validate_data()

    def _calculate_hash(self) -> str:
        """Return an MD5 hex digest over timestamp, symbol and input sizes.

        Only container lengths and array shapes are hashed, not the data
        values. Returns "invalid_hash" if anything raises.
        """
        try:
            cnn_shape = self.cnn_features.shape if self.cnn_features is not None else 'None'
            rl_shape = self.rl_state.shape if self.rl_state is not None else 'None'
            pieces = [
                f"{self.timestamp}",
                f"{self.symbol}",
                f"{len(self.ohlcv_data)}",
                f"{len(self.tick_data)}",
                f"{cnn_shape}",
                f"{rl_shape}",
            ]
            return hashlib.md5("_".join(pieces).encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Error calculating data hash: {e}")
            return "invalid_hash"

    def _calculate_completeness(self) -> float:
        """Return the fraction (0.0 to 1.0) of the ten tracked inputs present.

        Returns 0.0 if any check raises.
        """
        try:
            total_fields = 10  # Total expected data fields
            present = (
                bool(self.ohlcv_data and len(self.ohlcv_data) > 0),
                bool(self.tick_data and len(self.tick_data) > 0),
                bool(self.cob_data and len(self.cob_data) > 0),
                bool(self.technical_indicators and len(self.technical_indicators) > 0),
                bool(self.pivot_points and len(self.pivot_points) > 0),
                bool(self.cnn_features is not None and self.cnn_features.size > 0),
                bool(self.rl_state is not None and self.rl_state.size > 0),
                bool(self.orchestrator_context and len(self.orchestrator_context) > 0),
                self.cnn_predictions is not None,
                self.rl_predictions is not None,
            )
            return sum(present) / total_fields
        except Exception as e:
            logger.warning(f"Error calculating completeness: {e}")
            return 0.0

    def _validate_data(self) -> Dict[str, bool]:
        """Return per-check validity flags.

        If a check raises, the flags computed so far are kept and
        ``validation_error`` is added.
        """
        flags = {}
        try:
            flags['valid_timestamp'] = isinstance(self.timestamp, datetime)

            ohlcv = self.ohlcv_data
            flags['valid_ohlcv'] = (
                ohlcv is not None
                and len(ohlcv) > 0
                and all(isinstance(frame, pd.DataFrame) for frame in ohlcv.values())
            )

            # Both feature arrays must be non-empty numpy arrays
            for flag_name, array in (
                ('valid_cnn_features', self.cnn_features),
                ('valid_rl_state', self.rl_state),
            ):
                flags[flag_name] = (
                    array is not None
                    and isinstance(array, np.ndarray)
                    and array.size > 0
                )

            # "Consistent" here simply means mostly complete
            flags['data_consistent'] = self.completeness_score > 0.7
        except Exception as e:
            logger.warning(f"Error validating data: {e}")
            flags['validation_error'] = True
        return flags
@dataclass
class TrainingOutcome:
    """Observed future outcome attached to one ``ModelInputPackage``.

    ``input_package_hash`` links the outcome to the package it was computed
    for.
    """
    input_package_hash: str
    timestamp: datetime
    symbol: str

    # Price movement outcomes, one per look-ahead horizon
    price_change_1m: float
    price_change_5m: float
    price_change_15m: float
    price_change_1h: float

    # Profitability metrics
    max_profit_potential: float
    max_loss_potential: float
    optimal_entry_price: float
    optimal_exit_price: float
    optimal_holding_time: timedelta

    # Classification labels
    is_profitable: bool
    profitability_score: float  # 0.0 to 1.0
    risk_reward_ratio: float

    # Rapid price change detection
    is_rapid_change: bool
    change_velocity: float  # Price change per minute
    volatility_spike: bool

    # Validation bookkeeping
    outcome_validated: bool = False
    validation_timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class TrainingEpisode:
    """One replayable training sample: inputs, model predictions and outcome."""
    episode_id: str
    input_package: ModelInputPackage
    model_predictions: Dict[str, Any]  # Predictions from all models
    actual_outcome: TrainingOutcome

    # Training metadata
    episode_type: str  # 'normal', 'rapid_change', 'high_profit'
    profitability_rank: float  # Ranking among all episodes
    training_priority: float  # Priority for replay training

    # Backpropagation data storage
    gradient_data: Optional[Dict[str, torch.Tensor]] = None
    loss_components: Optional[Dict[str, float]] = None
    model_states: Optional[Dict[str, Any]] = None

    # Episode statistics
    created_timestamp: datetime = field(default_factory=datetime.now)
    last_trained_timestamp: Optional[datetime] = None
    training_count: int = 0

    def calculate_training_priority(self) -> float:
        """Score this episode for replay training, capped at 1.0.

        Contributions: profitability score x 0.4 (profitable outcomes only),
        +0.3 for a rapid change, +0.2 when risk/reward exceeds 2.0, and
        completeness x 0.1. Episodes trained more than five times are scaled
        by 0.8. Returns 0.0 if anything raises.
        """
        try:
            outcome = self.actual_outcome
            score = 0.0

            if outcome.is_profitable:
                score += outcome.profitability_score * 0.4
            if outcome.is_rapid_change:
                # Rapid moves are treated as high learning value
                score += 0.3
            if outcome.risk_reward_ratio > 2.0:
                score += 0.2

            score += self.input_package.completeness_score * 0.1

            # Damp frequently replayed episodes to limit overfitting
            if self.training_count > 5:
                score *= 0.8

            return min(score, 1.0)
        except Exception as e:
            logger.warning(f"Error calculating training priority: {e}")
            return 0.0
class RapidChangeDetector:
    """Flags fast price moves and volatility spikes from a rolling price window."""

    def __init__(self,
                 velocity_threshold: float = 0.5,  # % per minute
                 volatility_multiplier: float = 3.0,
                 lookback_minutes: int = 5):
        self.velocity_threshold = velocity_threshold
        self.volatility_multiplier = volatility_multiplier
        self.lookback_minutes = lookback_minutes

        # Per-symbol rolling (timestamp, price) samples and volatility baseline
        self.price_history: Dict[str, deque] = {}
        self.volatility_baseline: Dict[str, float] = {}

    @staticmethod
    def _abs_relative_changes(prices):
        """Return |p[i] - p[i-1]| / p[i-1] for each consecutive price pair."""
        return [abs(cur - prev) / prev for prev, cur in zip(prices, prices[1:])]

    def add_price_point(self, symbol: str, timestamp: datetime, price: float):
        """Append a price sample for ``symbol`` and refresh its baseline."""
        if symbol not in self.price_history:
            # Window is sized for one sample per second
            window_size = self.lookback_minutes * 60
            self.price_history[symbol] = deque(maxlen=window_size)
            self.volatility_baseline[symbol] = 0.0

        self.price_history[symbol].append((timestamp, price))
        self._update_volatility_baseline(symbol)

    def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
        """
        Inspect the latest 60 samples for a rapid move

        Returns:
            (is_rapid_change, change_velocity, volatility_spike);
            (False, 0.0, False) with fewer than 60 samples or on error
        """
        history = self.price_history.get(symbol)
        if history is None or len(history) < 60:
            return False, 0.0, False

        try:
            window = list(history)[-60:]
            first_ts, start_price = window[0]
            last_ts, end_price = window[-1]

            time_diff = (last_ts - first_ts).total_seconds() / 60.0  # minutes
            if time_diff <= 0:
                return False, 0.0, False

            # Velocity is the absolute % change per elapsed minute
            velocity = abs((end_price - start_price) / start_price * 100) / time_diff
            is_rapid = velocity > self.velocity_threshold

            # A spike is current volatility well above a non-zero baseline
            current_volatility = self._calculate_current_volatility(symbol)
            baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
            volatility_spike = (
                baseline_volatility > 0 and
                current_volatility > baseline_volatility * self.volatility_multiplier
            )
            return is_rapid, velocity, volatility_spike
        except Exception as e:
            logger.warning(f"Error detecting rapid change for {symbol}: {e}")
            return False, 0.0, False

    def _update_volatility_baseline(self, symbol: str):
        """Fold the latest volatility reading into the symbol's baseline.

        Does nothing until at least 120 samples are stored.
        """
        try:
            history = self.price_history[symbol]
            if len(history) < 120:  # Need at least 2 minutes of data
                return

            prices = [sample[1] for sample in list(history)[-300:]]  # Last 5 minutes
            if len(prices) < 2:
                return

            # Std-dev of absolute relative changes, as a percentage
            volatility = np.std(self._abs_relative_changes(prices)) * 100

            previous = self.volatility_baseline[symbol]
            if previous == 0:
                # First reading seeds the baseline directly
                self.volatility_baseline[symbol] = volatility
            else:
                # Exponential moving average
                alpha = 0.1
                self.volatility_baseline[symbol] = (
                    alpha * volatility + (1 - alpha) * previous
                )
        except Exception as e:
            logger.warning(f"Error updating volatility baseline for {symbol}: {e}")

    def _calculate_current_volatility(self, symbol: str) -> float:
        """Return volatility (%) over the last 60 samples, or 0.0 if unavailable."""
        try:
            history = self.price_history[symbol]
            if len(history) < 60:
                return 0.0

            recent_prices = [sample[1] for sample in list(history)[-60:]]
            if len(recent_prices) < 2:
                return 0.0

            return np.std(self._abs_relative_changes(recent_prices)) * 100
        except Exception as e:
            logger.warning(f"Error calculating current volatility for {symbol}: {e}")
            return 0.0
class TrainingDataCollector:
"""Main training data collection system"""
def __init__(self,
             storage_dir: str = "training_data",
             max_episodes_per_symbol: int = 10000,
             outcome_validation_delay: timedelta = timedelta(hours=1)):
    """Set up storage, in-memory buffers and bookkeeping.

    Creates ``storage_dir`` (including parents) if it does not exist.

    Args:
        storage_dir: Directory where episodes are persisted.
        max_episodes_per_symbol: In-memory episode cap per symbol.
        outcome_validation_delay: Age a package must reach before its
            outcome is computed.
    """
    self.storage_dir = Path(storage_dir)
    self.storage_dir.mkdir(parents=True, exist_ok=True)

    self.max_episodes_per_symbol = max_episodes_per_symbol
    self.outcome_validation_delay = outcome_validation_delay

    # In-memory stores, both keyed by symbol
    self.training_episodes: Dict[str, List[TrainingEpisode]] = {}
    self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {}  # awaiting outcome validation

    self.rapid_change_detector = RapidChangeDetector()

    # Running counters reported to callers
    self.collection_stats = dict(
        total_episodes=0,
        profitable_episodes=0,
        rapid_change_episodes=0,
        validation_errors=0,
        data_completeness_avg=0.0,
    )

    # Background processing state
    self.is_collecting = False
    self.collection_thread = None
    self.outcome_validation_thread = None

    # Guards the shared stores above
    self.data_lock = threading.Lock()

    logger.info("Training Data Collector initialized")
    logger.info(f"Storage directory: {self.storage_dir}")
    logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")
def start_collection(self):
"""Start the training data collection system"""
if self.is_collecting:
logger.warning("Training data collection already running")
return
self.is_collecting = True
# Start outcome validation thread
self.outcome_validation_thread = threading.Thread(
target=self._outcome_validation_worker,
daemon=True
)
self.outcome_validation_thread.start()
logger.info("Training data collection started")
def stop_collection(self):
"""Stop the training data collection system"""
self.is_collecting = False
if self.outcome_validation_thread:
self.outcome_validation_thread.join(timeout=5)
logger.info("Training data collection stopped")
def collect_training_data(self,
                          symbol: str,
                          ohlcv_data: Dict[str, pd.DataFrame],
                          tick_data: List[Dict[str, Any]],
                          cob_data: Dict[str, Any],
                          technical_indicators: Dict[str, float],
                          pivot_points: List[Dict[str, Any]],
                          cnn_features: np.ndarray,
                          rl_state: np.ndarray,
                          orchestrator_context: Dict[str, Any],
                          model_predictions: Optional[Dict[str, Any]] = None) -> Optional[str]:
    """
    Collect comprehensive training data package

    The package is queued in ``pending_outcomes`` until the background
    worker can attach an outcome to it.

    Returns:
        episode_id for tracking, or None when the package is less than
        50% complete or an error occurs
    """
    # NOTE(review): model_predictions is accepted but not stored anywhere
    # in this method - confirm whether it should feed the package.
    try:
        # Create input package
        input_package = ModelInputPackage(
            timestamp=datetime.now(),
            symbol=symbol,
            ohlcv_data=ohlcv_data,
            tick_data=tick_data,
            cob_data=cob_data,
            technical_indicators=technical_indicators,
            pivot_points=pivot_points,
            cnn_features=cnn_features,
            rl_state=rl_state,
            orchestrator_context=orchestrator_context
        )

        # Validate data completeness
        if input_package.completeness_score < 0.5:
            logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
            with self.data_lock:
                self.collection_stats['validation_errors'] += 1
            return None

        # Feed the rapid-change detector with the latest price
        current_price = self._extract_current_price(ohlcv_data)
        if current_price:
            self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)

        # Generate episode ID
        episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"

        # Queue and statistics are shared with the validation worker, so
        # both are updated under the same lock.
        with self.data_lock:
            pending = self.pending_outcomes.setdefault(symbol, [])
            pending.append(input_package)

            # Limit pending outcomes to prevent memory issues
            if len(pending) > 1000:
                self.pending_outcomes[symbol] = pending[-500:]

            stats = self.collection_stats
            stats['total_episodes'] += 1
            stats['data_completeness_avg'] = (
                (stats['data_completeness_avg'] * (stats['total_episodes'] - 1) +
                 input_package.completeness_score) / stats['total_episodes']
            )

        logger.debug(f"Collected training data for {symbol}: {episode_id}")
        logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")

        return episode_id

    except Exception as e:
        logger.error(f"Error collecting training data for {symbol}: {e}")
        # The lock is never held here: `with` released it before the
        # exception reached this handler.
        with self.data_lock:
            self.collection_stats['validation_errors'] += 1
        return None
def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
"""Extract current price from OHLCV data"""
try:
# Try to get price from shortest timeframe first
for timeframe in ['1s', '1m', '5m', '15m', '1h']:
if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
return float(ohlcv_data[timeframe]['close'].iloc[-1])
return None
except Exception as e:
logger.warning(f"Error extracting current price: {e}")
return None
def _outcome_validation_worker(self):
"""Background worker for validating training outcomes"""
logger.info("Outcome validation worker started")
while self.is_collecting:
try:
self._validate_pending_outcomes()
threading.Event().wait(60) # Check every minute
except Exception as e:
logger.error(f"Error in outcome validation worker: {e}")
threading.Event().wait(30) # Wait before retrying
logger.info("Outcome validation worker stopped")
def _validate_pending_outcomes(self):
"""Validate outcomes for pending training data"""
current_time = datetime.now()
with self.data_lock:
for symbol in list(self.pending_outcomes.keys()):
if symbol not in self.pending_outcomes:
continue
validated_packages = []
remaining_packages = []
for package in self.pending_outcomes[symbol]:
# Check if enough time has passed for outcome validation
if current_time - package.timestamp >= self.outcome_validation_delay:
outcome = self._calculate_training_outcome(package)
if outcome:
self._create_training_episode(package, outcome)
validated_packages.append(package)
else:
remaining_packages.append(package)
else:
remaining_packages.append(package)
# Update pending outcomes
self.pending_outcomes[symbol] = remaining_packages
if validated_packages:
logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}")
def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
    """Build the outcome record for a package (placeholder implementation).

    Only the rapid-change fields come from real data; the price-change and
    profitability fields are fixed placeholder values. Returns None when no
    base price can be extracted or on error.
    """
    try:
        # The package's own latest close serves as entry and exit price
        base_price = self._extract_current_price(input_package.ohlcv_data)
        if not base_price:
            return None

        # Placeholder: real future prices would be fetched from the data
        # provider here to fill in the fields below.
        symbol = input_package.symbol
        is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(symbol)

        return TrainingOutcome(
            input_package_hash=input_package.data_hash,
            timestamp=input_package.timestamp,
            symbol=symbol,
            # Placeholder values - replace with actual calculation
            price_change_1m=0.0,
            price_change_5m=0.0,
            price_change_15m=0.0,
            price_change_1h=0.0,
            max_profit_potential=0.0,
            max_loss_potential=0.0,
            optimal_entry_price=base_price,
            optimal_exit_price=base_price,
            optimal_holding_time=timedelta(minutes=5),
            is_profitable=False,
            profitability_score=0.0,
            risk_reward_ratio=1.0,
            # Real values from the rapid-change detector
            is_rapid_change=is_rapid,
            change_velocity=velocity,
            volatility_spike=volatility_spike,
            outcome_validated=True
        )
    except Exception as e:
        logger.error(f"Error calculating training outcome: {e}")
        return None
def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome):
    """Assemble an episode, store it in memory and persist it to disk.

    Errors are logged and swallowed.
    """
    try:
        stamp = input_package.timestamp.strftime('%Y%m%d_%H%M%S')
        episode_id = f"{input_package.symbol}_{stamp}_{input_package.data_hash[:8]}"

        stats = self.collection_stats

        # Rapid change takes precedence over high profit when classifying
        if outcome.is_rapid_change:
            episode_type = 'rapid_change'
            stats['rapid_change_episodes'] += 1
        elif outcome.profitability_score > 0.8:
            episode_type = 'high_profit'
        else:
            episode_type = 'normal'

        if outcome.is_profitable:
            stats['profitable_episodes'] += 1

        episode = TrainingEpisode(
            episode_id=episode_id,
            input_package=input_package,
            model_predictions={},  # Will be filled when models make predictions
            actual_outcome=outcome,
            episode_type=episode_type,
            profitability_rank=0.0,  # Will be calculated later
            training_priority=0.0
        )
        episode.training_priority = episode.calculate_training_priority()

        # Store in memory, keeping only the highest-priority episodes
        symbol = input_package.symbol
        bucket = self.training_episodes.setdefault(symbol, [])
        bucket.append(episode)
        if len(bucket) > self.max_episodes_per_symbol:
            bucket.sort(key=lambda ep: ep.training_priority, reverse=True)
            self.training_episodes[symbol] = bucket[:self.max_episodes_per_symbol]

        self._save_episode_to_disk(episode)

        logger.debug(f"Created training episode: {episode_id}")
        logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")

    except Exception as e:
        logger.error(f"Error creating training episode: {e}")
def _save_episode_to_disk(self, episode: TrainingEpisode):
"""Save training episode to disk for persistence"""
try:
symbol_dir = self.storage_dir / episode.input_package.symbol
symbol_dir.mkdir(parents=True, exist_ok=True)
# Save episode data
episode_file = symbol_dir / f"{episode.episode_id}.pkl"
with open(episode_file, 'wb') as f:
pickle.dump(episode, f)
# Save episode metadata for quick access
metadata = {
'episode_id': episode.episode_id,
'timestamp': episode.input_package.timestamp.isoformat(),
'episode_type': episode.episode_type,
'training_priority': episode.training_priority,
'profitability_score': episode.actual_outcome.profitability_score,
'is_profitable': episode.actual_outcome.is_profitable,
'is_rapid_change': episode.actual_outcome.is_rapid_change,
'data_completeness': episode.input_package.completeness_score
}
metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2)
except Exception as e:
logger.error(f"Error saving episode to disk: {e}")
def get_high_priority_episodes(self,
                               symbol: str,
                               limit: int = 100,
                               min_priority: float = 0.5) -> List[TrainingEpisode]:
    """Return up to ``limit`` episodes for ``symbol`` with priority >= ``min_priority``.

    Results are ordered by descending ``training_priority``.  An unknown
    symbol, or any error while selecting, yields an empty list.
    """
    try:
        if symbol not in self.training_episodes:
            return []
        ranked = sorted(
            (ep for ep in self.training_episodes[symbol]
             if ep.training_priority >= min_priority),
            key=lambda ep: ep.training_priority,
            reverse=True,
        )
        return ranked[:limit]
    except Exception as e:
        logger.error(f"Error getting high priority episodes for {symbol}: {e}")
        return []
def get_collection_statistics(self) -> Dict[str, Any]:
    """Return a copy of the collection counters enriched with derived figures.

    Added keys: ``episodes_per_symbol``, ``pending_outcomes`` (both
    symbol -> count), ``profitability_rate`` and ``rapid_change_rate``
    (both 0.0 when ``total_episodes`` is not positive).
    """
    stats = self.collection_stats.copy()

    stats['episodes_per_symbol'] = {
        name: len(stored) for name, stored in self.training_episodes.items()
    }
    stats['pending_outcomes'] = {
        name: len(waiting) for name, waiting in self.pending_outcomes.items()
    }

    total = stats['total_episodes']
    has_episodes = total > 0
    stats['profitability_rate'] = (
        stats['profitable_episodes'] / total if has_episodes else 0.0
    )
    stats['rapid_change_rate'] = (
        stats['rapid_change_episodes'] / total if has_episodes else 0.0
    )
    return stats
def validate_data_integrity(self) -> Dict[str, Any]:
    """Check every stored episode and summarise the problems found.

    For each episode three independent checks are made: the recomputed input
    hash must equal the stored ``data_hash``, ``completeness_score`` must be
    at least 0.7, and ``validation_flags['data_consistent']`` must be truthy.

    Returns:
        A dict with the per-check counters, the ids of episodes whose hash
        did not match (``corrupted_episodes``) and ``integrity_score``.  The
        score is ``1 - issues / episodes`` clamped to a minimum of 0.0: one
        episode can fail up to three checks, so without the clamp the value
        could drop as low as -2.0.  On an unexpected error the partial
        results are returned with an extra ``validation_error`` message.
    """
    validation_results = {
        'total_episodes_checked': 0,
        'hash_mismatches': 0,
        'completeness_issues': 0,
        'validation_flag_failures': 0,
        'corrupted_episodes': [],
        'integrity_score': 1.0
    }
    try:
        for episodes in self.training_episodes.values():
            for episode in episodes:
                package = episode.input_package
                validation_results['total_episodes_checked'] += 1

                # Stored hash must match a fresh recomputation.
                if package._calculate_hash() != package.data_hash:
                    validation_results['hash_mismatches'] += 1
                    validation_results['corrupted_episodes'].append(episode.episode_id)

                if package.completeness_score < 0.7:
                    validation_results['completeness_issues'] += 1

                if not package.validation_flags.get('data_consistent', False):
                    validation_results['validation_flag_failures'] += 1

        total_issues = (
            validation_results['hash_mismatches'] +
            validation_results['completeness_issues'] +
            validation_results['validation_flag_failures']
        )
        checked = validation_results['total_episodes_checked']
        if checked > 0:
            # Clamp: several issues on one episode must not push the score below 0.
            validation_results['integrity_score'] = max(0.0, 1.0 - total_issues / checked)

        logger.info("Data integrity validation completed")
        logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")
    except Exception as e:
        logger.error(f"Error during data integrity validation: {e}")
        validation_results['validation_error'] = str(e)
    return validation_results
# Module-wide collector, created on first request
training_data_collector = None


def get_training_data_collector() -> TrainingDataCollector:
    """Return the shared TrainingDataCollector, constructing it on first use."""
    global training_data_collector
    if training_data_collector is not None:
        return training_data_collector
    training_data_collector = TrainingDataCollector()
    return training_data_collector

File diff suppressed because it is too large Load Diff

View File

@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify that both model prediction and trading statistics issues are fixed
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
from core.trading_executor import TradingExecutor
import asyncio
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_model_predictions():
    """Build an orchestrator and check that it produces a decision for ETH/USDT.

    Logs the registered models and the decision details.  Returns True when
    ``make_trading_decision`` yields a truthy decision, False otherwise.
    """
    divider = "=" * 60
    logger.info(divider)
    logger.info("TESTING MODEL PREDICTIONS")
    logger.info(divider)

    # Components under test
    provider = DataProvider()
    orchestrator = TradingOrchestrator(provider)

    logger.info("1. Checking model registration...")
    models = orchestrator.model_registry.get_all_models()
    logger.info(f" Registered models: {list(models.keys()) if models else 'None'}")

    logger.info("2. Testing trading decision generation...")
    decision = await orchestrator.make_trading_decision('ETH/USDT')
    if not decision:
        logger.error(" ❌ No decision generated")
        return False

    logger.info(f" ✅ Decision generated: {decision.action} (confidence: {decision.confidence:.3f})")
    logger.info(f" ✅ Reasoning: {decision.reasoning}")
    return True
def test_trading_statistics():
    """Check the daily statistics reported by ``TradingExecutor``.

    The current statistics are logged first.  When the executor reports no
    trades, one winning (+$0.50) and one losing (-$0.20) mock trade are
    appended to its history and the recomputed win rate and average
    win/loss are compared with the expected values.

    Returns:
        True when the executor already had trades (nothing is verified in
        that case) or when every verified figure is within 0.01 of its
        expected value; False when any verification fails.  The original
        code returned True regardless of the verification outcome.
    """
    def log_stats(stats):
        # Shared dump of the figures shown before and after seeding trades.
        logger.info(f" Total trades: {stats.get('total_trades', 0)}")
        logger.info(f" Winning trades: {stats.get('winning_trades', 0)}")
        logger.info(f" Losing trades: {stats.get('losing_trades', 0)}")
        logger.info(f" Win rate: {stats.get('win_rate', 0.0) * 100:.1f}%")
        logger.info(f" Avg winning trade: ${stats.get('avg_winning_trade', 0.0):.2f}")
        logger.info(f" Avg losing trade: ${stats.get('avg_losing_trade', 0.0):.2f}")
        logger.info(f" Total P&L: ${stats.get('total_pnl', 0.0):.2f}")

    logger.info("=" * 60)
    logger.info("TESTING TRADING STATISTICS")
    logger.info("=" * 60)

    # Initialize trading executor
    trading_executor = TradingExecutor()

    # Check if we have any trades
    trade_history = trading_executor.get_trade_history()
    logger.info(f"1. Current trade history: {len(trade_history)} trades")

    # Get daily stats
    daily_stats = trading_executor.get_daily_stats()
    logger.info("2. Daily statistics from trading executor:")
    log_stats(daily_stats)

    # The expected values below only hold for the two seeded trades, so an
    # executor that already has trades is reported but not verified.
    if daily_stats.get('total_trades', 0) != 0:
        return True

    logger.info("3. No trades found - simulating some test trades...")
    from core.trading_executor import TradeRecord
    from datetime import datetime

    # Add a winning trade
    winning_trade = TradeRecord(
        symbol='ETH/USDT',
        side='LONG',
        quantity=0.01,
        entry_price=2500.0,
        exit_price=2550.0,
        entry_time=datetime.now(),
        exit_time=datetime.now(),
        pnl=0.50,  # $0.50 profit
        fees=0.01,
        confidence=0.8
    )
    trading_executor.trade_history.append(winning_trade)

    # Add a losing trade
    losing_trade = TradeRecord(
        symbol='ETH/USDT',
        side='LONG',
        quantity=0.01,
        entry_price=2500.0,
        exit_price=2480.0,
        entry_time=datetime.now(),
        exit_time=datetime.now(),
        pnl=-0.20,  # $0.20 loss
        fees=0.01,
        confidence=0.7
    )
    trading_executor.trade_history.append(losing_trade)

    # Get updated stats
    daily_stats = trading_executor.get_daily_stats()
    logger.info(" Updated statistics after adding test trades:")
    log_stats(daily_stats)

    # Verify calculations
    expected_win_rate = 1/2  # 1 win out of 2 trades = 50%
    expected_avg_win = 0.50
    expected_avg_loss = -0.20
    actual_win_rate = daily_stats.get('win_rate', 0.0)
    actual_avg_win = daily_stats.get('avg_winning_trade', 0.0)
    actual_avg_loss = daily_stats.get('avg_losing_trade', 0.0)

    win_rate_ok = abs(actual_win_rate - expected_win_rate) < 0.01
    avg_win_ok = abs(actual_avg_win - expected_avg_win) < 0.01
    avg_loss_ok = abs(actual_avg_loss - expected_avg_loss) < 0.01

    logger.info("4. Verifying calculations:")
    logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% {'✅' if win_rate_ok else '❌'}")
    logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} {'✅' if avg_win_ok else '❌'}")
    logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} {'✅' if avg_loss_ok else '❌'}")

    # Propagate the verification result so the caller's summary is truthful.
    return win_rate_ok and avg_win_ok and avg_loss_ok
async def main():
    """Run the prediction and statistics checks, then log a pass/fail summary."""
    logger.info("🚀 STARTING COMPREHENSIVE FIXES TEST")
    logger.info("Testing both model prediction fixes and trading statistics fixes")

    prediction_success = await test_model_predictions()
    stats_success = test_trading_statistics()

    divider = "=" * 60
    logger.info(divider)
    logger.info("TEST SUMMARY")
    logger.info(divider)

    prediction_label = '✅ FIXED' if prediction_success else '❌ STILL BROKEN'
    stats_label = '✅ FIXED' if stats_success else '❌ STILL BROKEN'
    logger.info(f"Model Predictions: {prediction_label}")
    logger.info(f"Trading Statistics: {stats_label}")

    if not (prediction_success and stats_success):
        logger.error("❌ Some issues remain. Check the logs above for details.")
    else:
        logger.info("🎉 ALL ISSUES FIXED! The system should now work correctly.")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,250 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify trading fixes:
1. Position sizes with leverage
2. ETH-only trading
3. Correct win rate calculations
4. Meaningful P&L values
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from core.trading_executor import TradingExecutor
from core.trading_executor import TradeRecord
from datetime import datetime
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_position_sizing():
    """Log the position size computed for ETH at several confidence levels.

    Calls ``TradingExecutor._calculate_position_size`` at a price of 2500
    and reports whether the value for confidence 0.8 exceeds $1000.
    Nothing is returned.
    """
    divider = "=" * 60
    logger.info(divider)
    logger.info("TESTING POSITION SIZING WITH LEVERAGE")
    logger.info(divider)

    trading_executor = TradingExecutor()

    confidence = 0.8
    current_price = 2500.0  # ETH price
    position_value = trading_executor._calculate_position_size(confidence, current_price)
    quantity = position_value / current_price

    logger.info("1. Position calculation test:")
    logger.info(f" Confidence: {confidence}")
    logger.info(f" ETH Price: ${current_price}")
    logger.info(f" Position Value: ${position_value:.2f}")
    logger.info(f" Quantity: {quantity:.6f} ETH")

    # Threshold used by the script: the position should exceed $1000.
    is_meaningful = position_value > 1000
    if not is_meaningful:
        logger.error(f" ❌ Position size too small: ${position_value:.2f}")
    else:
        logger.info(" ✅ Position size is meaningful (>$1000)")

    logger.info("2. Testing different confidence levels:")
    for level in [0.2, 0.5, 0.8, 1.0]:
        level_value = trading_executor._calculate_position_size(level, current_price)
        level_qty = level_value / current_price
        logger.info(f" Confidence {level}: ${level_value:.2f} ({level_qty:.6f} ETH)")
def test_eth_only_restriction():
    """Log whether the executor's safety check accepts ETH/USDT and BTC/USDT BUY orders.

    The script expects ETH/USDT to be allowed and BTC/USDT to be blocked;
    the result is only logged, nothing is returned.
    """
    divider = "=" * 60
    logger.info(divider)
    logger.info("TESTING ETH-ONLY TRADING RESTRICTION")
    logger.info(divider)

    trading_executor = TradingExecutor()

    logger.info("1. Testing ETH/USDT trade (should be allowed):")
    eth_allowed = trading_executor._check_safety_conditions('ETH/USDT', 'BUY')
    eth_verdict = '✅ YES' if eth_allowed else '❌ NO'
    logger.info(f" ETH/USDT allowed: {eth_verdict}")

    logger.info("2. Testing BTC/USDT trade (should be blocked):")
    btc_allowed = trading_executor._check_safety_conditions('BTC/USDT', 'BUY')
    btc_verdict = '❌ YES (ERROR!)' if btc_allowed else '✅ NO (CORRECT)'
    logger.info(f" BTC/USDT allowed: {btc_verdict}")
def test_win_rate_calculation():
    """Verify win rate and average win/loss on a known set of trades.

    The executor's trade history is replaced with three +$50 winners and two
    -$25 losers, then ``get_daily_stats`` is compared with the expected
    60% win rate, $50 average win and -$25 average loss (tolerance 0.01).

    Returns:
        True when all three figures match, False otherwise.
    """
    logger.info("=" * 60)
    logger.info("TESTING WIN RATE CALCULATIONS")
    logger.info("=" * 60)

    trading_executor = TradingExecutor()

    # Clear existing trades
    trading_executor.trade_history = []

    # Add test trades with meaningful P&L
    logger.info("1. Adding test trades with meaningful P&L:")

    # Add 3 winning trades
    for i in range(3):
        winning_trade = TradeRecord(
            symbol='ETH/USDT',
            side='LONG',
            quantity=1.0,
            entry_price=2500.0,
            exit_price=2550.0,
            entry_time=datetime.now(),
            exit_time=datetime.now(),
            pnl=50.0,  # $50 profit with leverage
            fees=1.0,
            confidence=0.8,
            hold_time_seconds=30.0  # 30 second hold
        )
        trading_executor.trade_history.append(winning_trade)
        logger.info(f" Added winning trade #{i+1}: +$50.00 (30s hold)")

    # Add 2 losing trades
    for i in range(2):
        losing_trade = TradeRecord(
            symbol='ETH/USDT',
            side='LONG',
            quantity=1.0,
            entry_price=2500.0,
            exit_price=2475.0,
            entry_time=datetime.now(),
            exit_time=datetime.now(),
            pnl=-25.0,  # $25 loss with leverage
            fees=1.0,
            confidence=0.7,
            hold_time_seconds=15.0  # 15 second hold
        )
        trading_executor.trade_history.append(losing_trade)
        logger.info(f" Added losing trade #{i+1}: -$25.00 (15s hold)")

    # Get statistics
    stats = trading_executor.get_daily_stats()
    logger.info("2. Calculated statistics:")
    logger.info(f" Total trades: {stats['total_trades']}")
    logger.info(f" Winning trades: {stats['winning_trades']}")
    logger.info(f" Losing trades: {stats['losing_trades']}")
    logger.info(f" Win rate: {stats['win_rate']*100:.1f}%")
    logger.info(f" Avg winning trade: ${stats['avg_winning_trade']:.2f}")
    logger.info(f" Avg losing trade: ${stats['avg_losing_trade']:.2f}")
    logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")

    # Verify calculations
    expected_win_rate = 3/5  # 3 wins out of 5 trades = 60%
    expected_avg_win = 50.0
    expected_avg_loss = -25.0

    logger.info("3. Verification:")
    win_rate_ok = abs(stats['win_rate'] - expected_win_rate) < 0.01
    avg_win_ok = abs(stats['avg_winning_trade'] - expected_avg_win) < 0.01
    avg_loss_ok = abs(stats['avg_losing_trade'] - expected_avg_loss) < 0.01

    # The pass/fail marker was an empty string in both branches, so the log
    # never showed the outcome; use the same markers as the rest of the file.
    logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {stats['win_rate']*100:.1f}% {'✅' if win_rate_ok else '❌'}")
    logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${stats['avg_winning_trade']:.2f} {'✅' if avg_win_ok else '❌'}")
    logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${stats['avg_losing_trade']:.2f} {'✅' if avg_loss_ok else '❌'}")
    return win_rate_ok and avg_win_ok and avg_loss_ok
def test_new_features():
    """Exercise account info, leverage control and percentage-based sizing.

    Logs the account information, sets leverage to 100x and restores the
    previous value, then compares ``_calculate_position_size`` with
    ``balance * base_percent/100 * confidence * leverage``.

    Returns:
        True when the computed position value is within 0.01 of that
        expectation, False otherwise.
    """
    divider = "=" * 60
    logger.info(divider)
    logger.info("TESTING NEW FEATURES")
    logger.info(divider)

    trading_executor = TradingExecutor()

    # Account information
    account_info = trading_executor.get_account_info()
    logger.info("1. Account Information:")
    logger.info(f" Account Balance: ${account_info['account_balance']:.2f}")
    logger.info(f" Leverage: {account_info['leverage']:.0f}x")
    logger.info(f" Trading Mode: {account_info['trading_mode']}")
    logger.info(f" Position Sizing: {account_info['position_sizing']['base_percent']:.1f}% base")

    # Leverage control: switch to 100x, report, then put the old value back.
    logger.info("2. Testing leverage control:")
    previous_leverage = trading_executor.get_leverage()
    logger.info(f" Current leverage: {previous_leverage:.0f}x")
    was_set = trading_executor.set_leverage(100.0)
    current_leverage = trading_executor.get_leverage()
    leverage_verdict = '✅ SUCCESS' if was_set and current_leverage == 100.0 else '❌ FAILED'
    logger.info(f" Set to 100x: {leverage_verdict}")
    trading_executor.set_leverage(previous_leverage)

    # Percentage-based position sizing
    logger.info("3. Testing percentage-based position sizing:")
    confidence = 0.8
    eth_price = 2500.0
    position_value = trading_executor._calculate_position_size(confidence, eth_price)
    account_balance = trading_executor._get_account_balance_for_sizing()
    base_percent = trading_executor.mexc_config.get('base_position_percent', 5.0)
    leverage = trading_executor.get_leverage()
    expected_base = account_balance * (base_percent / 100.0) * confidence
    expected_leveraged = expected_base * leverage

    logger.info(f" Account: ${account_balance:.2f}")
    logger.info(f" Base %: {base_percent:.1f}%")
    logger.info(f" Confidence: {confidence:.1f}")
    logger.info(f" Leverage: {leverage:.0f}x")
    logger.info(f" Expected base: ${expected_base:.2f}")
    logger.info(f" Expected leveraged: ${expected_leveraged:.2f}")
    logger.info(f" Actual: ${position_value:.2f}")

    sizing_ok = abs(position_value - expected_leveraged) < 0.01
    sizing_verdict = '✅ CORRECT' if sizing_ok else '❌ INCORRECT'
    logger.info(f" Percentage sizing: {sizing_verdict}")
    return sizing_ok
def main():
    """Run all four checks and log a summary.

    Only the win-rate and new-feature checks return a result; the position
    sizing and ETH-only lines in the summary are fixed text.
    """
    logger.info("🚀 TESTING TRADING FIXES AND NEW FEATURES")
    logger.info("Testing position sizing, ETH-only trading, win rate calculations, and new features")

    test_position_sizing()
    test_eth_only_restriction()
    calculation_success = test_win_rate_calculation()
    features_success = test_new_features()

    divider = "=" * 60
    logger.info(divider)
    logger.info("TEST SUMMARY")
    logger.info(divider)
    logger.info("Position Sizing: ✅ Updated with percentage-based leverage")
    logger.info("ETH-Only Trading: ✅ Configured in config")
    calculation_label = '✅ FIXED' if calculation_success else '❌ STILL BROKEN'
    features_label = '✅ WORKING' if features_success else '❌ ISSUES FOUND'
    logger.info(f"Win Rate Calculation: {calculation_label}")
    logger.info(f"New Features: {features_label}")

    if not (calculation_success and features_success):
        logger.error("❌ Some issues remain. Check the logs above for details.")
        return

    logger.info("🎉 ALL FEATURES WORKING! Now you should see:")
    for expectation in (
        " - Percentage-based position sizing (2-20% of account)",
        " - 50x leverage (adjustable in UI)",
        " - Hold time in seconds for each trade",
        " - Total fees in trading statistics",
        " - Only ETH/USDT trades",
        " - Correct win rate calculations",
    ):
        logger.info(expectation)


if __name__ == "__main__":
    main()

View File

@ -1,344 +0,0 @@
#!/usr/bin/env python3
"""
Trade Audit Tool
This tool analyzes trade data to identify potential issues with:
- Duplicate entry prices
- Rapid consecutive trades
- P&L calculation accuracy
- Position tracking problems
Usage:
python debug/trade_audit.py [--trades-file path/to/trades.json]
"""
import argparse
import json
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
def parse_trade_time(time_str):
    """Parse a trade time string into a ``datetime``.

    Two formats are tried in order: ``HH:MM:SS`` (which ``strptime`` places
    on 1900-01-01) and ``YYYY-MM-DD HH:MM:SS``.

    Args:
        time_str: The value to parse.

    Returns:
        The parsed ``datetime``, or ``time_str`` unchanged when no format
        matches.  Non-string input such as ``None`` is also returned
        unchanged instead of raising ``TypeError``.
    """
    for fmt in ("%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
        try:
            return datetime.strptime(time_str, fmt)
        except (ValueError, TypeError):
            # ValueError: format mismatch; TypeError: input is not a string.
            continue
    return time_str
def load_trades_from_file(file_path):
    """Load trades from a JSON file.

    The file is read as UTF-8 rather than with the platform's default
    encoding, so the result does not depend on the locale.

    Args:
        file_path: Path of the JSON file.

    Returns:
        The decoded JSON content, or an empty list (after printing an error)
        when the file is missing, is not valid JSON, or is not valid UTF-8.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Undecodable bytes are reported the same way as malformed JSON.
        print(f"Error: File {file_path} is not valid JSON")
        return []
def load_trades_from_dashboard_cache():
    """Load trades from the first known cache file that exists.

    The candidate paths are checked in order, relative to the current
    working directory.  Returns an empty list when none of them exists.
    """
    candidates = (
        "cache/dashboard_trades.json",
        "cache/closed_trades.json",
        "data/trades_history.json",
    )
    found = next((p for p in candidates if os.path.exists(p)), None)
    if found is None:
        print("No trade cache files found")
        return []
    print(f"Loading trades from cache: {found}")
    return load_trades_from_file(found)
def parse_trade_data(trades_data):
    """Normalise a list of trade dicts into a DataFrame for analysis.

    Alternative key names are accepted (``entry_time``/``time``,
    ``side``/``action``, ``size``/``quantity``, ``entry_price``/``entry``,
    ``exit_price``/``exit``, ``hold_time_seconds``/``hold``).  For every
    trade the P&L implied by the prices and size is added as
    ``expected_pnl`` together with ``pnl_difference`` (reported minus
    expected).  Returns an empty DataFrame when there are no trades.
    """
    rows = []
    for trade in trades_data:
        def number(primary, fallback):
            # Numeric field under either of two key names, defaulting to 0.
            return float(trade.get(primary, trade.get(fallback, 0)))

        if 'entry_time' in trade:
            when = parse_trade_time(trade['entry_time'])
        elif 'time' in trade:
            when = parse_trade_time(trade['time'])
        else:
            when = None

        row = {
            'time': when,
            'side': trade.get('side', trade.get('action', 'UNKNOWN')),
            'size': number('size', 'quantity'),
            'entry_price': number('entry_price', 'entry'),
            'exit_price': number('exit_price', 'exit'),
            'hold_time': number('hold_time_seconds', 'hold'),
            'pnl': float(trade.get('pnl', 0)),
            'fees': float(trade.get('fees', 0)),
        }

        # Longs gain when the price rises; every other side is treated as short.
        if row['side'] in ('LONG', 'BUY'):
            price_move = row['exit_price'] - row['entry_price']
        else:
            price_move = row['entry_price'] - row['exit_price']
        row['expected_pnl'] = price_move * row['size']
        row['pnl_difference'] = row['pnl'] - row['expected_pnl']
        rows.append(row)

    return pd.DataFrame(rows) if rows else pd.DataFrame()
def analyze_trades(df):
    """Print an audit report for a trades DataFrame and save two histograms.

    The report covers duplicate entry prices, trades placed within 30
    seconds of the previous one, discrepancies between reported and expected
    P&L (more than 0.01), side and hold-time distributions and a performance
    summary.  Histograms of entry price and P&L are written to ``debug/``.

    Args:
        df: DataFrame with the columns produced by ``parse_trade_data``
            (``time``, ``side``, ``size``, ``entry_price``, ``hold_time``,
            ``pnl``, ``fees``, ``expected_pnl``, ``pnl_difference``).  The
            ``.dt`` accessor is used on ``time``, so it must hold datetimes.

    Returns:
        None.  Prints a notice and returns immediately for an empty frame.
    """
    if df.empty:
        print("No trades to analyze")
        return

    print(f"\n{'='*50}")
    print("TRADE AUDIT RESULTS")
    print(f"{'='*50}")
    print(f"Total trades analyzed: {len(df)}")

    # Check for duplicate entry prices
    entry_price_counts = df['entry_price'].value_counts()
    duplicate_entries = entry_price_counts[entry_price_counts > 1]
    print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}")
    if not duplicate_entries.empty:
        print(f"Found {len(duplicate_entries)} prices with multiple entries:")
        for price, count in duplicate_entries.items():
            print(f" ${price:.2f}: {count} trades")

        # Analyze the duplicate entry trades in more detail
        for price in duplicate_entries.index:
            duplicate_df = df[df['entry_price'] == price].copy()
            duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds()
            print(f"\nDetailed analysis for entry price ${price:.2f}:")
            print(f" Time gaps between consecutive trades:")
            for i, (_, row) in enumerate(duplicate_df.iterrows()):
                if i > 0:  # Skip first row as it has no previous trade
                    time_diff = row['time_diff']
                    if pd.notna(time_diff):
                        print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade")
    else:
        print("No duplicate entry prices found")

    # Check for rapid consecutive trades
    df = df.sort_values('time')
    df['time_since_last'] = df['time'].diff().dt.total_seconds()
    rapid_trades = df[df['time_since_last'] < 30].copy()
    print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}")
    if not rapid_trades.empty:
        print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:")
        for _, row in rapid_trades.iterrows():
            if pd.notna(row['time_since_last']):
                print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous")
    else:
        print("No rapid consecutive trades found")

    # Check for P&L calculation accuracy
    pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy()
    print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}")
    if not pnl_diff.empty:
        print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:")
        for _, row in pnl_diff.iterrows():
            print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}")
    else:
        print("No P&L calculation issues found")

    # Check for side distribution
    side_counts = df['side'].value_counts()
    print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}")
    for side, count in side_counts.items():
        print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)")

    # Check for hold time distribution
    print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}")
    print(f" Min hold time: {df['hold_time'].min():.0f} seconds")
    print(f" Max hold time: {df['hold_time'].max():.0f} seconds")
    print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds")
    print(f" Median hold time: {df['hold_time'].median():.0f} seconds")

    # Hold time buckets.  include_lowest=True keeps trades with a hold time
    # of exactly 0 in the first bucket; pd.cut's default left-open intervals
    # would silently drop them from the distribution.
    hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')]
    hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+']
    df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels, include_lowest=True)
    hold_dist = df['hold_bucket'].value_counts().sort_index()
    for bucket, count in hold_dist.items():
        print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)")

    # Generate summary statistics
    print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}")
    winning_trades = df[df['pnl'] > 0]
    losing_trades = df[df['pnl'] < 0]
    print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)")
    print(f" Avg win: ${winning_trades['pnl'].mean():.2f}")
    print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}")
    print(f" Total P&L: ${df['pnl'].sum():.2f}")
    print(f" Total fees: ${df['fees'].sum():.2f}")
    print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}")

    # The plots go to debug/; create it so the function also works when it
    # is called without main() having prepared the directory.
    os.makedirs('debug', exist_ok=True)

    # Plot entry price distribution
    plt.figure(figsize=(10, 6))
    plt.hist(df['entry_price'], bins=20, alpha=0.7)
    plt.title('Entry Price Distribution')
    plt.xlabel('Entry Price ($)')
    plt.ylabel('Number of Trades')
    plt.grid(True, alpha=0.3)
    plt.savefig('debug/entry_price_distribution.png')
    plt.close()  # release the figure; pyplot keeps it alive otherwise

    # Plot P&L distribution
    plt.figure(figsize=(10, 6))
    plt.hist(df['pnl'], bins=20, alpha=0.7)
    plt.title('P&L Distribution')
    plt.xlabel('P&L ($)')
    plt.ylabel('Number of Trades')
    plt.grid(True, alpha=0.3)
    plt.savefig('debug/pnl_distribution.png')
    plt.close()

    print(f"\n{'='*20} AUDIT COMPLETE {'='*20}")
    print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png")
def analyze_manual_trades(trades_data):
    """Parse pasted trade lines into a DataFrame.

    Each trade line is split on ``$``; the expected layout is
    ``Time Side $Size $Entry $Exit Hold $P&L $Fees`` with fees optional.
    Header and summary lines are skipped.  Lines that fail to parse are
    reported on stdout and ignored.  Returns an empty DataFrame when no
    line could be parsed.
    """
    skipped_prefixes = (
        'from last session',
        'Recent Closed Trades',
        'Trading Performance',
        'Win Rate:',  # summary line
    )
    records = []
    for line in trades_data.strip().split('\n'):
        if not line or line.startswith(skipped_prefixes):
            continue
        try:
            parts = line.split('$')
            head = parts[0].strip().split()
            time_text, side = head[0], head[1]
            size = float(parts[1].split()[0])
            entry_price = float(parts[2].split()[0])
            # The exit price shares its segment with the hold time.
            exit_fields = parts[3].split()
            exit_price = float(exit_fields[0])
            hold = int(exit_fields[1])
            pnl = float(parts[4].split()[0])
            fees = float(parts[5].strip()) if len(parts) > 5 else 0.0

            record = {
                'time': parse_trade_time(time_text),
                'side': side,
                'size': size,
                'entry_price': entry_price,
                'exit_price': exit_price,
                'hold_time': hold,
                'pnl': pnl,
                'fees': fees
            }

            # Longs gain when the price rises; every other side is treated as short.
            if side in ('LONG', 'BUY'):
                expected_pnl = (exit_price - entry_price) * size
            else:
                expected_pnl = (entry_price - exit_price) * size
            record['expected_pnl'] = expected_pnl
            record['pnl_difference'] = pnl - expected_pnl
            records.append(record)
        except Exception as e:
            print(f"Error parsing trade line: {line}")
            print(f"Error details: {e}")

    return pd.DataFrame(records) if records else pd.DataFrame()
def main():
    """Command-line entry point of the trade audit tool.

    Trades come from ``--trades-file`` (JSON), from ``--manual-trades``
    (pasted text) or, when neither is given, from the dashboard cache
    files.  A non-empty result is passed to ``analyze_trades``.
    """
    parser = argparse.ArgumentParser(description='Trade Audit Tool')
    parser.add_argument('--trades-file', type=str, help='Path to trades JSON file')
    parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades')
    args = parser.parse_args()

    # The plots are written to debug/, so make sure it exists.
    os.makedirs('debug', exist_ok=True)

    if args.trades_file:
        df = parse_trade_data(load_trades_from_file(args.trades_file))
    elif args.manual_trades:
        try:
            with open(args.manual_trades, 'r') as handle:
                pasted_text = handle.read()
            df = analyze_manual_trades(pasted_text)
        except Exception as e:
            print(f"Error reading manual trades file: {e}")
            df = pd.DataFrame()
    else:
        # Fall back to whatever the dashboard cached.
        cached_trades = load_trades_from_dashboard_cache()
        if not cached_trades:
            print("No trade data provided. Use --trades-file or --manual-trades")
            return
        df = parse_trade_data(cached_trades)

    if df.empty:
        print("No valid trade data to analyze")
    else:
        analyze_trades(df)


if __name__ == "__main__":
    main()

View File

@ -1,84 +0,0 @@
#!/usr/bin/env python3
"""
Debug Training Methods
This script checks what training methods are available on each model.
"""
import asyncio
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
async def debug_training_methods():
    """Print which training hooks each registered model exposes, then try training.

    For every entry in the orchestrator's model registry the interface and
    underlying model types, the training-related methods and a few
    attributes are printed.  Afterwards predictions are requested for
    ETH/USDT and immediate training is triggered per model, reporting
    success or the raised exception.
    """
    print("=== Debugging Training Methods ===")

    print("1. Initializing orchestrator...")
    provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=provider)
    # Pause two seconds before inspecting the registry.
    await asyncio.sleep(2)

    print("\n2. Checking available training methods on each model:")
    candidate_methods = (
        'train_on_outcome', 'add_experience', 'remember', 'replay',
        'add_training_sample', 'train', 'train_with_reward', 'update_loss',
    )
    sized_attrs = ('memory', 'training_data')  # reported by length when sized
    for model_name, model_interface in orchestrator.model_registry.models.items():
        print(f"\n--- {model_name} ---")
        print(f"Interface type: {type(model_interface).__name__}")

        underlying_model = getattr(model_interface, 'model', None)
        if not underlying_model:
            print("No underlying model found")
            continue
        print(f"Underlying model type: {type(underlying_model).__name__}")

        training_methods = [
            method for method in candidate_methods if hasattr(underlying_model, method)
        ]
        print(f"Available training methods: {training_methods}")

        attributes = []
        for attr in ('memory', 'batch_size', 'training_data'):
            if not hasattr(underlying_model, attr):
                continue
            attr_value = getattr(underlying_model, attr)
            if attr in sized_attrs and hasattr(attr_value, '__len__'):
                attributes.append(f"{attr}(len={len(attr_value)})")
            else:
                attributes.append(f"{attr}={attr_value}")
        print(f"Relevant attributes: {attributes}")

        # Rough classification based on the methods the model offers.
        looks_like_rl = hasattr(underlying_model, 'act') and hasattr(underlying_model, 'remember')
        if looks_like_rl:
            print("✅ Detected as RL Agent")
        elif hasattr(underlying_model, 'predict') and hasattr(underlying_model, 'add_training_sample'):
            print("✅ Detected as CNN Model")
        else:
            print("❓ Unknown model type")

    print("\n3. Testing a simple training attempt:")
    predictions = await orchestrator._get_all_predictions('ETH/USDT')
    print(f"Got {len(predictions)} predictions")

    for model_name in orchestrator.model_registry.models.keys():
        print(f"\nTesting training for {model_name}...")
        try:
            await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT')
            print(f"✅ Training attempt completed for {model_name}")
        except Exception as e:
            print(f"❌ Training failed for {model_name}: {e}")


if __name__ == "__main__":
    asyncio.run(debug_training_methods())

View File

@ -1,233 +0,0 @@
"""
Bybit Integration Examples
Based on official pybit library documentation and examples
"""
import os
from pybit.unified_trading import HTTP
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_bybit_session(testnet=True):
    """Build an authenticated pybit HTTP session from environment credentials.

    Args:
        testnet (bool): Connect to the Bybit testnet when True, to the live
            exchange when False.

    Returns:
        HTTP: The configured Bybit session object.

    Raises:
        ValueError: If BYBIT_API_KEY or BYBIT_API_SECRET is unset or empty.
    """
    credentials = {
        "api_key": os.getenv('BYBIT_API_KEY'),
        "api_secret": os.getenv('BYBIT_API_SECRET'),
    }
    # Fail before touching the network when either credential is missing.
    if not all(credentials.values()):
        raise ValueError("BYBIT_API_KEY and BYBIT_API_SECRET must be set in environment")

    bybit_session = HTTP(testnet=testnet, **credentials)
    logger.info(f"Created Bybit session (testnet: {testnet})")
    return bybit_session
def get_account_info(session):
    """Fetch the wallet balance of the UNIFIED account.

    Args:
        session: Bybit HTTP session.

    Returns:
        The raw response from ``get_wallet_balance``, or None if the call
        raised (the error is logged).
    """
    try:
        wallet = session.get_wallet_balance(accountType="UNIFIED")
        logger.info(f"Account info: {wallet}")
        outcome = wallet
    except Exception as e:
        logger.error(f"Error getting account info: {e}")
        outcome = None
    return outcome
def get_ticker_info(session, symbol="BTCUSDT"):
    """Look up the linear-category ticker for one symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol (default: BTCUSDT)

    Returns:
        The raw response from ``get_tickers``, or None if the call raised
        (the error is logged).
    """
    query = {"category": "linear", "symbol": symbol}
    try:
        ticker_response = session.get_tickers(**query)
        logger.info(f"Ticker for {symbol}: {ticker_response}")
        outcome = ticker_response
    except Exception as e:
        logger.error(f"Error getting ticker for {symbol}: {e}")
        outcome = None
    return outcome
def get_orderbook(session, symbol="BTCUSDT", limit=25):
    """Fetch the linear-category order book for one symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        limit: Number of price levels to return

    Returns:
        The raw response from ``get_orderbook``, or None if the call raised
        (the error is logged).
    """
    query = {"category": "linear", "symbol": symbol, "limit": limit}
    try:
        book = session.get_orderbook(**query)
        logger.info(f"Orderbook for {symbol}: {book}")
        outcome = book
    except Exception as e:
        logger.error(f"Error getting orderbook for {symbol}: {e}")
        outcome = None
    return outcome
def place_limit_order(session, symbol="BTCUSDT", side="Buy", qty="0.001", price="50000"):
    """Submit a Good-Till-Cancelled limit order in the linear category.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        side: "Buy" or "Sell"
        qty: Order quantity as string
        price: Order price as string

    Returns:
        The raw response from ``place_order``, or None if the call raised
        (the error is logged).
    """
    order_request = {
        "category": "linear",
        "symbol": symbol,
        "side": side,
        "orderType": "Limit",
        "qty": qty,
        "price": price,
        "timeInForce": "GTC",  # Good Till Cancelled
    }
    try:
        placed = session.place_order(**order_request)
        logger.info(f"Placed order: {placed}")
        outcome = placed
    except Exception as e:
        logger.error(f"Error placing order: {e}")
        outcome = None
    return outcome
def place_market_order(session, symbol="BTCUSDT", side="Buy", qty="0.001"):
    """Submit a market order in the linear category.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        side: "Buy" or "Sell"
        qty: Order quantity as string

    Returns:
        The raw response from ``place_order``, or None if the call raised
        (the error is logged).
    """
    order_request = {
        "category": "linear",
        "symbol": symbol,
        "side": side,
        "orderType": "Market",
        "qty": qty,
    }
    try:
        placed = session.place_order(**order_request)
        logger.info(f"Placed market order: {placed}")
        outcome = placed
    except Exception as e:
        logger.error(f"Error placing market order: {e}")
        outcome = None
    return outcome
def get_open_orders(session, symbol=None, settle_coin="USDT"):
    """Get currently open (active) linear orders.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol (optional, gets all if None)
        settle_coin: Settlement coin used to scope the query when no symbol
            is given (default: USDT).

    Returns:
        The raw response from ``get_open_orders``, or None if the call raised
        (the error is logged).
    """
    try:
        # openOnly is an integer selector in the v5 API: 0 returns active
        # orders, non-zero values select recently closed ones. A boolean
        # True is not a valid value for it.
        params = {"category": "linear", "openOnly": 0}
        if symbol:
            params["symbol"] = symbol
        else:
            # The linear category needs one of symbol / baseCoin / settleCoin;
            # scope the "all orders" query by settlement coin.
            params["settleCoin"] = settle_coin
        orders = session.get_open_orders(**params)
        logger.info(f"Open orders: {orders}")
        return orders
    except Exception as e:
        logger.error(f"Error getting open orders: {e}")
        return None
def cancel_order(session, symbol, order_id):
    """Cancel one linear-category order by its ID.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        order_id: Order ID to cancel

    Returns:
        The raw response from ``cancel_order``, or None if the call raised
        (the error is logged).
    """
    cancel_request = {"category": "linear", "symbol": symbol, "orderId": order_id}
    try:
        cancellation = session.cancel_order(**cancel_request)
        logger.info(f"Cancelled order {order_id}: {cancellation}")
        outcome = cancellation
    except Exception as e:
        logger.error(f"Error cancelling order {order_id}: {e}")
        outcome = None
    return outcome
def get_position(session, symbol="BTCUSDT"):
    """Fetch linear-category position information for one symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol

    Returns:
        The raw response from ``get_positions``, or None if the call raised
        (the error is logged).
    """
    query = {"category": "linear", "symbol": symbol}
    try:
        position_response = session.get_positions(**query)
        logger.info(f"Position for {symbol}: {position_response}")
        outcome = position_response
    except Exception as e:
        logger.error(f"Error getting position for {symbol}: {e}")
        outcome = None
    return outcome
def get_trade_history(session, symbol="BTCUSDT", limit=50):
    """Fetch recent linear-category executions for one symbol.

    Args:
        session: Bybit HTTP session
        symbol: Trading symbol
        limit: Number of trades to return

    Returns:
        The raw response from ``get_executions``, or None if the call raised
        (the error is logged).
    """
    query = {"category": "linear", "symbol": symbol, "limit": limit}
    try:
        executions = session.get_executions(**query)
        logger.info(f"Trade history for {symbol}: {executions}")
        outcome = executions
    except Exception as e:
        logger.error(f"Error getting trade history for {symbol}: {e}")
        outcome = None
    return outcome
# Example usage: a read-only walkthrough against the testnet.
if __name__ == "__main__":
    # Testnet session; credentials come from the environment.
    session = create_bybit_session(testnet=True)

    # Read-only queries. Each helper logs its own result.
    account_info = get_account_info(session)
    ticker = get_ticker_info(session, "BTCUSDT")
    orderbook = get_orderbook(session, "BTCUSDT")
    open_orders = get_open_orders(session)
    position = get_position(session, "BTCUSDT")

    # Note: Uncomment below to actually place orders (use with caution)
    # order = place_limit_order(session, "BTCUSDT", "Buy", "0.001", "30000")
    # market_order = place_market_order(session, "BTCUSDT", "Buy", "0.001")

View File

@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""
Example usage of the simplified data provider
"""
import time
import logging
from core.data_provider import DataProvider
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
    """Demonstrate the simplified data provider usage.

    Walks through six examples (historical data, current prices, cache
    status, multiple timeframes, health check, automatic updates) and always
    stops the provider's background maintenance on the way out, even if one
    of the examples raises.
    """

    def _latest_timestamp(frame):
        """Return the last index label of *frame*, or None when it has no rows."""
        if frame is None or frame.empty:
            return None
        return frame.index[-1]

    # Initialize data provider (starts automatic maintenance)
    logger.info("Initializing DataProvider...")
    dp = DataProvider()

    try:
        # Wait for initial data load (happens automatically in background)
        logger.info("Waiting for initial data load...")
        time.sleep(15)  # Give it time to load data

        # Example 1: Get cached historical data (no API calls)
        logger.info("\n=== Example 1: Getting Historical Data ===")
        eth_1m_data = dp.get_historical_data('ETH/USDT', '1m', limit=50)
        # An empty frame is possible while the cache is still warming up;
        # .iloc[-1] would raise IndexError on it.
        if eth_1m_data is not None and not eth_1m_data.empty:
            logger.info(f"ETH/USDT 1m data: {len(eth_1m_data)} candles")
            logger.info(f"Latest candle: {eth_1m_data.iloc[-1]['close']}")

        # Example 2: Get current prices
        logger.info("\n=== Example 2: Current Prices ===")
        eth_price = dp.get_current_price('ETH/USDT')
        btc_price = dp.get_current_price('BTC/USDT')
        logger.info(f"ETH current price: ${eth_price}")
        logger.info(f"BTC current price: ${btc_price}")

        # Example 3: Check cache status
        logger.info("\n=== Example 3: Cache Status ===")
        cache_summary = dp.get_cached_data_summary()
        for symbol, timeframes in cache_summary['cached_data'].items():
            logger.info(f"\n{symbol}:")
            for timeframe, info in timeframes.items():
                if 'candle_count' in info and info['candle_count'] > 0:
                    logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
                else:
                    logger.info(f" {timeframe}: {info.get('status', 'no data')}")

        # Example 4: Multiple timeframe data
        logger.info("\n=== Example 4: Multiple Timeframes ===")
        for tf in ['1s', '1m', '1h', '1d']:
            data = dp.get_historical_data('ETH/USDT', tf, limit=5)
            if data is not None and not data.empty:
                logger.info(f"ETH {tf}: {len(data)} candles, range: ${data['close'].min():.2f} - ${data['close'].max():.2f}")

        # Example 5: Health check
        logger.info("\n=== Example 5: Health Check ===")
        health = dp.health_check()
        logger.info(f"Data maintenance active: {health['data_maintenance_active']}")
        logger.info(f"Symbols: {health['symbols']}")
        logger.info(f"Timeframes: {health['timeframes']}")

        # Example 6: Wait and show automatic updates
        logger.info("\n=== Example 6: Automatic Updates ===")
        logger.info("Waiting 30 seconds to show automatic data updates...")
        initial_time = _latest_timestamp(dp.get_historical_data('ETH/USDT', '1s', limit=1))
        time.sleep(30)
        updated_time = _latest_timestamp(dp.get_historical_data('ETH/USDT', '1s', limit=1))
        if initial_time is not None and updated_time is not None and updated_time > initial_time:
            logger.info(f"✅ Data automatically updated! New timestamp: {updated_time}")
        else:
            logger.info("⏳ Data update in progress...")
    finally:
        # Clean shutdown: stop background maintenance even after a failure.
        logger.info("\n=== Shutting Down ===")
        dp.stop_automatic_data_maintenance()
        logger.info("DataProvider stopped successfully")


if __name__ == "__main__":
    main()

View File

@ -1,331 +0,0 @@
"""
Kill Stale Processes
This script identifies and kills stale Python processes that might be causing
the dashboard startup freeze. It looks for:
1. Hanging dashboard processes
2. Stale COB data collection threads
3. Matplotlib GUI processes
4. Blocked network connections
Usage:
python kill_stale_processes.py
"""
import os
import sys
import psutil
import signal
import time
from datetime import datetime
def find_python_processes():
    """Collect every running process whose name contains "python".

    Returns:
        list[dict]: One entry per match with the keys 'pid', 'name',
        'cmdline' (arguments joined with spaces, '' when unavailable),
        'create_time', 'status' and 'process' (the psutil process object).
        Processes that vanish or deny access during the scan are skipped.
    """
    found = []
    wanted_fields = ['pid', 'name', 'cmdline', 'create_time', 'status']
    try:
        for candidate in psutil.process_iter(wanted_fields):
            try:
                proc_name = candidate.info['name']
                if not proc_name or 'python' not in proc_name.lower():
                    continue
                # Get command line to identify dashboard processes
                argv = candidate.info['cmdline']
                found.append({
                    'pid': candidate.info['pid'],
                    'name': candidate.info['name'],
                    'cmdline': ' '.join(argv) if argv else '',
                    'create_time': candidate.info['create_time'],
                    'status': candidate.info['status'],
                    'process': candidate,
                })
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
    except Exception as e:
        print(f"Error finding Python processes: {e}")
    return found
def identify_dashboard_processes(python_processes):
    """Identify processes related to the dashboard.

    A process counts as dashboard-related when its lower-cased command line
    contains any of the dashboard keywords. The current process is never
    reported: the keywords are broad ('trading', 'dashboard'), so they can
    match this script's own path, and the tool must not offer to kill itself.

    Args:
        python_processes: Process-info dicts with at least a 'cmdline' key
            (and normally a 'pid' key).

    Returns:
        list[dict]: The matching entries, in their original order.
    """
    dashboard_keywords = (
        'clean_dashboard',
        'run_clean_dashboard',
        'dashboard',
        'trading',
        'cob_data',
        'orchestrator',
        'data_provider',
    )
    own_pid = os.getpid()

    dashboard_processes = []
    for proc_info in python_processes:
        if proc_info.get('pid') == own_pid:
            continue  # never target the process running this tool
        cmdline = proc_info['cmdline'].lower()
        if any(keyword in cmdline for keyword in dashboard_keywords):
            dashboard_processes.append(proc_info)
    return dashboard_processes
def identify_stale_processes(python_processes):
    """Identify potentially stale processes.

    A process is reported (at most once) when:
      * its status is 'zombie' or 'stopped', or
      * its CPU usage is below 0.1% and it is either older than 24 hours
        or holds more than 500MB of resident memory.

    Args:
        python_processes: Process-info dicts with the keys 'status',
            'create_time' and 'process' (an object offering ``cpu_percent``
            and ``memory_info``).

    Returns:
        list[dict]: Copies of the matching entries with an added 'reason'
        string. When several criteria match they are joined with "; ".
    """
    stale_processes = []
    current_time = time.time()
    for proc_info in python_processes:
        try:
            proc = proc_info['process']

            # Check if process is in a problematic state
            if proc_info['status'] in ('zombie', 'stopped'):
                stale_processes.append({
                    **proc_info,
                    'reason': f"Process status: {proc_info['status']}"
                })
                continue

            age_hours = (current_time - proc_info['create_time']) / 3600
            is_old = age_hours > 24  # Running for more than 24 hours

            try:
                memory_mb = proc.memory_info().rss / 1024 / 1024
            except psutil.Error:
                memory_mb = None  # unreadable: skip only the memory criterion
            is_large = memory_mb is not None and memory_mb > 500  # More than 500MB

            if not (is_old or is_large):
                continue

            # Sample CPU once. The call blocks for the whole interval, so
            # sampling separately per criterion would double the wait.
            try:
                cpu_percent = proc.cpu_percent(interval=1)
            except psutil.Error:
                continue
            if cpu_percent >= 0.1:
                continue

            reasons = []
            if is_old:
                reasons.append(f"Old process ({age_hours:.1f}h) with low CPU usage ({cpu_percent:.1f}%)")
            if is_large:
                reasons.append(f"High memory usage ({memory_mb:.1f}MB) with low CPU usage ({cpu_percent:.1f}%)")
            # One entry per process, so callers never kill the same PID twice.
            stale_processes.append({**proc_info, 'reason': "; ".join(reasons)})
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
    return stale_processes
def kill_process_safely(proc_info, force=False):
    """Kill a process safely.

    Sends a graceful terminate by default, or a hard kill when *force* is
    True, then waits up to 5 seconds for the process to exit.

    Args:
        proc_info: Dict with the keys 'pid', 'name' and 'process' (a psutil
            process object).
        force: Use ``kill()`` instead of ``terminate()``.

    Returns:
        bool: True when the process is confirmed gone (including the case
        where it had already exited), False when it is still running or
        could not be signalled.
    """
    pid = proc_info['pid']
    try:
        proc = proc_info['process']
        print(f"Attempting to {'force kill' if force else 'terminate'} PID {pid}: {proc_info['name']}")

        if force:
            # psutil's kill() is portable, so no shell command (taskkill)
            # or raw signal is needed.
            proc.kill()
        else:
            # Graceful termination
            proc.terminate()

        # Confirm the process really exited instead of assuming success.
        try:
            proc.wait(timeout=5)
        except psutil.TimeoutExpired:
            if force:
                print(f"❌ Process {pid} is still running after force kill")
            else:
                print(f"⚠️ Process {pid} didn't terminate gracefully, will force kill")
            return False

        if force:
            print(f"✅ Process {pid} killed")
        else:
            print(f"✅ Process {pid} terminated gracefully")
        return True
    except psutil.NoSuchProcess:
        # Already gone: the goal is met, so callers need not escalate.
        print(f"✅ Process {pid} already exited")
        return True
    except psutil.AccessDenied as e:
        print(f"⚠️ Could not kill process {pid}: {e}")
        return False
    except Exception as e:
        print(f"❌ Error killing process {pid}: {e}")
        return False
def check_port_usage(port=8050):
    """Check if the dashboard port is in use.

    Probes ``localhost:port`` with a TCP connect. When something answers,
    scans the system connection table for a socket bound to that local port.

    Args:
        port: TCP port to probe (default: 8050).

    Returns:
        int | None: PID of a process bound to the port, or None when the
        port is free, the owner cannot be determined, or the check failed.
    """
    try:
        import socket

        # The context manager closes the socket even if connect_ex raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            result = sock.connect_ex(('localhost', port))

        if result != 0:
            print(f"✅ Port {port} is available")
            return None

        print(f"⚠️ Port {port} is in use")
        # Find process using the port
        for conn in psutil.net_connections():
            if not conn.laddr or conn.laddr.port != port:
                continue
            # pid is None when the owner is not visible to us, and
            # psutil.Process(None) would resolve to the current process.
            if conn.pid is None:
                continue
            try:
                proc = psutil.Process(conn.pid)
                print(f" Port {port} used by PID {conn.pid}: {proc.name()}")
                return conn.pid
            except psutil.Error:
                continue
        return None
    except Exception as e:
        print(f"Error checking port usage: {e}")
        return None
def _port_process_info(port_pid):
    """Build a minimal proc_info dict for *port_pid*, or None if it cannot be inspected."""
    try:
        proc = psutil.Process(port_pid)
        return {'pid': port_pid, 'name': proc.name(), 'process': proc}
    except psutil.Error:
        return None


def _kill_processes(processes):
    """Terminate each distinct PID once, escalating to a force kill on failure."""
    seen_pids = set()
    for proc in processes:
        if proc['pid'] in seen_pids:
            continue  # same process listed under several categories
        seen_pids.add(proc['pid'])
        if not kill_process_safely(proc):
            kill_process_safely(proc, force=True)


def main():
    """Interactive entry point: find problem processes and offer to kill them.

    Lists dashboard-related processes, the owner of port 8050 and processes
    that look stale, then asks which group to kill.

    Returns:
        bool: True when the run completed (whatever the user chose), False
        when an unexpected error occurred.
    """
    print("🔍 Stale Process Killer")
    print("=" * 50)
    try:
        # Step 1: Find all Python processes
        print("🔍 Finding Python processes...")
        python_processes = find_python_processes()
        print(f"Found {len(python_processes)} Python processes")

        # Step 2: Identify dashboard processes
        print("\n🎯 Identifying dashboard processes...")
        dashboard_processes = identify_dashboard_processes(python_processes)
        if dashboard_processes:
            print(f"Found {len(dashboard_processes)} dashboard-related processes:")
            for proc in dashboard_processes:
                age_hours = (time.time() - proc['create_time']) / 3600
                print(f" PID {proc['pid']}: {proc['name']} (age: {age_hours:.1f}h, status: {proc['status']})")
                print(f" Command: {proc['cmdline'][:100]}...")
        else:
            print("No dashboard processes found")

        # Step 3: Check port usage
        print("\n🌐 Checking port usage...")
        port_pid = check_port_usage()

        # Step 4: Identify stale processes
        print("\n🕵️ Identifying stale processes...")
        stale_processes = identify_stale_processes(python_processes)
        if stale_processes:
            print(f"Found {len(stale_processes)} potentially stale processes:")
            for proc in stale_processes:
                print(f" PID {proc['pid']}: {proc['name']} - {proc['reason']}")
        else:
            print("No stale processes identified")

        # Step 5: Ask user what to do
        if dashboard_processes or stale_processes or port_pid:
            print("\n🤔 What would you like to do?")
            print("1. Kill all dashboard processes")
            print("2. Kill only stale processes")
            print("3. Kill process using port 8050")
            print("4. Kill all identified processes")
            print("5. Show process details and exit")
            print("6. Exit without killing anything")
            try:
                choice = input("\nEnter your choice (1-6): ").strip()
                if choice == '1':
                    print("\n🔫 Killing dashboard processes...")
                    _kill_processes(dashboard_processes)
                elif choice == '2':
                    print("\n🔫 Killing stale processes...")
                    _kill_processes(stale_processes)
                elif choice == '3':
                    if port_pid:
                        print(f"\n🔫 Killing process using port 8050 (PID {port_pid})...")
                        port_proc = _port_process_info(port_pid)
                        if port_proc is None:
                            print(f"❌ Could not kill process {port_pid}")
                        else:
                            _kill_processes([port_proc])
                    else:
                        print("No process found using port 8050")
                elif choice == '4':
                    print("\n🔫 Killing all identified processes...")
                    all_processes = dashboard_processes + stale_processes
                    if port_pid:
                        port_proc = _port_process_info(port_pid)
                        if port_proc is not None:
                            all_processes.append(port_proc)
                    # The groups overlap (a stale dashboard process, the port
                    # owner). _kill_processes handles each PID only once.
                    _kill_processes(all_processes)
                elif choice == '5':
                    print("\n📋 Process Details:")
                    all_processes = dashboard_processes + stale_processes
                    for proc in all_processes:
                        print(f"\nPID {proc['pid']}: {proc['name']}")
                        print(f" Status: {proc['status']}")
                        print(f" Command: {proc['cmdline']}")
                        print(f" Created: {datetime.fromtimestamp(proc['create_time'])}")
                elif choice == '6':
                    print("👋 Exiting without killing processes")
                else:
                    print("❌ Invalid choice")
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed or not interactive. Treat it like a cancel.
                print("\n👋 Cancelled by user")
        else:
            print("\n✅ No problematic processes found")

        print("\n" + "=" * 50)
        print("💡 After killing processes, you can try:")
        print(" python run_lightweight_dashboard.py")
        print(" or")
        print(" python fix_startup_freeze.py")
        return True
    except Exception as e:
        print(f"❌ Error in main function: {e}")
        return False


if __name__ == "__main__":
    # Report failure to the shell through a non-zero exit status.
    success = main()
    if not success:
        sys.exit(1)

View File

@ -1,41 +0,0 @@
"""
Launch training with optimized short-term models only
"""
import os
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import load_config
from core.training import TrainingManager
from core.models import OptimizedShortTermModel
def main():
    """Main training function using only optimized models.

    Builds an ``OptimizedShortTermModel``, restores its weights from the
    'ticks_model' path in the configuration when that file exists, and runs
    a tick-based, real-time ``TrainingManager`` session.
    """
    config = load_config()

    # Initialize model
    model = OptimizedShortTermModel()

    # Load best model if exists. The path may be missing from the config,
    # and os.path.exists(None) raises TypeError, so check for it first.
    best_model_path = config.model_paths.get('ticks_model')
    if best_model_path and os.path.exists(best_model_path):
        # torch is not imported at file level; import it where it is needed.
        import torch
        model.load_state_dict(torch.load(best_model_path))

    # Initialize training
    trainer = TrainingManager(
        model=model,
        config=config,
        use_ticks=True,
        use_realtime=True
    )

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()

458
main.py
View File

@ -1,458 +0,0 @@
#!/usr/bin/env python3
"""
Streamlined Trading System - Web Dashboard + Training
Integrated system with both training loop and web dashboard:
- Training Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution
- Web Dashboard: Real-time monitoring and control interface
- 2-Action System: BUY/SELL with intelligent position management
- Always invested approach with smart risk/reward setup detection
Usage:
python main.py [--symbol ETH/USDT] [--port 8050]
"""
import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'
import asyncio
import argparse
import logging
import sys
from pathlib import Path
from threading import Thread
import time
from safe_logging import setup_safe_logging
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging, Config
from core.data_provider import DataProvider
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
async def run_web_dashboard():
    """Set up the data provider, orchestrator and executor, then run the training loop.

    Despite its name this coroutine does not start a web server: it wires up
    the pipeline components and awaits ``start_training_loop``. Any exception
    is logged with its traceback and swallowed.
    """
    try:
        for banner_line in (
            "Starting Streamlined Trading Dashboard...",
            "2-Action System: BUY/SELL with intelligent position management",
            "Always Invested Approach: Smart risk/reward setup detection",
            "Integrated Training Pipeline: Live data -> Models -> Trading",
        ):
            logger.info(banner_line)

        # Get configuration
        config = get_config()

        # Initialize core components for streamlined pipeline
        from core.standardized_data_provider import StandardizedDataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor

        # Create standardized data provider (validated BaseDataInput, pivots, COB)
        data_provider = StandardizedDataProvider()

        # Start real-time streaming for BOM caching. Failure is not fatal.
        try:
            await data_provider.start_real_time_streaming()
            logger.info("[SUCCESS] Real-time data streaming started for BOM caching")
        except Exception as e:
            logger.warning(f"[WARNING] Real-time streaming failed: {e}")

        # Verify data connection with retry mechanism
        logger.info("[DATA] Verifying live data connection...")
        symbol = config.get('symbols', ['ETH/USDT'])[0]

        # Poll until the provider returns candles. Give up after max_retries
        # attempts, but keep starting the system.
        max_retries = 10
        retry_delay = 2
        for attempt_number in range(1, max_retries + 1):
            test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
            if test_df is not None and len(test_df) > 0:
                logger.info("[SUCCESS] Data connection verified")
                logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
                break
            if attempt_number < max_retries:
                logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt_number}/{max_retries})")
                await asyncio.sleep(retry_delay)
        else:
            logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
            logger.warning("The system will attempt to fetch data as needed during operation")

        # Load model registry for integrated pipeline. The import is only an
        # availability probe: an empty dict is used either way.
        try:
            from models import get_model_registry
        except ImportError:
            logger.warning("Model registry not available, using empty registry")
        else:
            logger.info("[MODELS] Model registry initialized for training")
        model_registry = {}  # Use simple dict for now

        # Initialize checkpoint management
        checkpoint_manager = get_checkpoint_manager()
        training_integration = get_training_integration()
        logger.info("Checkpoint management initialized for training pipeline")

        # Create unified orchestrator with full ML pipeline
        orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True,
            model_registry=model_registry,
        )
        for status_line in (
            "Unified Trading Orchestrator initialized with full ML pipeline",
            "Data Bus -> Models (DQN + CNN + COB) -> Decision Model -> Trading Signals",
            # Checkpoint management will be handled in the training loop
            "Checkpoint management will be initialized in training loop",
            # Unified orchestrator includes COB integration as part of data bus
            "COB Integration available - feeds into unified data bus",
        ):
            logger.info(status_line)

        # Create trading executor for live execution
        trading_executor = TradingExecutor()

        # Start the training and monitoring loop
        for status_line in (
            "Starting Enhanced Training Pipeline",
            "Live Data Processing: ENABLED",
            "COB Integration: ENABLED (Real-time market microstructure)",
            "Integrated CNN Training: ENABLED",
            "Integrated RL Training: ENABLED",
            "Real-time Indicators & Pivots: ENABLED",
            "Live Trading Execution: ENABLED",
            "2-Action System: BUY/SELL with position intelligence",
            "Always Invested: Different thresholds for entry/exit",
            "Pipeline: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution",
            "Starting training loop...",
            "About to start training loop...",
        ):
            logger.info(status_line)

        await start_training_loop(orchestrator, trading_executor)
    except Exception as e:
        import traceback
        logger.error(f"Error in streamlined dashboard: {e}")
        logger.error("Training stopped")
        logger.error(traceback.format_exc())
def start_web_ui(port=8051):
    """Build the CleanTradingDashboard and serve it (blocking) on 127.0.0.1.

    Creates its own data provider, orchestrator and trading executor, starts
    real-time streaming on a daemon thread, then runs the dashboard server.
    Any exception is logged with its traceback and swallowed.

    Args:
        port: TCP port for the dashboard server (default: 8051).
    """
    try:
        separator = "=" * 50
        logger.info(separator)
        logger.info("Starting Main Trading Dashboard UI...")
        logger.info(f"Trading Dashboard: http://127.0.0.1:{port}")
        logger.info("COB Integration: ENABLED (Real-time order book visualization)")
        logger.info(separator)

        # Import and create the Clean Trading Dashboard
        from web.clean_dashboard import CleanTradingDashboard
        from core.standardized_data_provider import StandardizedDataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor

        # Initialize components for the dashboard
        config = get_config()
        data_provider = StandardizedDataProvider()

        # Start real-time streaming for BOM caching (non-blocking)
        try:
            def start_streaming():
                # A dedicated event loop, owned by the streaming thread.
                asyncio.run(data_provider.start_real_time_streaming())

            streaming_thread = Thread(target=start_streaming, daemon=True)
            streaming_thread.start()
            logger.info("[SUCCESS] Real-time streaming thread started for dashboard")
        except Exception as e:
            logger.warning(f"[WARNING] Dashboard streaming setup failed: {e}")

        # Load model registry for enhanced features. The import is only an
        # availability probe: an empty dict is used either way.
        try:
            from models import get_model_registry
        except ImportError:
            pass
        model_registry = {}  # Use simple dict for now

        # Initialize checkpoint management for dashboard
        dashboard_checkpoint_manager = get_checkpoint_manager()
        dashboard_training_integration = get_training_integration()

        # Create unified orchestrator for the dashboard
        dashboard_orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True,
            model_registry=model_registry,
        )
        trading_executor = TradingExecutor("config.yaml")

        # Create the clean trading dashboard with enhanced features
        dashboard = CleanTradingDashboard(
            data_provider=data_provider,
            orchestrator=dashboard_orchestrator,
            trading_executor=trading_executor,
        )
        logger.info("Clean Trading Dashboard created successfully")
        logger.info("Features: Live trading, COB visualization, ML pipeline monitoring, Position management")
        logger.info("✅ Unified orchestrator with decision-making model and checkpoint management")

        # Run the dashboard server (COB integration will start automatically)
        dashboard.run_server(host='127.0.0.1', port=port, debug=False)
    except Exception as e:
        import traceback
        logger.error(f"Error starting main trading dashboard UI: {e}")
        logger.error(traceback.format_exc())
async def start_training_loop(orchestrator, trading_executor):
    """Start the main training and monitoring loop with checkpoint management.

    Every 5 seconds the loop asks the orchestrator for an ETH/USDT decision,
    accumulates confidence-based statistics, saves checkpoints on a schedule
    or on a significant improvement, and logs metrics every 10 iterations.
    Final checkpoints are saved and real-time processing is stopped when the
    loop exits for any reason.

    Args:
        orchestrator: Object providing ``make_trading_decision`` and
            ``get_performance_metrics``. Optional attributes (``rl_agent``,
            ``cnn_model``, ``extrema_trainer``, COB state, start/stop hooks)
            are used only when present.
        trading_executor: Accepted for interface compatibility; the actual
            execution call is currently commented out.
    """
    logger.info("=" * 70)
    logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
    logger.info("=" * 70)
    logger.info("Training loop function entered successfully")

    # Initialize checkpoint management for training loop
    checkpoint_manager = get_checkpoint_manager()
    training_integration = get_training_integration()

    # Training statistics for checkpoint management
    training_stats = {
        'iteration_count': 0,
        'total_decisions': 0,
        'successful_trades': 0,
        'best_performance': 0.0,
        'last_checkpoint_iteration': 0
    }

    try:
        # Start real-time processing (Basic orchestrator doesn't have this method)
        logger.info("Checking for real-time processing capabilities...")
        try:
            if hasattr(orchestrator, 'start_realtime_processing'):
                logger.info("Starting real-time processing...")
                await orchestrator.start_realtime_processing()
                logger.info("Real-time processing started")
            else:
                logger.info("Basic orchestrator - no real-time processing method available")
        except Exception as e:
            logger.warning(f"Real-time processing not available: {e}")

        logger.info("About to enter main training loop...")

        # Main training loop
        iteration = 0
        while True:
            iteration += 1
            training_stats['iteration_count'] = iteration
            logger.info(f"Training iteration {iteration}")

            # Make trading decisions using Basic orchestrator (single symbol method)
            decisions = {}
            symbols = ['ETH/USDT']  # Focus on ETH only for training
            for symbol in symbols:
                try:
                    decisions[symbol] = await orchestrator.make_trading_decision(symbol)
                except Exception as e:
                    logger.warning(f"Error making decision for {symbol}: {e}")
                    decisions[symbol] = None

            # Process decisions and collect training metrics
            iteration_decisions = 0
            iteration_performance = 0.0
            for symbol, decision in decisions.items():
                if not decision:
                    continue
                iteration_decisions += 1
                logger.info(f"{symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
                # Track performance for checkpoint management
                iteration_performance += decision.confidence
                # Execute if confidence is high enough
                if decision.confidence > 0.7:
                    logger.info(f"Executing {symbol}: {decision.action}")
                    training_stats['successful_trades'] += 1
                    # trading_executor.execute_action(decision)

            # Update training statistics
            training_stats['total_decisions'] += iteration_decisions

            # Compare against the best value from BEFORE this iteration.
            # Comparing after the update can never show an improvement.
            previous_best = training_stats['best_performance']
            improved_significantly = iteration_performance > previous_best * 1.1  # 10% improvement
            if iteration_performance > previous_best:
                training_stats['best_performance'] = iteration_performance

            # Save checkpoint every 50 iterations or when performance improves significantly
            should_save_checkpoint = (
                iteration % 50 == 0 or  # Regular interval
                improved_significantly or
                iteration - training_stats['last_checkpoint_iteration'] >= 100  # Force save every 100 iterations
            )
            if should_save_checkpoint:
                try:
                    # Summary of the state being checkpointed
                    performance_metrics = {
                        'avg_confidence': iteration_performance / max(iteration_decisions, 1),
                        'success_rate': training_stats['successful_trades'] / max(training_stats['total_decisions'], 1),
                        'total_decisions': training_stats['total_decisions'],
                        'iteration': iteration
                    }
                    logger.debug(f"Checkpoint metrics: {performance_metrics}")

                    # Save orchestrator state (if it has models)
                    if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
                        saved = orchestrator.rl_agent.save_checkpoint(iteration_performance)
                        if saved:
                            logger.info(f"✅ RL Agent checkpoint saved at iteration {iteration}")
                    if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
                        # Simulate CNN checkpoint save
                        logger.info(f"✅ CNN Model training state saved at iteration {iteration}")
                    if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
                        saved = orchestrator.extrema_trainer.save_checkpoint()
                        if saved:
                            logger.info(f"✅ ExtremaTrainer checkpoint saved at iteration {iteration}")

                    training_stats['last_checkpoint_iteration'] = iteration
                    logger.info(f"📊 Checkpoint management completed for iteration {iteration}")
                except Exception as e:
                    logger.warning(f"Checkpoint saving failed at iteration {iteration}: {e}")

            # Log performance metrics every 10 iterations
            if iteration % 10 == 0:
                metrics = orchestrator.get_performance_metrics()
                logger.info(f"Performance metrics: {metrics}")

                # Log training statistics
                logger.info(f"Training stats: {training_stats}")

                # Log checkpoint statistics
                checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
                logger.info(f"Checkpoints: {checkpoint_stats['total_checkpoints']} total, "
                            f"{checkpoint_stats['total_size_mb']:.2f} MB")

                # Log COB integration status (Basic orchestrator doesn't have COB features)
                cob_symbols = getattr(orchestrator, 'symbols', ['ETH/USDT'])
                if hasattr(orchestrator, 'latest_cob_features'):
                    for symbol in cob_symbols:
                        cob_features = orchestrator.latest_cob_features.get(symbol)
                        cob_state = orchestrator.latest_cob_state.get(symbol)
                        if cob_features is not None:
                            logger.info(f"{symbol} COB: CNN features {cob_features.shape}, DQN state {cob_state.shape if cob_state is not None else 'None'}")
                else:
                    logger.debug("Basic orchestrator - no COB integration features available")

            # Sleep between iterations
            await asyncio.sleep(5)  # 5 second intervals
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Error in training loop: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Save final checkpoints before shutdown
        try:
            logger.info("Saving final checkpoints before shutdown...")
            if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
                orchestrator.rl_agent.save_checkpoint(0.0, force_save=True)
                logger.info("✅ Final RL Agent checkpoint saved")
            if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
                orchestrator.extrema_trainer.save_checkpoint(force_save=True)
                logger.info("✅ Final ExtremaTrainer checkpoint saved")

            # Log final checkpoint statistics
            final_stats = checkpoint_manager.get_checkpoint_stats()
            logger.info(f"📊 Final checkpoint stats: {final_stats['total_checkpoints']} checkpoints, "
                        f"{final_stats['total_size_mb']:.2f} MB total")
        except Exception as e:
            logger.warning(f"Error saving final checkpoints: {e}")

        # Stop real-time processing (Basic orchestrator doesn't have these methods)
        try:
            if hasattr(orchestrator, 'stop_realtime_processing'):
                await orchestrator.stop_realtime_processing()
        except Exception as e:
            logger.warning(f"Error stopping real-time processing: {e}")
        try:
            if hasattr(orchestrator, 'stop_cob_integration'):
                await orchestrator.stop_cob_integration()
        except Exception as e:
            logger.warning(f"Error stopping COB integration: {e}")

        logger.info("Training loop stopped with checkpoint management")
async def main():
    """Parse CLI options, start the dashboard thread and run the training loop.

    Returns:
        int: 0 after a normal finish or a user interrupt, 1 when an
        unexpected exception escaped the training loop.
    """
    cli = argparse.ArgumentParser(description='Streamlined Trading System - Training + Web Dashboard')
    cli.add_argument('--symbol', type=str, default='ETH/USDT',
                     help='Primary trading symbol (default: ETH/USDT)')
    cli.add_argument('--port', type=int, default=8050,
                     help='Web dashboard port (default: 8050)')
    cli.add_argument('--debug', action='store_true',
                     help='Enable debug mode')
    args = cli.parse_args()

    # Log and model directories must exist before logging is configured.
    for required_dir in ("logs", "NN/models/saved"):
        Path(required_dir).mkdir(parents=True, exist_ok=True)
    setup_safe_logging()

    separator = "=" * 70
    startup_banner = (
        separator,
        "STREAMLINED TRADING SYSTEM - TRAINING + MAIN DASHBOARD",
        f"Primary Symbol: {args.symbol}",
        f"Training Port: {args.port}",
        f"Main Trading Dashboard: http://127.0.0.1:{args.port}",
        "2-Action System: BUY/SELL with intelligent position management",
        "Always Invested: Learning to spot high risk/reward setups",
        "Flow: Data -> COB -> Indicators -> CNN -> RL -> Orchestrator -> Execution",
        "Main Dashboard: Live trading, RL monitoring, Position management",
        "🔄 Checkpoint Management: Automatic training state persistence",
        # "📊 W&B Integration: Optional experiment tracking",
        "💾 Model Rotation: Keep best 5 checkpoints per model",
        separator,
    )

    try:
        for banner_line in startup_banner:
            logger.info(banner_line)

        # The dashboard UI runs in a daemon thread so it dies with the process.
        ui_thread = Thread(target=lambda: start_web_ui(args.port), daemon=True)
        ui_thread.start()
        logger.info("Main trading dashboard UI thread started")

        # Short pause so the UI can come up before training starts.
        await asyncio.sleep(2)

        # Runs until interrupted or until it fails.
        await run_web_dashboard()
        logger.info("[SUCCESS] Operation completed successfully!")
    except KeyboardInterrupt:
        logger.info("System shutdown requested by user")
    except Exception as exc:
        logger.error(f"Fatal error: {exc}")
        import traceback
        logger.error(traceback.format_exc())
        return 1
    return 0
# Script entry point: run the async main and use its return value as exit code.
if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

View File

@ -1,133 +0,0 @@
#!/usr/bin/env python3
"""
Clean Main Entry Point for Enhanced Trading Dashboard
This is the main entry point that safely launches the clean dashboard
with proper error handling and optimized settings.
"""
import os
import sys
import logging
import argparse
from typing import Optional
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import core components
try:
from core.config import setup_logging
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
except ImportError as e:
print(f"Error importing core modules: {e}")
sys.exit(1)
logger = logging.getLogger(__name__)
def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
    """Build a TradingOrchestrator with enhanced RL training switched off.

    Returns:
        The orchestrator, or None when construction raised; the failure is
        logged and the caller is expected to continue without one.
    """
    try:
        # enhanced_rl_training=False: keep the problematic training path off at startup.
        instance = TradingOrchestrator(enhanced_rl_training=False)
    except Exception as exc:
        logger.error(f"Error creating orchestrator: {exc}")
        logger.info("Continuing without orchestrator - dashboard will run in view-only mode")
        return None
    else:
        logger.info("Trading orchestrator created successfully")
        return instance
def create_safe_trading_executor() -> Optional[TradingExecutor]:
    """Build a TradingExecutor from ``config.yaml``.

    Returns:
        The executor, or None when construction raised; the failure is
        logged and the caller is expected to continue without one.
    """
    try:
        # config_path is the only constructor argument passed here.
        executor = TradingExecutor(config_path="config.yaml")
    except Exception as exc:
        logger.error(f"Error creating trading executor: {exc}")
        logger.info("Continuing without trading executor - dashboard will be view-only")
        return None
    else:
        logger.info("Trading executor created successfully")
        return executor
def main():
    """Parse CLI options, wire up the components and serve the clean dashboard.

    Exits the process with status 1 when the dashboard fails with an
    unexpected exception; a keyboard interrupt is logged and treated as a
    normal stop.
    """
    cli = argparse.ArgumentParser(description='Enhanced Trading Dashboard')
    cli.add_argument('--port', type=int, default=8050, help='Dashboard port (default: 8050)')
    cli.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host (default: 127.0.0.1)')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    cli.add_argument('--no-training', action='store_true', help='Disable ML training for stability')
    args = cli.parse_args()

    rule = "================================================================================"
    try:
        setup_logging()
        for banner_line in (
            rule,
            "CLEAN ENHANCED TRADING DASHBOARD",
            rule,
            f"Starting on http://{args.host}:{args.port}",
            "Features: Real-time Charts, Trading Interface, Model Monitoring",
            rule,
        ):
            logger.info(banner_line)
    except Exception as exc:
        # Logging setup is best-effort; the dashboard still starts without it.
        print(f"Error setting up logging: {exc}")

    # Feature switches are passed on through environment variables.
    os.environ['ENABLE_REALTIME_CHARTS'] = '1'
    if not args.no_training:
        os.environ['ENABLE_NN_MODELS'] = '1'

    try:
        logger.info("Initializing data provider...")
        provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])

        logger.info("Initializing trading orchestrator...")
        orchestrator = create_safe_orchestrator()

        logger.info("Initializing trading executor...")
        executor = create_safe_trading_executor()

        logger.info("Creating clean dashboard...")
        dashboard = create_clean_dashboard(
            data_provider=provider,
            orchestrator=orchestrator,
            trading_executor=executor,
        )

        logger.info(f"Starting dashboard server on http://{args.host}:{args.port}")
        dashboard.run_server(host=args.host, port=args.port, debug=args.debug)
    except KeyboardInterrupt:
        logger.info("Dashboard stopped by user")
    except Exception as exc:
        logger.error(f"Error running dashboard: {exc}")
        # Point the user at --no-training when the failure looks CNN-related.
        error_text = str(exc)
        if any(marker in error_text for marker in ("model.fit", "CNN")):
            logger.error("CNN model training error detected. Try running with --no-training flag")
            logger.error("Command: python main_clean.py --no-training")
        sys.exit(1)
    finally:
        logger.info("Clean dashboard shutdown complete")
# Script entry point: start the dashboard only when this file is executed directly.
if __name__ == '__main__':
    main()

View File

@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Migrate Existing Models to Checkpoint System
This script migrates existing model files to the new checkpoint system
and creates proper database metadata entries.
"""
import os
import shutil
import logging
from datetime import datetime
from pathlib import Path
from utils.database_manager import get_database_manager, CheckpointMetadata
from utils.checkpoint_manager import save_checkpoint
from utils.text_logger import get_text_logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def migrate_existing_models():
    """Copy the known legacy model files into the checkpoint tree.

    For every entry in the hard-coded migration table whose source file
    exists, the file is copied to ``models/checkpoints/<model_name>/``, a
    CheckpointMetadata row is saved through the database manager and the
    event is written to the text logger. A summary of the checkpoints
    currently listed per model is printed at the end.

    A failure while migrating one model is logged and does not stop the
    remaining migrations.
    """
    print("=== Migrating Existing Models to Checkpoint System ===")
    db_manager = get_database_manager()
    text_logger = get_text_logger()

    # Legacy file -> checkpoint mapping; the metrics are the values recorded
    # for these files at migration time.
    migrations = [
        {
            'model_name': 'enhanced_cnn',
            'model_type': 'cnn',
            'source_file': 'models/enhanced_cnn/ETH_USDT_cnn.pth',
            'performance_metrics': {'loss': 0.0187, 'accuracy': 0.92},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True},
        },
        {
            'model_name': 'dqn_agent',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_policy.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'policy'},
        },
        {
            'model_name': 'dqn_agent_target',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_target.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'target'},
        },
    ]

    migrated_count = 0
    for migration in migrations:
        source_path = Path(migration['source_file'])
        if not source_path.exists():
            logger.warning(f"Source file not found: {source_path}")
            continue
        try:
            if _migrate_single_model(migration, source_path, db_manager, text_logger):
                migrated_count += 1
        except Exception as e:
            logger.error(f"Failed to migrate {migration['model_name']}: {e}")

    print(f"\nMigration completed: {migrated_count} models migrated")

    # Show current checkpoint status
    print("\n=== Current Checkpoint Status ===")
    for model_name in ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target']:
        checkpoints = db_manager.list_checkpoints(model_name)
        if not checkpoints:
            print(f"{model_name}: No checkpoints")
            continue
        print(f"{model_name}: {len(checkpoints)} checkpoints")
        for checkpoint in checkpoints[:2]:  # Show first 2
            print(f" - {checkpoint.checkpoint_id} ({checkpoint.file_size_mb:.1f}MB)")


def _migrate_single_model(migration, source_path, db_manager, text_logger):
    """Copy one legacy model file and record its checkpoint metadata.

    Returns:
        bool: True when the metadata row was saved, False when the database
        manager reported failure. Exceptions propagate to the caller.
    """
    model_name = migration['model_name']

    checkpoint_dir = Path("models/checkpoints") / model_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # One timestamp for both the checkpoint id and the metadata row, so the
    # two can never disagree.
    migrated_at = datetime.now()
    checkpoint_id = f"{model_name}_{migrated_at.strftime('%Y%m%d_%H%M%S')}"
    checkpoint_file = checkpoint_dir / f"{checkpoint_id}.pt"

    shutil.copy2(source_path, checkpoint_file)
    logger.info(f"Copied {source_path} -> {checkpoint_file}")

    file_size_mb = checkpoint_file.stat().st_size / (1024 * 1024)
    metadata = CheckpointMetadata(
        checkpoint_id=checkpoint_id,
        model_name=model_name,
        model_type=migration['model_type'],
        timestamp=migrated_at,
        performance_metrics=migration['performance_metrics'],
        training_metadata=migration['training_metadata'],
        file_path=str(checkpoint_file),
        file_size_mb=file_size_mb,
        is_active=True,
    )

    if not db_manager.save_checkpoint_metadata(metadata):
        logger.error(f"Failed to save checkpoint metadata: {checkpoint_id}")
        return False

    logger.info(f"Saved checkpoint metadata: {checkpoint_id}")
    text_logger.log_checkpoint_event(
        model_name=model_name,
        event_type="MIGRATED",
        checkpoint_id=checkpoint_id,
        details=f"from {source_path}, size={file_size_mb:.1f}MB",
    )
    return True
def verify_checkpoint_system():
    """Print, per model, whether best-checkpoint metadata exists and its file is on disk."""
    print("\n=== Verifying Checkpoint System ===")
    database = get_database_manager()

    for name in ('dqn_agent', 'enhanced_cnn'):
        best = database.get_best_checkpoint_metadata(name)
        if not best:
            print(f"{name}: ❌ No checkpoint metadata found")
            continue

        on_disk = Path(best.file_path).exists()
        print(f"{name}: ✅ Metadata found, File exists: {on_disk}")
        if on_disk:
            print(f" -> {best.checkpoint_id} ({best.file_size_mb:.1f}MB)")
        else:
            print(f" -> ERROR: File missing: {best.file_path}")
def create_test_checkpoint():
    """Save a tiny throwaway model through the checkpoint system as a smoke test.

    Progress and failures are reported with print(); any exception is caught
    and reported rather than raised. Only the checkpoint file is removed
    afterwards.
    """
    print("\n=== Testing Checkpoint Saving ===")
    try:
        import torch
        import torch.nn as nn

        class TestModel(nn.Module):
            """Single linear layer; just enough state to serialize."""

            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 1)

            def forward(self, x):
                return self.linear(x)

        from utils.checkpoint_manager import save_checkpoint

        saved = save_checkpoint(
            model=TestModel(),
            model_name="test_model",
            model_type="test",
            performance_metrics={"loss": 0.1, "accuracy": 0.95},
            training_metadata={"test": True, "created": datetime.now().isoformat()},
        )
        if not saved:
            print("❌ Test checkpoint saving failed")
            return
        print(f"✅ Test checkpoint saved successfully: {saved.checkpoint_id}")

        # Read it back through the database to confirm the file really exists.
        stored = get_database_manager().get_best_checkpoint_metadata("test_model")
        if not (stored and Path(stored.file_path).exists()):
            print("❌ Test checkpoint verification failed")
            return
        print(f"✅ Test checkpoint verified: {stored.file_path}")

        Path(stored.file_path).unlink()
        print("🧹 Test checkpoint cleaned up")
    except Exception as exc:
        print(f"❌ Test checkpoint creation failed: {exc}")
def main():
    """Run migration, verification and the save smoke test, then print a closing summary."""
    for step in (migrate_existing_models, verify_checkpoint_system, create_test_checkpoint):
        step()

    for line in (
        "\n=== Migration Complete ===",
        "The checkpoint system should now work properly!",
        "Existing models have been migrated and the system is ready for use.",
    ):
        print(line)
# Script entry point: run the migration only when this file is executed directly.
if __name__ == "__main__":
    main()

View File

@ -1,558 +0,0 @@
"""
Enhanced Model Management System for Trading Dashboard
This system provides:
- Automatic cleanup of old model checkpoints
- Best model tracking with performance metrics
- Configurable retention policies
- Startup model loading
- Performance-based model selection
"""
import os
import json
import shutil
import logging
import torch
import glob
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class ModelMetrics:
    """Performance figures used to rank a trained model."""
    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0

    def get_composite_score(self) -> float:
        """Return a weighted score clamped to [0, 1].

        Profit factor (weight 0.3) and Sharpe ratio (0.25) are rescaled and
        clamped to [0, 1]; win rate (0.2), accuracy (0.15) and confidence
        (0.1) are used as given. The weighted sum is multiplied by a
        drawdown factor that falls linearly from 1 at zero drawdown to 0 at
        a drawdown of 0.2 or more.
        """
        def clamp_unit(value):
            return min(max(value, 0), 1)

        w_profit, w_sharpe, w_win, w_accuracy, w_confidence = 0.3, 0.25, 0.2, 0.15, 0.1

        # A profit factor of 3 or more maps to 1.0.
        profit_term = w_profit * clamp_unit(self.profit_factor / 3.0)
        # Sharpe in [-2, 2] maps linearly onto [0, 1].
        sharpe_term = w_sharpe * clamp_unit((self.sharpe_ratio + 2) / 4)
        win_term = w_win * self.win_rate
        accuracy_term = w_accuracy * self.accuracy
        confidence_term = w_confidence * self.confidence_score

        drawdown_factor = max(0, 1 - self.max_drawdown / 0.2)

        weighted = (
            profit_term +
            sharpe_term +
            win_term +
            accuracy_term +
            confidence_term
        ) * drawdown_factor
        return clamp_unit(weighted)
@dataclass
class ModelInfo:
    """Metadata record describing one registered model file."""
    model_type: str  # 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict; datetimes become ISO-8601 strings."""
        data = asdict(self)  # also converts the nested metrics dataclass to a dict
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Rebuild a ModelInfo from the dict form produced by ``to_dict``.

        The input mapping is left untouched, so the same dict can be parsed
        more than once.

        Raises:
            KeyError, TypeError, ValueError: when required keys are missing
            or values are not in the expected format.
        """
        # Work on a shallow copy: parsing in place would replace the caller's
        # ISO strings with datetime objects and break a second from_dict call.
        fields = dict(data)
        fields['creation_time'] = datetime.fromisoformat(fields['creation_time'])
        fields['last_updated'] = datetime.fromisoformat(fields['last_updated'])
        fields['metrics'] = ModelMetrics(**fields['metrics'])
        return cls(**fields)
class ModelManager:
    """Registry and housekeeping for trained model files.

    The registry is a flat mapping of model name -> ModelInfo, persisted as
    JSON in ``<base_dir>/models/model_registry.json``. Every method relies on
    that shape; nothing else may be stored in ``model_registry``.
    """

    # Model families this manager reports on.
    _MODEL_TYPES = ('cnn', 'rl', 'transformer')

    # Glob patterns, relative to each model directory, treated as model files
    # (covers both the 2-action and the legacy 3-action systems).
    _MODEL_PATTERNS = (
        "**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
        "**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*",
    )

    # Directories, relative to base_dir, wiped by cleanup_all_existing_models.
    _MODEL_DIRECTORIES = (
        "models/saved",
        "NN/models/saved",
        "NN/models/saved/checkpoints",
        "NN/models/saved/realtime_checkpoints",
        "NN/models/saved/realtime_ticks_checkpoints",
        "model_backups",
    )

    # Patterns used by cleanup_checkpoints.
    _CHECKPOINT_PATTERNS = (
        "**/checkpoint_*.pt",
        "**/model_*.pt",
        "**/*checkpoint*",
        "**/epoch_*.pt",
    )

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        """Set up paths, create the best-models directory and load the registry.

        Args:
            base_dir: Project root under which ``models/`` and ``NN/models/`` live.
            config: Optional overrides for the defaults from
                ``_get_default_config``; missing keys fall back to the defaults.
        """
        self.base_dir = Path(base_dir)
        # Overlay the overrides on the defaults so a partial config cannot
        # cause a KeyError in later lookups.
        self.config = {**self._get_default_config(), **(config or {})}

        # Model directories
        self.models_dir = self.base_dir / "models"
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.registry_file = self.models_dir / "model_registry.json"
        self.best_models_dir = self.models_dir / "best_models"
        self.best_models_dir.mkdir(parents=True, exist_ok=True)

        # Model registry: name -> ModelInfo
        self.model_registry: Dict[str, ModelInfo] = {}
        self._load_registry()

        logger.info(f"Model Manager initialized - Base: {self.base_dir}")
        logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the default retention and storage settings."""
        return {
            'max_models_per_type': 3,          # Keep top 3 models per type
            'max_total_models': 10,            # Maximum total models to keep
            'cleanup_frequency_hours': 24,     # Cleanup every 24 hours
            'min_performance_threshold': 0.3,  # Minimum composite score
            'max_checkpoint_age_days': 7,      # Delete checkpoints older than 7 days
            'auto_cleanup_enabled': True,
            'backup_before_cleanup': True,
            'model_size_limit_mb': 100,        # Individual model size limit
            'total_storage_limit_gb': 5.0,     # Total storage limit
        }

    def _load_registry(self):
        """Load the registry from disk; on any error start with an empty one."""
        try:
            if self.registry_file.exists():
                with open(self.registry_file, 'r') as f:
                    data = json.load(f)
                self.model_registry = {
                    k: ModelInfo.from_dict(v) for k, v in data.items()
                }
                logger.info(f"Loaded {len(self.model_registry)} models from registry")
            else:
                logger.info("No existing model registry found")
        except Exception as e:
            logger.error(f"Error loading model registry: {e}")
            self.model_registry = {}

    def _save_registry(self):
        """Write the registry to disk as JSON; failures are logged, not raised."""
        try:
            self.models_dir.mkdir(parents=True, exist_ok=True)
            data = {k: v.to_dict() for k, v in self.model_registry.items()}
            with open(self.registry_file, 'w') as f:
                json.dump(data, f, indent=2, default=str)
            logger.info(f"Saved registry with {len(self.model_registry)} models")
        except Exception as e:
            logger.error(f"Error saving model registry: {e}")

    def _find_model_files(self) -> List[Path]:
        """Return each model file under the model directories exactly once.

        The directories overlap and a file can match several patterns, so the
        hits are de-duplicated before counting or deleting.
        """
        unique_files = set()
        for directory in self._MODEL_DIRECTORIES:
            dir_path = self.base_dir / directory
            if not dir_path.exists():
                continue
            for pattern in self._MODEL_PATTERNS:
                unique_files.update(p for p in dir_path.glob(pattern) if p.is_file())
        return sorted(unique_files)

    def _remove_empty_directories(self, cleanup_stats: Dict[str, Any]) -> None:
        """Remove empty sub-directories of the model directories, deepest first."""
        for directory in self._MODEL_DIRECTORIES:
            dir_path = self.base_dir / directory
            if not dir_path.exists():
                continue
            # Deepest first, so a parent emptied by removing its children is
            # removed in the same pass.
            subdirs = sorted(
                (p for p in dir_path.rglob("*") if p.is_dir()),
                key=lambda p: len(p.parts),
                reverse=True,
            )
            for subdir in subdirs:
                try:
                    if subdir.exists() and not any(subdir.iterdir()):
                        subdir.rmdir()
                        cleanup_stats['directories_cleaned'] += 1
                        logger.info(f"Removed empty directory: {subdir}")
                except Exception as e:
                    cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")

    def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
        """Delete every model file under the known model directories.

        Args:
            confirm: If True, delete the files, remove empty directories and
                reset the registry. If False, only report what would be
                deleted.

        Returns:
            Dict with ``files_found``, ``files_deleted``,
            ``directories_cleaned``, ``space_freed_mb`` (size of the files
            actually deleted, or of the files found in preview mode) and
            ``errors`` (list of messages).
        """
        cleanup_stats = {
            'files_found': 0,
            'files_deleted': 0,
            'directories_cleaned': 0,
            'space_freed_mb': 0.0,
            'errors': []
        }
        try:
            for file_path in self._find_model_files():
                cleanup_stats['files_found'] += 1
                size_mb = file_path.stat().st_size / (1024 * 1024)
                if not confirm:
                    cleanup_stats['space_freed_mb'] += size_mb
                    continue
                try:
                    file_path.unlink()
                    cleanup_stats['files_deleted'] += 1
                    cleanup_stats['space_freed_mb'] += size_mb
                    logger.info(f"Deleted model file: {file_path}")
                except Exception as e:
                    cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")

            if confirm:
                self._remove_empty_directories(cleanup_stats)

                # Reset to an empty flat registry. The rest of the class
                # expects name -> ModelInfo, so no metadata section is stored.
                self.model_registry = {}
                self._save_registry()

                logger.info("=" * 60)
                logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
                logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
                logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
                logger.info("Registry reset for 2-action system (BUY/SELL)")
                logger.info("Ready for fresh training with intelligent position management")
                logger.info("=" * 60)
            else:
                logger.info("=" * 60)
                logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
                logger.info(f"Files to delete: {cleanup_stats['files_found']}")
                logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info("Run with confirm=True to perform cleanup")
                logger.info("=" * 60)
        except Exception as e:
            cleanup_stats['errors'].append(f"Cleanup error: {e}")
            logger.error(f"Error during model cleanup: {e}")
        return cleanup_stats

    def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
        """Register a model file and prune weaker models of the same type.

        Args:
            model_path: Path to the model file.
            model_type: Type of model ('cnn', 'rl', 'transformer').
            metrics: Performance metrics; neutral defaults are used when omitted.

        Returns:
            str: The generated model name, ``<type>_2action_<timestamp>``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        file_path = Path(model_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        registered_at = datetime.now()
        model_name = f"{model_type}_2action_{registered_at.strftime('%Y%m%d_%H%M%S')}"
        file_size_mb = file_path.stat().st_size / (1024 * 1024)

        # Neutral starting metrics for a model that has not been evaluated yet.
        if metrics is None:
            metrics = ModelMetrics(
                accuracy=0.0,
                profit_factor=1.0,
                win_rate=0.5,
                sharpe_ratio=0.0,
                max_drawdown=0.0,
                confidence_score=0.5
            )

        model_info = ModelInfo(
            model_type=model_type,
            model_name=model_name,
            file_path=str(file_path.absolute()),
            creation_time=registered_at,
            last_updated=registered_at,
            file_size_mb=file_size_mb,
            metrics=metrics,
            model_version="2.0"  # 2-action system version
        )

        # Store the ModelInfo itself under its name; this is the shape
        # _save_registry, get_models_by_type and the other methods read.
        self.model_registry[model_name] = model_info
        self._save_registry()

        # Cleanup old models if necessary
        self._cleanup_models_by_type(model_type)

        logger.info(f"Registered 2-action model: {model_name}")
        logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
        logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
        return model_name

    def _should_keep_model(self, model_info: ModelInfo) -> bool:
        """Return True if the model passes the score, size and ranking checks."""
        score = model_info.metrics.get_composite_score()

        if score < self.config['min_performance_threshold']:
            return False

        if model_info.file_size_mb > self.config['model_size_limit_mb']:
            logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
            return False

        # With the per-type quota full, the model must beat the current worst.
        existing_models = self.get_models_by_type(model_info.model_type)
        if len(existing_models) >= self.config['max_models_per_type']:
            worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
            if score <= worst_model.metrics.get_composite_score():
                return False
        return True

    def _cleanup_models_by_type(self, model_type: str):
        """Keep the best ``max_models_per_type`` models of a type; delete the rest."""
        models_of_type = self.get_models_by_type(model_type)
        max_keep = self.config['max_models_per_type']
        if len(models_of_type) <= max_keep:
            return

        ranked = sorted(
            models_of_type.items(),
            key=lambda x: x[1].metrics.get_composite_score(),
            reverse=True
        )

        removed_any = False
        for model_name, model_info in ranked[max_keep:]:
            try:
                model_path = Path(model_info.file_path)
                if model_path.exists():
                    model_path.unlink()
                del self.model_registry[model_name]
                removed_any = True
                logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
            except Exception as e:
                logger.error(f"Error removing model {model_name}: {e}")

        # Persist the removals so the registry file does not keep pointing at
        # deleted files.
        if removed_any:
            self._save_registry()

    def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
        """Return the registry entries whose ``model_type`` matches."""
        return {
            name: info for name, info in self.model_registry.items()
            if info.model_type == model_type
        }

    def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
        """Return the highest-scoring model of a type, or None if there is none."""
        models_of_type = self.get_models_by_type(model_type)
        if not models_of_type:
            return None
        return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())

    def load_best_models(self) -> Dict[str, Any]:
        """Load the best model of each type from disk with ``torch.load``.

        Returns:
            Dict keyed by model type; each value holds ``model`` (the loaded
            object), ``info`` (ModelInfo) and ``path``. Types with no
            registered model, a missing file or a load error are omitted.
        """
        loaded_models = {}
        for model_type in self._MODEL_TYPES:
            best_model = self.get_best_model(model_type)
            if not best_model:
                logger.info(f"No {model_type} model available")
                continue
            try:
                model_path = Path(best_model.file_path)
                if not model_path.exists():
                    logger.warning(f"Best {model_type} model file not found: {model_path}")
                    continue
                # NOTE(review): torch.load unpickles the file; only load
                # model files that come from a trusted source.
                model_data = torch.load(model_path, map_location='cpu')
                loaded_models[model_type] = {
                    'model': model_data,
                    'info': best_model,
                    'path': str(model_path)
                }
                logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
                            f"(Score: {best_model.metrics.get_composite_score():.3f})")
            except Exception as e:
                logger.error(f"Error loading {model_type} model: {e}")
        return loaded_models

    def update_model_performance(self, model_name: str, metrics: ModelMetrics):
        """Replace a registered model's metrics and persist the registry."""
        if model_name not in self.model_registry:
            logger.warning(f"Model {model_name} not found in registry")
            return
        entry = self.model_registry[model_name]
        entry.metrics = metrics
        entry.last_updated = datetime.now()
        self._save_registry()
        logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")

    def get_storage_stats(self) -> Dict[str, Any]:
        """Return registered vs. on-disk storage usage and per-type counts."""
        registered_size_mb = sum(info.file_size_mb for info in self.model_registry.values())

        # Actual usage is measured over the best-models directory only.
        actual_size_mb = 0
        if self.best_models_dir.exists():
            actual_size_mb = sum(
                f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
            ) / 1024 / 1024

        limit_gb = self.config['total_storage_limit_gb']
        # A zero or missing limit reports 0% rather than dividing by zero.
        utilization = (actual_size_mb / 1024) / limit_gb * 100 if limit_gb else 0.0

        return {
            'total_models': len(self.model_registry),
            'registered_size_mb': registered_size_mb,
            'actual_size_mb': actual_size_mb,
            'storage_limit_gb': limit_gb,
            'utilization_percent': utilization,
            'models_by_type': {
                model_type: len(self.get_models_by_type(model_type))
                for model_type in self._MODEL_TYPES
            }
        }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Return one summary row per registered model, best score first."""
        now = datetime.now()
        leaderboard = [
            {
                'name': model_name,
                'type': model_info.model_type,
                'score': model_info.metrics.get_composite_score(),
                'profit_factor': model_info.metrics.profit_factor,
                'win_rate': model_info.metrics.win_rate,
                'sharpe_ratio': model_info.metrics.sharpe_ratio,
                'size_mb': model_info.file_size_mb,
                'age_days': (now - model_info.creation_time).days,
                'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
            }
            for model_name, model_info in self.model_registry.items()
        ]
        leaderboard.sort(key=lambda x: x['score'], reverse=True)
        return leaderboard

    def cleanup_checkpoints(self) -> Dict[str, Any]:
        """Delete checkpoint files older than ``max_checkpoint_age_days``.

        Files whose path contains ``best_models`` are never touched.

        Returns:
            Dict with ``deleted_files``, ``freed_space_mb`` and ``errors``.
        """
        cleanup_summary = {
            'deleted_files': 0,
            'freed_space_mb': 0,
            'errors': []
        }
        cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])

        # NOTE(review): "**/*checkpoint*" matches ANY old file under base_dir
        # with "checkpoint" in its name, not only model files - confirm this
        # breadth is intended before running against a source tree.
        for pattern in self._CHECKPOINT_PATTERNS:
            for file_path in self.base_dir.rglob(pattern):
                if "best_models" in str(file_path) or not file_path.is_file():
                    continue
                try:
                    file_stat = file_path.stat()
                    if datetime.fromtimestamp(file_stat.st_mtime) < cutoff_date:
                        size_mb = file_stat.st_size / 1024 / 1024
                        file_path.unlink()
                        cleanup_summary['deleted_files'] += 1
                        cleanup_summary['freed_space_mb'] += size_mb
                except Exception as e:
                    error_msg = f"Error deleting checkpoint {file_path}: {e}"
                    logger.error(error_msg)
                    cleanup_summary['errors'].append(error_msg)

        if cleanup_summary['deleted_files'] > 0:
            logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
                        f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
        return cleanup_summary
def create_model_manager() -> ModelManager:
    """Build a ModelManager with the default base directory and configuration.

    A new instance is created on every call.
    """
    manager = ModelManager()
    return manager
# Interactive example: wipe all models after an explicit typed confirmation.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()

    print("WARNING: This will delete ALL existing models!")
    print("Type 'CONFIRM' to proceed:")
    answer = input().strip()

    if answer != "CONFIRM":
        print("Cleanup cancelled")
    else:
        outcome = manager.cleanup_all_existing_models(confirm=True)
        print(f"\nCleanup complete:")
        print(f"- Deleted {outcome['files_deleted']} files")
        print(f"- Freed {outcome['space_freed_mb']:.1f}MB of space")
        print(f"- Cleaned {outcome['directories_cleaned']} directories")
        if outcome['errors']:
            print(f"- {len(outcome['errors'])} errors occurred")

View File

@ -1,193 +0,0 @@
#!/usr/bin/env python3
"""
Position Sync Enhancement - Fix P&L and Win Rate Calculation
This script enhances the position synchronization and P&L calculation
to properly account for leverage in the trading system.
"""
import os
import sys
import logging
from pathlib import Path
from datetime import datetime
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.trading_executor import TradingExecutor, TradeRecord
# Setup logging
setup_logging()
logger = logging.getLogger(__name__)
def analyze_trade_records():
    """Recompute leveraged P&L for every stored trade and log the findings.

    For each trade the gross P&L is recomputed from entry/exit value and
    leverage, a 0.1% fee is charged on both legs, and the result is compared
    with the stored ``pnl``; differences above one cent are logged as
    warnings. Aggregate win/loss statistics and any trades with leverage
    <= 1.0 are logged afterwards. Nothing is returned or modified.
    """
    logger.info("Analyzing trade records for P&L calculation issues...")

    executor = TradingExecutor()
    records = executor.trade_records
    if not records:
        logger.warning("No trade records found.")
        return
    logger.info(f"Found {len(records)} trade records.")

    fee_rate = 0.001   # 0.1% fee on both entry and exit
    tolerance = 0.01   # one cent

    sum_net = 0.0
    sum_gross = 0.0
    sum_fees = 0.0
    wins = losses = flats = 0

    for record in records:
        entry_value = record.entry_price * record.quantity
        exit_value = record.exit_price * record.quantity

        # Anything that is not LONG is treated as SHORT.
        if record.side == 'LONG':
            gross = (exit_value - entry_value) * record.leverage
        else:
            gross = (entry_value - exit_value) * record.leverage

        fee = (entry_value + exit_value) * fee_rate
        net = gross - fee

        # Flag trades whose stored P&L disagrees with the recomputed value.
        mismatch = abs(net - record.pnl)
        if mismatch > tolerance:
            logger.warning(f"P&L calculation issue detected for trade {record.entry_time}:")
            logger.warning(f" Stored P&L: ${record.pnl:.2f}")
            logger.warning(f" Calculated P&L: ${net:.2f}")
            logger.warning(f" Difference: ${mismatch:.2f}")
            logger.warning(f" Leverage used: {record.leverage}x")

        sum_net += net
        sum_gross += gross
        sum_fees += fee

        if net > tolerance:
            wins += 1
        elif net < -tolerance:
            losses += 1
        else:
            flats += 1

    total_trades = wins + losses + flats
    win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0

    logger.info("\nTrade Analysis Results:")
    logger.info(f" Total trades: {total_trades}")
    logger.info(f" Winning trades: {wins}")
    logger.info(f" Losing trades: {losses}")
    logger.info(f" Breakeven trades: {flats}")
    logger.info(f" Win rate: {win_rate:.1f}%")
    logger.info(f" Total P&L: ${sum_net:.2f}")
    logger.info(f" Total gross P&L: ${sum_gross:.2f}")
    logger.info(f" Total fees: ${sum_fees:.2f}")

    # Second pass: report every trade recorded with leverage of 1x or less.
    found_low_leverage = False
    for record in records:
        if record.leverage <= 1.0:
            found_low_leverage = True
            logger.warning(f"Low leverage detected: {record.leverage}x for trade at {record.entry_time}")

    if not found_low_leverage:
        logger.info("\nNo leverage issues detected.")
    else:
        logger.warning("\nLeverage issues detected. Consider fixing the leverage calculation.")
        logger.info("Recommended fix: Ensure leverage is properly set in the trading executor.")
def fix_leverage_calculation():
    """Repair the executor's leverage setting and re-derive P&L for affected trades.

    Creates a ``TradingExecutor``, bumps ``current_leverage`` to 20x when it is
    1x or lower, then rebuilds every trade record whose stored leverage is 1.0
    or lower using the executor's current leverage. Gross P&L, fees (0.1% of
    entry plus exit notional) and net P&L are recomputed for each rebuilt
    record. The changes are made in memory only; nothing is persisted here.

    Returns:
        bool: True when at least one trade record was rewritten.
    """
    logger.info("Fixing leverage calculation in the trading executor...")

    executor = TradingExecutor()
    leverage_now = executor.current_leverage
    logger.info(f"Current leverage setting: {leverage_now}x")

    if leverage_now <= 1:
        logger.warning("Leverage is set too low. Updating to 20x...")
        executor.current_leverage = 20
        logger.info(f"Updated leverage to {executor.current_leverage}x")
    else:
        logger.info("Leverage is already set correctly.")

    rewritten = 0
    for index, old in enumerate(executor.trade_records):
        if not old.leverage <= 1.0:
            # Record already carries a real leverage value; leave it alone.
            continue

        # Copy every field, swapping in the executor's current leverage.
        fixed = TradeRecord(
            symbol=old.symbol,
            side=old.side,
            quantity=old.quantity,
            entry_price=old.entry_price,
            exit_price=old.exit_price,
            entry_time=old.entry_time,
            exit_time=old.exit_time,
            pnl=old.pnl,
            fees=old.fees,
            confidence=old.confidence,
            hold_time_seconds=old.hold_time_seconds,
            leverage=executor.current_leverage,
            position_size_usd=old.position_size_usd,
            gross_pnl=old.gross_pnl,
            net_pnl=old.net_pnl
        )

        notional_in = fixed.entry_price * fixed.quantity
        notional_out = fixed.exit_price * fixed.quantity

        # Anything that is not 'LONG' is treated as a short position.
        fixed.gross_pnl = (
            (notional_out - notional_in) * fixed.leverage
            if fixed.side == 'LONG'
            else (notional_in - notional_out) * fixed.leverage
        )
        # 0.1% fee charged on both the entry and the exit notional.
        fixed.fees = (notional_in + notional_out) * 0.001
        fixed.net_pnl = fixed.gross_pnl - fixed.fees
        fixed.pnl = fixed.net_pnl

        executor.trade_records[index] = fixed
        rewritten += 1

    logger.info(f"Updated {rewritten} trade records with correct leverage.")

    # Placeholder: the rewritten records are not written to storage here.
    logger.info("Changes will take effect on next dashboard restart.")
    return rewritten > 0
if __name__ == "__main__":
    banner = "=" * 70
    logger.info(banner)
    logger.info("POSITION SYNC ENHANCEMENT")
    logger.info(banner)

    # "fix" as the first CLI argument selects repair mode; anything else
    # (including no argument) runs the read-only analysis.
    if sys.argv[1:2] == ['fix']:
        fix_leverage_calculation()
    else:
        analyze_trade_records()

View File

@ -1,124 +0,0 @@
#!/usr/bin/env python
"""
Log Reader Utility
This script provides a convenient way to read and filter log files during
development.
"""
import os
import sys
import time
import argparse
from datetime import datetime
def parse_args():
    """Build the log-reader CLI parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with ``file``, ``tail`` (default 50), ``follow``,
        ``filter`` and ``list`` attributes.
    """
    option_table = (
        (('--file',), {'type': str, 'help': 'Log file to read (defaults to most recent .log file)'}),
        (('--tail',), {'type': int, 'default': 50, 'help': 'Number of lines to show from the end'}),
        (('--follow', '-f'), {'action': 'store_true', 'help': 'Follow the file as it grows'}),
        (('--filter',), {'type': str, 'help': 'Only show lines containing this string'}),
        (('--list',), {'action': 'store_true', 'help': 'List all log files sorted by modification time'}),
    )

    cli = argparse.ArgumentParser(description='Read and filter log files')
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def get_most_recent_log():
    """Return the name of the newest ``.log`` entry in the current directory.

    "Newest" is decided by modification time. Prints a message and exits with
    status 1 when the directory contains no ``.log`` entries.
    """
    candidates = [entry for entry in os.listdir('.') if entry.endswith('.log')]
    if len(candidates) == 0:
        print("No log files found in current directory.")
        sys.exit(1)

    # The entry with the largest mtime; on a tie the earliest listed one wins.
    return max(candidates, key=os.path.getmtime)
def list_log_files():
    """Print a table of ``.log`` entries in the current directory, newest first.

    Each row shows the modification time, the size (in KB above 1024 bytes,
    otherwise in bytes) and the file name. Prints a message and exits with
    status 1 when there are no ``.log`` entries.
    """
    names = [entry for entry in os.listdir('.') if entry.endswith('.log')]
    if not names:
        print("No log files found in current directory.")
        sys.exit(1)

    newest_first = sorted(names, key=os.path.getmtime, reverse=True)

    print(f"{'LAST MODIFIED':<20} {'SIZE':<10} FILENAME")
    print("-" * 60)
    for name in newest_first:
        modified = datetime.fromtimestamp(os.path.getmtime(name))
        stamp = modified.strftime('%Y-%m-%d %H:%M:%S')
        byte_count = os.path.getsize(name)
        if byte_count > 1024:
            size_label = f"{byte_count / 1024:.1f} KB"
        else:
            size_label = f"{byte_count} B"
        print(f"{stamp:<20} {size_label:<10} {name}")
def read_log_tail(file_path, num_lines, filter_text=None):
    """Return the last ``num_lines`` lines of a text file, optionally filtered.

    The file is streamed through a bounded buffer, so memory use is
    proportional to ``num_lines`` rather than to the file size.

    Args:
        file_path: Path of the UTF-8 text file to read.
        num_lines: Maximum number of trailing lines to return. Zero or a
            negative value yields an empty list.
        filter_text: When truthy, only lines containing this substring are
            considered; the tail is taken after filtering.

    Returns:
        list[str]: The selected lines, oldest first, line endings intact.

    Exits:
        Prints an error and calls ``sys.exit(1)`` when the file cannot be
        opened, read or decoded.
    """
    # Local import keeps this block self-contained.
    from collections import deque

    # A plain ``lines[-0:]`` slice would return everything; clamping makes
    # "--tail 0" (and negative counts) mean "no lines".
    keep = max(num_lines, 0)
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if filter_text:
                wanted = (line for line in f if filter_text in line)
            else:
                wanted = f
            # deque(maxlen=keep) retains only the newest ``keep`` items.
            return list(deque(wanted, maxlen=keep))
    except (OSError, UnicodeError) as e:
        print(f"Error reading file: {str(e)}")
        sys.exit(1)
def follow_log(file_path, filter_text=None):
    """Print lines appended to ``file_path`` until interrupted (like ``tail -f``).

    Starts at the current end of the file, polls every 0.1 s for new data and
    prints each new line, or only those containing ``filter_text`` when it is
    truthy. Ctrl+C stops the loop with a message; any other error is printed
    and the process exits with status 1.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as stream:
            # Skip existing content: position at end-of-file.
            stream.seek(0, os.SEEK_END)
            while True:
                text = stream.readline()
                if not text:
                    # Nothing new yet; back off briefly instead of spinning.
                    time.sleep(0.1)
                    continue
                if filter_text and filter_text not in text:
                    continue
                # Strip the trailing newline so print() does not double-space.
                print(text.rstrip())
    except KeyboardInterrupt:
        print("\nLog reading stopped.")
    except Exception as e:
        print(f"Error following file: {str(e)}")
        sys.exit(1)
def main():
    """Entry point of the log reader: list, tail, or tail-and-follow a log file."""
    args = parse_args()

    if args.list:
        list_log_files()
        return

    # Fall back to the newest .log file when none was named.
    target = args.file
    if not target:
        target = get_most_recent_log()
        print(f"Reading most recent log file: {target}")

    def emit_tail():
        # Print the requested tail without the stored line endings.
        for entry in read_log_tail(target, args.tail, args.filter):
            print(entry.rstrip())

    if not args.follow:
        emit_tail()
        return

    # Follow mode: show the tail first, then stream new lines.
    print(f"Following {target} (Press Ctrl+C to stop)...")
    emit_tail()
    print("-" * 80)
    print("Waiting for new content...")
    follow_log(target, args.filter)
# Run the log-reader CLI only when this file is executed as a script.
if __name__ == "__main__":
main()

View File

@ -1,31 +0,0 @@
#!/usr/bin/env python3
"""
Script to reset the database manager instance to trigger migration in running system
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from utils.database_manager import reset_database_manager
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
    """Reset the database manager singleton so the migration runs again.

    Returns:
        bool: True when ``reset_database_manager()`` completed, False when it
        (or the surrounding logging) raised; the failure is logged.
    """
    succeeded = False
    try:
        logger.info("Resetting database manager to trigger migration...")
        reset_database_manager()
        logger.info("✅ Database manager reset successfully!")
        logger.info("The migration will run automatically on the next database access.")
        succeeded = True
    except Exception as e:
        logger.error(f"❌ Failed to reset database manager: {e}")
    return succeeded
if __name__ == "__main__":
    # Map the boolean outcome onto a conventional process exit status.
    success = main()
    raise SystemExit(0 if success else 1)

View File

@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Reset Models and Fix Action Mapping
This script:
1. Deletes existing model files
2. Creates new model files with consistent action mapping
3. Updates action mapping in key files
"""
import os
import shutil
import logging
import sys
import torch
import numpy as np
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def ensure_directory(directory):
    """Create ``directory`` (including parents) when nothing exists at that path.

    A creation is logged; an already-existing path is left untouched and
    nothing is logged.

    Args:
        directory: Path of the directory to guarantee.
    """
    if os.path.exists(directory):
        return
    # exist_ok=True closes the window between the check above and the create:
    # if something else makes the directory in between, this no longer raises.
    os.makedirs(directory, exist_ok=True)
    logger.info(f"Created directory: {directory}")
def delete_directory_contents(directory):
    """Remove every entry inside ``directory`` while keeping the directory itself.

    Files and symlinks are unlinked, sub-directories are removed recursively.
    A failure on one entry is logged and does not stop the remaining entries.
    Nothing happens when ``directory`` does not exist.
    """
    if not os.path.exists(directory):
        return

    for entry in os.listdir(directory):
        target = os.path.join(directory, entry)
        try:
            removable_as_file = os.path.isfile(target) or os.path.islink(target)
            if removable_as_file:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
            logger.info(f"Deleted: {target}")
        except Exception as e:
            logger.error(f"Failed to delete {target}. Reason: {e}")
def create_backup_directory():
    """Create ``models/backup_<YYYYmmdd_HHMMSS>`` and return its path."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = "models/backup_" + stamp
    ensure_directory(target)
    return target
def backup_models():
    """Copy the files of each known model directory into a fresh backup folder.

    Only regular files directly inside each model directory are copied;
    sub-directories are not descended into. Model directories that do not
    exist are skipped.

    Returns:
        str: Path of the timestamped backup directory.
    """
    backup_root = create_backup_directory()

    source_dirs = (
        "models/enhanced_rl",
        "models/enhanced_cnn",
        "models/realtime_rl_cob",
        "models/rl",
        "models/cnn",
    )

    for source_dir in source_dirs:
        if not os.path.exists(source_dir):
            continue

        destination = os.path.join(backup_root, os.path.basename(source_dir))
        ensure_directory(destination)

        for entry in os.listdir(source_dir):
            candidate = os.path.join(source_dir, entry)
            if not os.path.isfile(candidate):
                continue
            # copy2 also carries over file metadata such as timestamps.
            shutil.copy2(candidate, destination)
            logger.info(f"Backed up: {candidate} to {destination}")

    return backup_root
def initialize_dqn_model():
    """Create fresh BTC and ETH DQN agents and save their initial weights.

    Both agents are built with 3 actions (BUY=0, SELL=1, HOLD=2) before any
    file is written; the policy and target network state dicts are then saved
    under ``models/enhanced_rl``.

    Returns:
        bool: True on success, False when anything raised (the error is logged).
    """
    try:
        # Make the project root importable before pulling in the agent class.
        sys.path.append(os.path.dirname(os.path.abspath(__file__)))
        from NN.models.dqn_agent import DQNAgent

        feature_shape = (100,)  # default feature dimension

        ensure_directory("models/enhanced_rl")

        def build_agent(model_name):
            # Identical hyper-parameters for every symbol; only the name differs.
            return DQNAgent(
                state_shape=feature_shape,
                n_actions=3,  # BUY=0, SELL=1, HOLD=2
                learning_rate=0.001,
                epsilon=0.5,  # moderate initial exploration
                epsilon_min=0.01,
                epsilon_decay=0.995,
                model_name=model_name
            )

        # Construct every agent first so a constructor failure writes no files.
        agents = [
            (prefix, build_agent(f"{prefix}_dqn"))
            for prefix in ("BTC_USDT", "ETH_USDT")
        ]

        for prefix, agent in agents:
            torch.save(agent.policy_net.state_dict(), f"models/enhanced_rl/{prefix}_dqn_policy.pth")
            torch.save(agent.target_net.state_dict(), f"models/enhanced_rl/{prefix}_dqn_target.pth")

        logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize DQN models: {e}")
        return False
def initialize_cnn_model():
    """Create fresh BTC and ETH CNN models and save their initial weights.

    Both models are built with input dimension 100 and 3 actions (BUY=0,
    SELL=1, HOLD=2) before any file is written; their state dicts are then
    saved under ``models/enhanced_cnn``.

    Returns:
        bool: True on success, False when anything raised (the error is logged).
    """
    try:
        # Make the project root importable before pulling in the model class.
        sys.path.append(os.path.dirname(os.path.abspath(__file__)))
        from NN.models.enhanced_cnn import EnhancedCNN

        feature_dim = 100  # default feature dimension
        action_count = 3  # BUY=0, SELL=1, HOLD=2

        ensure_directory("models/enhanced_cnn")

        # Build every model first so a constructor failure writes no files.
        networks = [
            (prefix, EnhancedCNN(feature_dim, action_count))
            for prefix in ("BTC_USDT", "ETH_USDT")
        ]

        for prefix, network in networks:
            torch.save(network.state_dict(), f"models/enhanced_cnn/{prefix}_cnn.pth")

        logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize CNN models: {e}")
        return False
def initialize_realtime_rl_model():
    """Prepare ``models/realtime_rl_cob`` with a README describing the action mapping.

    No model weights are created here; the README only keeps the directory
    from being empty.

    Returns:
        bool: True on success, False when anything raised (the error is logged).
    """
    try:
        ensure_directory("models/realtime_rl_cob")

        readme_lines = [
            "Realtime RL COB models will be saved here.\n",
            "Action mapping: BUY=0, SELL=1, HOLD=2\n",
        ]
        with open("models/realtime_rl_cob/README.txt", "w") as readme:
            readme.writelines(readme_lines)

        logger.info("Initialized realtime RL model directory")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize realtime RL models: {e}")
        return False
def main():
    """Back up, wipe and re-create the model files with the BUY/SELL/HOLD mapping.

    Steps: back up existing models, delete the contents of the three active
    model directories, then initialize the DQN, CNN and realtime-RL models.
    All three initializers always run; the outcome is logged.
    """
    logger.info("Starting model reset and action mapping fix")

    backup_dir = backup_models()
    logger.info(f"Backed up existing models to {backup_dir}")

    for model_dir in ("models/enhanced_rl", "models/enhanced_cnn", "models/realtime_rl_cob"):
        delete_directory_contents(model_dir)
        logger.info(f"Deleted contents of {model_dir}")

    # Evaluated eagerly in this order so every initializer runs even after a failure.
    outcomes = [
        initialize_dqn_model(),
        initialize_cnn_model(),
        initialize_realtime_rl_model(),
    ]

    if all(outcomes):
        logger.info("Successfully reset models and fixed action mapping")
        logger.info("New action mapping: BUY=0, SELL=1, HOLD=2")
    else:
        logger.error("Failed to reset models and fix action mapping")

    logger.info("Model reset complete")
# Run the model reset only when this file is executed as a script.
if __name__ == "__main__":
main()

View File

@ -1,325 +1,26 @@
#!/usr/bin/env python3
"""
Run Clean Trading Dashboard with Full Training Pipeline
Integrated system with both training loop and clean web dashboard
"""
import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'
# Fix matplotlib backend issue - set non-interactive backend before any imports
import matplotlib
matplotlib.use('Agg') # Use non-interactive Agg backend
import asyncio
import logging
import sys
import platform
from safe_logging import setup_safe_logging
import threading
import time
from pathlib import Path
# Windows-specific async event loop configuration
if platform.system() == "Windows":
# Use ProactorEventLoop on Windows for better I/O handling
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.data_provider import DataProvider
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
# Setup logging
setup_safe_logging()
logger = logging.getLogger(__name__)
# Background training coroutine: installs an asyncio exception handler on the
# running loop, starts real-time processing and COB integration on the
# orchestrator when it exposes those methods, then loops forever asking the
# orchestrator for a trading decision per symbol.
# NOTE(review): trading_executor is accepted but never referenced in this body.
async def start_training_pipeline(orchestrator, trading_executor):
"""Start the training pipeline in the background with comprehensive error handling"""
logger.info("=" * 70)
logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD")
logger.info("=" * 70)
# Set up async exception handler
def handle_async_exception(loop, context):
"""Handle uncaught async exceptions"""
exception = context.get('exception')
if exception:
logger.error(f"Uncaught async exception: {exception}")
logger.error(f"Context: {context}")
else:
logger.error(f"Async error: {context.get('message', 'Unknown error')}")
# Get current event loop and set exception handler
loop = asyncio.get_running_loop()
loop.set_exception_handler(handle_async_exception)
# Initialize checkpoint management
# NOTE(review): neither handle is read again below; presumably fetched for
# initialization side effects — confirm.
checkpoint_manager = get_checkpoint_manager()
training_integration = get_training_integration()
# Training statistics
# Only 'iteration_count', 'total_decisions' and 'last_checkpoint_iteration'
# are updated in this function.
training_stats = {
'iteration_count': 0,
'total_decisions': 0,
'successful_trades': 0,
'best_performance': 0.0,
'last_checkpoint_iteration': 0
}
try:
# Start real-time processing with error handling
try:
if hasattr(orchestrator, 'start_realtime_processing'):
await orchestrator.start_realtime_processing()
logger.info("Real-time processing started")
except Exception as e:
logger.error(f"Error starting real-time processing: {e}")
# Start COB integration with error handling
try:
if hasattr(orchestrator, 'start_cob_integration'):
await orchestrator.start_cob_integration()
logger.info("COB integration started - 5-minute data matrix active")
else:
logger.info("COB integration not available")
except Exception as e:
logger.error(f"Error starting COB integration: {e}")
# Main training loop
# There is no break condition: the loop ends only when the task is cancelled
# or the process exits.
iteration = 0
last_checkpoint_time = time.time()
while True:
try:
iteration += 1
training_stats['iteration_count'] = iteration
# Get symbols to process
symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']
# Process each symbol
for symbol in symbols:
try:
# Make trading decision (this triggers model training)
decision = await orchestrator.make_trading_decision(symbol)
if decision:
training_stats['total_decisions'] += 1
logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")
except Exception as e:
logger.warning(f"Error processing {symbol}: {e}")
# Status logging every 100 iterations
if iteration % 100 == 0:
current_time = time.time()
# Seconds since the previous status line (or since loop start for the first).
elapsed = current_time - last_checkpoint_time
logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")
# Models will save their own checkpoints when performance improves
training_stats['last_checkpoint_iteration'] = iteration
last_checkpoint_time = current_time
# Brief pause to prevent overwhelming the system
await asyncio.sleep(0.1) # 100ms between iterations
except Exception as e:
logger.error(f"Training loop error: {e}")
await asyncio.sleep(5) # Wait longer on error
except Exception as e:
logger.error(f"Training pipeline error: {e}")
import traceback
logger.error(traceback.format_exc())
# Builds the data provider, orchestrator and trading executor, starts system
# monitoring, runs the training pipeline in a daemon thread, starts TensorBoard
# and finally blocks in the dashboard server. Ports come from DASHBOARD_PORT
# (default 8051) and TENSORBOARD_PORT (default 6006).
def start_clean_dashboard_with_training():
"""Start clean dashboard with full training pipeline"""
try:
logger.info("=" * 80)
logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
logger.info("=" * 80)
logger.info("Features: Real-time Training, COB Integration, Clean UI")
# NOTE(review): the banner lines below are unconditional; the env-var-derived
# status of the first three features is logged again further down.
logger.info("Universal Data Stream: ENABLED")
logger.info("Neural Decision Fusion: ENABLED")
logger.info("COB Integration: ENABLED")
logger.info("GPU Training: ENABLED")
logger.info("TensorBoard Integration: ENABLED")
logger.info("Multi-symbol: ETH/USDT, BTC/USDT")
# Get port from environment or use default
dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
logger.info(f"TensorBoard: http://127.0.0.1:{tensorboard_port}")
logger.info("=" * 80)
# Check environment variables
# NOTE(review): these three flags are only logged; nothing below branches on them.
enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'
logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")
# Get configuration
config = get_config()
# Initialize core components with standardized versions
from core.standardized_data_provider import StandardizedDataProvider
from core.orchestrator import TradingOrchestrator
# NOTE(review): if the next import really sits inside this function, it makes
# get_config a function-local name and the get_config() call above would raise
# UnboundLocalError. The line may instead be residue of the newer revision
# shown in the same diff hunk — confirm against the actual file.
from core.config import setup_logging, get_config
from core.trading_executor import TradingExecutor
# Create standardized data provider
data_provider = StandardizedDataProvider()
logger.info("StandardizedDataProvider created with BaseDataInput support")
# Create enhanced orchestrator with standardized data provider
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
logger.info("Enhanced Trading Orchestrator created with COB integration")
# Create trading executor
trading_executor = TradingExecutor(config_path="config.yaml")
logger.info(f"Creating trading executor with {trading_executor.primary_name} configuration...")
# Connect trading executor to orchestrator
orchestrator.trading_executor = trading_executor
logger.info("Trading Executor connected to Orchestrator")
# Initialize system resource monitoring
from utils.system_monitor import start_system_monitoring
system_monitor = start_system_monitoring()
# Set up cleanup callback for memory management
def cleanup_callback():
"""Custom cleanup for memory management"""
try:
# Clear orchestrator caches
# Per-symbol decision lists longer than 50 are cut back to their last 25.
if hasattr(orchestrator, 'recent_decisions'):
for symbol in orchestrator.recent_decisions:
if len(orchestrator.recent_decisions[symbol]) > 50:
orchestrator.recent_decisions[symbol] = orchestrator.recent_decisions[symbol][-25:]
# Clear data provider caches
if hasattr(data_provider, 'clear_old_data'):
data_provider.clear_old_data()
logger.info("Custom memory cleanup completed")
except Exception as e:
logger.error(f"Error in custom cleanup: {e}")
system_monitor.set_callbacks(cleanup=cleanup_callback)
logger.info("System resource monitoring started with memory cleanup")
# Import clean dashboard
from web.clean_dashboard import create_clean_dashboard
# Create clean dashboard
logger.info("Creating clean dashboard...")
dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
logger.info("Clean Trading Dashboard created")
# Add memory cleanup method to dashboard
def cleanup_dashboard_memory():
"""Clean up dashboard memory caches"""
try:
if hasattr(dashboard, 'recent_decisions'):
dashboard.recent_decisions = dashboard.recent_decisions[-50:] # Keep last 50
if hasattr(dashboard, 'closed_trades'):
dashboard.closed_trades = dashboard.closed_trades[-100:] # Keep last 100
if hasattr(dashboard, 'tick_cache'):
dashboard.tick_cache = dashboard.tick_cache[-1000:] # Keep last 1000
logger.debug("Dashboard memory cleanup completed")
except Exception as e:
logger.error(f"Error in dashboard memory cleanup: {e}")
# Set cleanup method on dashboard
dashboard.cleanup_memory = cleanup_dashboard_memory
# Start training pipeline in background thread with enhanced error handling
def training_worker():
"""Run training pipeline in background with comprehensive error handling"""
try:
# asyncio.run gives the worker thread its own event loop.
asyncio.run(start_training_pipeline(orchestrator, trading_executor))
except KeyboardInterrupt:
logger.info("Training worker stopped by user")
except Exception as e:
logger.error(f"Training worker error: {e}")
import traceback
logger.error(f"Training worker traceback: {traceback.format_exc()}")
# Don't exit - let main thread handle restart
# Daemon thread: it does not keep the process alive once the main thread ends.
training_thread = threading.Thread(target=training_worker, daemon=True)
training_thread.start()
logger.info("Training pipeline started in background with error handling")
# Wait a moment for training to initialize
time.sleep(3)
# Start TensorBoard in background
from web.tensorboard_integration import get_tensorboard_integration
# The port is read from the environment a second time here.
tensorboard_port = int(os.environ.get('TENSORBOARD_PORT', '6006'))
tensorboard_integration = get_tensorboard_integration(log_dir="runs", port=tensorboard_port)
# Start TensorBoard server
tensorboard_started = tensorboard_integration.start_tensorboard(open_browser=False)
if tensorboard_started:
logger.info(f"TensorBoard started at {tensorboard_integration.get_tensorboard_url()}")
else:
logger.warning("Failed to start TensorBoard - training metrics will not be visualized")
# Start dashboard server with error handling (this blocks)
logger.info("Starting Clean Dashboard Server with error handling...")
try:
dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
except Exception as e:
logger.error(f"Dashboard server error: {e}")
import traceback
logger.error(f"Dashboard server traceback: {traceback.format_exc()}")
raise # Re-raise to trigger main error handling
except KeyboardInterrupt:
logger.info("System stopped by user")
# Stop TensorBoard
# NOTE(review): get_tensorboard_integration is bound only by the local import
# above; if the interrupt arrives before that import runs, the resulting error
# is silently swallowed by the bare except below — confirm this is intended.
try:
tensorboard_integration = get_tensorboard_integration()
tensorboard_integration.stop_tensorboard()
except:
pass
except Exception as e:
logger.error(f"Error running clean dashboard with training: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
from web.clean_dashboard import CleanTradingDashboard
# Wrapper around start_clean_dashboard_with_training(): exits with status 0 on
# Ctrl+C and with status 1 (after logging the traceback) on any other error.
def main():
"""Main function with comprehensive error handling"""
try:
start_clean_dashboard_with_training()
except KeyboardInterrupt:
logger.info("Dashboard stopped by user (Ctrl+C)")
sys.exit(0)
except Exception as e:
logger.error(f"Critical error in main: {e}")
import traceback
logger.error(traceback.format_exc())
sys.exit(1)
# NOTE(review): the statements below contain two separate __main__ guards and
# loose setup calls; they look like two revisions of the entry point
# interleaved by the diff rendering — confirm which lines belong to the
# current file before editing.
setup_logging()
cfg = get_config()
if __name__ == "__main__":
# Ensure logging is flushed on exit
import atexit
def flush_logs():
logging.shutdown()
atexit.register(flush_logs)
# Build the dashboard from a data provider, executor and orchestrator, then
# serve it on 127.0.0.1:8050 (run_server blocks).
data_provider = StandardizedDataProvider()
trading_executor = TradingExecutor()
orchestrator = TradingOrchestrator(data_provider=data_provider)
dashboard = CleanTradingDashboard(
data_provider=data_provider,
orchestrator=orchestrator,
trading_executor=trading_executor
)
logging.getLogger(__name__).info("Starting Clean Trading Dashboard at http://127.0.0.1:8050")
dashboard.run_server(host='127.0.0.1', port=8050, debug=False)
if __name__ == '__main__':
main()

View File

@ -1,501 +0,0 @@
#!/usr/bin/env python3
"""
Continuous Full Training System (RL + CNN)
This system runs continuous training for both RL and CNN models using the enhanced
DataProvider for consistent data streaming to both models and the dashboard.
Features:
- Single DataProvider instance for all data needs
- Continuous RL training with real-time market data
- CNN training with perfect move detection
- Real-time performance monitoring
- Automatic model checkpointing
- Integration with live trading dashboard
"""
import asyncio
import logging
import time
import signal
import sys
from datetime import datetime, timedelta
from threading import Thread, Event
from typing import Dict, Any
# Setup logging
# logging.FileHandler opens its file as soon as it is constructed and raises
# FileNotFoundError when the parent directory is missing, so make sure "logs/"
# exists before configuring the handlers.
from pathlib import Path
Path('logs').mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/continuous_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Import our components
from core.config import get_config
from core.data_provider import DataProvider, MarketTick
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
class ContinuousTrainingSystem:
"""Comprehensive continuous training system for RL + CNN models"""
def __init__(self):
    """Wire up the data provider, orchestrator, checkpoint helpers and counters.

    Creates one ``DataProvider`` for ETH/USDT and BTC/USDT, hands it to an
    ``EnhancedTradingOrchestrator``, fetches the checkpoint manager and
    training integration, and zeroes the statistics. Nothing is started here.
    """
    self.config = get_config()

    # One DataProvider instance serves every consumer of market data.
    self.data_provider = DataProvider(
        symbols=['ETH/USDT', 'BTC/USDT'],
        timeframes=['1s', '1m', '1h', '1d'],
    )
    self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)

    # Dashboard is created later, on demand.
    self.dashboard = None

    # Run-state flags.
    self.running = False
    self.shutdown_event = Event()

    self.checkpoint_manager = get_checkpoint_manager()
    self.training_integration = get_training_integration()

    # Counters reported by the monitoring loop.
    self.training_stats = dict(
        start_time=None,
        rl_training_cycles=0,
        cnn_training_cycles=0,
        perfect_moves_detected=0,
        total_ticks_processed=0,
        models_saved=0,
        last_checkpoint=None,
        best_rl_reward=float('-inf'),
        best_cnn_accuracy=0.0,
    )

    # Loop periods, in seconds.
    self.rl_training_interval = 5 * 60
    self.cnn_training_interval = 10 * 60
    self.checkpoint_interval = 30 * 60

    logger.info("Continuous Training System initialized with checkpoint management")
    interval_report = (
        ("RL training", self.rl_training_interval),
        ("CNN training", self.cnn_training_interval),
        ("Checkpoint", self.checkpoint_interval),
    )
    for label, seconds in interval_report:
        logger.info(f"{label} interval: {seconds}s")
async def start(self, run_dashboard: bool = True):
    """Start streaming, launch the background loops and wait for shutdown.

    Args:
        run_dashboard: When True, the dashboard task is scheduled alongside
            the training, checkpoint and monitoring loops.

    Raises:
        Exception: any error raised during start-up or while waiting is
            logged and re-raised. ``self.stop()`` is awaited in every case.
    """
    logger.info("Starting Continuous Training System...")
    self.running = True
    self.training_stats['start_time'] = datetime.now()

    try:
        logger.info("Starting DataProvider real-time streaming...")
        await self.data_provider.start_real_time_streaming()

        # Route live ticks into the training callback.
        subscriber_id = self.data_provider.subscribe_to_ticks(
            callback=self._handle_training_tick,
            symbols=['ETH/USDT', 'BTC/USDT'],
            subscriber_name="ContinuousTraining"
        )
        logger.info(f"Subscribed to training tick stream: {subscriber_id}")

        # Schedule the background loops as asyncio tasks, in a fixed order.
        background_loops = (
            self._rl_training_loop,
            self._cnn_training_loop,
            self._checkpoint_loop,
            self._monitoring_loop,
        )
        training_tasks = [asyncio.create_task(loop_fn()) for loop_fn in background_loops]

        if run_dashboard:
            training_tasks.append(asyncio.create_task(self._run_dashboard()))

        logger.info("All training components started successfully")

        await self._wait_for_shutdown()
    except Exception as e:
        logger.error(f"Error in continuous training system: {e}")
        raise
    finally:
        await self.stop()
def _handle_training_tick(self, tick: MarketTick):
    """Count an incoming tick and forward it to the orchestrator when supported.

    Every 1000th tick is logged. Any exception is logged as a warning and
    not propagated.
    """
    try:
        stats = self.training_stats
        stats['total_ticks_processed'] += 1

        # Forward only when the orchestrator exposes a tick hook.
        engine = self.orchestrator
        if engine and hasattr(engine, 'process_tick'):
            engine.process_tick(tick)

        if stats['total_ticks_processed'] % 1000 == 0:
            logger.info(f"Processed {stats['total_ticks_processed']} training ticks")
    except Exception as e:
        logger.warning(f"Error processing training tick: {e}")
async def _rl_training_loop(self):
    """Run RL training cycles every ``rl_training_interval`` seconds while running.

    A cycle runs only when the orchestrator has a truthy ``rl_agent``. The
    sleep is shortened by the time the cycle took. After an error the loop
    waits 60 seconds before trying again.
    """
    logger.info("Starting RL training loop...")

    while self.running:
        try:
            cycle_started = time.time()

            if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                logger.info("Starting RL training cycle...")

                training_data = self._prepare_rl_training_data()
                if training_data is None:
                    logger.warning("No training data available for RL agent")
                else:
                    training_results = await self._train_rl_agent(training_data)
                    if training_results:
                        self.training_stats['rl_training_cycles'] += 1
                        logger.info(f"RL training cycle {self.training_stats['rl_training_cycles']} completed")
                        logger.info(f"Training results: {training_results}")

            # Sleep for whatever remains of the interval (never negative).
            remaining = self.rl_training_interval - (time.time() - cycle_started)
            await asyncio.sleep(remaining if remaining > 0 else 0)
        except Exception as e:
            logger.error(f"Error in RL training loop: {e}")
            await asyncio.sleep(60)  # back off before retrying
async def _cnn_training_loop(self):
    """Run CNN training cycles every ``cnn_training_interval`` seconds while running.

    A cycle runs only when the orchestrator has a truthy ``cnn_model``;
    training uses the moves returned by ``_detect_perfect_moves()``. The sleep
    is shortened by the time the cycle took. After an error the loop waits
    60 seconds before trying again.
    """
    logger.info("Starting CNN training loop...")

    while self.running:
        try:
            cycle_started = time.time()

            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                logger.info("Starting CNN training cycle...")

                perfect_moves = self._detect_perfect_moves()
                if not perfect_moves:
                    logger.info("No perfect moves detected for CNN training")
                else:
                    self.training_stats['perfect_moves_detected'] += len(perfect_moves)

                    training_results = await self._train_cnn_model(perfect_moves)
                    if training_results:
                        self.training_stats['cnn_training_cycles'] += 1
                        logger.info(f"CNN training cycle {self.training_stats['cnn_training_cycles']} completed")
                        logger.info(f"Perfect moves processed: {len(perfect_moves)}")

            # Sleep for whatever remains of the interval (never negative).
            remaining = self.cnn_training_interval - (time.time() - cycle_started)
            await asyncio.sleep(remaining if remaining > 0 else 0)
        except Exception as e:
            logger.error(f"Error in CNN training loop: {e}")
            await asyncio.sleep(60)  # back off before retrying
# Periodic checkpointing: while self.running, sleeps checkpoint_interval
# seconds, then saves the RL and CNN models when the orchestrator has them.
# Errors are logged and the loop continues with the next interval.
async def _checkpoint_loop(self):
"""Automatic model checkpointing loop"""
logger.info("Starting checkpoint loop...")
while self.running:
try:
# Sleep first, so no checkpoint is written immediately at start-up.
await asyncio.sleep(self.checkpoint_interval)
logger.info("Creating model checkpoints...")
# Save RL model
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
rl_checkpoint = await self._save_rl_checkpoint()
if rl_checkpoint:
logger.info(f"RL checkpoint saved: {rl_checkpoint}")
# Save CNN model
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
cnn_checkpoint = await self._save_cnn_checkpoint()
if cnn_checkpoint:
logger.info(f"CNN checkpoint saved: {cnn_checkpoint}")
# NOTE(review): with indentation lost it is unclear whether the two updates
# below run once per cycle or only after a CNN checkpoint — confirm.
self.training_stats['models_saved'] += 1
self.training_stats['last_checkpoint'] = datetime.now()
except Exception as e:
logger.error(f"Error in checkpoint loop: {e}")
async def _monitoring_loop(self):
    """Log a status report every five minutes while the system is running.

    The report contains the uptime, the counters kept in
    ``self.training_stats`` and, when the corresponding methods exist,
    subscriber statistics from the data provider and performance metrics
    from the orchestrator.  Failures are logged and do not stop the loop.
    """
    logger.info("Starting monitoring loop...")
    counters = (
        ("RL training cycles", 'rl_training_cycles'),
        ("CNN training cycles", 'cnn_training_cycles'),
        ("Perfect moves detected", 'perfect_moves_detected'),
        ("Total ticks processed", 'total_ticks_processed'),
        ("Models saved", 'models_saved'),
    )
    while self.running:
        try:
            await asyncio.sleep(300)  # Monitor every 5 minutes

            uptime = datetime.now() - self.training_stats['start_time']
            logger.info("=== CONTINUOUS TRAINING SYSTEM STATUS ===")
            logger.info(f"Uptime: {uptime}")
            for label, key in counters:
                logger.info(f"{label}: {self.training_stats[key]}")

            # DataProvider statistics (optional API)
            if hasattr(self.data_provider, 'get_subscriber_stats'):
                stats = self.data_provider.get_subscriber_stats()
                logger.info(f"Active subscribers: {stats.get('active_subscribers', 0)}")
                distribution = stats.get('distribution_stats', {})
                logger.info(f"Total ticks distributed: {distribution.get('total_ticks_distributed', 0)}")

            # Orchestrator performance (optional API)
            if hasattr(self.orchestrator, 'get_performance_metrics'):
                logger.info(f"Orchestrator performance: {self.orchestrator.get_performance_metrics()}")

            logger.info("==========================================")
        except Exception as e:
            logger.error(f"Error in monitoring loop: {e}")
async def _run_dashboard(self):
    """Start the scalping dashboard on a daemon thread and idle alongside it.

    The dashboard object is created inside the worker thread and stored on
    ``self.dashboard``; it is served on 127.0.0.1:8051.  The coroutine then
    wakes every 10 seconds until ``self.running`` turns false.  Exceptions
    raised in this coroutine are logged and swallowed.
    """
    try:
        logger.info("Starting live trading dashboard...")

        def run_dashboard():
            # Construct and serve the dashboard entirely on the worker thread.
            dashboard = RealTimeScalpingDashboard(
                data_provider=self.data_provider,
                orchestrator=self.orchestrator,
            )
            self.dashboard = dashboard
            self.dashboard.run(host='127.0.0.1', port=8051, debug=False)

        Thread(target=run_dashboard, daemon=True).start()
        logger.info("Dashboard started at http://127.0.0.1:8051")

        # Nothing to do here except stay alive while the system runs.
        while self.running:
            await asyncio.sleep(10)
    except Exception as e:
        logger.error(f"Error running dashboard: {e}")
def _prepare_rl_training_data(self) -> Dict[str, Any]:
"""Prepare training data for RL agent"""
try:
# Get recent market data from DataProvider
eth_data = self.data_provider.get_latest_candles('ETH/USDT', '1m', limit=1000)
btc_data = self.data_provider.get_latest_candles('BTC/USDT', '1m', limit=1000)
if eth_data is not None and not eth_data.empty:
return {
'eth_data': eth_data,
'btc_data': btc_data,
'timestamp': datetime.now()
}
return None
except Exception as e:
logger.error(f"Error preparing RL training data: {e}")
return None
def _detect_perfect_moves(self) -> list:
"""Detect perfect trading moves for CNN training"""
try:
# Get recent tick data
recent_ticks = self.data_provider.get_recent_ticks('ETHUSDT', count=500)
if not recent_ticks:
return []
# Simple perfect move detection (can be enhanced)
perfect_moves = []
for i in range(1, len(recent_ticks) - 1):
prev_tick = recent_ticks[i-1]
curr_tick = recent_ticks[i]
next_tick = recent_ticks[i+1]
# Detect significant price movements
price_change = (next_tick.price - curr_tick.price) / curr_tick.price
if abs(price_change) > 0.001: # 0.1% movement
perfect_moves.append({
'timestamp': curr_tick.timestamp,
'price': curr_tick.price,
'action': 'BUY' if price_change > 0 else 'SELL',
'confidence': min(abs(price_change) * 100, 1.0)
})
return perfect_moves[-10:] # Return last 10 perfect moves
except Exception as e:
logger.error(f"Error detecting perfect moves: {e}")
return []
async def _train_rl_agent(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
    """Placeholder RL training step.

    No real training happens here: the coroutine logs, sleeps one second to
    simulate work and returns fixed metrics.  ``training_data`` is accepted
    but not used.

    Returns:
        ``{'loss': 0.05, 'reward': 0.75, 'episodes': 100}``, or ``None`` if
        an exception is raised (it is logged).
    """
    try:
        # TODO: integrate with the actual RL agent.
        logger.info("Training RL agent with market data...")
        await asyncio.sleep(1)  # Simulated training time
        simulated_metrics = {'loss': 0.05, 'reward': 0.75, 'episodes': 100}
        return simulated_metrics
    except Exception as e:
        logger.error(f"Error training RL agent: {e}")
        return None
async def _train_cnn_model(self, perfect_moves: list) -> Dict[str, Any]:
    """Placeholder CNN training step.

    No real training happens here: the coroutine logs how many moves it was
    given, sleeps two seconds to simulate work and returns fixed metrics.

    Returns:
        A dict with ``accuracy`` (0.92), ``loss`` (0.08) and
        ``perfect_moves_processed`` (the number of moves passed in), or
        ``None`` if an exception is raised (it is logged).
    """
    try:
        # TODO: integrate with the actual CNN model.
        move_count = len(perfect_moves)
        logger.info(f"Training CNN model with {move_count} perfect moves...")
        await asyncio.sleep(2)  # Simulated training time
        simulated_metrics = {
            'accuracy': 0.92,
            'loss': 0.08,
            'perfect_moves_processed': len(perfect_moves)
        }
        return simulated_metrics
    except Exception as e:
        logger.error(f"Error training CNN model: {e}")
        return None
async def _save_rl_checkpoint(self) -> str:
    """Build and log a timestamped RL checkpoint path.

    Nothing is written to disk yet; the actual model saving is still a
    placeholder.

    Returns:
        The path ``models/rl/checkpoint_rl_<YYYYmmdd_HHMMSS>.pt``, or
        ``None`` if an exception is raised (it is logged).
    """
    try:
        stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
        checkpoint_path = "models/rl/checkpoint_rl_" + stamp + ".pt"
        # TODO: persist the RL model to checkpoint_path.
        logger.info(f"Saving RL checkpoint to {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        logger.error(f"Error saving RL checkpoint: {e}")
        return None
async def _save_cnn_checkpoint(self) -> str:
    """Build and log a timestamped CNN checkpoint path.

    Nothing is written to disk yet; the actual model saving is still a
    placeholder.

    Returns:
        The path ``models/cnn/checkpoint_cnn_<YYYYmmdd_HHMMSS>.pt``, or
        ``None`` if an exception is raised (it is logged).
    """
    try:
        stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
        checkpoint_path = "models/cnn/checkpoint_cnn_" + stamp + ".pt"
        # TODO: persist the CNN model to checkpoint_path.
        logger.info(f"Saving CNN checkpoint to {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        logger.error(f"Error saving CNN checkpoint: {e}")
        return None
async def _wait_for_shutdown(self):
    """Install SIGINT/SIGTERM handlers and block until shutdown is requested.

    The handlers only log and set ``self.shutdown_event``; this coroutine
    polls that event once per second and returns as soon as it is set.
    """
    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, shutting down...")
        self.shutdown_event.set()

    for signo in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signo, signal_handler)

    # Poll rather than await the event directly.
    while True:
        if self.shutdown_event.is_set():
            break
        await asyncio.sleep(1)
async def stop(self):
    """Shut the training system down and log the final statistics.

    Clears ``self.running`` first so the background loops wind down, then
    stops data-provider streaming (when a provider is set), takes final RL
    and CNN checkpoints and logs the counters from ``self.training_stats``.
    Errors during these steps are logged; the closing "stopped" message is
    always emitted.
    """
    logger.info("Stopping Continuous Training System...")
    self.running = False

    try:
        if self.data_provider:
            await self.data_provider.stop_real_time_streaming()

        logger.info("Creating final checkpoints...")
        await self._save_rl_checkpoint()
        await self._save_cnn_checkpoint()

        stats = self.training_stats
        uptime = datetime.now() - stats['start_time']
        logger.info("=== FINAL TRAINING STATISTICS ===")
        logger.info(f"Total uptime: {uptime}")
        for label, key in (
            ("RL training cycles", 'rl_training_cycles'),
            ("CNN training cycles", 'cnn_training_cycles'),
            ("Perfect moves detected", 'perfect_moves_detected'),
            ("Total ticks processed", 'total_ticks_processed'),
            ("Models saved", 'models_saved'),
        ):
            logger.info(f"{label}: {stats[key]}")
        logger.info("=================================")
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")

    logger.info("Continuous Training System stopped")
async def main():
    """Entry point: build the continuous training system and run it.

    The system is started with the dashboard enabled.  A keyboard interrupt
    is logged and ends the run normally; any other exception is logged and
    terminates the process with exit status 1.
    """
    logger.info("Starting Continuous Full Training System (RL + CNN)")

    system = ContinuousTrainingSystem()
    try:
        await system.start(run_dashboard=True)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,269 +0,0 @@
#!/usr/bin/env python3
"""
Crash-Safe Dashboard Runner
This runner is designed to prevent crashes by:
1. Isolating imports with try/except blocks
2. Minimal initialization
3. Graceful error handling
4. No complex training loops
5. Safe component loading
"""
import os
import sys
import logging
import traceback
from pathlib import Path
# Environment tweaks have to be in place before any heavy library is imported.
os.environ.update({
    'KMP_DUPLICATE_LIB_OK': 'TRUE',
    'OMP_NUM_THREADS': '1',  # Minimal threads
    'MPLBACKEND': 'Agg',
})

# Make the project root importable regardless of the working directory.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

# Basic logging for this runner.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Third-party loggers are limited to errors to keep the console readable.
for _noisy_name in ('werkzeug', 'dash', 'matplotlib'):
    logging.getLogger(_noisy_name).setLevel(logging.ERROR)
del _noisy_name
class CrashSafeDashboard:
    """Minimal Dash status page that keeps running when components fail.

    Every project component is imported and constructed inside its own
    ``try`` block.  Failures are collected in ``initialization_errors`` and
    the component is stored as ``None`` in ``components`` so the status page
    can show it as failed instead of silently omitting it.
    """

    def __init__(self):
        """Initialize with safe error handling"""
        self.components = {}  # component name -> instance, or None on failure
        self.dashboard_app = None  # dash.Dash app, set by create_minimal_dashboard
        self.initialization_errors = []  # human-readable failure messages
        logger.info("Initializing crash-safe dashboard...")

    def safe_import(self, module_name, class_name=None):
        """Import a module, or one attribute of it, without raising.

        Args:
            module_name: Dotted module path to import.
            class_name: Optional attribute to fetch from the module.

        Returns:
            The attribute when ``class_name`` is given, otherwise whatever
            ``__import__(module_name)`` returns.  ``None`` is returned when
            the import or attribute lookup fails; the failure is logged and
            appended to ``initialization_errors``.
        """
        try:
            if class_name:
                module = __import__(module_name, fromlist=[class_name])
                return getattr(module, class_name)
            else:
                return __import__(module_name)
        except Exception as e:
            error_msg = f"Failed to import {module_name}.{class_name if class_name else ''}: {e}"
            logger.error(error_msg)
            self.initialization_errors.append(error_msg)
            return None

    def _record_failure(self, key, label, exc):
        """Log a component failure and mark ``components[key]`` as failed."""
        logger.error(f"✗ {label} failed: {exc}")
        self.initialization_errors.append(f"{label}: {exc}")
        # Storing None lets the status page render the component as "Failed".
        self.components[key] = None

    def initialize_core_components(self):
        """Import and construct config, data provider, executor and orchestrator.

        Each step is isolated so that one failure does not prevent the
        remaining components from being initialized.
        """
        logger.info("Initializing core components...")

        # Try to import and initialize config
        try:
            from core.config import get_config, setup_logging
            setup_logging()
            self.components['config'] = get_config()
            logger.info("✓ Config loaded")
        except Exception as e:
            self._record_failure('config', "Config", e)

        # Try to initialize data provider
        try:
            from core.data_provider import DataProvider
            self.components['data_provider'] = DataProvider()
            logger.info("✓ Data provider initialized")
        except Exception as e:
            self._record_failure('data_provider', "Data provider", e)

        # Try to initialize trading executor
        try:
            from core.trading_executor import TradingExecutor
            self.components['trading_executor'] = TradingExecutor()
            logger.info("✓ Trading executor initialized")
        except Exception as e:
            self._record_failure('trading_executor', "Trading executor", e)

        # Try to initialize orchestrator (WITHOUT training to avoid crashes)
        try:
            from core.orchestrator import TradingOrchestrator
            self.components['orchestrator'] = TradingOrchestrator(
                data_provider=self.components.get('data_provider'),
                enhanced_rl_training=False  # DISABLED to prevent crashes
            )
            logger.info("✓ Orchestrator initialized (training disabled)")
        except Exception as e:
            self._record_failure('orchestrator', "Orchestrator", e)

    def create_minimal_dashboard(self):
        """Build the Dash app: two status sections refreshed every 5 seconds.

        Returns:
            ``True`` when the app and its callback were created, ``False``
            when anything failed (the error and traceback are logged).
        """
        try:
            import dash
            from dash import html, dcc

            # Create minimal Dash app
            self.dashboard_app = dash.Dash(__name__)

            # Create simple layout
            self.dashboard_app.layout = html.Div([
                html.H1("Trading Dashboard - Safe Mode", style={'textAlign': 'center'}),
                html.Hr(),
                # Status section
                html.Div([
                    html.H3("System Status"),
                    html.Div(id="system-status", children=self._get_system_status()),
                ], style={'margin': '20px'}),
                # Error section
                html.Div([
                    html.H3("Initialization Status"),
                    html.Div(id="init-status", children=self._get_init_status()),
                ], style={'margin': '20px'}),
                # Simple refresh interval
                dcc.Interval(
                    id='interval-component',
                    interval=5000,  # Update every 5 seconds
                    n_intervals=0
                )
            ])

            # Add simple callback
            @self.dashboard_app.callback(
                [dash.dependencies.Output('system-status', 'children'),
                 dash.dependencies.Output('init-status', 'children')],
                [dash.dependencies.Input('interval-component', 'n_intervals')]
            )
            def update_status(n):
                try:
                    return self._get_system_status(), self._get_init_status()
                except Exception as e:
                    logger.error(f"Callback error: {e}")
                    return f"Callback error: {e}", "Error in callback"

            logger.info("✓ Minimal dashboard created")
            return True
        except Exception as e:
            logger.error(f"✗ Dashboard creation failed: {e}")
            logger.error(traceback.format_exc())
            return False

    def _get_system_status(self):
        """Return a list of ``html.P`` elements, one per component plus a timestamp."""
        # Imported here because this module only imports dash lazily and has
        # no module-level ``html`` or ``datetime``; without these imports the
        # method raised NameError on every call.
        from datetime import datetime
        from dash import html

        try:
            status_items = []
            # Check components
            for name, component in self.components.items():
                title = name.replace('_', ' ').title()
                if component is not None:
                    status_items.append(html.P(f"{title}: OK",
                                               style={'color': 'green'}))
                else:
                    status_items.append(html.P(f"{title}: Failed",
                                               style={'color': 'red'}))

            # Add timestamp
            status_items.append(html.P(f"Last update: {datetime.now().strftime('%H:%M:%S')}",
                                       style={'color': 'gray', 'fontSize': '12px'}))
            return status_items
        except Exception as e:
            return [html.P(f"Status error: {e}", style={'color': 'red'})]

    def _get_init_status(self):
        """Return ``html.P`` elements summarising the initialization errors."""
        # Local import for the same reason as in _get_system_status.
        from dash import html

        try:
            if not self.initialization_errors:
                return [html.P("✓ All components initialized successfully", style={'color': 'green'})]

            error_items = [html.P("⚠️ Some components failed to initialize:", style={'color': 'orange'})]
            for error in self.initialization_errors:
                error_items.append(html.P(f"{error}", style={'color': 'red', 'fontSize': '12px'}))
            return error_items
        except Exception as e:
            return [html.P(f"Init status error: {e}", style={'color': 'red'})]

    def run(self, port=8051):
        """Initialize components, build the dashboard and serve it.

        Args:
            port: TCP port on 127.0.0.1 to serve the dashboard on.

        Returns:
            ``True`` when the server ran and stopped normally or was
            interrupted from the keyboard; ``False`` when the dashboard could
            not be created or the server raised.
        """
        try:
            logger.info("=" * 60)
            logger.info("CRASH-SAFE DASHBOARD")
            logger.info("=" * 60)
            logger.info("Mode: Safe mode with minimal features")
            logger.info("Training: Completely disabled")
            logger.info("Focus: System stability and basic monitoring")
            logger.info("=" * 60)

            # Initialize components
            self.initialize_core_components()

            # Create dashboard
            if not self.create_minimal_dashboard():
                logger.error("Failed to create dashboard")
                return False

            # Report initialization status
            if self.initialization_errors:
                logger.warning(f"Dashboard starting with {len(self.initialization_errors)} component failures")
                for error in self.initialization_errors:
                    logger.warning(f" - {error}")
            else:
                logger.info("All components initialized successfully")

            # Start dashboard
            logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
            logger.info("Press Ctrl+C to stop")
            self.dashboard_app.run_server(
                host='127.0.0.1',
                port=port,
                debug=False,
                use_reloader=False,
                threaded=True
            )
            return True
        except KeyboardInterrupt:
            logger.info("Dashboard stopped by user")
            return True
        except Exception as e:
            logger.error(f"Dashboard failed: {e}")
            logger.error(traceback.format_exc())
            return False
def main():
    """Run the crash-safe dashboard and report the outcome.

    Logs whether ``CrashSafeDashboard.run`` reported success.  An exception
    that escapes the dashboard is logged with its traceback and ends the
    process with exit status 1.
    """
    try:
        dashboard = CrashSafeDashboard()
        if dashboard.run():
            logger.info("Dashboard completed successfully")
        else:
            logger.error("Dashboard failed")
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@ -1,525 +0,0 @@
#!/usr/bin/env python3
"""
Enhanced RL Training Launcher with Real Data Integration
This script launches the comprehensive RL training system that uses:
- Real-time tick data (300s window for momentum detection)
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
- BTC reference data for correlation
- CNN hidden features and predictions
- Williams Market Structure pivot points
- Market microstructure analysis
The RL model will receive ~13,400 features instead of the previous ~100 basic features.
Training metrics are automatically logged to TensorBoard for visualization.
"""
import asyncio
import logging
import time
import signal
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional
# Configure logging: every record goes both to enhanced_rl_training.log and
# to stdout, using the same format.
logging.basicConfig(
    handlers=[
        logging.FileHandler('enhanced_rl_training.log'),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Import our enhanced components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_rl_trainer import EnhancedRLTrainer
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge
from utils.tensorboard_logger import TensorBoardLogger
class EnhancedRLTrainingSystem:
    """Comprehensive RL training system with real data integration"""

    def __init__(self):
        """Load the config, reset runtime state and open a TensorBoard run.

        The data provider, orchestrator and RL trainer start out as ``None``;
        nothing in this constructor creates them.
        """
        self.config = get_config()
        self.running = False
        self.data_provider = None
        self.orchestrator = None
        self.rl_trainer = None

        # Performance tracking
        self.training_stats = {
            'training_sessions': 0,
            'total_experiences': 0,
            'avg_state_size': 0,
            'data_quality_score': 0.0,
            'last_training_time': None
        }

        # One TensorBoard experiment per process start, named by timestamp.
        experiment_name = "enhanced_rl_training_" + datetime.now().strftime('%Y%m%d_%H%M%S')
        self.tb_logger = TensorBoardLogger(
            log_dir="runs",
            experiment_name=experiment_name,
            enabled=True
        )

        logger.info("Enhanced RL Training System initialized")
        logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
        for banner_line in (
            "Features:",
            "- Real-time tick data processing (300s window)",
            "- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)",
            "- BTC correlation analysis",
            "- CNN feature integration",
            "- Williams Market Structure pivot points",
            "- ~13,400 feature state vector (vs previous ~100)",
        ):
            logger.info(banner_line)
# async def initialize(self):
# """Initialize all components"""
# try:
# logger.info("Initializing enhanced RL training components...")
# # Initialize data provider with real-time streaming
# logger.info("Setting up data provider with real-time streaming...")
# self.data_provider = DataProvider(
# symbols=self.config.symbols,
# timeframes=self.config.timeframes
# )
# # Start real-time data streaming
# await self.data_provider.start_real_time_streaming()
# logger.info("Real-time data streaming started")
# # Wait for initial data collection
# logger.info("Collecting initial market data...")
# await asyncio.sleep(30) # Allow 30 seconds for data collection
# # Initialize enhanced orchestrator
# logger.info("Initializing enhanced orchestrator...")
# self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
# # Initialize enhanced RL trainer with comprehensive state building
# logger.info("Initializing enhanced RL trainer...")
# self.rl_trainer = EnhancedRLTrainer(
# config=self.config,
# orchestrator=self.orchestrator
# )
# # Verify data availability
# data_status = await self._verify_data_availability()
# if not data_status['has_sufficient_data']:
# logger.warning("Insufficient data detected. Continuing with limited training.")
# logger.warning(f"Data status: {data_status}")
# else:
# logger.info("Sufficient data available for comprehensive RL training")
# logger.info(f"Tick data: {data_status['tick_count']} ticks")
# logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
# self.running = True
# logger.info("Enhanced RL training system initialized successfully")
# except Exception as e:
# logger.error(f"Error during initialization: {e}")
# raise
# async def _verify_data_availability(self) -> Dict[str, any]:
# """Verify that we have sufficient data for training"""
# try:
# data_status = {
# 'has_sufficient_data': False,
# 'tick_count': 0,
# 'ohlcv_bars': 0,
# 'symbols_with_data': [],
# 'missing_data': []
# }
# for symbol in self.config.symbols:
# # Check tick data
# recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
# tick_count = len(recent_ticks)
# # Check OHLCV data
# ohlcv_bars = 0
# for timeframe in ['1s', '1m', '1h', '1d']:
# try:
# df = self.data_provider.get_historical_data(
# symbol=symbol,
# timeframe=timeframe,
# limit=50,
# refresh=True
# )
# if df is not None and not df.empty:
# ohlcv_bars += len(df)
# except Exception as e:
# logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
# data_status['tick_count'] += tick_count
# data_status['ohlcv_bars'] += ohlcv_bars
# if tick_count >= 50 and ohlcv_bars >= 100:
# data_status['symbols_with_data'].append(symbol)
# else:
# data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
# # Consider data sufficient if we have at least one symbol with good data
# data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
# return data_status
# except Exception as e:
# logger.error(f"Error verifying data availability: {e}")
# return {'has_sufficient_data': False, 'error': str(e)}
# async def run_training_loop(self):
# """Run the main training loop with real data"""
# logger.info("Starting enhanced RL training loop...")
# training_cycle = 0
# last_state_size_log = time.time()
# try:
# while self.running:
# training_cycle += 1
# cycle_start_time = time.time()
# logger.info(f"Training cycle {training_cycle} started")
# # Get comprehensive market states with real data
# market_states = await self._get_comprehensive_market_states()
# if not market_states:
# logger.warning("No market states available. Waiting for data...")
# await asyncio.sleep(60)
# continue
# # Train RL agents with comprehensive states
# training_results = await self._train_rl_agents(market_states)
# # Update performance tracking
# self._update_training_stats(training_results, market_states)
# # Log training progress
# cycle_duration = time.time() - cycle_start_time
# logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
# # Log state size periodically
# if time.time() - last_state_size_log > 300: # Every 5 minutes
# self._log_state_size_info(market_states)
# last_state_size_log = time.time()
# # Save models periodically
# if training_cycle % 10 == 0:
# await self._save_training_progress()
# # Wait before next training cycle
# await asyncio.sleep(300) # Train every 5 minutes
# except Exception as e:
# logger.error(f"Error in training loop: {e}")
# raise
# async def _get_comprehensive_market_states(self) -> Dict[str, any]:
# """Get comprehensive market states with all required data"""
# try:
# # Get market states from orchestrator
# universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
# market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
# # Verify data quality
# quality_score = self._calculate_data_quality(market_states)
# self.training_stats['data_quality_score'] = quality_score
# if quality_score < 0.5:
# logger.warning(f"Low data quality detected: {quality_score:.2f}")
# return market_states
# except Exception as e:
# logger.error(f"Error getting comprehensive market states: {e}")
# return {}
# def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
# """Calculate data quality score based on available data"""
# try:
# if not market_states:
# return 0.0
# total_score = 0.0
# total_symbols = len(market_states)
# for symbol, state in market_states.items():
# symbol_score = 0.0
# # Score based on tick data availability
# if hasattr(state, 'raw_ticks') and state.raw_ticks:
# tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks
# symbol_score += tick_score * 0.3
# # Score based on OHLCV data availability
# if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
# ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes
# symbol_score += min(ohlcv_score, 1.0) * 0.4
# # Score based on CNN features
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
# symbol_score += 0.15
# # Score based on pivot points
# if hasattr(state, 'pivot_points') and state.pivot_points:
# symbol_score += 0.15
# total_score += symbol_score
# return total_score / total_symbols if total_symbols > 0 else 0.0
# except Exception as e:
# logger.warning(f"Error calculating data quality: {e}")
# return 0.5 # Default to medium quality
async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]:
"""Train RL agents with comprehensive market states"""
try:
training_results = {
'symbols_trained': [],
'total_experiences': 0,
'avg_state_size': 0,
'training_errors': [],
'losses': {},
'rewards': {}
}
for symbol, market_state in market_states.items():
try:
# Convert market state to comprehensive RL state
rl_state = self.rl_trainer._market_state_to_rl_state(market_state)
if rl_state is not None and len(rl_state) > 0:
# Record state size
state_size = len(rl_state)
training_results['avg_state_size'] += state_size
# Log state size to TensorBoard
self.tb_logger.log_scalar(
f'State/{symbol}/Size',
state_size,
self.training_stats['training_sessions']
)
# Simulate trading action for experience generation
# In real implementation, this would be actual trading decisions
action = self._simulate_trading_action(symbol, rl_state)
# Generate reward based on market outcome
reward = self._calculate_training_reward(symbol, market_state, action)
# Store reward for TensorBoard logging
training_results['rewards'][symbol] = reward
# Log action and reward to TensorBoard
self.tb_logger.log_scalars(f'Actions/{symbol}', {
'action': action,
'reward': reward
}, self.training_stats['training_sessions'])
# Add experience to RL agent
agent = self.rl_trainer.agents.get(symbol)
if agent:
# Create next state (would be actual next market state in real scenario)
next_state = rl_state # Simplified for now
agent.remember(
state=rl_state,
action=action,
reward=reward,
next_state=next_state,
done=False
)
# Train agent if enough experiences
if len(agent.replay_buffer) >= agent.batch_size:
loss = agent.replay()
if loss is not None:
logger.debug(f"Agent {symbol} training loss: {loss:.4f}")
# Store loss for TensorBoard logging
training_results['losses'][symbol] = loss
# Log loss to TensorBoard
self.tb_logger.log_scalar(
f'Training/{symbol}/Loss',
loss,
self.training_stats['training_sessions']
)
training_results['symbols_trained'].append(symbol)
training_results['total_experiences'] += 1
except Exception as e:
error_msg = f"Error training {symbol}: {e}"
logger.warning(error_msg)
training_results['training_errors'].append(error_msg)
# Calculate average state size
if len(training_results['symbols_trained']) > 0:
training_results['avg_state_size'] /= len(training_results['symbols_trained'])
# Log overall training metrics to TensorBoard
self.tb_logger.log_scalars('Training/Overall', {
'symbols_trained': len(training_results['symbols_trained']),
'experiences': training_results['total_experiences'],
'avg_state_size': training_results['avg_state_size'],
'errors': len(training_results['training_errors'])
}, self.training_stats['training_sessions'])
return training_results
except Exception as e:
logger.error(f"Error training RL agents: {e}")
return {'error': str(e)}
# def _simulate_trading_action(self, symbol: str, rl_state) -> int:
# """Simulate trading action for training (would be real decision in production)"""
# # Simple simulation based on state features
# if len(rl_state) > 100:
# # Use momentum features to decide action
# momentum_features = rl_state[:100] # First 100 features assumed to be momentum
# avg_momentum = sum(momentum_features) / len(momentum_features)
# if avg_momentum > 0.6:
# return 1 # BUY
# elif avg_momentum < 0.4:
# return 2 # SELL
# else:
# return 0 # HOLD
# else:
# return 0 # HOLD as default
# def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
# """Calculate training reward based on market state and action"""
# try:
# # Simple reward calculation based on market conditions
# base_reward = 0.0
# # Reward based on volatility alignment
# if hasattr(market_state, 'volatility'):
# if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility
# base_reward += 0.1
# elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility
# base_reward += 0.1
# # Reward based on trend alignment
# if hasattr(market_state, 'trend_strength'):
# if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend
# base_reward += 0.2
# elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend
# base_reward += 0.2
# return base_reward
# except Exception as e:
# logger.warning(f"Error calculating reward for {symbol}: {e}")
# return 0.0
# def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
# """Update training statistics"""
# self.training_stats['training_sessions'] += 1
# self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
# self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
# self.training_stats['last_training_time'] = datetime.now()
# # Log statistics periodically
# if self.training_stats['training_sessions'] % 10 == 0:
# logger.info("Training Statistics:")
# logger.info(f" Sessions: {self.training_stats['training_sessions']}")
# logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
# logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
# logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")
# def _log_state_size_info(self, market_states: Dict[str, any]):
# """Log information about state sizes for debugging"""
# for symbol, state in market_states.items():
# info = []
# if hasattr(state, 'raw_ticks'):
# info.append(f"ticks: {len(state.raw_ticks)}")
# if hasattr(state, 'ohlcv_data'):
# total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
# info.append(f"OHLCV bars: {total_bars}")
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
# info.append("CNN features: available")
# if hasattr(state, 'pivot_points') and state.pivot_points:
# info.append("pivot points: available")
# logger.info(f"{symbol} state data: {', '.join(info)}")
# async def _save_training_progress(self):
# """Save training progress and models"""
# try:
# if self.rl_trainer:
# self.rl_trainer._save_all_models()
# logger.info("Training progress saved")
# except Exception as e:
# logger.error(f"Error saving training progress: {e}")
# async def shutdown(self):
# """Graceful shutdown"""
# logger.info("Shutting down enhanced RL training system...")
# self.running = False
# # Save final state
# await self._save_training_progress()
# # Stop data provider
# if self.data_provider:
# await self.data_provider.stop_real_time_streaming()
# logger.info("Enhanced RL training system shutdown complete")
# async def main():
# """Main function to run enhanced RL training"""
# system = None
# def signal_handler(signum, frame):
# logger.info("Received shutdown signal")
# if system:
# asyncio.create_task(system.shutdown())
# # Set up signal handlers
# signal.signal(signal.SIGINT, signal_handler)
# signal.signal(signal.SIGTERM, signal_handler)
# try:
# # Create and initialize the training system
# system = EnhancedRLTrainingSystem()
# await system.initialize()
# logger.info("Enhanced RL Training System is now running...")
# logger.info("The RL model now receives ~13,400 features instead of ~100!")
# logger.info("Press Ctrl+C to stop")
# # Run the training loop
# await system.run_training_loop()
# except KeyboardInterrupt:
# logger.info("Training interrupted by user")
# except Exception as e:
# logger.error(f"Error in main training loop: {e}")
# raise
# finally:
# if system:
# await system.shutdown()
# if __name__ == "__main__":
# asyncio.run(main())

View File

@ -1,95 +0,0 @@
#!/usr/bin/env python3
"""
Run Dashboard with Enhanced Training System Enabled
This script starts the trading dashboard with the enhanced real-time
training system automatically enabled and running.
"""
import sys
import os
import asyncio
import logging
from datetime import datetime
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
# Configure logging: INFO and above, timestamped, with logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
async def main():
    """Start the trading dashboard with the enhanced training system enabled.

    Builds the data provider, trading executor and orchestrator, starts
    enhanced training when the orchestrator exposes a training system,
    creates the dashboard, attaches it to the orchestrator and then keeps
    the coroutine alive for one hour.

    Any ``Exception`` raised along the way is logged together with its
    traceback and swallowed. Ctrl+C is handled by the entry point below
    rather than here, because under ``asyncio.run`` an interrupt normally
    reaches a coroutine as a cancellation, not as ``KeyboardInterrupt``.
    """
    try:
        logger.info("=" * 70)
        logger.info("STARTING DASHBOARD WITH ENHANCED TRAINING SYSTEM")
        logger.info("=" * 70)

        # 1. Initialize components with enhanced training
        logger.info("1. Initializing components...")
        data_provider = DataProvider()
        trading_executor = TradingExecutor()

        # 2. Create orchestrator; the flag below is what enables enhanced training
        logger.info("2. Creating orchestrator with enhanced training...")
        orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True
        )

        # 3. Verify enhanced training is available
        logger.info("3. Verifying enhanced training system...")
        if orchestrator.enhanced_training_system:
            logger.info("✅ Enhanced training system available")
            logger.info(f" - Training enabled: {orchestrator.training_enabled}")

            # 4. Start enhanced training
            logger.info("4. Starting enhanced training system...")
            start_result = orchestrator.start_enhanced_training()
            if start_result:
                logger.info("✅ Enhanced training started successfully")
            else:
                logger.warning("⚠️ Enhanced training start failed")
        else:
            logger.warning("⚠️ Enhanced training system not available")

        # 5. Create dashboard
        logger.info("5. Creating dashboard...")
        dashboard = create_clean_dashboard(
            data_provider=data_provider,
            orchestrator=orchestrator,
            trading_executor=trading_executor
        )

        # 6. Connect training system to dashboard
        logger.info("6. Connecting training system to dashboard...")
        orchestrator.set_training_dashboard(dashboard)

        # 7. Start dashboard
        # NOTE(review): nothing in this function actually serves the
        # dashboard; confirm that create_clean_dashboard() starts the web
        # server itself, otherwise the URL logged below is never reachable.
        logger.info("7. Starting dashboard...")
        logger.info("🎉 Dashboard with enhanced training is now running!")
        logger.info(" - Enhanced training: ENABLED")
        logger.info(" - Real-time learning: ACTIVE")
        logger.info(" - Dashboard URL: http://127.0.0.1:8051")

        # Keep the process alive; the script exits on its own after an hour.
        await asyncio.sleep(3600)
    except Exception as e:
        logger.error(f"Error starting dashboard: {e}")
        import traceback
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() re-raises the interrupt here once the main task has
        # been torn down, so this is the place where it can be reported.
        logger.info("Dashboard stopped by user")

View File

@ -1,510 +0,0 @@
#!/usr/bin/env python3
"""
Integrated Real-time RL COB Trading System with Dashboard
This script starts both:
1. RealtimeRLCOBTrader - 1B parameter RL model with real-time training
2. COB Dashboard - Real-time visualization with RL predictions
The RL predictions are integrated into the dashboard for live visualization.
"""
import asyncio
import logging
import signal
import sys
import json
import os
from datetime import datetime
from typing import Dict, Any
from aiohttp import web
# Local imports
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult
from core.trading_executor import TradingExecutor
from core.config import load_config
from web.cob_realtime_dashboard import COBDashboardServer
# Configure logging
# logging.FileHandler opens its target file as soon as it is constructed, so
# the directory has to exist before this module-level setup runs; creating it
# only later (inside main()) is too late on a fresh checkout.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/integrated_rl_cob_system.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
class IntegratedRLCOBSystem:
    """Run the real-time RL COB trader and its dashboard as one unit.

    The system owns a ``TradingExecutor``, a ``RealtimeRLCOBTrader`` and an
    ``EnhancedCOBDashboardServer``. Predictions emitted by the trader are
    captured in memory so the dashboard can serve them.
    """

    # Upper bound for the per-symbol list of live predictions.
    _MAX_RECENT_PREDICTIONS = 100
    # Upper bound for the per-symbol prediction history (chart overlay).
    _MAX_PREDICTION_HISTORY = 1000

    def __init__(self, config_path: str = "config.yaml"):
        """Load the configuration and install shutdown signal handlers.

        Args:
            config_path: Path passed to ``load_config`` and, later, to the
                ``TradingExecutor``.
        """
        self._config_path = config_path
        self.config = load_config(config_path)
        self.trader = None
        self.dashboard = None
        self.trading_executor = None
        self.running = False

        # Track which components were actually started. ``running`` cannot
        # be used for this: the signal handler clears it to end the main
        # loop, and the components still have to be stopped afterwards.
        self._trader_started = False
        self._dashboard_started = False

        # RL prediction storage for dashboard, keyed by symbol
        self.rl_predictions: Dict[str, list] = {}
        self.prediction_history: Dict[str, list] = {}

        # Setup signal handlers for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        logger.info("IntegratedRLCOBSystem initialized")

    def _signal_handler(self, signum, frame):
        """Request shutdown by clearing ``running``; the main loop exits on its next pass."""
        logger.info(f"Received signal {signum}, initiating graceful shutdown...")
        self.running = False

    async def start(self):
        """Initialize every component, start them and run the monitoring loop.

        ``stop()`` is always awaited on the way out. Exceptions are logged
        and re-raised.
        """
        try:
            logger.info("=" * 60)
            logger.info("INTEGRATED RL COB SYSTEM STARTING")
            logger.info("Real-time RL Trading + Dashboard")
            logger.info("=" * 60)

            # Initialize trading executor
            await self._initialize_trading_executor()
            # Initialize RL trader with prediction callback
            await self._initialize_rl_trader()
            # Initialize dashboard with RL integration
            await self._initialize_dashboard()
            # Start the integrated system
            await self._start_integrated_system()
            # Run main loop
            await self._run_main_loop()
        except Exception as e:
            logger.error(f"Critical error in integrated system: {e}")
            raise
        finally:
            await self.stop()

    async def _initialize_trading_executor(self):
        """Create the trading executor, asking for confirmation before live trading."""
        logger.info("Initializing Trading Executor...")

        mexc_config = self.config.get('mexc', {})

        # Determine if we should run in simulation mode
        simulation_mode = mexc_config.get('simulation_mode', True)
        if simulation_mode:
            logger.info("Running in SIMULATION mode - no real trades will be executed")
        else:
            logger.warning("Running in LIVE TRADING mode - real money at risk!")
            # Safety confirmation for live trading (blocks until answered).
            confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
            if confirmation != 'CONFIRM_LIVE_TRADING':
                logger.info("Live trading not confirmed, switching to simulation mode")
                simulation_mode = True

        # NOTE(review): ``simulation_mode`` is only logged; it is never
        # handed to the executor, which is built from the config file alone.
        # If the config says live, declining the prompt above may not
        # actually prevent live trading - confirm against TradingExecutor.
        self.trading_executor = TradingExecutor(self._config_path)
        logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")

    async def _initialize_rl_trader(self):
        """Create the RL trader from the ``realtime_rl`` config and hook its signal path."""
        logger.info("Initializing Real-time RL COB Trader...")

        # Get RL configuration
        rl_config = self.config.get('realtime_rl', {})

        # Trading symbols
        symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])

        # Initialize prediction storage
        for symbol in symbols:
            self.rl_predictions[symbol] = []
            self.prediction_history[symbol] = []

        # RL parameters
        inference_interval_ms = rl_config.get('inference_interval_ms', 200)
        min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
        required_confident_predictions = rl_config.get('required_confident_predictions', 3)
        model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')

        # Initialize RL trader
        self.trader = RealtimeRLCOBTrader(
            symbols=symbols,
            trading_executor=self.trading_executor,
            model_checkpoint_dir=model_checkpoint_dir,
            inference_interval_ms=inference_interval_ms,
            min_confidence_threshold=min_confidence_threshold,
            required_confident_predictions=required_confident_predictions
        )

        # Wrap the trader's private ``_add_signal`` so every prediction it
        # records is also captured for the dashboard.
        original_add_signal = self.trader._add_signal

        def enhanced_add_signal(symbol: str, prediction: PredictionResult):
            original_add_signal(symbol, prediction)
            self._on_rl_prediction(symbol, prediction)

        self.trader._add_signal = enhanced_add_signal

        logger.info(f"RL Trader initialized for symbols: {symbols}")
        logger.info(f"Inference interval: {inference_interval_ms}ms")
        logger.info(f"Confidence threshold: {min_confidence_threshold}")
        logger.info(f"Required predictions: {required_confident_predictions}")

    def _on_rl_prediction(self, symbol: str, prediction: PredictionResult):
        """Store one prediction in dashboard format; errors are logged, not raised."""
        try:
            # predicted_direction is used as an index: 0=DOWN, 1=SIDEWAYS, 2=UP
            prediction_data = {
                'timestamp': prediction.timestamp.isoformat(),
                'direction': prediction.predicted_direction,
                'confidence': prediction.confidence,
                'predicted_change': prediction.predicted_change,
                'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][prediction.predicted_direction],
                'color': ['red', 'gray', 'green'][prediction.predicted_direction]
            }

            # Add to current predictions (for live display)
            self.rl_predictions[symbol].append(prediction_data)
            if len(self.rl_predictions[symbol]) > self._MAX_RECENT_PREDICTIONS:
                self.rl_predictions[symbol] = \
                    self.rl_predictions[symbol][-self._MAX_RECENT_PREDICTIONS:]

            # Add to history (for chart overlay)
            self.prediction_history[symbol].append(prediction_data)
            if len(self.prediction_history[symbol]) > self._MAX_PREDICTION_HISTORY:
                self.prediction_history[symbol] = \
                    self.prediction_history[symbol][-self._MAX_PREDICTION_HISTORY:]

            logger.debug(f"Captured RL prediction for {symbol}: {prediction.predicted_direction} "
                         f"(confidence: {prediction.confidence:.3f})")
        except Exception as e:
            logger.error(f"Error capturing RL prediction: {e}")

    async def _initialize_dashboard(self):
        """Create the dashboard server from the ``dashboard`` config section."""
        logger.info("Initializing COB Dashboard with RL Integration...")

        # Get dashboard configuration
        dashboard_config = self.config.get('dashboard', {})
        host = dashboard_config.get('host', 'localhost')
        port = dashboard_config.get('port', 8053)

        # The dashboard keeps a reference to this system to read predictions
        self.dashboard = EnhancedCOBDashboardServer(
            host=host,
            port=port,
            rl_system=self
        )
        logger.info(f"COB Dashboard initialized at http://{host}:{port}")

    async def _start_integrated_system(self):
        """Start the trader, then the dashboard, and mark the system as running."""
        logger.info("Starting Integrated RL COB System...")

        # Start RL trader first (this initializes COB integration)
        await self.trader.start()
        self._trader_started = True
        logger.info("RL Trader started")

        # Start dashboard (uses same COB integration)
        await self.dashboard.start()
        self._dashboard_started = True
        logger.info("COB Dashboard started")

        self.running = True
        logger.info("INTEGRATED SYSTEM FULLY OPERATIONAL!")
        logger.info("1B parameter RL model: ACTIVE")
        logger.info("Real-time COB data: STREAMING")
        logger.info("Signal accumulation: ACTIVE")
        logger.info("Live predictions: VISIBLE IN DASHBOARD")
        logger.info("Continuous training: ACTIVE")
        logger.info(f"Dashboard URL: http://{self.dashboard.host}:{self.dashboard.port}")

    async def _run_main_loop(self):
        """Wake every 10 seconds while ``running`` and print statistics once a minute."""
        logger.info("Starting integrated system monitoring...")

        last_stats_time = datetime.now()
        stats_interval = 60  # Print stats every 60 seconds

        while self.running:
            try:
                await asyncio.sleep(10)

                # Print periodic statistics
                current_time = datetime.now()
                if (current_time - last_stats_time).total_seconds() >= stats_interval:
                    await self._print_integrated_stats()
                    last_stats_time = current_time
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in main loop: {e}")
                await asyncio.sleep(5)

        logger.info("Integrated system monitoring stopped")

    async def _print_integrated_stats(self):
        """Log trader, dashboard and trading statistics; errors are logged, not raised."""
        try:
            logger.info("=" * 80)
            logger.info("INTEGRATED RL COB SYSTEM STATISTICS")
            logger.info("=" * 80)

            # RL Trader Statistics
            if self.trader:
                rl_stats = self.trader.get_performance_stats()
                logger.info("\n🤖 RL TRADER PERFORMANCE:")
                for symbol in self.trader.symbols:
                    training_stats = rl_stats.get('training_stats', {}).get(symbol, {})
                    inference_stats = rl_stats.get('inference_stats', {}).get(symbol, {})
                    signal_stats = rl_stats.get('signal_stats', {}).get(symbol, {})

                    logger.info(f"\n 📈 {symbol}:")
                    logger.info(f" Predictions: {training_stats.get('total_predictions', 0)}")
                    logger.info(f" Success Rate: {signal_stats.get('success_rate', 0):.1%}")
                    logger.info(f" Avg Inference: {inference_stats.get('average_inference_time_ms', 0):.1f}ms")
                    logger.info(f" Current Signals: {signal_stats.get('current_signals', 0)}")

                    # RL prediction stats for dashboard
                    recent_predictions = len(self.rl_predictions.get(symbol, []))
                    total_predictions = len(self.prediction_history.get(symbol, []))
                    logger.info(f" Dashboard Predictions: {recent_predictions} recent, {total_predictions} total")

            # Dashboard Statistics
            if self.dashboard:
                logger.info("\nDASHBOARD STATISTICS:")
                logger.info(f" Active Connections: {len(self.dashboard.websocket_connections)}")
                logger.info(f" Server Status: {'RUNNING' if self.dashboard.site else 'STOPPED'}")
                logger.info(f" URL: http://{self.dashboard.host}:{self.dashboard.port}")

            # Trading Executor Statistics
            if self.trading_executor:
                positions = self.trading_executor.get_positions()
                trade_history = self.trading_executor.get_trade_history()
                logger.info("\n💰 TRADING STATISTICS:")
                logger.info(f" Active Positions: {len(positions)}")
                logger.info(f" Total Trades: {len(trade_history)}")
                if trade_history:
                    total_pnl = sum(trade.pnl for trade in trade_history)
                    profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
                    win_rate = (profitable_trades / len(trade_history)) * 100
                    logger.info(f" Total P&L: ${total_pnl:.2f}")
                    logger.info(f" Win Rate: {win_rate:.1f}%")

            logger.info("=" * 80)
        except Exception as e:
            logger.error(f"Error printing integrated stats: {e}")

    async def stop(self):
        """Stop every component that was started; safe to call more than once.

        The decision is based on the per-component "started" flags rather
        than on ``running``, so a shutdown requested through the signal
        handler (which clears ``running``) still stops the trader and the
        dashboard. The trader is stopped even if stopping the dashboard
        raises; that exception still propagates afterwards.
        """
        if not (self._dashboard_started or self._trader_started):
            # Nothing was started, or everything has already been stopped.
            self.running = False
            return

        logger.info("Stopping Integrated RL COB System...")
        self.running = False

        try:
            # Stop dashboard
            if self._dashboard_started:
                self._dashboard_started = False
                await self.dashboard.stop()
                logger.info("Dashboard stopped")
        finally:
            # Stop RL trader
            if self._trader_started:
                self._trader_started = False
                await self.trader.stop()
                logger.info("RL Trader stopped")

        logger.info("🏁 Integrated system stopped successfully")

    def get_rl_predictions(self, symbol: str) -> Dict[str, Any]:
        """Return the stored predictions for ``symbol``.

        Returns:
            A dict with ``recent_predictions`` and ``prediction_history``
            (lists, empty for an unknown symbol) plus their lengths as
            ``recent_count`` and ``total_predictions``.
        """
        recent = self.rl_predictions.get(symbol, [])
        history = self.prediction_history.get(symbol, [])
        return {
            'recent_predictions': recent,
            'prediction_history': history,
            'total_predictions': len(history),
            'recent_count': len(recent)
        }
class EnhancedCOBDashboardServer(COBDashboardServer):
    """COB dashboard server extended with RL prediction endpoints.

    Adds ``/api/rl-predictions/{symbol}`` and ``/api/rl-status`` and attaches
    the latest RL predictions to every COB update that is broadcast.
    """

    def __init__(self, host: str = 'localhost', port: int = 8053, rl_system: IntegratedRLCOBSystem = None):
        """Create the server and register the RL routes.

        Args:
            host: Forwarded to ``COBDashboardServer``.
            port: Forwarded to ``COBDashboardServer``.
            rl_system: System queried for predictions and status; when it is
                ``None`` the RL endpoints report that RL is unavailable.
        """
        super().__init__(host, port)
        self.rl_system = rl_system

        # Add RL prediction routes
        self._setup_rl_routes()
        logger.info("Enhanced COB Dashboard with RL predictions initialized")

    async def serve_dashboard(self, request):
        """Serve ``enhanced_cob_dashboard.html`` or fall back to the parent's page.

        The HTML file is looked up next to this module. Any error yields a
        plain-text 500 response.
        """
        try:
            dashboard_path = os.path.join(os.path.dirname(__file__), 'enhanced_cob_dashboard.html')
            if os.path.exists(dashboard_path):
                with open(dashboard_path, 'r', encoding='utf-8') as f:
                    html_content = f.read()
                return web.Response(text=html_content, content_type='text/html')
            else:
                # Fallback to basic dashboard
                logger.warning("Enhanced dashboard HTML not found, using basic dashboard")
                return await super().serve_dashboard(request)
        except Exception as e:
            logger.error(f"Error serving enhanced dashboard: {e}")
            return web.Response(text="Dashboard error", status=500)

    def _setup_rl_routes(self):
        """Register the RL prediction and status GET routes on ``self.app``."""
        self.app.router.add_get('/api/rl-predictions/{symbol}', self.get_rl_predictions)
        self.app.router.add_get('/api/rl-status', self.get_rl_status)

    async def get_rl_predictions(self, request):
        """Return the stored RL predictions for the symbol in the URL as JSON.

        Responds with 400 for a symbol not in ``self.symbols``, 503 when no
        RL system is attached and 500 on any other error.
        """
        try:
            symbol = request.match_info['symbol']
            # Symbols contain a slash, which clients send URL-encoded
            symbol = symbol.replace('%2F', '/')

            if symbol not in self.symbols:
                return web.json_response({
                    'error': f'Symbol {symbol} not supported'
                }, status=400)

            if not self.rl_system:
                return web.json_response({
                    'error': 'RL system not available'
                }, status=503)

            predictions = self.rl_system.get_rl_predictions(symbol)
            return web.json_response({
                'symbol': symbol,
                'timestamp': datetime.now().isoformat(),
                'predictions': predictions
            })
        except Exception as e:
            logger.error(f"Error getting RL predictions: {e}")
            return web.json_response({
                'error': str(e)
            }, status=500)

    async def get_rl_status(self, request):
        """Return the RL trader's configuration and running state as JSON.

        Reports ``inactive`` when no RL system or trader is attached and
        responds with 500 on any error.
        """
        try:
            if not self.rl_system or not self.rl_system.trader:
                return web.json_response({
                    'status': 'inactive',
                    'error': 'RL system not available'
                })

            trader = self.rl_system.trader
            rl_stats = trader.get_performance_stats()
            status = {
                'status': 'active',
                'symbols': trader.symbols,
                'model_info': rl_stats.get('model_info', {}),
                'inference_interval_ms': trader.inference_interval_ms,
                'confidence_threshold': trader.min_confidence_threshold,
                'required_predictions': trader.required_confident_predictions,
                'device': str(trader.device),
                'running': trader.running
            }
            return web.json_response(status)
        except Exception as e:
            logger.error(f"Error getting RL status: {e}")
            return web.json_response({
                'status': 'error',
                'error': str(e)
            }, status=500)

    async def _broadcast_cob_update(self, symbol: str, data: Dict):
        """Send a COB update, extended with RL data, to every WebSocket client.

        Clients whose send fails are removed from
        ``self.websocket_connections``. Errors are logged, not raised.
        """
        try:
            # Get RL predictions if available
            rl_data = {}
            if self.rl_system:
                rl_predictions = self.rl_system.get_rl_predictions(symbol)
                rl_data = {
                    'rl_predictions': rl_predictions.get('recent_predictions', [])[-10:],  # Last 10
                    'prediction_count': rl_predictions.get('total_predictions', 0)
                }

            # Enhanced data with RL predictions
            enhanced_data = {
                **data,
                'rl_data': rl_data
            }

            # default=str stringifies anything json cannot encode natively
            message = json.dumps({
                'type': 'cob_update',
                'symbol': symbol,
                'timestamp': datetime.now().isoformat(),
                'data': enhanced_data
            }, default=str)

            # Iterate over a snapshot: each send awaits, and other coroutines
            # may add or remove connections in the meantime, which would
            # break iteration over the live collection.
            disconnected = []
            for ws in list(self.websocket_connections):
                try:
                    await ws.send_str(message)
                except Exception as e:
                    logger.debug(f"WebSocket send failed: {e}")
                    disconnected.append(ws)

            # Remove disconnected clients
            for ws in disconnected:
                self.websocket_connections.discard(ws)
        except Exception as e:
            logger.error(f"Error broadcasting enhanced COB update: {e}")
async def main():
    """Entry coroutine: build the integrated RL COB system and run it.

    Ensures the ``logs`` directory exists, starts the system and always
    awaits ``stop()`` afterwards. A keyboard interrupt is logged; any other
    exception is logged and re-raised.
    """
    os.makedirs('logs', exist_ok=True)

    rl_system = IntegratedRLCOBSystem()
    try:
        await rl_system.start()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as exc:
        logger.error(f"System error: {exc}")
        raise
    finally:
        await rl_system.stop()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,52 +0,0 @@
#!/usr/bin/env python3
"""
One-Click MEXC Browser Launcher
Simply run this script to start capturing MEXC futures trading requests.
"""
import sys
import os
# Add project root to path
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
def main():
    """Print the launcher banner and hand over to the MEXC auto browser.

    If the auto-browser module cannot be imported, the setup script's
    ``main`` is run instead; if that is missing too, install instructions
    are printed. Any other error is reported with troubleshooting hints.
    Nothing is raised to the caller.
    """
    banner = (
        "🚀 MEXC Futures Request Interceptor",
        "=" * 50,
        "This will automatically:",
        "✅ Install ChromeDriver",
        "✅ Open MEXC futures page",
        "✅ Capture all API requests",
        "✅ Extract session cookies",
        "✅ Save data to JSON files",
        "\nRequirements will be installed automatically if missing.",
    )
    for line in banner:
        print(line)

    try:
        # Preferred path: run the auto browser directly
        from core.mexc_webclient.auto_browser import main as launch_browser
        launch_browser()
    except ImportError as import_err:
        print(f"\n⚠️ Import error: {import_err}")
        print("Installing requirements first...")
        # Fall back to the setup script
        try:
            from setup_mexc_browser import main as run_setup
            run_setup()
        except ImportError:
            print("❌ Could not find setup script")
            print("Please run: pip install selenium webdriver-manager")
    except Exception as err:
        print(f"❌ Error: {err}")
        for hint in (
            "\nTroubleshooting:",
            "1. Make sure you have Chrome browser installed",
            "2. Check your internet connection",
            "3. Try running: pip install selenium webdriver-manager",
        ):
            print(hint)


if __name__ == "__main__":
    main()

View File

@ -1,451 +0,0 @@
#!/usr/bin/env python3
"""
Optimized COB System - Eliminates Redundant Implementations
This optimized script runs both the COB dashboard and 1B RL trading system
in a single process with shared data sources to eliminate redundancies:
BEFORE (Redundant):
- Dashboard: Own COBIntegration instance
- RL Trader: Own COBIntegration instance
- Training: Own COBIntegration instance
= 3x WebSocket connections, 3x order book processing
AFTER (Optimized):
- Shared COBIntegration instance
- Single WebSocket connection per exchange
- Shared order book processing and caching
= 1x connections, 1x processing, shared memory
Resource savings: ~60% memory, ~70% network bandwidth
"""
import asyncio
import logging
import signal
import sys
import json
import os
from datetime import datetime
from typing import Dict, Any, List, Optional
from aiohttp import web
import threading
# Local imports
from core.cob_integration import COBIntegration
from core.data_provider import DataProvider
from core.trading_executor import TradingExecutor
from core.config import load_config
from web.cob_realtime_dashboard import COBDashboardServer
# Configure logging
# logging.FileHandler opens its target file as soon as it is constructed, so
# the directory has to exist before this module-level setup runs; creating it
# only later (inside main()) is too late on a fresh checkout.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/optimized_cob_system.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
class OptimizedCOBSystem:
    """Run the COB dashboard and the RL trader on one shared COB integration.

    A single ``COBIntegration`` instance is created and handed to both the
    dashboard and the RL trader instead of each component building its own.
    """

    # Symbols handled by the shared COB integration and the RL trader.
    _SYMBOLS = ('BTC/USDT', 'ETH/USDT')

    def __init__(self, config_path: str = "config.yaml"):
        """Load the configuration and install shutdown signal handlers.

        Args:
            config_path: Path passed to ``load_config``.
        """
        self.config = load_config(config_path)
        self.running = False
        # True once the shared COB integration has been started. ``running``
        # cannot serve this purpose: the signal handler clears it to end the
        # monitoring loop, and the components must still be stopped after.
        self._resources_started = False

        # Shared components (eliminate redundancy)
        self.data_provider = DataProvider()
        self.shared_cob_integration: Optional[COBIntegration] = None
        self.trading_executor: Optional[TradingExecutor] = None

        # Dashboard and RL trader, both using the shared COB
        self.dashboard_server: Optional[COBDashboardServer] = None
        self.rl_trader: Optional["OptimizedRLTrader"] = None

        # Performance tracking. Every key read by the stats printers is
        # present from the start so they cannot fail before the first update.
        self.performance_stats = {
            'start_time': None,
            'cob_updates_processed': 0,
            'dashboard_connections': 0,
            'memory_saved_mb': 0,
            'rl_predictions': 0,
            'trades_executed': 0
        }

        # Setup signal handlers
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        logger.info("OptimizedCOBSystem initialized - Eliminating redundant implementations")

    def _signal_handler(self, signum, frame):
        """Request shutdown by clearing ``running``; the loop exits on its next pass."""
        logger.info(f"Received signal {signum}, initiating graceful shutdown...")
        self.running = False

    async def start(self):
        """Initialize and start all components, then run the monitoring loop.

        ``stop()`` is always awaited on the way out. Exceptions are logged
        with their traceback and re-raised.
        """
        try:
            logger.info("=" * 70)
            logger.info("🚀 OPTIMIZED COB SYSTEM STARTING")
            logger.info("=" * 70)
            logger.info("Eliminating redundant COB implementations...")
            logger.info("Single shared COB integration for all components")
            logger.info("=" * 70)

            # Initialize shared components
            await self._initialize_shared_components()
            # Initialize dashboard with shared COB
            await self._initialize_optimized_dashboard()
            # Start the integrated system
            await self._start_optimized_system()
            # Run main monitoring loop
            await self._run_optimized_loop()
        except Exception as e:
            logger.error(f"Critical error in optimized system: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise
        finally:
            await self.stop()

    async def _initialize_shared_components(self):
        """Create and start the shared COB integration, the executor and the RL trader."""
        logger.info("1. Initializing shared COB integration...")

        # Single COB integration instance for entire system
        self.shared_cob_integration = COBIntegration(
            data_provider=self.data_provider,
            symbols=list(self._SYMBOLS)
        )

        # Start the shared COB integration
        await self.shared_cob_integration.start()
        self._resources_started = True

        logger.info("2. Initializing trading executor...")
        mexc_config = self.config.get('mexc', {})
        # NOTE(review): simulation_mode is only logged below; it is not
        # passed to TradingExecutor - confirm the executor reads it itself.
        simulation_mode = mexc_config.get('simulation_mode', True)
        self.trading_executor = TradingExecutor()

        # RL trader subscribing to the shared COB rather than owning one
        self.rl_trader = OptimizedRLTrader(
            symbols=list(self._SYMBOLS),
            shared_cob=self.shared_cob_integration,
            trading_executor=self.trading_executor,
            performance_tracker=self.performance_stats
        )

        logger.info("✅ Shared components initialized")
        logger.info(f" Single COB integration: {len(self.shared_cob_integration.symbols)} symbols")
        logger.info(f" Trading mode: {'SIMULATION' if simulation_mode else 'LIVE'}")

    async def _initialize_optimized_dashboard(self):
        """Create the dashboard and point it at the shared COB integration."""
        logger.info("3. Initializing optimized dashboard...")

        # EnhancedCOBDashboard swaps in the shared COB itself and provides
        # the get_stats() that _update_performance_stats() relies on.
        self.dashboard_server = EnhancedCOBDashboard(
            host='localhost',
            port=8053,
            shared_cob=self.shared_cob_integration,
            performance_tracker=self.performance_stats
        )

        logger.info("✅ Optimized dashboard initialized with shared COB")

    async def _start_optimized_system(self):
        """Start the dashboard and the RL trader and record the start time."""
        logger.info("4. Starting optimized system...")

        self.running = True
        self.performance_stats['start_time'] = datetime.now()

        # Start dashboard server with shared COB
        await self.dashboard_server.start()

        # Start RL trader
        if self.rl_trader is not None:
            await self.rl_trader.start()

        # Estimate memory savings
        estimated_savings = self._calculate_memory_savings()
        self.performance_stats['memory_saved_mb'] = estimated_savings

        logger.info("🚀 Optimized COB System started successfully!")
        logger.info(f"💾 Estimated memory savings: {estimated_savings:.0f} MB")
        logger.info("🌐 Dashboard: http://localhost:8053")
        logger.info("🤖 RL Training: Active with 1B parameters")
        logger.info("📊 Shared COB: Single integration for all components")
        logger.info("🔄 System Status: OPTIMIZED - No redundant implementations")

    def _calculate_memory_savings(self) -> float:
        """Return the estimated MB saved by sharing one COB integration.

        The figures are fixed estimates, not measurements: three separate
        integrations (dashboard, RL trader, training) are compared with one.
        """
        # Estimates based on typical COB memory usage
        cob_integration_memory_mb = 512  # Order books, caches, connections
        websocket_connection_memory_mb = 64  # Per exchange connection

        # Before: 3 separate COB integrations (dashboard + RL trader + training)
        before_memory = 3 * cob_integration_memory_mb + 3 * websocket_connection_memory_mb
        # After: 1 shared COB integration
        after_memory = 1 * cob_integration_memory_mb + 1 * websocket_connection_memory_mb

        return before_memory - after_memory

    async def _run_optimized_loop(self):
        """Wake every 10 seconds while ``running``; refresh stats and print them once a minute."""
        logger.info("Starting optimized monitoring loop...")

        last_stats_time = datetime.now()
        stats_interval = 60  # Print stats every 60 seconds

        while self.running:
            try:
                await asyncio.sleep(10)

                # Update performance stats
                self._update_performance_stats()

                # Print periodic statistics
                current_time = datetime.now()
                if (current_time - last_stats_time).total_seconds() >= stats_interval:
                    await self._print_optimized_stats()
                    last_stats_time = current_time
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in optimized loop: {e}")
                await asyncio.sleep(5)

        logger.info("Optimized monitoring loop stopped")

    def _update_performance_stats(self):
        """Pull the latest counters from every component; errors are logged as warnings."""
        try:
            # Get stats from shared COB integration
            if self.shared_cob_integration:
                cob_stats = self.shared_cob_integration.get_statistics()
                self.performance_stats['cob_updates_processed'] = \
                    cob_stats.get('total_signals', {}).get('BTC/USDT', 0)

            # Get stats from dashboard
            if self.dashboard_server:
                dashboard_stats = self.dashboard_server.get_stats()
                self.performance_stats['dashboard_connections'] = \
                    dashboard_stats.get('active_connections', 0)

            # Get stats from RL trader
            if self.rl_trader:
                rl_stats = self.rl_trader.get_stats()
                self.performance_stats['rl_predictions'] = rl_stats.get('total_predictions', 0)

            # Get stats from trading executor
            if self.trading_executor:
                trade_history = self.trading_executor.get_trade_history()
                self.performance_stats['trades_executed'] = len(trade_history)
        except Exception as e:
            logger.warning(f"Error updating performance stats: {e}")

    async def _print_optimized_stats(self):
        """Log the collected performance statistics; errors are logged, not raised."""
        try:
            stats = self.performance_stats
            uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0

            logger.info("=" * 80)
            logger.info("🚀 OPTIMIZED COB SYSTEM PERFORMANCE STATISTICS")
            logger.info("=" * 80)

            logger.info("📊 Resource Optimization:")
            logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
            logger.info(f" Uptime: {uptime:.0f} seconds")
            logger.info(f" COB Updates: {stats['cob_updates_processed']}")

            logger.info("\n🌐 Dashboard Statistics:")
            logger.info(f" Active Connections: {stats['dashboard_connections']}")
            logger.info(f" Server Status: {'RUNNING' if self.dashboard_server else 'STOPPED'}")

            logger.info("\n🤖 RL Trading Statistics:")
            logger.info(f" Total Predictions: {stats['rl_predictions']}")
            logger.info(f" Trades Executed: {stats['trades_executed']}")
            logger.info(f" Trainer Status: {'ACTIVE' if self.rl_trader else 'STOPPED'}")

            # Shared COB statistics
            if self.shared_cob_integration:
                cob_stats = self.shared_cob_integration.get_statistics()
                logger.info("\n📈 Shared COB Integration:")
                logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}")
                logger.info(f" Streaming: {cob_stats.get('is_streaming', False)}")
                logger.info(f" CNN Callbacks: {cob_stats.get('cnn_callbacks', 0)}")
                logger.info(f" DQN Callbacks: {cob_stats.get('dqn_callbacks', 0)}")
                logger.info(f" Dashboard Callbacks: {cob_stats.get('dashboard_callbacks', 0)}")

            logger.info("=" * 80)
            logger.info("✅ OPTIMIZATION STATUS: Redundancy eliminated, shared resources active")
        except Exception as e:
            logger.error(f"Error printing optimized stats: {e}")

    async def stop(self):
        """Stop all components and print the final report; safe to call twice.

        The guard uses ``_resources_started`` rather than ``running`` so a
        shutdown requested through the signal handler (which clears
        ``running``) still stops the trader, dashboard and COB integration.
        """
        if not self._resources_started:
            return
        self._resources_started = False

        logger.info("Stopping Optimized COB System...")
        self.running = False

        # Stop RL trader
        if self.rl_trader:
            await self.rl_trader.stop()
            logger.info("✅ RL Trader stopped")

        # Stop dashboard
        if self.dashboard_server:
            await self.dashboard_server.stop()
            logger.info("✅ Dashboard stopped")

        # Stop shared COB integration (last, as others depend on it)
        if self.shared_cob_integration:
            await self.shared_cob_integration.stop()
            logger.info("✅ Shared COB integration stopped")

        # Print final optimization report
        await self._print_final_optimization_report()
        logger.info("Optimized COB System stopped successfully")

    async def _print_final_optimization_report(self):
        """Log the run's totals from ``performance_stats``."""
        stats = self.performance_stats
        uptime = (datetime.now() - stats['start_time']).total_seconds() if stats['start_time'] else 0

        logger.info("\n📊 FINAL OPTIMIZATION REPORT:")
        logger.info(f" Total Runtime: {uptime:.0f} seconds")
        logger.info(f" Memory Saved: {stats['memory_saved_mb']:.0f} MB")
        logger.info(f" COB Updates Processed: {stats['cob_updates_processed']}")
        logger.info(f" RL Predictions Made: {stats['rl_predictions']}")
        logger.info(f" Trades Executed: {stats['trades_executed']}")
        logger.info(" ✅ Redundant implementations eliminated")
        logger.info(" ✅ Shared COB integration successful")
# Simplified components that use shared COB (no redundant integrations)
class EnhancedCOBDashboard(COBDashboardServer):
    """Dashboard server that uses a COB integration supplied by the caller."""

    def __init__(self, host: str = 'localhost', port: int = 8053,
                 shared_cob: COBIntegration = None, performance_tracker: Dict = None):
        """Create the server and replace its COB integration with ``shared_cob``.

        Args:
            host: Forwarded to ``COBDashboardServer``.
            port: Forwarded to ``COBDashboardServer``.
            shared_cob: Integration assigned to ``self.cob_integration``
                after the parent constructor has run.
            performance_tracker: Dict kept as ``self.performance_tracker``;
                a new dict is created only when ``None`` is passed, so an
                empty dict supplied by the caller stays shared.
        """
        self.shared_cob = shared_cob
        # ``is None`` rather than ``or {}``: an empty tracker passed in by
        # the caller must be kept, not silently replaced by a private dict.
        self.performance_tracker = {} if performance_tracker is None else performance_tracker

        super().__init__(host, port)

        # Use shared COB instead of whatever the parent set up
        self.cob_integration = shared_cob
        logger.info("Enhanced dashboard using shared COB integration (no redundancy)")

    def get_stats(self) -> Dict[str, Any]:
        """Return connection count, whether a shared COB is set and whether a runner exists."""
        return {
            'active_connections': len(self.websocket_connections),
            'using_shared_cob': self.shared_cob is not None,
            'server_running': self.runner is not None
        }
class OptimizedRLTrader:
    """RL trader skeleton that subscribes to a shared COB integration.

    The prediction logic is a placeholder: COB updates are counted and a
    fixed confidence is logged every 100 updates.
    """

    def __init__(self, symbols: List[str], shared_cob: COBIntegration,
                 trading_executor: TradingExecutor, performance_tracker: Dict = None):
        """Store the collaborators; nothing is started here.

        Args:
            symbols: Symbols this trader is responsible for.
            shared_cob: COB integration to subscribe to in ``start()``.
            trading_executor: Executor kept for trade execution.
            performance_tracker: Dict kept as ``self.performance_tracker``;
                a new dict is created only when ``None`` is passed, so an
                empty dict supplied by the caller stays shared.
        """
        self.symbols = symbols
        self.shared_cob = shared_cob
        self.trading_executor = trading_executor
        self.performance_tracker = {} if performance_tracker is None else performance_tracker
        self.running = False

        # Subscribe to shared COB updates instead of creating new integration
        self.subscription_id = None
        self.prediction_count = 0
        # Handle of the background prediction loop. The event loop only
        # keeps weak references to tasks, so it must be held here.
        self._prediction_task = None

        logger.info("Optimized RL trader using shared COB integration (no redundancy)")

    async def start(self):
        """Subscribe to COB updates and launch the prediction loop."""
        self.running = True

        # Subscribe to shared COB updates
        self.subscription_id = self.shared_cob.add_dqn_callback(self._on_cob_update)

        # Start prediction loop
        self._prediction_task = asyncio.create_task(self._prediction_loop())

        logger.info("Optimized RL trader started with shared COB subscription")

    async def stop(self):
        """Stop the prediction loop and wait for it to finish."""
        self.running = False

        task, self._prediction_task = self._prediction_task, None
        if task is not None:
            # Cancel instead of waiting out the loop's sleep; gather() with
            # return_exceptions absorbs the resulting CancelledError.
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

        logger.info("Optimized RL trader stopped")

    async def _on_cob_update(self, symbol: str, data: Dict):
        """Count one COB update and log every 100th; errors are logged, not raised."""
        try:
            self.prediction_count += 1

            # Simple prediction logic (placeholder)
            confidence = 0.75  # Example confidence

            if self.prediction_count % 100 == 0:
                logger.info(f"RL Prediction #{self.prediction_count} for {symbol} (confidence: {confidence:.2f})")
        except Exception as e:
            logger.error(f"Error in RL update: {e}")

    async def _prediction_loop(self):
        """Tick every 200 ms while ``running``; model inference is not implemented yet."""
        while self.running:
            try:
                # RL model inference would go here
                await asyncio.sleep(0.2)  # 200ms inference interval
            except Exception as e:
                logger.error(f"Error in prediction loop: {e}")
                await asyncio.sleep(1)

    def get_stats(self) -> Dict[str, Any]:
        """Return the update count and whether COB and subscription are set."""
        return {
            'total_predictions': self.prediction_count,
            'using_shared_cob': self.shared_cob is not None,
            'subscription_active': self.subscription_id is not None
        }
async def main():
    """Create the ``logs`` directory, then build and start the optimized COB system.

    A keyboard interrupt is logged as a normal shutdown; any other exception
    is logged and its traceback printed, without being re-raised.
    """
    try:
        # The logs directory must exist before the system starts.
        os.makedirs('logs', exist_ok=True)

        cob_system = OptimizedCOBSystem()
        await cob_system.start()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
    except Exception as exc:
        logger.error(f"Critical error: {exc}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
    # Install the proactor event-loop policy where asyncio provides it
    # (the attribute only exists on Windows builds), then run the entry point.
    proactor_policy = getattr(asyncio, 'WindowsProactorEventLoopPolicy', None)
    if proactor_policy is not None:
        asyncio.set_event_loop_policy(proactor_policy())
    asyncio.run(main())

View File

@ -1,324 +0,0 @@
#!/usr/bin/env python3
"""
Real-time RL COB Trader Launcher
Launch script for the real-time reinforcement learning trader that:
1. Uses COB data for training a 1B parameter model
2. Performs inference every 200ms
3. Accumulates confident signals for trade execution
4. Trains continuously in real-time based on outcomes
This script provides a complete trading system integration.
"""
import asyncio
import logging
import signal
import sys
import json
import os
from datetime import datetime
from typing import Dict, Any, Optional
# Local imports
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader
from core.trading_executor import TradingExecutor
from core.config import load_config
# Configure logging
# FileHandler opens its target file as soon as it is constructed, and this
# statement runs at import time - before main() gets a chance to create the
# directory - so make sure 'logs' exists first.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/realtime_rl_cob_trader.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
class RealtimeRLCOBTraderLauncher:
    """
    Launcher for Real-time RL COB Trader system

    Loads configuration, builds the trading executor and the RL trader,
    starts the trader, then runs a monitoring loop that periodically logs
    performance statistics until a shutdown signal clears ``running``.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and install SIGINT/SIGTERM handlers.

        Args:
            config_path: Path passed to ``load_config`` and later to
                ``TradingExecutor``.
        """
        self.config_path = config_path
        self.config = load_config(config_path)
        self.trader: Optional[RealtimeRLCOBTrader] = None
        self.trading_executor: Optional[TradingExecutor] = None
        self.running = False
        # True once the trader has been started and not yet stopped. Kept
        # separate from `running`, which the signal handler clears to end the
        # main loop *before* stop() gets called.
        self._system_started = False
        # Setup signal handlers for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        logger.info("RealtimeRLCOBTraderLauncher initialized")

    def _signal_handler(self, signum, frame):
        """Request shutdown by clearing ``running``; the main loop then exits."""
        logger.info(f"Received signal {signum}, initiating graceful shutdown...")
        self.running = False

    async def start(self):
        """Initialize components, start trading and block in the monitoring loop.

        Raises:
            Exception: Any startup or runtime error is logged and re-raised.
                ``stop()`` is always awaited on the way out.
        """
        try:
            logger.info("=" * 60)
            logger.info("REAL-TIME RL COB TRADER SYSTEM STARTING")
            logger.info("=" * 60)
            # Initialize trading executor
            await self._initialize_trading_executor()
            # Initialize RL trader
            await self._initialize_rl_trader()
            # Start the trading system
            await self._start_trading_system()
            # Run main loop
            await self._run_main_loop()
        except Exception as e:
            logger.error(f"Critical error in trader launcher: {e}")
            raise
        finally:
            await self.stop()

    async def _initialize_trading_executor(self):
        """Create the trading executor, asking for confirmation before live trading.

        Raises:
            RuntimeError: If the configuration selects live trading and the
                operator does not type the confirmation phrase.
        """
        logger.info("Initializing Trading Executor...")
        mexc_config = self.config.get('mexc', {})
        # Determine if we should run in simulation mode
        simulation_mode = mexc_config.get('simulation_mode', True)
        if simulation_mode:
            logger.info("Running in SIMULATION mode - no real trades will be executed")
        else:
            logger.warning("Running in LIVE TRADING mode - real money at risk!")
            # Blocking prompt; nothing else is scheduled on the loop at this point.
            confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
            if confirmation != 'CONFIRM_LIVE_TRADING':
                # The executor below is built from the config file alone, so a
                # local "switch to simulation" cannot be passed on to it.
                # Refuse to continue rather than report simulation mode while
                # the executor may still be configured for live trading.
                raise RuntimeError(
                    "Live trading not confirmed; aborting startup. "
                    "Set mexc.simulation_mode to true in the config to run in simulation."
                )
        # Initialize trading executor
        self.trading_executor = TradingExecutor(self.config_path)
        logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")

    async def _initialize_rl_trader(self):
        """Build the RL trader from the ``realtime_rl`` configuration section.

        Raises:
            RuntimeError: If the trading executor has not been created yet.
        """
        logger.info("Initializing Real-time RL COB Trader...")
        # Get RL configuration
        rl_config = self.config.get('realtime_rl', {})
        # Trading symbols
        symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
        # RL parameters
        inference_interval_ms = rl_config.get('inference_interval_ms', 200)
        min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
        required_confident_predictions = rl_config.get('required_confident_predictions', 3)
        model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
        # Initialize RL trader
        if self.trading_executor is None:
            raise RuntimeError("Trading executor not initialized")
        self.trader = RealtimeRLCOBTrader(
            symbols=symbols,
            trading_executor=self.trading_executor,
            model_checkpoint_dir=model_checkpoint_dir,
            inference_interval_ms=inference_interval_ms,
            min_confidence_threshold=min_confidence_threshold,
            required_confident_predictions=required_confident_predictions
        )
        logger.info(f"RL Trader initialized for symbols: {symbols}")
        logger.info(f"Inference interval: {inference_interval_ms}ms")
        logger.info(f"Confidence threshold: {min_confidence_threshold}")
        logger.info(f"Required predictions: {required_confident_predictions}")

    async def _start_trading_system(self):
        """Start the RL trader and mark the system as running.

        Raises:
            RuntimeError: If the RL trader has not been created yet.
        """
        logger.info("Starting Real-time RL COB Trading System...")
        # Start RL trader (this will start COB integration internally)
        if self.trader is None:
            raise RuntimeError("RL trader not initialized")
        await self.trader.start()
        self.running = True
        self._system_started = True
        logger.info("✅ Real-time RL COB Trading System started successfully!")
        logger.info("🔥 1B parameter model training and inference active")
        logger.info("📊 COB data streaming and processing")
        logger.info("🎯 Signal accumulation and trade execution ready")
        logger.info("⚡ Real-time training on prediction outcomes")

    async def _run_main_loop(self):
        """Wake every 10 s while ``running`` and log statistics about once a minute."""
        logger.info("Starting main monitoring loop...")
        last_stats_time = datetime.now()
        stats_interval = 60  # Print stats every 60 seconds
        while self.running:
            try:
                # Sleep for a bit
                await asyncio.sleep(10)
                # Print periodic statistics
                current_time = datetime.now()
                if (current_time - last_stats_time).total_seconds() >= stats_interval:
                    await self._print_performance_stats()
                    last_stats_time = current_time
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in main loop: {e}")
                await asyncio.sleep(5)
        logger.info("Main monitoring loop stopped")

    async def _print_performance_stats(self):
        """Log model, training, inference, signal and trading statistics.

        Reads ``trader.get_performance_stats()`` and, when an executor exists,
        its positions and trade history. Any error is logged, never raised.
        """
        try:
            if not self.trader:
                return
            stats = self.trader.get_performance_stats()
            logger.info("=" * 80)
            logger.info("🔥 REAL-TIME RL COB TRADER PERFORMANCE STATISTICS")
            logger.info("=" * 80)
            # Model information
            logger.info("📊 Model Information:")
            for symbol, model_info in stats.get('model_info', {}).items():
                total_params = model_info.get('total_parameters', 0)
                logger.info(f" {symbol}: {total_params:,} parameters ({total_params/1e9:.2f}B)")
            # Training statistics
            logger.info("\n🧠 Training Statistics:")
            for symbol, training_stats in stats.get('training_stats', {}).items():
                total_preds = training_stats.get('total_predictions', 0)
                successful_preds = training_stats.get('successful_predictions', 0)
                # max(1, ...) avoids dividing by zero before any prediction exists.
                success_rate = (successful_preds / max(1, total_preds)) * 100
                avg_loss = training_stats.get('average_loss', 0.0)
                training_steps = training_stats.get('total_training_steps', 0)
                last_training = training_stats.get('last_training_time')
                logger.info(f" {symbol}:")
                logger.info(f" Predictions: {total_preds} (Success: {success_rate:.1f}%)")
                logger.info(f" Training Steps: {training_steps}")
                logger.info(f" Average Loss: {avg_loss:.6f}")
                if last_training:
                    logger.info(f" Last Training: {last_training}")
            # Inference statistics
            logger.info("\n⚡ Inference Statistics:")
            for symbol, inference_stats in stats.get('inference_stats', {}).items():
                total_inferences = inference_stats.get('total_inferences', 0)
                avg_time = inference_stats.get('average_inference_time_ms', 0.0)
                last_inference = inference_stats.get('last_inference_time')
                logger.info(f" {symbol}:")
                logger.info(f" Total Inferences: {total_inferences}")
                logger.info(f" Average Time: {avg_time:.1f}ms")
                if last_inference:
                    logger.info(f" Last Inference: {last_inference}")
            # Signal statistics
            logger.info("\n🎯 Signal Accumulation:")
            for symbol, signal_stats in stats.get('signal_stats', {}).items():
                current_signals = signal_stats.get('current_signals', 0)
                confidence_sum = signal_stats.get('confidence_sum', 0.0)
                success_rate = signal_stats.get('success_rate', 0.0) * 100
                logger.info(f" {symbol}:")
                logger.info(f" Current Signals: {current_signals}")
                logger.info(f" Confidence Sum: {confidence_sum:.2f}")
                logger.info(f" Historical Success Rate: {success_rate:.1f}%")
            # Trading executor statistics
            if self.trading_executor:
                positions = self.trading_executor.get_positions()
                trade_history = self.trading_executor.get_trade_history()
                logger.info("\n💰 Trading Statistics:")
                logger.info(f" Active Positions: {len(positions)}")
                logger.info(f" Total Trades: {len(trade_history)}")
                if trade_history:
                    # Calculate P&L statistics
                    total_pnl = sum(trade.pnl for trade in trade_history)
                    profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
                    win_rate = (profitable_trades / len(trade_history)) * 100
                    logger.info(f" Total P&L: ${total_pnl:.2f}")
                    logger.info(f" Win Rate: {win_rate:.1f}%")
                # Show active positions
                if positions:
                    logger.info("\n📍 Active Positions:")
                    for symbol, position in positions.items():
                        logger.info(f" {symbol}: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
            logger.info("=" * 80)
        except Exception as e:
            logger.error(f"Error printing performance stats: {e}")

    async def stop(self):
        """Stop the trader and log a final summary; safe to call more than once.

        The guard uses ``_system_started`` rather than ``running``: the signal
        handler clears ``running`` to end the main loop, and guarding on it
        would skip the trader shutdown on exactly that path.
        """
        if not self._system_started:
            return
        self._system_started = False
        logger.info("Stopping Real-time RL COB Trading System...")
        self.running = False
        # Stop RL trader
        if self.trader:
            await self.trader.stop()
            logger.info("✅ RL Trader stopped")
        # Print final statistics
        if self.trader:
            logger.info("\n📊 Final Performance Summary:")
            await self._print_performance_stats()
        logger.info("Real-time RL COB Trading System stopped successfully")
async def main():
    """Ensure the ``logs`` directory exists, then run the launcher to completion.

    A keyboard interrupt is logged and swallowed; every other exception is
    logged and propagated to the caller.
    """
    try:
        # Create logs directory if it doesn't exist
        os.makedirs('logs', exist_ok=True)

        rl_launcher = RealtimeRLCOBTraderLauncher()
        await rl_launcher.start()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
    except Exception as exc:
        logger.error(f"Critical error: {exc}")
        raise
if __name__ == "__main__":
    # Install the proactor event-loop policy where asyncio provides it
    # (the attribute only exists on Windows builds), then run the entry point.
    proactor_policy = getattr(asyncio, 'WindowsProactorEventLoopPolicy', None)
    if proactor_policy is not None:
        asyncio.set_event_loop_policy(proactor_policy())
    asyncio.run(main())

View File

@ -1,218 +0,0 @@
#!/usr/bin/env python3
"""
Simple Dashboard Runner - Fixed version for testing
"""
import os
import sys
import logging
import time
import threading
from pathlib import Path
# Fix OpenMP library conflicts
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'
# Fix matplotlib backend
import matplotlib
matplotlib.use('Agg')
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def create_simple_dashboard():
    """Build a Dash app showing a status panel and two randomly generated charts.

    The layout has a status card, a static trading-stats card, a price chart,
    a performance chart and a 5-second interval timer. Three callbacks driven
    by that timer refresh the clock and regenerate both charts from random
    sample data.

    Returns:
        The configured ``dash.Dash`` app, or None if anything fails while
        importing dependencies or building the app (the error is logged).
    """
    try:
        import dash
        from dash import html, dcc, Input, Output
        import plotly.graph_objs as go
        import pandas as pd
        import numpy as np
        from datetime import datetime, timedelta

        def clock_text():
            # Text shown in the 'current-time' paragraph.
            return f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

        # Create Dash app
        app = dash.Dash(__name__)

        # --- layout pieces -------------------------------------------------
        title = html.H1("Trading System Dashboard", style={'textAlign': 'center', 'color': '#2c3e50'})

        status_card = html.Div([
            html.H3("System Status", style={'color': '#27ae60'}),
            html.P(id='system-status', children="System: RUNNING", style={'fontSize': '18px'}),
            html.P(id='current-time', children=clock_text()),
        ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'})

        stats_card = html.Div([
            html.H3("Trading Stats", style={'color': '#3498db'}),
            html.P("Total Trades: 0"),
            html.P("Success Rate: 0%"),
            html.P("Current PnL: $0.00"),
        ], style={'width': '48%', 'display': 'inline-block', 'padding': '20px'})

        price_panel = html.Div([dcc.Graph(id='price-chart')], style={'padding': '20px'})
        performance_panel = html.Div([dcc.Graph(id='performance-chart')], style={'padding': '20px'})

        # Auto-refresh component: fires every 5 seconds
        refresh_timer = dcc.Interval(
            id='interval-component',
            interval=5000,
            n_intervals=0
        )

        app.layout = html.Div([
            title,
            html.Div([status_card, stats_card]),
            price_panel,
            performance_panel,
            refresh_timer,
        ])

        # --- callbacks -----------------------------------------------------
        @app.callback(
            Output('current-time', 'children'),
            Input('interval-component', 'n_intervals')
        )
        def update_time(n):
            return clock_text()

        @app.callback(
            Output('price-chart', 'figure'),
            Input('interval-component', 'n_intervals')
        )
        def update_price_chart(n):
            # Sample data only: an hourly random walk around 3000 for the last day.
            window_start = datetime.now() - timedelta(hours=24)
            dates = pd.date_range(start=window_start, end=datetime.now(), freq='1H')
            prices = 3000 + np.cumsum(np.random.randn(len(dates)) * 10)

            price_trace = go.Scatter(
                x=dates,
                y=prices,
                mode='lines',
                name='ETH/USDT',
                line=dict(color='#3498db', width=2)
            )
            fig = go.Figure()
            fig.add_trace(price_trace)
            fig.update_layout(
                title='ETH/USDT Price Chart (24H)',
                xaxis_title='Time',
                yaxis_title='Price (USD)',
                template='plotly_white',
                height=400
            )
            return fig

        @app.callback(
            Output('performance-chart', 'figure'),
            Input('interval-component', 'n_intervals')
        )
        def update_performance_chart(n):
            # Sample data only: a daily random walk, scaled to percent, for the last week.
            window_start = datetime.now() - timedelta(days=7)
            dates = pd.date_range(start=window_start, end=datetime.now(), freq='1D')
            performance = np.cumsum(np.random.randn(len(dates)) * 0.02) * 100

            performance_trace = go.Scatter(
                x=dates,
                y=performance,
                mode='lines+markers',
                name='Portfolio Performance',
                line=dict(color='#27ae60', width=3),
                marker=dict(size=6)
            )
            fig = go.Figure()
            fig.add_trace(performance_trace)
            fig.update_layout(
                title='Portfolio Performance (7 Days)',
                xaxis_title='Date',
                yaxis_title='Performance (%)',
                template='plotly_white',
                height=400
            )
            return fig

        return app
    except Exception as exc:
        logger.error(f"Error creating dashboard: {exc}")
        import traceback
        logger.error(traceback.format_exc())
        return None
def test_data_provider():
    """Smoke-check the data provider and log the rate limiter's endpoint status.

    Fetches up to 10 one-minute ETH/USDT candles and logs whether any came
    back, then logs ``get_all_endpoint_status()`` from the rate limiter. All
    failures, including import errors, are logged and never raised.
    """
    try:
        from core.data_provider import DataProvider
        from core.api_rate_limiter import get_rate_limiter

        logger.info("Testing data provider...")

        provider = DataProvider(
            symbols=['ETH/USDT'],
            timeframes=['1m', '5m']
        )

        candles = provider.get_historical_data('ETH/USDT', '1m', limit=10)
        if candles is None or len(candles) == 0:
            logger.warning("⚠ Data provider returned no data (rate limiting)")
        else:
            logger.info(f"✓ Data provider working: {len(candles)} candles retrieved")

        # Report the limiter state whether or not data came back.
        limiter = get_rate_limiter()
        status = limiter.get_all_endpoint_status()
        logger.info(f"Rate limiter status: {status}")
    except Exception as exc:
        logger.error(f"Data provider test error: {exc}")
def main():
    """Start the background data-provider check, then serve the simple dashboard.

    Returns early, after logging, if the dashboard could not be created. The
    server call blocks until interrupted; errors are logged with a traceback.
    """
    rule = "=" * 60
    logger.info(rule)
    logger.info("SIMPLE DASHBOARD RUNNER - TESTING SYSTEM")
    logger.info(rule)

    # The provider check runs on a daemon thread so it never delays the server.
    threading.Thread(target=test_data_provider, daemon=True).start()

    app = create_simple_dashboard()
    if app is None:
        logger.error("Failed to create dashboard")
        return

    try:
        for startup_line in (
            "Starting dashboard server...",
            "Dashboard URL: http://127.0.0.1:8050",
            "Press Ctrl+C to stop",
        ):
            logger.info(startup_line)
        app.run(debug=False, host='127.0.0.1', port=8050, use_reloader=False)
    except KeyboardInterrupt:
        logger.info("Dashboard stopped by user")
    except Exception as exc:
        logger.error(f"Dashboard error: {exc}")
        import traceback
        logger.error(traceback.format_exc())
if __name__ == "__main__":
main()

View File

@ -1,275 +0,0 @@
#!/usr/bin/env python3
"""
Stable Dashboard Runner - Prioritizes System Stability
This runner focuses on:
1. System stability and reliability
2. Core trading functionality
3. Minimal resource usage
4. Robust error handling
5. Graceful degradation
Deferred features (until stability is achieved):
- TensorBoard integration
- Complex training loops
- Advanced visualizations
"""
import os
import sys
import time
import logging
import threading
import signal
from pathlib import Path
# Fix environment issues before imports
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '2' # Reduced from 4 for stability
# Fix matplotlib backend
import matplotlib
matplotlib.use('Agg')
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from system_stability_audit import SystemStabilityAuditor
# Setup logging with reduced verbosity for stability
setup_logging()
logger = logging.getLogger(__name__)
# Reduce logging noise from other modules
logging.getLogger('werkzeug').setLevel(logging.ERROR)
logging.getLogger('dash').setLevel(logging.ERROR)
logging.getLogger('matplotlib').setLevel(logging.ERROR)
class StableDashboardRunner:
    """
    Stable dashboard runner with focus on reliability

    Builds the data provider, trading executor, orchestrator and dashboard,
    starts the stability auditor and a background health-check thread, then
    serves the dashboard until it exits. ``shutdown()`` runs on every exit
    path of ``run()``.
    """

    def __init__(self):
        """Load configuration and set every component slot to None."""
        self.config = get_config()
        self.running = False
        self.dashboard = None
        self.stability_auditor = None
        # Core components
        self.data_provider = None
        self.orchestrator = None
        self.trading_executor = None
        # Stability monitoring
        self.last_health_check = time.time()
        self.health_check_interval = 30  # Check every 30 seconds
        # Makes shutdown() idempotent; it can be reached from several paths.
        self._shutdown_done = False
        logger.info("Stable Dashboard Runner initialized")

    def initialize_components(self):
        """Create the core components.

        Returns:
            True when all four components were created, False if any import
            or constructor raised (the error is logged).
        """
        try:
            logger.info("Initializing core components...")
            # Initialize data provider
            from core.data_provider import DataProvider
            self.data_provider = DataProvider()
            logger.info("✓ Data provider initialized")
            # Initialize trading executor
            from core.trading_executor import TradingExecutor
            self.trading_executor = TradingExecutor()
            logger.info("✓ Trading executor initialized")
            # Initialize orchestrator with minimal features for stability
            from core.orchestrator import TradingOrchestrator
            self.orchestrator = TradingOrchestrator(
                data_provider=self.data_provider,
                enhanced_rl_training=False  # Disabled for stability
            )
            logger.info("✓ Orchestrator initialized (training disabled for stability)")
            # Initialize dashboard
            from web.clean_dashboard import CleanTradingDashboard
            self.dashboard = CleanTradingDashboard(
                data_provider=self.data_provider,
                orchestrator=self.orchestrator,
                trading_executor=self.trading_executor
            )
            logger.info("✓ Dashboard initialized")
            return True
        except Exception as e:
            logger.error(f"Error initializing components: {e}")
            return False

    def start_stability_monitoring(self):
        """Create the stability auditor and start it; failures are only logged."""
        try:
            self.stability_auditor = SystemStabilityAuditor()
            self.stability_auditor.start_monitoring()
            logger.info("✓ Stability monitoring started")
        except Exception as e:
            logger.error(f"Error starting stability monitoring: {e}")

    def health_check(self):
        """Log the stability score and component presence.

        Does nothing if less than ``health_check_interval`` seconds have
        passed since the last check. A score below 50 triggers the auditor's
        ``fix_common_issues()``. Errors are logged, never raised.
        """
        try:
            current_time = time.time()
            if current_time - self.last_health_check < self.health_check_interval:
                return
            self.last_health_check = current_time
            # Check stability score
            if self.stability_auditor:
                report = self.stability_auditor.get_stability_report()
                stability_score = report.get('stability_score', 0)
                if stability_score < 50:
                    logger.warning(f"Low stability score: {stability_score:.1f}/100")
                    # Attempt to fix issues
                    self.stability_auditor.fix_common_issues()
                elif stability_score < 80:
                    logger.info(f"Moderate stability: {stability_score:.1f}/100")
                else:
                    logger.debug(f"Good stability: {stability_score:.1f}/100")
            # Check component health
            if self.dashboard and hasattr(self.dashboard, 'app'):
                logger.debug("✓ Dashboard responsive")
            if self.data_provider:
                logger.debug("✓ Data provider active")
            if self.orchestrator:
                logger.debug("✓ Orchestrator active")
        except Exception as e:
            logger.error(f"Error in health check: {e}")

    def run(self):
        """Initialize everything and serve the dashboard (blocking).

        Returns:
            True if the server exited normally or through a keyboard
            interrupt; False if initialization failed, the dashboard was
            unusable, or an unexpected exception occurred.
        """
        try:
            logger.info("=" * 60)
            logger.info("STABLE TRADING DASHBOARD")
            logger.info("=" * 60)
            logger.info("Priority: System Stability & Core Functionality")
            logger.info("Training: Disabled (will be enabled after stability)")
            logger.info("TensorBoard: Deferred (documented in design)")
            logger.info("Focus: Dashboard, Data, Basic Trading")
            logger.info("=" * 60)
            # Initialize components
            if not self.initialize_components():
                logger.error("Failed to initialize components")
                return False
            # Start stability monitoring
            self.start_stability_monitoring()
            # The health loop runs `while self.running`, so the flag has to be
            # set before the thread starts or the loop exits immediately.
            self.running = True
            health_thread = threading.Thread(target=self._health_check_loop, daemon=True)
            health_thread.start()
            # Get dashboard port
            port = int(os.environ.get('DASHBOARD_PORT', '8051'))
            logger.info(f"Starting dashboard on http://127.0.0.1:{port}")
            logger.info("Press Ctrl+C to stop")
            # Start dashboard (this blocks)
            if self.dashboard and hasattr(self.dashboard, 'app'):
                self.dashboard.app.run_server(
                    host='127.0.0.1',
                    port=port,
                    debug=False,
                    use_reloader=False,  # Disable reloader for stability
                    threaded=True
                )
            else:
                logger.error("Dashboard not properly initialized")
                return False
            # A normal return from the server is a clean exit, not a failure.
            return True
        except KeyboardInterrupt:
            logger.info("Dashboard stopped by user")
            return True
        except Exception as e:
            logger.error(f"Error running dashboard: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return False
        finally:
            # Also reached when a signal handler raises SystemExit inside the
            # blocking server call, which none of the except clauses catch.
            self.shutdown()

    def _health_check_loop(self):
        """Call ``health_check()`` every interval while ``running`` is set."""
        while self.running:
            try:
                self.health_check()
                time.sleep(self.health_check_interval)
            except Exception as e:
                logger.error(f"Error in health check loop: {e}")
                time.sleep(60)  # Wait longer on error

    def shutdown(self):
        """Stop monitoring and components; repeated calls are no-ops."""
        if self._shutdown_done:
            return
        self._shutdown_done = True
        try:
            logger.info("Shutting down stable dashboard...")
            self.running = False
            # Stop stability monitoring
            if self.stability_auditor:
                self.stability_auditor.stop_monitoring()
                logger.info("✓ Stability monitoring stopped")
            # Stop components
            if self.orchestrator and hasattr(self.orchestrator, 'stop'):
                self.orchestrator.stop()
                logger.info("✓ Orchestrator stopped")
            if self.data_provider and hasattr(self.data_provider, 'stop'):
                self.data_provider.stop()
                logger.info("✓ Data provider stopped")
            logger.info("Stable dashboard shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
def signal_handler(signum, frame):
    """Log the shutdown request and exit the process with status 0.

    Both arguments are supplied by the ``signal`` module and are unused.
    """
    logger.info("Received shutdown signal")
    raise SystemExit(0)
def main():
    """Install the signal handlers, run the stable dashboard and exit.

    The process exits with status 0 when the runner reports success, and
    with status 1 when it reports failure or raises.
    """
    # Route both termination signals through the shared handler.
    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(shutdown_signal, signal_handler)

    try:
        succeeded = StableDashboardRunner().run()
        if not succeeded:
            logger.error("Dashboard failed")
            sys.exit(1)
        logger.info("Dashboard completed successfully")
        sys.exit(0)
    except Exception as exc:
        logger.error(f"Fatal error: {exc}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -1,64 +0,0 @@
#!/usr/bin/env python3
"""
Run Templated Trading Dashboard
Demonstrates the new MVC template-based architecture
"""
import logging
import sys
import os
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from web.templated_dashboard import create_templated_dashboard
from web.dashboard_model import create_sample_dashboard_data
from web.template_renderer import DashboardTemplateRenderer
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def main():
    """Exercise the template helpers, then serve the templated dashboard.

    Sample data and a template renderer are created first as a smoke check.
    The dashboard is then built and served on 127.0.0.1:8052. Any exception
    is logged and its traceback printed, without being re-raised.
    """
    try:
        logger.info("=== TEMPLATED DASHBOARD DEMO ===")

        # Smoke-check the template system before building the dashboard.
        logger.info("Testing template system...")
        demo_data = create_sample_dashboard_data()
        logger.info(f"Created sample data with {len(demo_data.metrics)} metrics")

        renderer = DashboardTemplateRenderer()
        logger.info("Template renderer initialized")

        logger.info("Creating templated dashboard...")
        dashboard = create_templated_dashboard()
        logger.info("Dashboard created successfully!")

        feature_lines = (
            "Template-based MVC architecture features:",
            " ✓ HTML templates separated from Python code",
            " ✓ Data models for structured data",
            " ✓ Template renderer for clean separation",
            " ✓ Easy to modify HTML without touching Python",
            " ✓ Reusable components and templates",
        )
        for feature_line in feature_lines:
            logger.info(feature_line)

        logger.info("Starting templated dashboard server...")
        dashboard.run_server(host='127.0.0.1', port=8052, debug=False)
    except Exception as exc:
        logger.error(f"Error running templated dashboard: {exc}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
main()

View File

@ -1,155 +0,0 @@
#!/usr/bin/env python3
"""
TensorBoard Launch Script
Starts TensorBoard server for monitoring training progress.
Visualizes training metrics, rewards, state information, and model performance.
This script can be run standalone or integrated with the dashboard.
"""
import subprocess
import sys
import os
import time
import webbrowser
import argparse
from pathlib import Path
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def start_tensorboard(logdir="runs", port=6006, open_browser=True):
    """
    Start TensorBoard server programmatically

    Creates ``logdir`` if it is missing, logs the entries found in it, checks
    that the ``tensorboard`` package is importable, spawns
    ``python -m tensorboard.main`` with the current interpreter and, if
    requested, opens a browser tab on the server URL.

    Args:
        logdir: Directory containing TensorBoard logs
        port: Port to run TensorBoard on
        open_browser: Whether to open browser automatically

    Returns:
        subprocess.Popen: TensorBoard process, or None if TensorBoard is not
        installed or the process could not be started.
    """
    import importlib.util

    # Set log directory
    runs_dir = Path(logdir)
    if not runs_dir.exists():
        logger.warning(f"No '{logdir}' directory found. Creating it.")
        runs_dir.mkdir(parents=True, exist_ok=True)
    # Check if there are any log directories
    log_dirs = list(runs_dir.glob("*"))
    if not log_dirs:
        logger.warning(f"No training logs found in '{logdir}' directory.")
    else:
        logger.info(f"Found {len(log_dirs)} training sessions")
        # List available sessions
        logger.info("Available training sessions:")
        for i, log_dir in enumerate(sorted(log_dirs), 1):
            logger.info(f" {i}. {log_dir.name}")
    try:
        # Popen launches sys.executable, which always exists, so a missing
        # tensorboard package would never surface as FileNotFoundError; the
        # child would just exit with an error. Check for the package up front.
        if importlib.util.find_spec("tensorboard") is None:
            logger.error("TensorBoard not found. Install with: pip install tensorboard")
            return None
        logger.info(f"Starting TensorBoard on port {port}...")
        # Start TensorBoard process with enhanced options
        cmd = [
            sys.executable,
            "-m",
            "tensorboard.main",
            "--logdir", str(runs_dir),
            "--port", str(port),
            "--samples_per_plugin", "images=100,audio=100,text=100",
            "--reload_interval", "5",  # Reload data every 5 seconds
            "--reload_multifile", "true"  # Better handling of multiple log files
        ]
        # NOTE: both output streams are pipes. A caller that waits on the
        # process should drain them (e.g. communicate()); a full pipe buffer
        # blocks the child.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
        logger.info("TensorBoard is running with enhanced training visualization!")
        logger.info(f"View training metrics at: http://localhost:{port}")
        logger.info("Available dashboards:")
        logger.info(" - SCALARS: Training metrics, rewards, and losses")
        logger.info(" - HISTOGRAMS: Feature distributions and model weights")
        logger.info(" - TIME SERIES: Training progress over time")
        # Open the browser only once the server process has been spawned.
        if open_browser:
            try:
                webbrowser.open(f"http://localhost:{port}")
                logger.info("Browser opened automatically")
            except Exception as e:
                logger.warning(f"Could not open browser automatically: {e}")
        # Return process for management
        return process
    except FileNotFoundError:
        logger.error("TensorBoard not found. Install with: pip install tensorboard")
        return None
    except Exception as e:
        logger.error(f"Error starting TensorBoard: {e}")
        return None
def main():
    """Launch TensorBoard with enhanced visualization options

    Parses ``--port``, ``--logdir``, ``--no-browser`` and
    ``--dashboard-integration``, starts TensorBoard and, unless running in
    dashboard-integration mode, blocks until the process exits or the user
    presses Ctrl+C.

    Returns:
        int: 0 on success or user interrupt; 1 if TensorBoard could not be
        started, exited with a non-zero status, or an unexpected error occurred.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Launch TensorBoard for training visualization")
    parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
    parser.add_argument("--logdir", type=str, default="runs", help="Directory containing TensorBoard logs")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
    parser.add_argument("--dashboard-integration", action="store_true", help="Run in dashboard integration mode")
    args = parser.parse_args()
    # Start TensorBoard
    process = start_tensorboard(
        logdir=args.logdir,
        port=args.port,
        open_browser=not args.no_browser
    )
    if process is None:
        return 1
    # If running in dashboard integration mode, return immediately
    if args.dashboard_integration:
        return 0
    # Otherwise, wait for process to complete
    try:
        print("\n" + "="*70)
        print("🔥 TensorBoard is running with enhanced training visualization!")
        print(f"📈 View training metrics at: http://localhost:{args.port}")
        print("⏹️ Press Ctrl+C to stop TensorBoard")
        print("="*70 + "\n")
        # The child's stdout/stderr are pipes. A bare wait() can deadlock once
        # a pipe buffer fills, so drain both streams while waiting.
        _, stderr_text = process.communicate()
        if process.returncode:
            # Report a failed start (e.g. port already in use) instead of
            # returning success with the error hidden in the pipe.
            print(f"❌ Error: TensorBoard exited with code {process.returncode}")
            if stderr_text:
                print(stderr_text, file=sys.stderr)
            return 1
        return 0
    except KeyboardInterrupt:
        print("\n🛑 TensorBoard stopped")
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()  # reap the killed child
        return 0
    except Exception as e:
        print(f"❌ Error: {e}")
        return 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,182 +0,0 @@
#!/usr/bin/env python3
"""
Unified Test Runner for Trading System
This script provides a unified interface to run all tests in the system:
- Essential functionality tests
- Model persistence tests
- Training integration tests
- Indicators and signals tests
- Remaining individual test files
Usage:
python run_tests.py # Run all tests
python run_tests.py essential # Run essential tests only
python run_tests.py persistence # Run model persistence tests only
python run_tests.py training # Run training integration tests only
python run_tests.py indicators # Run indicators and signals tests only
python run_tests.py individual # Run individual test files only
"""
import sys
import os
import subprocess
import logging
from pathlib import Path
from safe_logging import setup_safe_logging
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging
logger = logging.getLogger(__name__)
def run_test_module(module_path, test_type="all"):
    """Run one test script in a subprocess and report whether it exited with 0.

    Args:
        module_path: Path object of the script to run with the current
            interpreter.
        test_type: Passed to the script as its only argument unless it is
            "all".

    Returns:
        True when the script's return code is 0, False otherwise or when
        launching it raised. Captured output is forwarded to the logger.
    """
    try:
        argv = [sys.executable, str(module_path)]
        argv += [] if test_type == "all" else [test_type]
        logger.info(f"Running: {' '.join(argv)}")

        completed = subprocess.run(argv, capture_output=True, text=True, cwd=project_root)

        if completed.returncode != 0:
            logger.error(f"{module_path.name} failed")
            # stderr first, then stdout, both at error level.
            if completed.stderr:
                logger.error(completed.stderr)
            if completed.stdout:
                logger.error(completed.stdout)
            return False

        logger.info(f"{module_path.name} passed")
        if completed.stdout:
            logger.info(completed.stdout)
        return True
    except Exception as exc:
        logger.error(f"❌ Error running {module_path}: {exc}")
        return False
def run_essential_tests():
    """Run ``tests/test_essential.py`` and return whether it passed."""
    logger.info("=== Running Essential Tests ===")
    suite_script = project_root / "tests" / "test_essential.py"
    return run_test_module(suite_script)
def run_persistence_tests():
    """Run tests/test_model_persistence.py and return True when it passes."""
    logger.info("=== Running Model Persistence Tests ===")
    suite_script = project_root / "tests" / "test_model_persistence.py"
    return run_test_module(suite_script)
def run_training_tests():
    """Run tests/test_training_integration.py and return True when it passes."""
    logger.info("=== Running Training Integration Tests ===")
    suite_script = project_root / "tests" / "test_training_integration.py"
    return run_test_module(suite_script)
def run_indicators_tests():
    """Run tests/test_indicators_and_signals.py and return True when it passes."""
    logger.info("=== Running Indicators and Signals Tests ===")
    suite_script = project_root / "tests" / "test_indicators_and_signals.py"
    return run_test_module(suite_script)
def run_individual_tests():
    """Run the stand-alone test scripts that live in the project root.

    A script that is missing on disk is logged as a warning and counted as a
    failure.

    Returns:
        True only when every listed script exists and passes.
    """
    logger.info("=== Running Individual Test Files ===")

    outcomes = []
    for file_name in ("test_positions.py", "test_tick_cache.py", "test_timestamps.py"):
        candidate = project_root / file_name
        if not candidate.exists():
            logger.warning(f"Test file not found: {file_name}")
            outcomes.append(False)
            continue
        logger.info(f"Running {file_name}...")
        outcomes.append(run_test_module(candidate))

    return all(outcomes)
def run_all_tests():
    """Run every test suite in order and log a pass/fail summary.

    A suite function that raises is logged and recorded as failed, so one
    crashing suite does not stop the remaining ones.

    Returns:
        True when all suites passed, otherwise False.
    """
    logger.info("🧪 Running All Trading System Tests")
    logger.info("=" * 60)

    suites = (
        ("Essential Tests", run_essential_tests),
        ("Model Persistence Tests", run_persistence_tests),
        ("Training Integration Tests", run_training_tests),
        ("Indicators and Signals Tests", run_indicators_tests),
        ("Individual Tests", run_individual_tests),
    )

    outcomes = []
    for label, runner in suites:
        logger.info(f"\n📋 {label}")
        logger.info("-" * 40)
        try:
            verdict = runner()
        except Exception as exc:
            logger.error(f"{label} crashed: {exc}")
            verdict = False
        outcomes.append((label, verdict))

    # Summary section
    logger.info("\n" + "=" * 60)
    logger.info("📊 TEST RESULTS SUMMARY")
    logger.info("=" * 60)

    passed = 0
    for label, verdict in outcomes:
        logger.info(f"{'✅ PASS' if verdict else '❌ FAIL'}: {label}")
        passed += 1 if verdict else 0

    total = len(outcomes)
    logger.info(f"\nPassed: {passed}/{total} test suites")

    if passed != total:
        logger.warning(f"⚠️ {total - passed} test suite(s) failed. Please check the issues above.")
        return False

    logger.info("🎉 All tests passed! Trading system is working correctly.")
    return True
def main():
    """Command-line entry point: choose a suite from argv[1] and run it.

    With no argument every suite runs. "help", "-h" and "--help" print the
    module docstring. An unrecognised argument is logged as an error and the
    docstring is printed.

    Returns:
        Process exit status: 0 on success or help, 1 on failure or an
        unknown argument.
    """
    setup_safe_logging()

    if len(sys.argv) <= 1:
        success = run_all_tests()
    else:
        test_type = sys.argv[1].lower()
        suite_by_name = {
            "essential": run_essential_tests,
            "persistence": run_persistence_tests,
            "training": run_training_tests,
            "indicators": run_indicators_tests,
            "individual": run_individual_tests,
        }
        if test_type in suite_by_name:
            success = suite_by_name[test_type]()
        elif test_type in ("help", "-h", "--help"):
            print(__doc__)
            return 0
        else:
            logger.error(f"Unknown test type: {test_type}")
            print(__doc__)
            return 1

    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@ -1,197 +0,0 @@
#!/usr/bin/env python3
"""
Kill Stale Processes Script
Safely terminates stale Python processes related to the trading dashboard
with proper error handling and graceful termination.
"""
import os
import sys
import time
import signal
from pathlib import Path
import threading
# Global timeout flag
timeout_reached = False
def timeout_handler():
    """Watchdog callback: set the global timeout flag and hard-exit the process.

    Uses os._exit, so the interpreter stops immediately with status 0 and no
    normal interpreter cleanup takes place.
    """
    global timeout_reached
    timeout_reached = True
    print("\n⚠️ WARNING: Script timeout reached (10s) - forcing exit")
    # Immediate termination of the whole process, status 0
    os._exit(0)
def kill_stale_processes():
    """Terminate stale trading-dashboard processes, gracefully first, then by force.

    Scans for processes whose name contains "python" or "tensorboard" and
    whose command line contains one of the target keywords, sends them a
    terminate request, waits one second, and force-kills any that are still
    running. The current process is always skipped. When psutil is not
    installed, the work is delegated to kill_stale_fallback().

    A 10-second watchdog timer (timeout_handler) guards the whole run. It is
    cancelled on every return path, including the fallback path, so it cannot
    fire later inside a process that merely imported and called this function.

    Returns:
        True when no termination attempt failed (or nothing was found),
        False when at least one failed or an unexpected error occurred.
    """
    global timeout_reached

    # Overall watchdog (10 seconds)
    timer = threading.Timer(10.0, timeout_handler)
    timer.daemon = True
    timer.start()

    try:
        try:
            import psutil
        except ImportError:
            print("psutil not available - using fallback method")
            return kill_stale_fallback()

        current_pid = os.getpid()
        killed_processes = []
        failed_processes = []

        # Keywords that identify trading dashboard processes
        target_keywords = [
            'dashboard', 'scalping', 'trading', 'tensorboard',
            'run_clean', 'run_main', 'gogo2', 'mexc'
        ]

        try:
            print("Scanning for stale processes...")

            python_processes = []
            scan_start = time.time()

            for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
                if timeout_reached or (time.time() - scan_start) > 3.0:  # 3s max for scanning
                    print("Process scanning timeout - proceeding with found processes")
                    break

                try:
                    if proc.info['pid'] == current_pid:
                        continue

                    # psutil fills in None for attributes it could not read
                    # (e.g. access denied); treat that as an empty name instead
                    # of letting AttributeError abort the whole scan.
                    name = (proc.info['name'] or '').lower()
                    if 'python' in name or 'tensorboard' in name:
                        cmdline_str = ' '.join(proc.info['cmdline']) if proc.info['cmdline'] else ''

                        # Check if this is a target process
                        if any(keyword in cmdline_str.lower() for keyword in target_keywords):
                            python_processes.append({
                                'proc': proc,
                                'pid': proc.info['pid'],
                                'name': proc.info['name'],
                                'cmdline': cmdline_str
                            })
                except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                    continue

            if not python_processes:
                print("No stale processes found")
                return True

            print(f"Found {len(python_processes)} target processes to terminate:")
            for p in python_processes[:5]:  # Show max 5 to save time
                print(f" - PID {p['pid']}: {p['name']} - {p['cmdline'][:80]}...")
            if len(python_processes) > 5:
                print(f" ... and {len(python_processes) - 5} more")

            # Graceful termination first
            print("\nAttempting graceful termination...")
            termination_start = time.time()
            for p in python_processes:
                if timeout_reached or (time.time() - termination_start) > 2.0:
                    print("Termination timeout - moving to force kill")
                    break
                try:
                    proc = p['proc']
                    if proc.is_running():
                        proc.terminate()
                        print(f" Sent SIGTERM to PID {p['pid']}")
                except psutil.NoSuchProcess:
                    # Exited between the scan and now: nothing to do, and the
                    # force-kill pass below will record it as already gone.
                    continue
                except Exception as e:
                    failed_processes.append(f"Failed to terminate PID {p['pid']}: {e}")

            # Give processes a moment to shut down on their own
            time.sleep(1.0)

            # Force kill whatever is still alive
            print("\nChecking for remaining processes...")
            kill_start = time.time()
            for p in python_processes:
                if timeout_reached or (time.time() - kill_start) > 2.0:
                    print("Force kill timeout - exiting")
                    break
                try:
                    proc = p['proc']
                    if proc.is_running():
                        print(f" Force killing PID {p['pid']} ({p['name']})")
                        proc.kill()
                        killed_processes.append(f"Force killed PID {p['pid']} ({p['name']})")
                    else:
                        killed_processes.append(f"Gracefully terminated PID {p['pid']} ({p['name']})")
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    killed_processes.append(f"Process PID {p['pid']} already terminated")
                except Exception as e:
                    failed_processes.append(f"Failed to kill PID {p['pid']}: {e}")

            # Results (quick summary)
            print(f"\n=== Quick Results ===")
            print(f"✓ Cleaned up {len(killed_processes)} processes")
            if failed_processes:
                print(f"✗ Failed: {len(failed_processes)} processes")

            return len(failed_processes) == 0

        except Exception as e:
            print(f"Error during process cleanup: {e}")
            return False
    finally:
        # Always disarm the watchdog, whichever path returned.
        timer.cancel()
def kill_stale_fallback():
    """Kill dashboard processes with OS commands when psutil is unavailable.

    On Windows this force-kills every python.exe via taskkill. On other
    systems it runs pkill -f for the patterns "dashboard", "scalping" and
    "tensorboard".

    Returns:
        True when the commands ran without raising (their exit status does
        not affect the result), False on timeout or any other error.
    """
    print("Using fallback process killing method...")

    import subprocess

    try:
        if os.name != 'nt':  # Unix/Linux
            # Targeted by command-line pattern, each call capped at 2 seconds
            for pattern in ('dashboard', 'scalping', 'tensorboard'):
                subprocess.run(['pkill', '-f', pattern], capture_output=True, timeout=2.0)
            print("Unix: Killed dashboard-related processes")
        else:  # Windows
            # NOTE: matches on image name only, so all python.exe processes are hit
            outcome = subprocess.run(
                ['taskkill', '/f', '/im', 'python.exe'],
                capture_output=True, text=True, timeout=5.0,
            )
            if outcome.returncode == 0:
                print("Windows: Killed all Python processes")
            else:
                print("Windows: No Python processes to kill or access denied")

        return True

    except subprocess.TimeoutExpired:
        print("Fallback method timed out")
        return False
    except Exception as exc:
        print(f"Fallback method failed: {exc}")
        return False
if __name__ == "__main__":
    # Script entry: run the cleanup, report elapsed time, exit 0 on success / 1 on failure.
    banner = "=" * 50
    print(banner)
    print("STALE PROCESS CLEANUP (10s timeout)")
    print(banner)

    start_time = time.time()
    success = kill_stale_processes()
    elapsed = time.time() - start_time

    exit_code = 1 if not success else 0

    print(f"Completed in {elapsed:.1f}s")
    print(banner)

    sys.exit(exit_code)

View File

@ -1,40 +0,0 @@
#!/usr/bin/env python3
"""
Restart Dashboard with Forced Learning Enabled
Simple script to start dashboard with all learning features enabled
"""
import sys
import logging
from datetime import datetime
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
    """Log a start-up banner, then import and run the web dashboard.

    Ctrl+C is logged as a normal stop. Any other exception is logged and its
    traceback printed; nothing is re-raised.
    """
    rule = "=" * 60
    logger.info("🚀 Starting Dashboard with FORCED LEARNING ENABLED")
    logger.info(rule)
    logger.info(f"Timestamp: {datetime.now()}")
    for banner_line in (
        "Fixes Applied:",
        "✅ Enhanced RL: FORCED ENABLED",
        "✅ CNN Training: FORCED ENABLED",
        "✅ Williams Pivots: CNN INTEGRATED",
        "✅ Learning Pipeline: ACTIVE",
        rule,
    ):
        logger.info(banner_line)

    try:
        # Imported here rather than at module import time
        from main_clean import run_web_dashboard

        logger.info("Starting web dashboard...")
        run_web_dashboard()

    except KeyboardInterrupt:
        logger.info("Dashboard stopped by user")
    except Exception as exc:
        logger.error(f"Dashboard failed: {exc}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

View File

@ -1,188 +0,0 @@
#!/usr/bin/env python3
"""
Overnight Training Restart Script
Keeps main.py running continuously, restarting it if it crashes.
Designed for overnight training sessions with unstable code.
Usage:
python restart_main_overnight.py
Press Ctrl+C to stop the restart loop.
"""
import subprocess
import sys
import time
import logging
from datetime import datetime
from pathlib import Path
import signal
import os
# Setup logging for the restart script
def setup_restart_logging():
    """Configure root logging for the restart script and return its logger.

    Creates the "logs" directory if needed and logs to both a timestamped
    file inside it (UTF-8) and stdout.

    Returns:
        The logging.Logger for this module.
    """
    logs_path = Path("logs")
    logs_path.mkdir(exist_ok=True)

    # One log file per launch, named after the launch time
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = logs_path / f"restart_main_{stamp}.log"

    file_sink = logging.FileHandler(log_file, encoding='utf-8')
    console_sink = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_sink, console_sink],
    )

    restart_logger = logging.getLogger(__name__)
    restart_logger.info(f"Restart script logging to: {log_file}")
    return restart_logger
def kill_existing_processes(logger):
    """Kill other Python interpreters on Windows so a stale main.py cannot conflict.

    On Windows, taskkill force-kills python.exe and pythonw.exe processes.
    The current process is excluded with a PID filter; without it the restart
    script would kill itself, since it runs as python.exe too. On other
    platforms nothing is killed.

    Args:
        logger: Logger used to report a failure of the kill attempt.

    Note:
        Matching is by image name only, so unrelated Python processes (other
        than this one) are still terminated.
    """
    try:
        if os.name == 'nt':  # Windows
            # "PID ne <pid>" keeps taskkill from terminating this very script.
            not_self = f'PID ne {os.getpid()}'
            for image_name in ('python.exe', 'pythonw.exe'):
                subprocess.run(
                    ['taskkill', '/f', '/fi', not_self, '/im', image_name],
                    capture_output=True, check=False,
                )
            # Give the killed processes a moment to release their resources
            time.sleep(2)
    except Exception as e:
        logger.warning(f"Could not kill existing processes: {e}")
def run_main_with_restart(logger):
    """Run main.py in a loop, restarting it whenever it exits.

    The child's combined stdout/stderr is echoed line by line with a
    "[MAIN] " prefix. A run shorter than 30 seconds counts as a fast exit:
    the restart delay grows by 10 seconds per consecutive fast exit (capped
    at 60) and the loop aborts after five in a row. Other runs restart after
    10 seconds. KeyboardInterrupt terminates the child and ends the loop.

    Args:
        logger: Logger that receives all restart-loop messages.
    """
    launches = 0
    fast_exit_streak = 0
    session_began = datetime.now()
    rule = "=" * 60

    logger.info(rule)
    logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED")
    logger.info(rule)
    logger.info("Press Ctrl+C to stop the restart loop")
    logger.info("Main script: main.py")
    logger.info("Restart delay on crash: 10 seconds")
    logger.info("Fast exit protection: Enabled")
    logger.info(rule)

    # Clear out any earlier instances before the first launch
    kill_existing_processes(logger)

    while True:
        try:
            launches += 1
            launched_at = datetime.now()

            logger.info(f"[RESTART #{launches}] Starting main.py at {launched_at.strftime('%H:%M:%S')}")

            # stderr is folded into stdout; line-buffered text mode
            child = subprocess.Popen(
                [sys.executable, "main.py"],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                bufsize=1,
            )

            logger.info(f"[PROCESS] main.py started with PID: {child.pid}")

            try:
                if child.stdout:
                    while True:
                        line = child.stdout.readline()
                        if line:
                            # Echo the child's output without its trailing newline
                            print(f"[MAIN] {line.rstrip()}")
                        elif child.poll() is not None:
                            # Empty read and the child has exited: stream is done
                            break
                else:
                    # No pipe to read from: just block until the child finishes
                    child.wait()

            except KeyboardInterrupt:
                logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...")
                child.terminate()
                try:
                    child.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    logger.warning("[FORCE KILL] Process didn't terminate, force killing...")
                    child.kill()
                raise

            # The child is gone; work out how long it lasted
            exit_code = child.poll()
            run_duration = (datetime.now() - launched_at).total_seconds()

            logger.info(f"[EXIT] main.py exited with code {exit_code}")
            logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds")

            if run_duration >= 30:
                fast_exit_streak = 0  # a healthy run resets the streak
                logger.info("[DELAY] Waiting 10 seconds before restart...")
                time.sleep(10)
            else:
                # Quick exits usually mean a start-up or configuration problem
                fast_exit_streak += 1
                logger.warning(f"[FAST EXIT] Process exited quickly ({fast_exit_streak} consecutive)")

                if fast_exit_streak >= 5:
                    logger.error("[ABORT] Too many consecutive fast exits (5+)")
                    logger.error("This indicates a configuration or startup problem")
                    logger.error("Please check the main.py script manually")
                    break

                delay = min(60, 10 * fast_exit_streak)
                logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...")
                time.sleep(delay)

            # Periodic session statistics
            if launches % 10 == 0:
                session_seconds = (datetime.now() - session_began).total_seconds()
                logger.info(f"[STATS] Session: {launches} restarts in {session_seconds/3600:.1f} hours")

        except KeyboardInterrupt:
            logger.info("[SHUTDOWN] Restart loop interrupted by user")
            break
        except Exception as exc:
            logger.error(f"[ERROR] Unexpected error in restart loop: {exc}")
            logger.error("Continuing restart loop after 30 second delay...")
            time.sleep(30)

    session_seconds = (datetime.now() - session_began).total_seconds()
    logger.info(rule)
    logger.info("OVERNIGHT TRAINING SESSION COMPLETE")
    logger.info(f"Total restarts: {launches}")
    logger.info(f"Total session time: {session_seconds/3600:.1f} hours")
    logger.info(rule)
def main():
    """Entry point: set up logging and signal handling, then run the restart loop.

    Logging is configured before the signal handlers are installed, so the
    handler never runs without a logger. The handler raises KeyboardInterrupt
    instead of exiting directly; that routes SIGINT and SIGTERM through the
    restart loop's own interrupt handling, which terminates the child
    main.py process and logs the session summary.

    Returns:
        0 on a normal or user-requested stop, 1 on a fatal error.
    """
    # Setup logging first: the signal handler below depends on it.
    global logger
    logger = setup_restart_logging()

    def signal_handler(signum, frame):
        logger.info(f"[SIGNAL] Received signal {signum}, shutting down...")
        # Unwind through the restart loop so it can stop the child cleanly.
        raise KeyboardInterrupt

    signal.signal(signal.SIGINT, signal_handler)
    if hasattr(signal, 'SIGTERM'):
        signal.signal(signal.SIGTERM, signal_handler)

    try:
        run_main_with_restart(logger)
    except KeyboardInterrupt:
        # Interrupted somewhere the restart loop does not catch it itself
        # (e.g. while sleeping in its error handler); still a clean stop.
        logger.info("[SHUTDOWN] Restart script interrupted")
        return 0
    except Exception as e:
        logger.error(f"[FATAL] Fatal error in restart script: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -1,88 +0,0 @@
#!/usr/bin/env python3
"""
MEXC Browser Setup & Runner
This script automatically installs dependencies and runs the MEXC browser automation.
"""
import subprocess
import sys
import os
import importlib
def check_and_install_requirements():
    """Make sure the packages the browser automation needs are importable.

    Each required package is probed by importing it (hyphens in the package
    name become underscores for the module name). Missing ones are installed
    with pip through the current interpreter.

    Returns:
        True when everything is available, False as soon as one pip install
        fails.
    """
    wanted = ('selenium', 'webdriver-manager', 'requests')

    print("🔍 Checking required packages...")

    missing_packages = []
    for package in wanted:
        module_name = package.replace('-', '_')
        try:
            importlib.import_module(module_name)
        except ImportError:
            missing_packages.append(package)
            print(f"{package} - missing")
        else:
            print(f"{package} - already installed")

    if missing_packages:
        print(f"\n📦 Installing missing packages: {', '.join(missing_packages)}")
        for package in missing_packages:
            pip_command = [sys.executable, '-m', 'pip', 'install', package]
            try:
                subprocess.check_call(pip_command)
                print(f"✅ Successfully installed {package}")
            except subprocess.CalledProcessError as exc:
                print(f"❌ Failed to install {package}: {exc}")
                return False

    print("✅ All requirements satisfied!")
    return True
def run_browser_automation():
    """Import the auto-browser module and run its main function.

    An ImportError (from the import or from inside the run) prints a hint
    about the expected module path; any other exception is printed as an
    error. Nothing is re-raised.
    """
    try:
        # Imported at call time so a missing module only fails here
        from core.mexc_webclient.auto_browser import main as launch_auto_browser
        launch_auto_browser()
    except ImportError:
        print("❌ Could not import auto browser module")
        print("Make sure core/mexc_webclient/auto_browser.py exists")
    except Exception as exc:
        print(f"❌ Error running browser automation: {exc}")
def main():
    """Check the environment, install requirements, then start the browser automation.

    Stops early, after printing a message, when Python is older than 3.7 or
    when a requirement cannot be installed. Waits for Enter before launching.
    """
    print("🚀 MEXC Browser Automation Setup")
    print("=" * 40)

    # Interpreter version gate
    if sys.version_info < (3, 7):
        print("❌ Python 3.7+ required")
        return

    print(f"✅ Python {sys.version.split()[0]} detected")

    # Dependencies
    requirements_ok = check_and_install_requirements()
    if not requirements_ok:
        print("❌ Failed to install requirements")
        return

    for notice in (
        "\n🌐 Starting browser automation...",
        "This will:",
        "• Download ChromeDriver automatically",
        "• Open MEXC futures page",
        "• Capture all trading requests",
        "• Extract session cookies",
    ):
        print(notice)

    input("\nPress Enter to continue...")

    # Hand over to the automation itself
    run_browser_automation()


if __name__ == "__main__":
    main()

View File

@ -1,160 +0,0 @@
#!/usr/bin/env python3
"""
Helper script to start monitoring services for RL training
"""
import subprocess
import sys
import time
import requests
import os
import json
from pathlib import Path
# Available ports to try for TensorBoard
TENSORBOARD_PORTS = [6006, 6007, 6008, 6009, 6010, 6011, 6012]
def check_port(port, service_name):
    """Report whether an HTTP GET to localhost on *port* gets a response.

    Args:
        port: TCP port to probe on localhost.
        service_name: Label used in the message printed on success.

    Returns:
        True when the request completes (any HTTP status), False when it
        raises a requests exception.
    """
    url = f"http://localhost:{port}"
    try:
        requests.get(url, timeout=3)
    except requests.exceptions.RequestException:
        return False
    print(f"{service_name} is running on port {port}")
    return True
def is_port_in_use(port):
    """Return True when *port* on localhost cannot be bound, False when it can.

    The check binds a throwaway TCP socket, which is closed again before
    returning.
    """
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        try:
            probe.bind(('localhost', port))
        except OSError:
            # Bind refused: something else holds the port
            return True
        return False
def find_available_port(ports_list, service_name):
    """Return the first port in *ports_list* that is free, or None if all are taken.

    Args:
        ports_list: Candidate ports, tried in order.
        service_name: Label used in the message printed for the chosen port.
    """
    for candidate in ports_list:
        if is_port_in_use(candidate):
            print(f"⚠️ Port {candidate} is already in use")
            continue
        print(f"🔍 Found available port {candidate} for {service_name}")
        return candidate
    return None
def save_port_config(tensorboard_port):
    """Write the monitoring port settings to monitoring_ports.json.

    Args:
        tensorboard_port: Port recorded under "tensorboard_port". The web
            dashboard port is always recorded as 8051.
    """
    settings = {
        "tensorboard_port": tensorboard_port,
        "web_dashboard_port": 8051,
    }

    with open("monitoring_ports.json", "w") as handle:
        json.dump(settings, handle, indent=2)

    print("💾 Port configuration saved to monitoring_ports.json")
def start_tensorboard():
    """Start TensorBoard in the background, or reuse an instance already running.

    Every port in TENSORBOARD_PORTS is first probed for a running instance.
    If none answers, the first free port is used to launch
    "python -m tensorboard" against the "runs" directory, and the port is
    polled every 2 seconds, up to 15 times, until it responds.

    Returns:
        The port TensorBoard is reachable on, or None when no port is free,
        start-up times out, or an error occurs.
    """
    try:
        # Reuse a TensorBoard that is already up on one of our ports
        for port in TENSORBOARD_PORTS:
            if check_port(port, "TensorBoard"):
                print(f"✅ TensorBoard already running on port {port}")
                save_port_config(port)
                return port

        port = find_available_port(TENSORBOARD_PORTS, "TensorBoard")
        if port is None:
            print(f"❌ No available ports found in range {TENSORBOARD_PORTS}")
            return None

        print(f"🚀 Starting TensorBoard on port {port}...")

        # TensorBoard reads from ./runs; make sure it exists
        Path("runs").mkdir(exist_ok=True)

        launch_command = [
            sys.executable, "-m", "tensorboard",
            "--logdir=runs", f"--port={port}", "--reload_interval=1",
        ]
        if os.name == 'nt':  # Windows: separate console window
            subprocess.Popen(launch_command, creationflags=subprocess.CREATE_NEW_CONSOLE)
        else:  # Linux/Mac: discard the child's output
            subprocess.Popen(launch_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        print(f"⏳ Waiting for TensorBoard to start on port {port}...")
        for _ in range(15):
            time.sleep(2)
            if check_port(port, "TensorBoard"):
                save_port_config(port)
                return port

        print(f"⚠️ TensorBoard failed to start on port {port} within 30 seconds")
        return None

    except Exception as exc:
        print(f"❌ Error starting TensorBoard: {exc}")
        return None
def check_web_dashboard_port():
    """Pick a port for the web dashboard, preferring 8051.

    Returns:
        8051 when it is free; otherwise the first free port among 8052-8055;
        otherwise 8051 again, even though it is in use.
    """
    port = 8051

    if not is_port_in_use(port):
        print(f"✅ Web dashboard port {port} is available")
        return port

    print(f"⚠️ Web dashboard port {port} is in use")
    # Fall back to the next few ports
    for alt_port in (8052, 8053, 8054, 8055):
        if is_port_in_use(alt_port):
            continue
        print(f"🔍 Alternative port {alt_port} available for web dashboard")
        return alt_port

    print("❌ No alternative ports found for web dashboard")
    return port
def main():
    """Prepare monitoring for RL training and print how to reach it.

    Picks a web dashboard port, starts (or finds) TensorBoard, then prints
    the status of both, the training command to run, and the resulting URLs.
    """
    rule = "=" * 60

    print(rule)
    print("🎯 RL TRAINING MONITORING SETUP")
    print(rule)

    web_port = check_web_dashboard_port()
    tensorboard_port = start_tensorboard()

    print("\n" + rule)
    print("📊 MONITORING STATUS")
    print(rule)

    if not tensorboard_port:
        print("❌ TensorBoard: Failed to start")
        print(" Manual start: python -m tensorboard --logdir=runs --port=6007")
    else:
        print(f"✅ TensorBoard: http://localhost:{tensorboard_port}")
        # Persist the port that ended up being used
        save_port_config(tensorboard_port)

    if web_port:
        print(f"✅ Web Dashboard: Ready on port {web_port}")

    print("\n🎯 Ready to start RL training!")
    needs_port_flag = bool(tensorboard_port) and web_port != 8051
    if needs_port_flag:
        print(f"Run: python train_realtime_with_tensorboard.py --episodes 10 --web-port {web_port}")
    else:
        print("Run: python train_realtime_with_tensorboard.py --episodes 10")

    print("\n📋 Available URLs:")
    if tensorboard_port:
        print(f" 📊 TensorBoard: http://localhost:{tensorboard_port}")
    if web_port:
        print(f" 🌐 Web Dashboard: http://localhost:{web_port} (starts with training)")


if __name__ == "__main__":
    main()

View File

@ -1,179 +0,0 @@
#!/usr/bin/env python3
"""
Start Overnight Training Session
This script starts a comprehensive overnight training session that:
1. Ensures CNN and COB RL training processes are implemented and running
2. Executes training passes on each signal when predictions change
3. Calculates PnL and records trades in SIM mode
4. Tracks model performance statistics
5. Converts signals to actual trades for performance tracking
"""
import os
import sys
import time
import logging
from datetime import datetime
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'overnight_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def main():
    """Start an overnight training session and report on it until interrupted.

    Builds the data provider, a trading executor with simulation_mode set,
    the orchestrator and the dashboard, then asks the dashboard to start
    overnight training. While running it wakes once a minute and logs a
    performance summary every 10 minutes. Ctrl+C stops the session and logs
    a final report.

    Returns:
        0 on a normal or user-interrupted run, 1 when the session cannot be
        started or an unexpected error occurs.
    """
    try:
        logger.info("🌙 STARTING OVERNIGHT TRAINING SESSION")
        logger.info("=" * 80)

        # Project components are imported at call time
        from core.config import get_config, setup_logging
        from core.data_provider import DataProvider
        from core.orchestrator import TradingOrchestrator
        from core.trading_executor import TradingExecutor
        from web.clean_dashboard import CleanTradingDashboard

        setup_logging()

        logger.info("Initializing components...")

        data_provider = DataProvider()
        logger.info("✅ Data Provider initialized")

        # Executor with the simulation flag set explicitly
        trading_executor = TradingExecutor()
        trading_executor.simulation_mode = True
        logger.info("✅ Trading Executor initialized (SIMULATION MODE)")

        orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True,
        )
        logger.info("✅ Trading Orchestrator initialized")

        # Wire the executor in when the orchestrator supports it
        if hasattr(orchestrator, 'set_trading_executor'):
            orchestrator.set_trading_executor(trading_executor)
            logger.info("✅ Trading Executor connected to Orchestrator")

        dashboard = CleanTradingDashboard(
            data_provider=data_provider,
            orchestrator=orchestrator,
            trading_executor=trading_executor,
        )
        logger.info("✅ Dashboard initialized with Overnight Training Coordinator")

        logger.info("Starting overnight training session...")
        started = dashboard.start_overnight_training()

        if not started:
            logger.error("❌ Failed to start overnight training session")
            return 1

        logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED SUCCESSFULLY")
        logger.info("=" * 80)
        logger.info("Training Features Active:")
        for feature_line in (
            "✅ CNN training on signal changes",
            "✅ COB RL training on market microstructure",
            "✅ DQN training on trading decisions",
            "✅ Trade execution and recording (SIMULATION)",
            "✅ Performance tracking and statistics",
            "✅ Model checkpointing every 50 trades",
            "✅ Signal-to-trade conversion with PnL calculation",
        ):
            logger.info(feature_line)
        logger.info("=" * 80)

        logger.info("Monitoring training progress...")
        logger.info("Press Ctrl+C to stop the training session")

        session_start = datetime.now()
        last_report = session_start

        while True:
            try:
                time.sleep(60)  # wake once a minute

                now = datetime.now()

                # Full progress report every 10 minutes
                if (now - last_report).total_seconds() >= 600:
                    elapsed = now - session_start
                    stats = dashboard.get_training_performance_summary()

                    logger.info("=" * 60)
                    logger.info(f"🌙 TRAINING PROGRESS REPORT - {elapsed}")
                    logger.info("=" * 60)
                    logger.info(f"Total Signals: {stats.get('total_signals', 0)}")
                    logger.info(f"Total Trades: {stats.get('total_trades', 0)}")
                    logger.info(f"Successful Trades: {stats.get('successful_trades', 0)}")
                    logger.info(f"Success Rate: {stats.get('success_rate', 0):.1%}")
                    logger.info(f"Total P&L: ${stats.get('total_pnl', 0):.2f}")
                    logger.info(f"Models Trained: {', '.join(stats.get('models_trained', []))}")
                    logger.info(f"Training Status: {'ACTIVE' if stats.get('is_running', False) else 'INACTIVE'}")
                    logger.info("=" * 60)

                    last_report = now

            except KeyboardInterrupt:
                logger.info("\n🛑 Training session interrupted by user")
                break
            except Exception as exc:
                logger.error(f"Error during training monitoring: {exc}")
                time.sleep(30)  # back off before the next attempt

        logger.info("Stopping overnight training session...")
        dashboard.stop_overnight_training()

        # Closing report
        final_stats = dashboard.get_training_performance_summary()
        total_time = datetime.now() - session_start

        logger.info("=" * 80)
        logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")
        logger.info("=" * 80)
        logger.info(f"Total Duration: {total_time}")
        logger.info("Final Statistics:")
        logger.info(f" Total Signals: {final_stats.get('total_signals', 0)}")
        logger.info(f" Total Trades: {final_stats.get('total_trades', 0)}")
        logger.info(f" Successful Trades: {final_stats.get('successful_trades', 0)}")
        logger.info(f" Success Rate: {final_stats.get('success_rate', 0):.1%}")
        logger.info(f" Total P&L: ${final_stats.get('total_pnl', 0):.2f}")
        logger.info(f" Models Trained: {', '.join(final_stats.get('models_trained', []))}")
        logger.info("=" * 80)

    except KeyboardInterrupt:
        logger.info("\n🛑 Training session interrupted by user")
        return 0
    except Exception as exc:
        logger.error(f"❌ Error in overnight training session: {exc}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)

View File

@ -1,426 +0,0 @@
#!/usr/bin/env python3
"""
System Stability Audit and Monitoring
This script performs a comprehensive audit of the trading system to identify
and fix stability issues, memory leaks, and performance bottlenecks.
"""
import os
import sys
import psutil
import logging
import time
import threading
import gc
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Any
import traceback
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
# Setup logging
setup_logging()
logger = logging.getLogger(__name__)
class SystemStabilityAuditor:
"""
Comprehensive system stability auditor and monitor
Monitors:
- Memory usage and leaks
- CPU usage and performance
- Thread health and deadlocks
- Model performance and stability
- Dashboard responsiveness
- Data provider health
"""
def __init__(self):
    """Set up configuration, baselines and empty monitoring state."""
    self.config = get_config()

    # Background monitoring is off until start_monitoring() is called
    self.monitoring_active = False
    self.monitoring_thread = None

    # Snapshot of system-wide memory in use and CPU percent at construction
    self.baseline_memory = psutil.virtual_memory().used
    self.baseline_cpu = psutil.cpu_percent()

    # Rolling samples filled in by the monitoring loop
    self.memory_history = []
    self.cpu_history = []
    self.thread_history = []
    self.error_history = []

    # Overall health state
    self.stability_score = 100.0
    self.critical_issues = []
    self.warnings = []

    logger.info("System Stability Auditor initialized")
def start_monitoring(self):
    """Start the background monitoring thread unless it is already running."""
    if not self.monitoring_active:
        self.monitoring_active = True
        # Daemon thread: it will not keep the interpreter alive on exit
        self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
        self.monitoring_thread.start()
        logger.info("System stability monitoring started")
        return

    logger.warning("Monitoring already active")
def stop_monitoring(self):
    """Signal the monitoring loop to stop and wait up to 5 seconds for its thread."""
    # The loop checks this flag at the top of each iteration
    self.monitoring_active = False
    worker = self.monitoring_thread
    if worker:
        worker.join(timeout=5)
    logger.info("System stability monitoring stopped")
def _monitoring_loop(self):
    """Run the periodic health checks until monitoring_active is cleared.

    Each pass collects metrics, runs the memory, CPU, thread and deadlock
    checks, updates the stability score, and sleeps 5 seconds. A status line
    is logged on every 12th successful pass. An exception in a pass is
    logged, recorded in error_history, and followed by a 10-second pause.
    """
    # Counts completed passes. The history length cannot be used for the
    # status cadence because it is capped at 720 (a multiple of 12), which
    # would make the status log fire on every pass once the cap is reached.
    completed_passes = 0

    while self.monitoring_active:
        try:
            # Collect system metrics
            self._collect_system_metrics()

            # Check for memory leaks
            self._check_memory_leaks()

            # Check CPU usage
            self._check_cpu_usage()

            # Check thread health
            self._check_thread_health()

            # Check for deadlocks
            self._check_for_deadlocks()

            # Update stability score
            self._update_stability_score()

            # Log status on every 12th pass (roughly once a minute)
            completed_passes += 1
            if completed_passes % 12 == 0:
                self._log_stability_status()

            time.sleep(5)  # Check every 5 seconds

        except Exception as e:
            logger.error(f"Error in monitoring loop: {e}")
            self.error_history.append({
                'timestamp': datetime.now(),
                'error': str(e),
                'traceback': traceback.format_exc()
            })
            time.sleep(10)  # Wait longer on error
def _collect_system_metrics(self):
    """Append one memory, CPU and thread sample to the history lists.

    Each history keeps at most its 720 most recent samples. The CPU reading
    is taken over a 1-second interval, so this call blocks for about that
    long. Errors are logged and swallowed.
    """
    history_cap = 720  # 1 hour at 5s intervals

    try:
        # --- memory ---
        vm = psutil.virtual_memory()
        gib = 1024**3
        self.memory_history.append({
            'timestamp': datetime.now(),
            'used_gb': vm.used / gib,
            'available_gb': vm.available / gib,
            'percent': vm.percent
        })
        if len(self.memory_history) > history_cap:
            self.memory_history = self.memory_history[-history_cap:]

        # --- CPU ---
        busy_percent = psutil.cpu_percent(interval=1)
        self.cpu_history.append({
            'timestamp': datetime.now(),
            'percent': busy_percent,
            'cores': psutil.cpu_count()
        })
        if len(self.cpu_history) > history_cap:
            self.cpu_history = self.cpu_history[-history_cap:]

        # --- threads of this interpreter ---
        live_threads = threading.active_count()
        self.thread_history.append({
            'timestamp': datetime.now(),
            'count': live_threads,
            'threads': [t.name for t in threading.enumerate()]
        })
        if len(self.thread_history) > history_cap:
            self.thread_history = self.thread_history[-history_cap:]

    except Exception as exc:
        logger.error(f"Error collecting system metrics: {exc}")
def _check_memory_leaks(self):
"""Check for memory leaks"""
try:
if len(self.memory_history) < 10:
return
# Check if memory usage is consistently increasing
recent_memory = [m['used_gb'] for m in self.memory_history[-10:]]
memory_trend = sum(recent_memory[-5:]) / 5 - sum(recent_memory[:5]) / 5
# If memory increased by more than 100MB in last 10 checks
if memory_trend > 0.1:
warning = f"Potential memory leak detected: +{memory_trend:.2f}GB in last 50s"
if warning not in self.warnings:
self.warnings.append(warning)
logger.warning(warning)
# Force garbage collection
gc.collect()
logger.info("Forced garbage collection to free memory")
# Check for excessive memory usage
current_memory = self.memory_history[-1]['percent']
if current_memory > 85:
critical = f"High memory usage: {current_memory:.1f}%"
if critical not in self.critical_issues:
self.critical_issues.append(critical)
logger.error(critical)
except Exception as e:
logger.error(f"Error checking memory leaks: {e}")
def _check_cpu_usage(self):
    """Flag sustained high CPU usage.

    Needs at least ten samples in ``self.cpu_history``. The mean ``percent``
    of the ten newest samples records a critical issue above 90 and a warning
    above 75.

    Each kind of finding is stored at most once. The messages embed the
    measured average, so findings are matched by their fixed prefix rather
    than by the whole message; otherwise a new entry would be added whenever
    the average moved by a tenth of a percent.
    """
    try:
        if len(self.cpu_history) < 10:
            return

        window = [sample['percent'] for sample in self.cpu_history[-10:]]
        avg_cpu = sum(window) / len(window)

        if avg_cpu > 90:
            prefix = "Sustained high CPU usage:"
            if not any(c.startswith(prefix) for c in self.critical_issues):
                critical = f"{prefix} {avg_cpu:.1f}%"
                self.critical_issues.append(critical)
                logger.error(critical)
        elif avg_cpu > 75:
            prefix = "High CPU usage:"
            if not any(w.startswith(prefix) for w in self.warnings):
                warning = f"{prefix} {avg_cpu:.1f}%"
                self.warnings.append(warning)
                logger.warning(warning)
    except Exception as e:
        logger.error(f"Error checking CPU usage: {e}")
def _check_thread_health(self):
    """Flag an excessive or steadily growing number of threads.

    Needs at least five samples in ``self.thread_history``. More than 50
    threads in the newest sample records a critical issue and logs the
    recorded thread names. With ten or more samples, the mean ``count`` of the
    five newest samples is compared with the mean of the five before them; a
    rise of more than two threads records a warning.

    Each kind of finding is stored at most once. The messages embed the
    measured value, so findings are matched by their fixed prefix rather than
    by the whole message to stop near-duplicates from accumulating.
    """
    explosion_prefix = "Thread explosion detected:"
    leak_prefix = "Potential thread leak:"
    try:
        if len(self.thread_history) < 5:
            return

        latest = self.thread_history[-1]
        current_threads = latest['count']

        if current_threads > 50 and not any(
            c.startswith(explosion_prefix) for c in self.critical_issues
        ):
            critical = f"{explosion_prefix} {current_threads} active threads"
            self.critical_issues.append(critical)
            logger.error(critical)
            # Thread names help identify which component is spawning threads
            logger.error(f"Active threads: {latest['threads']}")

        if len(self.thread_history) >= 10:
            counts = [sample['count'] for sample in self.thread_history[-10:]]
            thread_trend = sum(counts[5:]) / 5 - sum(counts[:5]) / 5
            if thread_trend > 2 and not any(
                w.startswith(leak_prefix) for w in self.warnings
            ):
                warning = f"{leak_prefix} +{thread_trend:.1f} threads in last 50s"
                self.warnings.append(warning)
                logger.warning(warning)
    except Exception as e:
        logger.error(f"Error checking thread health: {e}")
def _check_for_deadlocks(self):
    """Placeholder for deadlock detection; currently records nothing.

    Walks the threads reported by ``threading.enumerate()`` but takes no
    action: a live thread may be blocked or simply busy, and telling the two
    apart needs per-thread state that is not collected here. Liveness is read
    through the public ``Thread.is_alive()`` rather than a private attribute.
    """
    try:
        for thread in threading.enumerate():
            if thread.is_alive():
                # Simplified check only - a live thread is not evidence of a
                # deadlock, so nothing is recorded for it.
                continue
    except Exception as e:
        logger.error(f"Error checking for deadlocks: {e}")
def _update_stability_score(self):
    """Recompute ``self.stability_score`` from the recorded findings.

    Starts at 100 and subtracts 20 per critical issue, 5 per warning and 10
    per entry in ``self.error_history`` from the last ten minutes. Memory and
    CPU readings above 80 percent subtract a further 2 and 1 points per
    percentage point respectively. The result is floored at 0. On any error
    the previous score is left in place and the error is logged.
    """
    try:
        recent_error_count = sum(
            1 for err in self.error_history
            if err['timestamp'] > datetime.now() - timedelta(minutes=10)
        )
        total = 100.0
        total -= 20 * len(self.critical_issues)
        total -= 5 * len(self.warnings)
        total -= 10 * recent_error_count

        # Resource pressure: (latest sample history, points per percent over 80)
        for history, weight in ((self.memory_history, 2), (self.cpu_history, 1)):
            if not history:
                continue
            usage = history[-1]['percent']
            if usage > 80:
                total -= (usage - 80) * weight

        self.stability_score = max(0, total)
    except Exception as e:
        logger.error(f"Error updating stability score: {e}")
def _log_stability_status(self):
    """Log a banner-framed snapshot of the current stability state.

    Emits the score, the newest memory/CPU/thread samples when present, and
    up to the five most recent critical issues and warnings.
    """
    try:
        banner = "=" * 50
        logger.info(banner)
        logger.info("SYSTEM STABILITY STATUS")
        logger.info(banner)
        logger.info(f"Stability Score: {self.stability_score:.1f}/100")

        if self.memory_history:
            latest_memory = self.memory_history[-1]
            logger.info(f"Memory: {latest_memory['used_gb']:.1f}GB used ({latest_memory['percent']:.1f}%)")
        if self.cpu_history:
            latest_cpu = self.cpu_history[-1]
            logger.info(f"CPU: {latest_cpu['percent']:.1f}%")
        if self.thread_history:
            latest_threads = self.thread_history[-1]
            logger.info(f"Threads: {latest_threads['count']} active")

        if self.critical_issues:
            logger.error(f"Critical Issues ({len(self.critical_issues)}):")
            # Only the five newest entries are listed
            for entry in self.critical_issues[-5:]:
                logger.error(f" - {entry}")
        if self.warnings:
            logger.warning(f"Warnings ({len(self.warnings)}):")
            for entry in self.warnings[-5:]:
                logger.warning(f" - {entry}")

        logger.info(banner)
    except Exception as e:
        logger.error(f"Error logging stability status: {e}")
def get_stability_report(self) -> Dict[str, Any]:
    """Return a snapshot of the auditor state as a dictionary.

    Returns:
        A dict with the score, the issue and warning lists, the newest memory
        and CPU samples (``None`` when no sample exists), the newest thread
        count (0 when no sample exists), the number of errors from the last
        ten minutes and the monitoring flag. If building the report fails,
        ``{'error': <message>}`` is returned instead.
    """
    try:
        report: Dict[str, Any] = {}
        report['stability_score'] = self.stability_score
        report['critical_issues'] = self.critical_issues
        report['warnings'] = self.warnings
        report['memory_usage'] = self.memory_history[-1] if self.memory_history else None
        report['cpu_usage'] = self.cpu_history[-1] if self.cpu_history else None
        if self.thread_history:
            report['thread_count'] = self.thread_history[-1]['count']
        else:
            report['thread_count'] = 0
        report['recent_errors'] = sum(
            1 for err in self.error_history
            if err['timestamp'] > datetime.now() - timedelta(minutes=10)
        )
        report['monitoring_active'] = self.monitoring_active
        return report
    except Exception as e:
        logger.error(f"Error generating stability report: {e}")
        return {'error': str(e)}
def fix_common_issues(self):
    """Apply best-effort remedies for common stability problems.

    Forces a garbage collection, trims each metric history to its newest 360
    samples, drops errors older than one hour and clears the recorded
    warnings and critical issues. Failures are logged, not raised.
    """
    try:
        logger.info("Attempting to fix common stability issues...")

        gc.collect()
        logger.info("✓ Forced garbage collection")

        # Keep only the newest 360 samples (30 minutes) of every metric
        keep = 360
        for attr in ('memory_history', 'cpu_history', 'thread_history'):
            samples = getattr(self, attr)
            if len(samples) > keep:
                setattr(self, attr, samples[-keep:])
        logger.info("✓ Cleared old monitoring history")

        oldest_allowed = datetime.now() - timedelta(hours=1)
        self.error_history = [
            err for err in self.error_history if err['timestamp'] > oldest_allowed
        ]
        logger.info("✓ Cleared old error history")

        # Findings may be stale by now; start over with empty lists
        self.warnings = []
        self.critical_issues = []
        logger.info("✓ Reset stale warnings and critical issues")

        logger.info("Common stability fixes applied")
    except Exception as e:
        logger.error(f"Error fixing common issues: {e}")
def main():
    """Run a five-minute stability audit and log the resulting report.

    Starts a ``SystemStabilityAuditor``, waits 300 seconds, logs the score and
    finding counts, and applies ``fix_common_issues`` when the score is below
    80. Monitoring is stopped in a ``finally`` clause so it is also shut down
    when the run is interrupted or fails part-way.
    """
    auditor = None
    monitoring_started = False
    try:
        logger.info("Starting System Stability Audit")
        auditor = SystemStabilityAuditor()
        auditor.start_monitoring()
        monitoring_started = True

        # Run for 5 minutes then generate report
        time.sleep(300)

        report = auditor.get_stability_report()
        logger.info("FINAL STABILITY REPORT:")
        logger.info(f"Stability Score: {report['stability_score']:.1f}/100")
        logger.info(f"Critical Issues: {len(report['critical_issues'])}")
        logger.info(f"Warnings: {len(report['warnings'])}")

        # Attempt fixes if needed
        if report['stability_score'] < 80:
            auditor.fix_common_issues()
    except KeyboardInterrupt:
        logger.info("Audit interrupted by user")
    except Exception as e:
        logger.error(f"Error in stability audit: {e}")
    finally:
        if monitoring_started:
            # A failing shutdown is logged rather than raised, as before
            try:
                auditor.stop_monitoring()
            except Exception as e:
                logger.error(f"Error in stability audit: {e}")


if __name__ == "__main__":
    main()

View File

@ -1,191 +0,0 @@
#!/usr/bin/env python3
"""
Test Build Base Data Performance
This script tests the performance of build_base_data_input to ensure it's instantaneous.
"""
import sys
import os
import time
import logging
from datetime import datetime
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.orchestrator import TradingOrchestrator
from core.config import get_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_build_base_data_performance():
    """Measure how long ``build_base_data_input`` takes on a started orchestrator.

    Runs ten timed builds for ETH/USDT, logs the average against the
    10/50/100 ms thresholds, then times one build each for ETH/USDT and
    BTC/USDT. Durations use ``time.perf_counter()`` because the wall clock is
    neither monotonic nor high-resolution enough for millisecond timings.
    The orchestrator is stopped in a ``finally`` clause so a failure
    mid-test does not leave it running.

    Returns:
        bool: True when the average build time is below 100 ms; False when
        it is not or when the test raised.
    """
    logger.info("=== Testing Build Base Data Performance ===")
    orchestrator = None
    try:
        # Initialize orchestrator
        config = get_config()
        orchestrator = TradingOrchestrator(
            symbol="ETH/USDT",
            config=config
        )

        # Start the orchestrator to initialize data
        orchestrator.start()
        logger.info("✅ Orchestrator started")

        # Wait a bit for data to be populated
        time.sleep(2)

        symbol = "ETH/USDT"
        num_tests = 10
        total_time = 0
        logger.info(f"Running {num_tests} performance tests...")
        for i in range(num_tests):
            started = time.perf_counter()
            base_data = orchestrator.build_base_data_input(symbol)
            duration = (time.perf_counter() - started) * 1000  # milliseconds
            total_time += duration
            if base_data:
                logger.info(f"Test {i+1}: {duration:.2f}ms - ✅ Success")
            else:
                logger.warning(f"Test {i+1}: {duration:.2f}ms - ❌ Failed (no data)")

        avg_time = total_time / num_tests
        logger.info("=== Performance Results ===")
        logger.info(f"Average time: {avg_time:.2f}ms")
        logger.info(f"Total time: {total_time:.2f}ms")

        # Performance thresholds
        if avg_time < 10:
            logger.info("🎉 EXCELLENT: Build time is under 10ms")
        elif avg_time < 50:
            logger.info("✅ GOOD: Build time is under 50ms")
        elif avg_time < 100:
            logger.info("⚠️ ACCEPTABLE: Build time is under 100ms")
        else:
            logger.error("❌ SLOW: Build time is over 100ms - needs optimization")

        # Test with multiple symbols
        logger.info("Testing with multiple symbols...")
        for symbol in ["ETH/USDT", "BTC/USDT"]:
            started = time.perf_counter()
            orchestrator.build_base_data_input(symbol)
            duration = (time.perf_counter() - started) * 1000
            logger.info(f"{symbol}: {duration:.2f}ms")

        return avg_time < 100  # True if performance is acceptable
    except Exception as e:
        logger.error(f"❌ Performance test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        if orchestrator is not None:
            try:
                orchestrator.stop()
                logger.info("✅ Orchestrator stopped")
            except Exception as e:
                logger.error(f"❌ Failed to stop orchestrator: {e}")
def test_cache_effectiveness():
    """Check that repeated ``build_base_data_input`` calls get faster.

    Times three consecutive builds for ETH/USDT with
    ``time.perf_counter()``. The cache counts as effective when the second
    call takes less than half the time of the first. When all three results
    are truthy, the lengths of their ``ohlcv_1s`` attributes are compared and
    the outcome is logged. The orchestrator is stopped in a ``finally``
    clause so a failure mid-test does not leave it running.

    Returns:
        bool: True when the cache was judged effective; False otherwise or
        when the test raised.
    """
    logger.info("=== Testing Cache Effectiveness ===")
    orchestrator = None
    try:
        # Initialize orchestrator
        config = get_config()
        orchestrator = TradingOrchestrator(
            symbol="ETH/USDT",
            config=config
        )
        orchestrator.start()
        time.sleep(2)  # Let data populate

        symbol = "ETH/USDT"

        def timed_build():
            """Return (result, elapsed milliseconds) for one build call."""
            started = time.perf_counter()
            data = orchestrator.build_base_data_input(symbol)
            return data, (time.perf_counter() - started) * 1000

        base_data1, first_call_time = timed_build()   # should build cache
        base_data2, second_call_time = timed_build()  # should use cache
        base_data3, third_call_time = timed_build()   # should still use cache

        logger.info(f"First call (build cache): {first_call_time:.2f}ms")
        logger.info(f"Second call (use cache): {second_call_time:.2f}ms")
        logger.info(f"Third call (use cache): {third_call_time:.2f}ms")

        # Cache should make subsequent calls faster
        cache_effective = second_call_time < first_call_time * 0.5
        if cache_effective:
            logger.info("✅ Cache is working effectively")
        else:
            logger.warning("⚠️ Cache may not be working as expected")

        # Verify data consistency
        if base_data1 and base_data2 and base_data3:
            if len(base_data1.ohlcv_1s) == len(base_data2.ohlcv_1s) == len(base_data3.ohlcv_1s):
                logger.info("✅ Data consistency maintained")
            else:
                logger.warning("⚠️ Data consistency issues detected")

        return cache_effective
    except Exception as e:
        logger.error(f"❌ Cache effectiveness test failed: {e}")
        return False
    finally:
        if orchestrator is not None:
            try:
                orchestrator.stop()
            except Exception as e:
                logger.error(f"❌ Failed to stop orchestrator: {e}")
def main():
    """Run both performance checks in order and log a pass/fail summary."""
    logger.info("Starting Build Base Data Performance Tests")

    performance_ok = test_build_base_data_performance()
    cache_ok = test_cache_effectiveness()

    def verdict(passed):
        """Render a check result the way the summary lines expect it."""
        if passed:
            return '✅ PASSED'
        return '❌ FAILED'

    logger.info("=== Test Summary ===")
    logger.info(f"Performance Test: {verdict(performance_ok)}")
    logger.info(f"Cache Effectiveness: {verdict(cache_ok)}")

    if not (performance_ok and cache_ok):
        logger.error("❌ Some tests failed. Performance optimization needed.")
        return

    for line in (
        "🎉 All tests passed! build_base_data_input is optimized.",
        "The system now:",
        " - Builds BaseDataInput in under 100ms",
        " - Uses effective caching for repeated calls",
        " - Maintains data consistency",
    ):
        logger.info(line)


if __name__ == "__main__":
    main()

View File

@ -1,348 +0,0 @@
#!/usr/bin/env python3
"""
Test script for Bybit ETH futures position opening/closing
"""
import os
import sys
import time
import logging
from datetime import datetime
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Load environment variables from .env file
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # python-dotenv is not installed: fall back to a minimal KEY=VALUE parser.
    if os.path.exists('.env'):
        with open('.env', 'r', encoding='utf-8') as f:
            for raw_line in f:
                entry = raw_line.strip()
                # Skip blanks, comments (including indented ones) and lines
                # without '=', which would otherwise fail to unpack.
                if not entry or entry.startswith('#') or '=' not in entry:
                    continue
                key, value = entry.split('=', 1)
                os.environ[key.strip()] = value.strip()
from core.exchanges.bybit_interface import BybitInterface
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class BybitEthFuturesTest:
    """Scripted smoke test that opens and closes a small ETH futures position.

    Every ``test_*`` step prints its progress and reports success through its
    return value instead of raising. ``run_tests`` chains the steps and stops
    at the first failing prerequisite.

    Args:
        test_mode: Passed to ``BybitInterface``; also selects testnet for the
            unauthenticated connectivity probe.
        settle_seconds: Pause after opening and after closing the position
            before it is queried again.
    """

    def __init__(self, test_mode=True, settle_seconds=2):
        self.test_mode = test_mode
        self.bybit = BybitInterface(test_mode=test_mode)
        self.test_symbol = 'ETHUSDT'
        self.test_quantity = 0.01  # Small test amount
        self.settle_seconds = settle_seconds

    def run_tests(self):
        """Run all steps in order; return True only when every step passed."""
        print("=" * 60)
        print("BYBIT ETH FUTURES POSITION TESTING")
        print("=" * 60)
        print(f"Test mode: {'TESTNET' if self.test_mode else 'LIVE'}")
        print(f"Symbol: {self.test_symbol}")
        print(f"Test quantity: {self.test_quantity} ETH")
        print("=" * 60)

        # Test 1: Connection
        if not self.test_connection():
            print("❌ Connection failed - stopping tests")
            return False

        # Test 2: Check balance
        if not self.test_balance():
            print("❌ Balance check failed - stopping tests")
            return False

        # Test 3: Check current positions
        self.test_current_positions()

        # Test 4: Get ticker
        if not self.test_ticker():
            print("❌ Ticker test failed - stopping tests")
            return False

        # Tests 5-8: open, verify, close, verify
        if not self._run_position_cycle():
            return False

        print("\n" + "=" * 60)
        print("✅ ALL TESTS COMPLETED SUCCESSFULLY")
        print("=" * 60)
        return True

    def _run_position_cycle(self):
        """Open the test position, verify it, close it and verify again.

        Once the opening order was accepted a close is always attempted, even
        when the follow-up position check fails, so a failed check cannot
        leave the position open. Returns True when opening, the post-open
        check and the close all succeeded.
        """
        # Test 5: Open a long position
        long_order = self.test_open_long_position()
        if not long_order:
            print("❌ Open long position failed")
            return False

        # Test 6: Check position after opening
        time.sleep(self.settle_seconds)  # Wait for position to be reflected
        position_seen = self.test_position_after_open()
        if not position_seen:
            print("❌ Position check after opening failed")

        # Test 7: Close the position (attempted regardless of test 6)
        closed = self.test_close_position()
        if not closed:
            print("❌ Close position failed")

        if not (position_seen and closed):
            return False

        # Test 8: Check position after closing
        time.sleep(self.settle_seconds)  # Wait for position to be reflected
        self.test_position_after_close()
        return True

    def test_connection(self):
        """Probe the public API, then connect with credentials; return bool."""
        print("\n📡 Testing connection to Bybit...")

        # First test simple connectivity without auth
        print("Testing basic API connectivity...")
        try:
            from core.exchanges.bybit_rest_client import BybitRestClient
            client = BybitRestClient(
                api_key="dummy",
                api_secret="dummy",
                # Probe the same environment the authenticated client uses
                testnet=self.test_mode
            )
            # Test public endpoint (server time)
            server_time = client.get_server_time()
            print(f"✅ Public API working - Server time: {server_time.get('result', {}).get('timeSecond')}")
        except Exception as e:
            print(f"❌ Public API failed: {e}")
            return False

        # Now test with actual credentials
        print("Testing with API credentials...")
        try:
            connected = self.bybit.connect()
            if connected:
                print("✅ Successfully connected to Bybit with credentials")
                return True
            print("❌ Failed to connect to Bybit with credentials")
            print("This might be due to:")
            print("- Invalid API credentials")
            print("- Credentials not enabled for testnet")
            print("- Missing required permissions")
            return False
        except Exception as e:
            print(f"❌ Connection error: {e}")
            return False

    def test_balance(self):
        """Print balances; return True when the USDT balance exceeds 10."""
        print("\n💰 Testing account balance...")
        try:
            # Get USDT balance (for margin)
            usdt_balance = self.bybit.get_balance('USDT')
            print(f"USDT Balance: {usdt_balance}")

            # Get all balances
            all_balances = self.bybit.get_all_balances()
            print("All balances:")
            for asset, balance in all_balances.items():
                if balance['total'] > 0:
                    print(f" {asset}: Free={balance['free']}, Locked={balance['locked']}, Total={balance['total']}")

            if usdt_balance > 10:  # Need at least $10 for testing
                print("✅ Sufficient balance for testing")
                return True
            print("❌ Insufficient USDT balance for testing (need at least $10)")
            return False
        except Exception as e:
            print(f"❌ Balance check error: {e}")
            return False

    def test_current_positions(self):
        """Print every currently open position (informational only)."""
        print("\n📊 Checking current positions...")
        try:
            positions = self.bybit.get_positions()
            if positions:
                print(f"Found {len(positions)} open positions:")
                for pos in positions:
                    print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
                    print(f" PnL: ${pos['unrealized_pnl']:.2f} ({pos['percentage']:.2f}%)")
            else:
                print("No open positions found")
        except Exception as e:
            print(f"❌ Position check error: {e}")

    def test_ticker(self):
        """Fetch and print the ticker; return True when data was received."""
        print(f"\n📈 Testing ticker for {self.test_symbol}...")
        try:
            ticker = self.bybit.get_ticker(self.test_symbol)
            if ticker:
                print("✅ Ticker data received:")
                print(f" Last Price: ${ticker['last_price']:.2f}")
                print(f" Bid: ${ticker['bid_price']:.2f}")
                print(f" Ask: ${ticker['ask_price']:.2f}")
                print(f" 24h Volume: {ticker['volume_24h']:.2f}")
                print(f" 24h Change: {ticker['change_24h']:.4f}%")
                return True
            print("❌ Failed to get ticker data")
            return False
        except Exception as e:
            print(f"❌ Ticker error: {e}")
            return False

    def test_open_long_position(self):
        """Place a market buy; return the order dict, or None on failure."""
        print(f"\n🚀 Opening long position for {self.test_quantity} {self.test_symbol}...")
        try:
            # Place market buy order
            order = self.bybit.place_order(
                symbol=self.test_symbol,
                side='buy',
                order_type='market',
                quantity=self.test_quantity
            )
            if 'error' in order:
                print(f"❌ Order failed: {order['error']}")
                return None
            print("✅ Long position opened successfully:")
            print(f" Order ID: {order['order_id']}")
            print(f" Symbol: {order['symbol']}")
            print(f" Side: {order['side']}")
            print(f" Quantity: {order['quantity']}")
            print(f" Status: {order['status']}")
            return order
        except Exception as e:
            print(f"❌ Open position error: {e}")
            return None

    def test_position_after_open(self):
        """Print the first position for the symbol; return True if one exists."""
        print("\n📊 Checking position after opening...")
        try:
            positions = self.bybit.get_positions(self.test_symbol)
            if positions:
                position = positions[0]
                print("✅ Position found:")
                print(f" Symbol: {position['symbol']}")
                print(f" Side: {position['side']}")
                print(f" Size: {position['size']}")
                print(f" Entry Price: ${position['entry_price']:.2f}")
                print(f" Mark Price: ${position['mark_price']:.2f}")
                print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
                print(f" Percentage: {position['percentage']:.2f}%")
                print(f" Leverage: {position['leverage']}x")
                return True
            print("❌ No position found after opening")
            return False
        except Exception as e:
            print(f"❌ Position check error: {e}")
            return False

    def test_close_position(self):
        """Close the position for the symbol; return True when accepted."""
        print(f"\n🔄 Closing position for {self.test_symbol}...")
        try:
            # Close the position
            close_order = self.bybit.close_position(self.test_symbol)
            if 'error' in close_order:
                print(f"❌ Close order failed: {close_order['error']}")
                return False
            print("✅ Position closed successfully:")
            print(f" Order ID: {close_order['order_id']}")
            print(f" Symbol: {close_order['symbol']}")
            print(f" Side: {close_order['side']}")
            print(f" Quantity: {close_order['quantity']}")
            print(f" Status: {close_order['status']}")
            return True
        except Exception as e:
            print(f"❌ Close position error: {e}")
            return False

    def test_position_after_close(self):
        """Report whether a position still exists after the close."""
        print("\n📊 Checking position after closing...")
        try:
            positions = self.bybit.get_positions(self.test_symbol)
            if positions:
                position = positions[0]
                print("⚠️ Position still exists (may be partially closed):")
                print(f" Symbol: {position['symbol']}")
                print(f" Side: {position['side']}")
                print(f" Size: {position['size']}")
                print(f" Entry Price: ${position['entry_price']:.2f}")
                print(f" Unrealized PnL: ${position['unrealized_pnl']:.2f}")
            else:
                print("✅ Position successfully closed - no open positions")
        except Exception as e:
            print(f"❌ Position check error: {e}")

    def test_order_history(self):
        """Print the open orders for the symbol (not part of ``run_tests``)."""
        print("\n📋 Checking recent orders...")
        try:
            # Get open orders
            open_orders = self.bybit.get_open_orders(self.test_symbol)
            print(f"Open orders: {len(open_orders)}")
            for order in open_orders:
                print(f" {order['order_id']}: {order['side']} {order['quantity']} @ ${order['price']:.2f} - {order['status']}")
        except Exception as e:
            print(f"❌ Order history error: {e}")
def main():
    """Run the Bybit testnet scenario and return whether it passed.

    Returns False without creating the test object when either
    ``BYBIT_API_KEY`` or ``BYBIT_API_SECRET`` is missing or empty.
    """
    print("Starting Bybit ETH Futures Test...")

    # Both credentials must be present before anything touches the exchange
    credentials = (os.getenv('BYBIT_API_KEY'), os.getenv('BYBIT_API_SECRET'))
    if not all(credentials):
        print("❌ Please set BYBIT_API_KEY and BYBIT_API_SECRET environment variables")
        return False

    # Always use testnet for safety
    runner = BybitEthFuturesTest(test_mode=True)
    passed = runner.run_tests()

    if passed:
        print("\n🎉 All tests passed!")
    else:
        print("\n💥 Some tests failed!")
    return passed


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)

View File

@ -1,304 +0,0 @@
#!/usr/bin/env python3
"""
Fixed Bybit ETH futures trading test with proper minimum order size handling
"""
import os
import sys
import time
import logging
import json
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Load environment variables
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # python-dotenv is not installed: fall back to a minimal KEY=VALUE parser.
    if os.path.exists('.env'):
        with open('.env', 'r', encoding='utf-8') as f:
            for raw_line in f:
                entry = raw_line.strip()
                # Skip blanks, comments (including indented ones) and lines
                # without '=', which would otherwise fail to unpack.
                if not entry or entry.startswith('#') or '=' not in entry:
                    continue
                key, value = entry.split('=', 1)
                os.environ[key.strip()] = value.strip()
from core.exchanges.bybit_interface import BybitInterface
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def get_instrument_info(bybit: BybitInterface, symbol: str) -> dict:
    """Return the "linear" instrument entry whose symbol equals *symbol*.

    Returns an empty dict when no entry matches or when the lookup raises;
    the error is logged in the latter case.
    """
    try:
        candidates = (
            entry for entry in bybit.get_instruments("linear")
            if entry.get('symbol') == symbol
        )
        # First match wins; {} signals "not found"
        return next(candidates, {})
    except Exception as e:
        logger.error(f"Error getting instrument info: {e}")
        return {}
def test_eth_futures_trading():
    """Run a live, real-money open/close round trip for ETHUSDT on Bybit.

    Steps: connect, read the instrument's minimum order quantity, check that
    the USDT balance covers the order plus a 10% buffer, list existing
    positions, ask for confirmation on stdin, place a market buy for the
    minimum quantity, close the position and report the final position and
    balance.

    Returns:
        bool: False when a prerequisite fails, the user declines, the order
        is rejected, or the position could not be confirmed closed. True
        when the round trip completed and the position is closed: the final
        position check decides when it succeeds, otherwise the result of the
        close order is used.
    """
    print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
    print("=" * 60)
    print("BYBIT ETH FUTURES LIVE TRADING TEST (FIXED)")
    print("=" * 60)
    print("⚠️ This uses LIVE environment with real money!")
    print("⚠️ Will check minimum order size first")
    print("=" * 60)

    # Check if API credentials are set
    api_key = os.getenv('BYBIT_API_KEY')
    api_secret = os.getenv('BYBIT_API_SECRET')
    if not api_key or not api_secret:
        print("❌ API credentials not found in environment")
        return False

    # Create Bybit interface with live environment
    bybit = BybitInterface(
        api_key=api_key,
        api_secret=api_secret,
        test_mode=False  # Use live environment
    )
    symbol = 'ETHUSDT'

    # Test 1: Connection
    print("\n📡 Testing connection to Bybit live environment...")
    try:
        if not bybit.connect():
            print("❌ Failed to connect to Bybit")
            return False
        print("✅ Successfully connected to Bybit live environment")
    except Exception as e:
        print(f"❌ Connection error: {e}")
        return False

    # Test 2: Get instrument information to check minimum order size
    print(f"\n📋 Getting instrument information for {symbol}...")
    try:
        instrument_info = get_instrument_info(bybit, symbol)
        if not instrument_info:
            print(f"❌ Failed to get instrument info for {symbol}")
            return False
        print("✅ Instrument information retrieved:")
        print(f" Symbol: {instrument_info.get('symbol')}")
        print(f" Status: {instrument_info.get('status')}")
        print(f" Base Coin: {instrument_info.get('baseCoin')}")
        print(f" Quote Coin: {instrument_info.get('quoteCoin')}")

        # Extract minimum order size
        lot_size_filter = instrument_info.get('lotSizeFilter', {})
        min_order_qty = float(lot_size_filter.get('minOrderQty', 0.01))
        max_order_qty = float(lot_size_filter.get('maxOrderQty', 10000))
        qty_step = float(lot_size_filter.get('qtyStep', 0.01))
        print(f" Minimum Order Qty: {min_order_qty}")
        print(f" Maximum Order Qty: {max_order_qty}")
        print(f" Quantity Step: {qty_step}")

        # Use minimum order size for testing
        test_quantity = min_order_qty
        print(f" Using test quantity: {test_quantity} ETH")
    except Exception as e:
        print(f"❌ Instrument info error: {e}")
        return False

    # Test 3: Get account balance
    print("\n💰 Checking account balance...")
    try:
        usdt_balance = bybit.get_balance('USDT')
        print(f"USDT Balance: ${usdt_balance:.2f}")

        # Calculate required balance (with some buffer)
        current_price_data = bybit.get_ticker(symbol)
        if not current_price_data:
            print("❌ Failed to get current ETH price")
            return False
        current_price = current_price_data['last_price']
        required_balance = current_price * test_quantity * 1.1  # 10% buffer
        print(f"Current ETH price: ${current_price:.2f}")
        print(f"Required balance: ${required_balance:.2f}")
        if usdt_balance < required_balance:
            print(f"❌ Insufficient USDT balance for testing (need at least ${required_balance:.2f})")
            return False
        print("✅ Sufficient balance for testing")
    except Exception as e:
        print(f"❌ Balance check error: {e}")
        return False

    # Test 4: Check existing positions
    print("\n📊 Checking existing positions...")
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            print(f"Found {len(positions)} existing positions:")
            for pos in positions:
                print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
                print(f" PnL: ${pos['unrealized_pnl']:.2f}")
        else:
            print("No existing positions found")
    except Exception as e:
        print(f"❌ Position check error: {e}")
        return False

    # Test 5: Ask user confirmation before trading
    print("\n⚠️ TRADING CONFIRMATION")
    print(f" Symbol: {symbol}")
    print(f" Quantity: {test_quantity} ETH")
    print(f" Estimated cost: ${current_price * test_quantity:.2f}")
    print(" Environment: LIVE (real money)")
    print(f" Minimum order size confirmed: {min_order_qty}")
    try:
        response = input("\nDo you want to proceed with the live trading test? (y/N): ").lower()
    except EOFError:
        # No interactive stdin: never trade real money without explicit consent
        response = ''
    if response not in ('y', 'yes'):
        print("❌ Trading test cancelled by user")
        return False

    # Test 6: Open a small long position
    print("\n🚀 Opening small long position...")
    try:
        order = bybit.place_order(
            symbol=symbol,
            side='buy',
            order_type='market',
            quantity=test_quantity
        )
        if 'error' in order:
            print(f"❌ Order failed: {order['error']}")
            return False
        print("✅ Long position opened successfully:")
        print(f" Order ID: {order['order_id']}")
        print(f" Symbol: {order['symbol']}")
        print(f" Side: {order['side']}")
        print(f" Quantity: {order['quantity']}")
        print(f" Status: {order['status']}")
    except Exception as e:
        print(f"❌ Order placement error: {e}")
        return False

    # Test 7: Wait a moment and check position
    print("\n⏳ Waiting 5 seconds for position to be reflected...")
    time.sleep(5)
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            position = positions[0]
            print("✅ Position confirmed:")
            print(f" Symbol: {position['symbol']}")
            print(f" Side: {position['side']}")
            print(f" Size: {position['size']}")
            print(f" Entry Price: ${position['entry_price']:.2f}")
            print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
            print(f" Leverage: {position['leverage']}x")
        else:
            print("⚠️ No position found (may already be closed)")
    except Exception as e:
        print(f"❌ Position check error: {e}")

    # Test 8: Close the position
    # A failure here does not abort: the final checks below still run so the
    # operator sees the remaining position, but the overall result is False.
    position_closed = True
    print("\n🔄 Closing the position...")
    try:
        close_order = bybit.close_position(symbol)
        if 'error' in close_order:
            print(f"❌ Close order failed: {close_order['error']}")
            print("⚠️ You may need to manually close the position")
            position_closed = False
        else:
            print("✅ Position closed successfully:")
            print(f" Order ID: {close_order['order_id']}")
            print(f" Symbol: {close_order['symbol']}")
            print(f" Side: {close_order['side']}")
            print(f" Quantity: {close_order['quantity']}")
            print(f" Status: {close_order['status']}")
    except Exception as e:
        print(f"❌ Close position error: {e}")
        print("⚠️ You may need to manually close the position")
        position_closed = False

    # Test 9: Final position check (authoritative when it succeeds)
    print("\n📊 Final position check...")
    time.sleep(3)
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            position = positions[0]
            print("⚠️ Position still exists:")
            print(f" Size: {position['size']}")
            print(f" PnL: ${position['unrealized_pnl']:.2f}")
            print("💡 You may want to manually close this position")
            position_closed = False
        else:
            print("✅ No open positions - trading test completed successfully")
            position_closed = True
    except Exception as e:
        print(f"❌ Final position check error: {e}")

    # Test 10: Final balance check
    print("\n💰 Final balance check...")
    try:
        final_balance = bybit.get_balance('USDT')
        print(f"Final USDT Balance: ${final_balance:.2f}")
        balance_change = final_balance - usdt_balance
        if balance_change > 0:
            print(f"💰 Profit: +${balance_change:.2f}")
        elif balance_change < 0:
            print(f"📉 Loss: ${balance_change:.2f}")
        else:
            print(f"🔄 No change: ${balance_change:.2f}")
    except Exception as e:
        print(f"❌ Final balance check error: {e}")

    return position_closed
def main():
    """Run the live trading test, print a closing summary, return its result."""
    print("🚀 Starting Fixed Bybit ETH Futures Live Trading Test...")
    success = test_eth_futures_trading()

    if not success:
        print("\n💥 Trading test failed!")
        print("🔍 Check the error messages above for details")
        return success

    rule = "=" * 60
    summary = (
        "\n" + rule,
        "✅ BYBIT ETH FUTURES TRADING TEST COMPLETED",
        rule,
        "🎯 Your Bybit integration is fully functional!",
        "🔄 Position opening and closing works correctly",
        "💰 Account balance integration works",
        "📊 All trading functions are operational",
        "📏 Minimum order size handling works",
        rule,
    )
    for line in summary:
        print(line)
    return success


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)

View File

@ -1,249 +0,0 @@
#!/usr/bin/env python3
"""
Test Bybit ETH futures trading with live environment
"""
import os
import sys
import time
import logging
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Load environment variables
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # python-dotenv is not installed: fall back to a minimal KEY=VALUE parser.
    if os.path.exists('.env'):
        with open('.env', 'r', encoding='utf-8') as f:
            for raw_line in f:
                entry = raw_line.strip()
                # Skip blanks, comments (including indented ones) and lines
                # without '=', which would otherwise fail to unpack.
                if not entry or entry.startswith('#') or '=' not in entry:
                    continue
                key, value = entry.split('=', 1)
                os.environ[key.strip()] = value.strip()
from core.exchanges.bybit_interface import BybitInterface
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def test_eth_futures_trading():
    """Run an interactive open/close round trip on Bybit ETHUSDT (live).

    Steps: connect, check the USDT balance, fetch the ticker, list existing
    positions, ask the user for confirmation, place a market buy of
    ``test_quantity`` ETH, re-check the position, close it, then re-check
    positions and balance. Credentials come from the ``BYBIT_API_KEY`` and
    ``BYBIT_API_SECRET`` environment variables.

    Returns:
        bool: True when the order was placed and the close call succeeded;
        False when credentials are missing, a required step failed, or the
        user declined. Failures in the informational checks (steps 7, 9 and
        10) are printed but do not change the result.
    """
    symbol = 'ETHUSDT'
    test_quantity = 0.01  # Minimum order size for ETH futures

    print("=" * 60)
    print("BYBIT ETH FUTURES LIVE TRADING TEST")
    print("=" * 60)
    print("⚠️ This uses LIVE environment with real money!")
    # Derived from test_quantity so the warning always matches the real order
    # size (the banner previously claimed 0.001 ETH while 0.01 was traded).
    print(f"⚠️ Test amount: {test_quantity} ETH (very small)")
    print("=" * 60)

    # Check if API credentials are set
    api_key = os.getenv('BYBIT_API_KEY')
    api_secret = os.getenv('BYBIT_API_SECRET')
    if not api_key or not api_secret:
        print("❌ API credentials not found in environment")
        return False

    # Create Bybit interface with live environment
    bybit = BybitInterface(
        api_key=api_key,
        api_secret=api_secret,
        test_mode=False  # Use live environment
    )

    # Test 1: Connection
    print("\n📡 Testing connection to Bybit live environment...")
    try:
        if not bybit.connect():
            print("❌ Failed to connect to Bybit")
            return False
        print("✅ Successfully connected to Bybit live environment")
    except Exception as e:
        print(f"❌ Connection error: {e}")
        return False

    # Test 2: Get account balance
    print("\n💰 Checking account balance...")
    try:
        usdt_balance = bybit.get_balance('USDT')
        print(f"USDT Balance: ${usdt_balance:.2f}")
        if usdt_balance < 5:
            print("❌ Insufficient USDT balance for testing (need at least $5)")
            return False
        print("✅ Sufficient balance for testing")
    except Exception as e:
        print(f"❌ Balance check error: {e}")
        return False

    # Test 3: Get current ETH price
    print("\n📈 Getting current ETH price...")
    try:
        ticker = bybit.get_ticker(symbol)
        if not ticker:
            print("❌ Failed to get ticker")
            return False
        current_price = ticker['last_price']
        print(f"Current ETH price: ${current_price:.2f}")
        print(f"Test order value: ${current_price * test_quantity:.2f}")
    except Exception as e:
        print(f"❌ Ticker error: {e}")
        return False

    # Test 4: Check existing positions
    print("\n📊 Checking existing positions...")
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            print(f"Found {len(positions)} existing positions:")
            for pos in positions:
                print(f" {pos['symbol']}: {pos['side']} {pos['size']} @ ${pos['entry_price']:.2f}")
                print(f" PnL: ${pos['unrealized_pnl']:.2f}")
        else:
            print("No existing positions found")
    except Exception as e:
        print(f"❌ Position check error: {e}")
        return False

    # Test 5: Ask user confirmation before trading
    print("\n⚠️ TRADING CONFIRMATION")
    print(f" Symbol: {symbol}")
    print(f" Quantity: {test_quantity} ETH")
    print(f" Estimated cost: ${current_price * test_quantity:.2f}")
    print(" Environment: LIVE (real money)")
    # Strip so stray whitespace around "y" does not silently cancel the run.
    response = input("\nDo you want to proceed with the live trading test? (y/N): ").strip().lower()
    if response not in ('y', 'yes'):
        print("❌ Trading test cancelled by user")
        return False

    # Test 6: Open a small long position
    print("\n🚀 Opening small long position...")
    try:
        order = bybit.place_order(
            symbol=symbol,
            side='buy',
            order_type='market',
            quantity=test_quantity
        )
        if 'error' in order:
            print(f"❌ Order failed: {order['error']}")
            return False
        print("✅ Long position opened successfully:")
        print(f" Order ID: {order['order_id']}")
        print(f" Symbol: {order['symbol']}")
        print(f" Side: {order['side']}")
        print(f" Quantity: {order['quantity']}")
        print(f" Status: {order['status']}")
    except Exception as e:
        print(f"❌ Order placement error: {e}")
        return False

    # Test 7: Wait a moment and check position (informational only)
    print("\n⏳ Waiting 3 seconds for position to be reflected...")
    time.sleep(3)
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            position = positions[0]
            print("✅ Position confirmed:")
            print(f" Symbol: {position['symbol']}")
            print(f" Side: {position['side']}")
            print(f" Size: {position['size']}")
            print(f" Entry Price: ${position['entry_price']:.2f}")
            print(f" Current PnL: ${position['unrealized_pnl']:.2f}")
            print(f" Leverage: {position['leverage']}x")
        else:
            print("⚠️ No position found (may already be closed)")
    except Exception as e:
        print(f"❌ Position check error: {e}")

    # Test 8: Close the position
    print("\n🔄 Closing the position...")
    try:
        close_order = bybit.close_position(symbol)
        if 'error' in close_order:
            print(f"❌ Close order failed: {close_order['error']}")
            return False
        print("✅ Position closed successfully:")
        print(f" Order ID: {close_order['order_id']}")
        print(f" Symbol: {close_order['symbol']}")
        print(f" Side: {close_order['side']}")
        print(f" Quantity: {close_order['quantity']}")
        print(f" Status: {close_order['status']}")
    except Exception as e:
        print(f"❌ Close position error: {e}")
        return False

    # Test 9: Final position check (informational only)
    print("\n📊 Final position check...")
    time.sleep(2)
    try:
        positions = bybit.get_positions(symbol)
        if positions:
            position = positions[0]
            print("⚠️ Position still exists:")
            print(f" Size: {position['size']}")
            print(f" PnL: ${position['unrealized_pnl']:.2f}")
        else:
            print("✅ No open positions - trading test completed successfully")
    except Exception as e:
        print(f"❌ Final position check error: {e}")

    # Test 10: Final balance check (informational only)
    print("\n💰 Final balance check...")
    try:
        final_balance = bybit.get_balance('USDT')
        print(f"Final USDT Balance: ${final_balance:.2f}")
    except Exception as e:
        print(f"❌ Final balance check error: {e}")

    return True
def main():
    """Run the Bybit ETH futures live trading test and print a summary.

    Returns:
        Whatever ``test_eth_futures_trading()`` returned (truthy on success).
    """
    print("🚀 Starting Bybit ETH Futures Live Trading Test...")
    outcome = test_eth_futures_trading()

    if not outcome:
        print("\n💥 Trading test failed!")
        return outcome

    divider = "=" * 60
    for text in (
        "\n" + divider,
        "✅ BYBIT ETH FUTURES TRADING TEST COMPLETED",
        divider,
        "🎯 Your Bybit integration is fully functional!",
        "🔄 Position opening and closing works correctly",
        "💰 Account balance integration works",
        "📊 All trading functions are operational",
        divider,
    ):
        print(text)
    return outcome
if __name__ == "__main__":
    # Exit with 0 when the trading test succeeded, 1 otherwise.
    success = main()
    exit_status = 0 if success else 1
    sys.exit(exit_status)

View File

@ -1,220 +0,0 @@
#!/usr/bin/env python3
"""
Test Bybit public API functionality (no authentication required)
"""
import os
import sys
import time
import logging
# Put this script's directory (treated as the project root) on the import path
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_PROJECT_ROOT)

from core.exchanges.bybit_rest_client import BybitRestClient

# Logging: INFO level, each record prefixed with timestamp, logger name and level
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def test_public_api():
    """Exercise Bybit's public endpoints on both testnet and live.

    For each environment a ``BybitRestClient`` is created with placeholder
    credentials and four calls are made: server time, the ETHUSDT ticker, the
    linear instruments list and the ETHUSDT order book. Every result or error
    is printed; nothing is returned. If the server-time call fails, the rest
    of that environment is skipped.
    """
    divider = "=" * 60
    print(divider)
    print("BYBIT PUBLIC API TEST")
    print(divider)

    # Test both testnet and live for public endpoints
    for use_testnet in (True, False):
        env_name = "TESTNET" if use_testnet else "LIVE"
        print(f"\n🔄 Testing {env_name} environment...")
        client = BybitRestClient(api_key="dummy", api_secret="dummy", testnet=use_testnet)

        # Test 1: Server time (a failure here skips this environment)
        try:
            time_payload = client.get_server_time()
            print(f"✅ Server time: {time_payload.get('result', {}).get('timeSecond')}")
        except Exception as e:
            print(f"❌ Server time failed: {e}")
            continue

        # Test 2: Get ticker for ETHUSDT
        try:
            ticker_rows = client.get_ticker('ETHUSDT', 'linear').get('result', {}).get('list', [])
            if not ticker_rows:
                print("❌ No ticker data received")
            else:
                row = ticker_rows[0]
                print("✅ ETH/USDT ticker:")
                print(f" Last Price: ${float(row.get('lastPrice', 0)):.2f}")
                print(f" 24h Volume: {float(row.get('volume24h', 0)):.2f}")
                print(f" 24h Change: {float(row.get('price24hPcnt', 0)) * 100:.2f}%")
        except Exception as e:
            print(f"❌ Ticker failed: {e}")

        # Test 3: Get instruments info
        try:
            all_instruments = client.get_instruments_info('linear').get('result', {}).get('list', [])
            eth_only = [item for item in all_instruments if 'ETH' in item.get('symbol', '')]
            print(f"✅ Found {len(eth_only)} ETH instruments")
            # Show first 3
            for item in eth_only[:3]:
                print(f" {item.get('symbol')} - Status: {item.get('status')}")
        except Exception as e:
            print(f"❌ Instruments failed: {e}")

        # Test 4: Get orderbook
        try:
            book = client.get_orderbook('ETHUSDT', 'linear', 5).get('result', {})
            bids = book.get('b', [])
            asks = book.get('a', [])
            if not (bids and asks):
                print("❌ No orderbook data received")
            else:
                print("✅ Orderbook (top 3):")
                # Convert and print level by level so a bad value fails at the
                # same point in the output as before.
                bid_price, bid_qty = float(bids[0][0]), float(bids[0][1])
                print(f" Best bid: ${bid_price:.2f} (qty: {bid_qty:.4f})")
                ask_price, ask_qty = float(asks[0][0]), float(asks[0][1])
                print(f" Best ask: ${ask_price:.2f} (qty: {ask_qty:.4f})")
                print(f" Spread: ${ask_price - bid_price:.2f}")
        except Exception as e:
            print(f"❌ Orderbook failed: {e}")

        print(f"📊 {env_name} environment test completed")
def test_live_authentication():
    """Check connectivity and authentication against the live Bybit API.

    Loads environment variables (python-dotenv if installed, otherwise a
    minimal ``.env`` parser), reads ``BYBIT_API_KEY`` / ``BYBIT_API_SECRET``,
    then tests connectivity and authentication with a live
    ``BybitRestClient``. On success the account types, USDT wallet balance and
    any non-zero linear positions are printed.

    Returns:
        bool: True when authentication succeeded; False for missing
        credentials, connectivity failure, authentication failure or an
        unexpected error.
    """
    print("\n" + "=" * 60)
    print("BYBIT LIVE AUTHENTICATION TEST")
    print("=" * 60)
    print("⚠️ This will test with LIVE credentials (not testnet)")

    # Load environment variables
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        # If dotenv is not available, try to load .env manually
        if os.path.exists('.env'):
            with open('.env', 'r') as env_file:
                for raw_line in env_file:
                    entry = raw_line.strip()
                    # Ignore blanks, comments and lines without an assignment
                    # rather than failing on the unpack below.
                    if not entry or entry.startswith('#') or '=' not in entry:
                        continue
                    key, value = entry.split('=', 1)
                    os.environ[key.strip()] = value.strip()

    api_key = os.getenv('BYBIT_API_KEY')
    api_secret = os.getenv('BYBIT_API_SECRET')
    if not api_key or not api_secret:
        print("❌ No API credentials found in environment")
        return False  # explicit bool; this path used to return None

    print(f"🔑 Using API key: {api_key[:8]}...")

    # Test with live environment (testnet=False)
    client = BybitRestClient(
        api_key=api_key,
        api_secret=api_secret,
        testnet=False  # Use live environment
    )

    # Test connectivity
    try:
        if client.test_connectivity():
            print("✅ Basic connectivity OK")
        else:
            print("❌ Basic connectivity failed")
            return False
    except Exception as e:
        print(f"❌ Connectivity error: {e}")
        return False

    # Test authentication
    try:
        if not client.test_authentication():
            print("❌ Authentication failed")
            return False
        print("✅ Authentication successful!")

        # Get account info
        account_info = client.get_account_info()
        accounts = account_info.get('result', {}).get('list', [])
        if accounts:
            print("📊 Account information:")
            for account in accounts:
                account_type = account.get('accountType', 'Unknown')
                print(f" Account Type: {account_type}")
                usdt_balance = None
                for coin in account.get('coin', []):
                    if coin.get('coin') == 'USDT':
                        usdt_balance = float(coin.get('walletBalance', 0))
                        break
                # Compare against None so a 0.00 balance is still reported.
                if usdt_balance is not None:
                    print(f" USDT Balance: ${usdt_balance:.2f}")

        # Show positions if any; a failure here does not fail the test
        try:
            positions = client.get_positions('linear')
            pos_list = positions.get('result', {}).get('list', [])
            active_positions = [p for p in pos_list if float(p.get('size', 0)) != 0]
            if active_positions:
                print(f" Active Positions: {len(active_positions)}")
                for pos in active_positions:
                    symbol = pos.get('symbol')
                    side = pos.get('side')
                    size = float(pos.get('size', 0))
                    pnl = float(pos.get('unrealisedPnl', 0))
                    print(f" {symbol}: {side} {size} (PnL: ${pnl:.2f})")
            else:
                print(" No active positions")
        except Exception as e:
            print(f" ⚠️ Could not get positions: {e}")
        return True
    except Exception as e:
        print(f"❌ Authentication error: {e}")
        return False
def main():
    """Run the public API checks, then optionally the live authentication test.

    The live authentication test only runs when the user answers ``y`` or
    ``yes`` at the prompt. Nothing is returned.
    """
    print("🚀 Starting Bybit API Tests...")

    # Test public API
    test_public_api()

    # Ask user if they want to test live authentication
    print("\n" + "=" * 60)
    answer = input("Do you want to test live authentication? (y/N): ").lower()
    if answer not in ('y', 'yes'):
        print("\n📋 Skipping live authentication test")
    elif test_live_authentication():
        print("\n✅ Live authentication test passed!")
        print("🎯 Your Bybit integration is working!")
    else:
        print("\n❌ Live authentication test failed")

    print("\n🎉 Public API tests completed successfully!")
    print("📈 Bybit integration is functional for market data")
# Run the API checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()

View File

@ -1,137 +0,0 @@
#!/usr/bin/env python3
"""
Test Cache Fix
Creates a corrupted Parquet file to test the fix mechanism
"""
import os
import pandas as pd
from pathlib import Path
from utils.cache_manager import get_cache_manager
def create_test_data():
    """Write three fixture files into ``data/cache``.

    Creates the directory if needed, then writes a valid 100-row OHLCV Parquet
    file (ETHUSDT), a file containing non-Parquet bytes (BTCUSDT) and a
    zero-length file (SOLUSDT).
    """
    print("Creating test cache files...")

    # Ensure cache directory exists
    cache_root = Path("data/cache")
    cache_root.mkdir(parents=True, exist_ok=True)

    # Valid Parquet file: 100 one-minute candles with linearly rising values
    row_count = 100
    steps = range(row_count)
    candles = pd.DataFrame({
        'timestamp': pd.date_range('2025-01-01', periods=row_count, freq='1min'),
        'open': [100.0 + i for i in steps],
        'high': [101.0 + i for i in steps],
        'low': [99.0 + i for i in steps],
        'close': [100.5 + i for i in steps],
        'volume': [1000 + i*10 for i in steps],
    })
    good_path = cache_root / "ETHUSDT_1m.parquet"
    candles.to_parquet(good_path, index=False)
    print(f"Created valid file: {good_path}")

    # Corrupted Parquet file: plain bytes that are not a Parquet payload
    bad_path = cache_root / "BTCUSDT_1m.parquet"
    bad_path.write_bytes(b"This is not a valid Parquet file - corrupted data")
    print(f"Created corrupted file: {bad_path}")

    # Empty file
    blank_path = cache_root / "SOLUSDT_1m.parquet"
    blank_path.touch()
    print(f"Created empty file: {blank_path}")
def test_cache_manager():
    """Scan the cache, delete corrupted files and re-scan, printing each stage.

    Uses the object returned by ``get_cache_manager()``: ``get_cache_summary()``
    for the health reports and ``cleanup_corrupted_files(dry_run=False)`` for
    the actual deletion.
    """
    print("\n=== Testing Cache Manager ===")
    manager = get_cache_manager()

    # Scan health
    print("1. Scanning cache health...")
    before = manager.get_cache_summary()
    print(f"Total files: {before['total_files']}")
    print(f"Valid files: {before['valid_files']}")
    print(f"Corrupted files: {before['corrupted_files']}")
    print(f"Health percentage: {before['health_percentage']:.1f}%")

    # Show corrupted files, directory by directory
    for directory, report in before['directories'].items():
        if not report['corrupted_files'] > 0:
            continue
        print(f"\nCorrupted files in {directory}:")
        for bad in report['corrupted_files_list']:
            print(f" - {bad['file']}: {bad['error']}")

    # Test cleanup
    print("\n2. Testing cleanup...")
    cleanup_result = manager.cleanup_corrupted_files(dry_run=False)
    deleted_count = 0
    for entries in cleanup_result.values():
        for entry in entries:
            # Only entries tagged as deleted count; every entry is echoed.
            if "DELETED:" in entry:
                deleted_count += 1
            print(f" {entry}")
    print(f"Deleted {deleted_count} corrupted files")

    # Verify cleanup
    print("\n3. Verifying cleanup...")
    after = manager.get_cache_summary()
    print(f"Corrupted files after cleanup: {after['corrupted_files']}")
def test_data_provider_integration():
    """Check that ``DataProvider`` copes with a corrupted cache file.

    Writes bogus bytes to ``data/cache/ETHUSDT_5m.parquet``, constructs a
    ``DataProvider`` and asks it to load the ETH/USDT 5m cache entry. Every
    failure is printed instead of raised.
    """
    print("\n=== Testing Data Provider Integration ===")

    # Create another corrupted file (starts with the Parquet magic bytes)
    bad_path = Path("data/cache") / "ETHUSDT_5m.parquet"
    payload = b"PAR1\x00\x00corrupted thrift data that will cause deserialization error"
    with open(bad_path, 'wb') as handle:
        handle.write(payload)
    print(f"Created corrupted file with thrift error: {bad_path}")

    # Try to import and use data provider
    try:
        from core.data_provider import DataProvider

        # Create data provider (should auto-fix corrupted cache)
        provider = DataProvider()
        print("Data provider created successfully - auto-fix worked!")

        # Try to load data (should handle corruption gracefully)
        try:
            cached = provider._load_from_cache("ETH/USDT", "5m")
            if cached is not None:
                print(f"Loaded {len(cached)} rows from cache")
            else:
                print("Cache loading returned None (expected for corrupted file)")
        except Exception as e:
            print(f"Cache loading failed: {e}")
    except Exception as e:
        print(f"Data provider test failed: {e}")
def main():
    """Run the cache-fix checks in order after clearing ``data/cache``.

    Note: every ``*.parquet`` file directly inside ``data/cache`` is deleted
    before the fixtures are created.
    """
    print("=== Cache Fix Test Suite ===")

    # Clean up any existing test files
    cache_root = Path("data/cache")
    if cache_root.exists():
        for stale in cache_root.glob("*.parquet"):
            stale.unlink()

    # Run tests
    for stage in (create_test_data, test_cache_manager, test_data_provider_integration):
        stage()

    print("\n=== Test Complete ===")
    print("The cache fix system is working correctly!")
# Run the cache-fix suite only when executed as a script.
if __name__ == "__main__":
    main()

View File

@ -1,175 +0,0 @@
#!/usr/bin/env python3
"""
Test CNN Integration
This script tests if the CNN adapter is working properly and identifies issues.
"""
import logging
import sys
import os
from datetime import datetime
# Logging: INFO level, each record prefixed with timestamp, logger name and level
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def test_cnn_adapter():
    """Smoke-test ``EnhancedCNNAdapter`` construction, counters and training.

    Imports and builds the adapter, logs its model/device/name and counters,
    adds one training sample and, if at least two samples are then present,
    adds a second one and runs a single training epoch.

    Returns:
        bool: True when no exception was raised, False otherwise (the error
        and traceback are logged/printed).
    """
    try:
        logger.info("Testing CNN adapter initialization...")

        # Test 1: Import CNN adapter
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        logger.info("✅ CNN adapter import successful")

        # Test 2: Initialize CNN adapter
        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialization successful")

        # Test 3: Check adapter attributes
        logger.info(f"CNN adapter model: {adapter.model}")
        logger.info(f"CNN adapter device: {adapter.device}")
        logger.info(f"CNN adapter model_name: {adapter.model_name}")

        # Test 4: Check metrics tracking
        logger.info(f"Inference count: {adapter.inference_count}")
        logger.info(f"Training count: {adapter.training_count}")
        logger.info(f"Training data length: {len(adapter.training_data)}")

        # Test 5: Test simple training sample addition
        adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
        logger.info(f"✅ Training sample added, new length: {len(adapter.training_data)}")

        # Test 6: Test training if we have enough samples
        if len(adapter.training_data) < 2:
            logger.info("⚠️ Not enough training samples for training test")
        else:
            # Add another sample, then run one epoch
            adapter.add_training_sample("ETH/USDT", "SELL", -0.05)
            result = adapter.train(epochs=1)
            logger.info(f"✅ Training successful: {result}")

            # Check if metrics were updated
            logger.info(f"Last training time: {adapter.last_training_time}")
            logger.info(f"Last training loss: {adapter.last_training_loss}")
            logger.info(f"Training count: {adapter.training_count}")
        return True
    except Exception as e:
        logger.error(f"❌ CNN adapter test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_base_data_input():
    """Build a ``BaseDataInput`` from ten synthetic bars and exercise it.

    The same list of ``OHLCVBar`` objects is used for every OHLCV series. The
    input is validated and its feature vector requested; results are logged.

    Returns:
        The constructed ``BaseDataInput``, or None if any step raised.
    """
    try:
        logger.info("Testing BaseDataInput creation...")

        # Test 1: Import BaseDataInput
        from core.data_models import BaseDataInput, OHLCVBar, COBData
        logger.info("✅ BaseDataInput import successful")

        # Test 2: Create 10 sample OHLCV bars, prices shifted by the bar index
        bars = [
            OHLCVBar(
                symbol="ETH/USDT",
                timestamp=datetime.now(),
                open=3500.0 + offset,
                high=3510.0 + offset,
                low=3490.0 + offset,
                close=3505.0 + offset,
                volume=1000.0,
                timeframe="1s",
            )
            for offset in range(10)
        ]
        logger.info(f"✅ Created {len(bars)} sample OHLCV bars")

        # Test 3: Create BaseDataInput
        base_input = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=datetime.now(),
            ohlcv_1s=bars,
            ohlcv_1m=bars,
            ohlcv_1h=bars,
            ohlcv_1d=bars,
            btc_ohlcv_1s=bars,
        )
        logger.info("✅ BaseDataInput created successfully")

        # Test 4: Validate BaseDataInput
        logger.info(f"BaseDataInput validation: {base_input.validate()}")

        # Test 5: Get feature vector
        features = base_input.get_feature_vector()
        logger.info(f"✅ Feature vector created, shape: {features.shape}")
        return base_input
    except Exception as e:
        logger.error(f"❌ BaseDataInput test failed: {e}")
        import traceback
        traceback.print_exc()
        return None
def test_cnn_prediction():
    """Run one ``EnhancedCNNAdapter.predict`` call on a synthetic input.

    The input comes from ``test_base_data_input()``; a falsy result aborts the
    test.

    Returns:
        bool: True when the prediction completed, False when no input was
        available or any step raised.
    """
    try:
        logger.info("Testing CNN prediction...")

        # Get CNN adapter and base data
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        base_input = test_base_data_input()
        if not base_input:
            logger.error("❌ Cannot test prediction without valid BaseDataInput")
            return False

        # Test prediction
        output = adapter.predict(base_input)
        logger.info(f"✅ Prediction successful: {output.predictions['action']} ({output.confidence:.3f})")

        # Check if metrics were updated
        logger.info(f"Inference count after prediction: {adapter.inference_count}")
        logger.info(f"Last inference time: {adapter.last_inference_time}")
        logger.info(f"Last prediction output: {adapter.last_prediction_output}")
        return True
    except Exception as e:
        logger.error(f"❌ CNN prediction test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the CNN adapter and prediction tests, stopping at the first failure.

    Returns:
        bool: True when both tests passed, False as soon as one fails.
    """
    logger.info("🧪 Starting CNN Integration Tests")

    # Test 1: CNN Adapter
    adapter_ok = test_cnn_adapter()
    if not adapter_ok:
        logger.error("❌ CNN adapter test failed - stopping")
        return False

    # Test 2: CNN Prediction
    prediction_ok = test_cnn_prediction()
    if not prediction_ok:
        logger.error("❌ CNN prediction test failed - stopping")
        return False

    logger.info("✅ All CNN integration tests passed!")
    logger.info("🎯 The CNN adapter should now work properly in the dashboard")
    return True
if __name__ == "__main__":
    # Exit with 0 when all CNN tests passed, 1 otherwise.
    success = main()
    exit_status = 0 if success else 1
    sys.exit(exit_status)

View File

@ -1,22 +0,0 @@
#!/usr/bin/env python3
"""
Test COB Dashboard with Enhanced WebSocket
"""
import asyncio
import logging
from web.cob_realtime_dashboard import COBDashboardServer
# Logging: INFO level, each record prefixed with timestamp, logger name and level
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
async def main():
    """Create a ``COBDashboardServer`` on localhost:8053 and await its ``start()``."""
    server = COBDashboardServer(host='localhost', port=8053)
    await server.start()
# Drive the async entry point when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,118 +0,0 @@
#!/usr/bin/env python3
"""
Test script for COB data quality and imbalance indicators
"""
import time
import logging
from core.data_provider import DataProvider
# Logging at INFO level with the default format; module-level logger for this script
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_cob_data_quality():
    """Log COB data-quality and imbalance figures from a live ``DataProvider``.

    Creates a ``DataProvider``, waits 20 seconds, then logs: the cached-data
    summary, the COB quality report per symbol, the last five 1s-aggregated
    records for ETH/USDT and BTC/USDT, and how many raw ticks and aggregated
    records arrived during a further 30-second wait. Finally calls
    ``stop_automatic_data_maintenance()``. Nothing is returned.
    """
    logger.info("Testing COB data quality and imbalance indicators...")
    tracked_symbols = ['ETH/USDT', 'BTC/USDT']
    windows = ('1s', '5s', '15s', '60s')

    # Initialize data provider
    provider = DataProvider()

    # Wait for initial data load and COB connection
    logger.info("Waiting for initial data load and COB connection...")
    time.sleep(20)

    # Test 1: Check cached data summary
    logger.info("\n=== Test 1: Cached Data Summary ===")
    summary = provider.get_cached_data_summary()
    for symbol in summary['cached_data']:
        logger.info(f"\n{symbol}:")
        for timeframe, info in summary['cached_data'][symbol].items():
            has_candles = 'candle_count' in info and info['candle_count'] > 0
            if has_candles:
                logger.info(f" {timeframe}: {info['candle_count']} candles, latest: ${info['latest_price']}")
            else:
                logger.info(f" {timeframe}: {info.get('status', 'no data')}")

    # Test 2: Check COB data quality
    logger.info("\n=== Test 2: COB Data Quality ===")
    quality = provider.get_cob_data_quality()
    for symbol in quality['symbols']:
        logger.info(f"\n{symbol} COB Data:")

        # Raw ticks
        raw = quality['raw_ticks'].get(symbol, {})
        logger.info(f" Raw ticks: {raw.get('count', 0)} ticks")
        if raw.get('age_seconds') is not None:
            logger.info(f" Raw data age: {raw['age_seconds']:.1f} seconds")

        # Aggregated 1s data
        aggregated = quality['aggregated_1s'].get(symbol, {})
        logger.info(f" Aggregated 1s: {aggregated.get('count', 0)} records")
        if aggregated.get('age_seconds') is not None:
            logger.info(f" Aggregated data age: {aggregated['age_seconds']:.1f} seconds")

        # Imbalance indicators, one line per window
        imbalance = quality['imbalance_indicators'].get(symbol, {})
        if imbalance:
            for window in windows:
                logger.info(f" Imbalance {window}: {imbalance.get(f'imbalance_{window}', 0):.4f}")
            logger.info(f" Total volume: {imbalance.get('total_volume', 0):.2f}")
            logger.info(f" Price buckets: {imbalance.get('bucket_count', 0)}")

        # Data freshness
        logger.info(f" Data freshness: {quality['data_freshness'].get(symbol, 'unknown')}")

    # Test 3: Get recent COB aggregated data
    logger.info("\n=== Test 3: Recent COB Aggregated Data ===")
    for symbol in tracked_symbols:
        recent = provider.get_cob_1s_aggregated(symbol, count=5)
        logger.info(f"\n{symbol} - Last 5 aggregated records:")
        for position, record in enumerate(recent[-5:], start=1):
            timestamp = record.get('timestamp', 0)
            imb_1s = record.get('imbalance_1s', 0)
            imb_5s = record.get('imbalance_5s', 0)
            volume = record.get('total_volume', 0)
            buckets = len(record.get('bid_buckets', {})) + len(record.get('ask_buckets', {}))
            logger.info(f" [{position}] Time: {timestamp}, Imb1s: {imb_1s:.4f}, "
                        f"Imb5s: {imb_5s:.4f}, Vol: {volume:.2f}, Buckets: {buckets}")

    # Test 4: Monitor real-time updates by diffing two quality snapshots
    logger.info("\n=== Test 4: Real-time Updates (30 seconds) ===")
    logger.info("Monitoring COB data updates...")
    snapshot_before = provider.get_cob_data_quality()
    time.sleep(30)
    snapshot_after = provider.get_cob_data_quality()
    for symbol in tracked_symbols:
        ticks_before = snapshot_before['raw_ticks'].get(symbol, {}).get('count', 0)
        ticks_after = snapshot_after['raw_ticks'].get(symbol, {}).get('count', 0)
        agg_before = snapshot_before['aggregated_1s'].get(symbol, {}).get('count', 0)
        agg_after = snapshot_after['aggregated_1s'].get(symbol, {}).get('count', 0)
        logger.info(f"{symbol}: +{ticks_after - ticks_before} raw ticks, +{agg_after - agg_before} aggregated records")

        # Show latest imbalances
        latest = snapshot_after['imbalance_indicators'].get(symbol, {})
        if latest:
            joined = ", ".join(
                f"{window}={latest.get(f'imbalance_{window}', 0):.4f}" for window in windows
            )
            logger.info(f" Latest imbalances: {joined}")

    # Clean shutdown
    logger.info("\n=== Shutting Down ===")
    provider.stop_automatic_data_maintenance()
    logger.info("COB data quality test completed")
# Run the COB data-quality check only when executed as a script.
if __name__ == "__main__":
    test_cob_data_quality()

View File

@ -1,131 +0,0 @@
#!/usr/bin/env python3
"""
Test COB WebSocket Only Integration
This script tests that COB integration works with Enhanced WebSocket only,
without falling back to REST API calls.
"""
import asyncio
import time
from datetime import datetime
from typing import Dict
from core.cob_integration import COBIntegration
async def test_cob_websocket_only():
    """Check that ``COBIntegration`` delivers updates via the WebSocket path.

    Starts the integration as a background task, reports whether
    ``cob_provider`` is None and whether ``enhanced_websocket`` is set, counts
    dashboard callbacks over a 5-second window and reports the update rate and
    the age of the last update. Results are printed; nothing is returned.

    The integration is stopped and the background start task cancelled in a
    ``finally`` block, so teardown also happens when one of the checks raises.
    """
    print("=== Testing COB WebSocket Only Integration ===")

    # Initialize COB integration
    print("1. Initializing COB integration...")
    symbols = ['ETH/USDT', 'BTC/USDT']
    cob_integration = COBIntegration(symbols=symbols)

    # Track updates
    update_count = 0
    last_update_time = None

    def dashboard_callback(symbol: str, data: Dict):
        """Count every update and print details for the first five."""
        nonlocal update_count, last_update_time
        update_count += 1
        last_update_time = datetime.now()
        if update_count <= 5:  # Show first 5 updates
            data_type = data.get('type', 'unknown')
            if data_type == 'cob_update':
                stats = data.get('data', {}).get('stats', {})
                mid_price = stats.get('mid_price', 0)
                spread_bps = stats.get('spread_bps', 0)
                source = stats.get('source', 'unknown')
                print(f" Update #{update_count}: {symbol} - Price: ${mid_price:.2f}, Spread: {spread_bps:.1f}bps, Source: {source}")
            elif data_type == 'websocket_status':
                status_data = data.get('data', {})
                status = status_data.get('status', 'unknown')
                print(f" Status #{update_count}: {symbol} - WebSocket: {status}")

    # Add dashboard callback
    cob_integration.add_dashboard_callback(dashboard_callback)

    # Start COB integration
    print("2. Starting COB integration...")
    start_task = None
    try:
        # Start in background
        start_task = asyncio.create_task(cob_integration.start())

        # Wait for initialization
        await asyncio.sleep(3)

        # Check if COB provider is disabled
        print("3. Checking COB provider status:")
        if cob_integration.cob_provider is None:
            print(" ✅ COB provider is disabled (using Enhanced WebSocket only)")
        else:
            print(" ❌ COB provider is still active (may cause REST API fallback)")

        # Check Enhanced WebSocket status
        print("4. Checking Enhanced WebSocket status:")
        if cob_integration.enhanced_websocket:
            print(" ✅ Enhanced WebSocket is initialized")
            # Check WebSocket status for each symbol
            websocket_status = cob_integration.get_websocket_status()
            for symbol, status in websocket_status.items():
                print(f" {symbol}: {status}")
        else:
            print(" ❌ Enhanced WebSocket is not initialized")

        # Monitor updates for a few seconds
        print("5. Monitoring COB updates...")
        initial_count = update_count
        monitor_start = time.time()

        # Wait for updates
        await asyncio.sleep(5)
        monitor_duration = time.time() - monitor_start
        updates_received = update_count - initial_count
        update_rate = updates_received / monitor_duration
        print(f" Received {updates_received} updates in {monitor_duration:.1f}s")
        print(f" Update rate: {update_rate:.1f} updates/second")
        if update_rate >= 8:  # Should be around 10 updates/second
            print(" ✅ Update rate is excellent (8+ updates/second)")
        elif update_rate >= 5:
            print(" ✅ Update rate is good (5+ updates/second)")
        elif update_rate >= 1:
            print(" ⚠️ Update rate is low (1+ updates/second)")
        else:
            print(" ❌ Update rate is too low (<1 update/second)")

        # Check data quality
        print("6. Data quality check:")
        if last_update_time:
            time_since_last = (datetime.now() - last_update_time).total_seconds()
            if time_since_last < 1:
                print(f" ✅ Recent data (last update {time_since_last:.1f}s ago)")
            else:
                print(f" ⚠️ Stale data (last update {time_since_last:.1f}s ago)")
        else:
            print(" ❌ No updates received")
    except Exception as e:
        print(f" ❌ Error during COB integration test: {e}")
    finally:
        # Teardown runs on every path so a failed check cannot leave the
        # integration or its background task running.
        print("7. Stopping COB integration...")
        try:
            await cob_integration.stop()
        except Exception as e:
            print(f" ❌ Error during COB integration test: {e}")

        # Cancel the start task
        if start_task is not None:
            start_task.cancel()
            try:
                await start_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The task had already finished with its own error.
                print(f" ❌ Error during COB integration test: {e}")

    print("\n✅ COB WebSocket only test completed!")
    print(f"Total updates received: {update_count}")
    print("Enhanced WebSocket is now the sole data source (no REST API fallback)")
# Drive the async COB WebSocket test when executed as a script.
if __name__ == "__main__":
    asyncio.run(test_cob_websocket_only())

Some files were not shown because too many files have changed in this diff Show More