massive clenup

This commit is contained in:
Dobromir Popov 2025-05-24 10:32:00 +03:00
parent 310f3c5bf9
commit b5ad023b16
87 changed files with 1930 additions and 784568 deletions

View File

@ -1,150 +0,0 @@
# Project Cleanup & Reorganization Plan
## Current Issues
1. **Code Duplication**: Multiple CNN models, RL agents, training scripts doing similar things
2. **Missing Methods**: Core functionality like `run()`, `start_websocket()` missing from classes
3. **Unclear Architecture**: No clean separation between components
4. **Hard to Maintain**: Scattered implementations make changes difficult
## New Clean Architecture
```
gogo2/
├── core/ # Core system components
│ ├── __init__.py
│ ├── data_provider.py # Multi-timeframe, multi-symbol data
│ ├── orchestrator.py # Main decision making module
│ └── config.py # Central configuration
├── models/ # AI/ML Models
│ ├── __init__.py
│ ├── cnn/ # CNN module
│ │ ├── __init__.py
│ │ ├── model.py # Single CNN implementation
│ │ ├── trainer.py # CNN training pipeline
│ │ └── predictor.py # CNN inference with confidence
│ └── rl/ # RL module
│ ├── __init__.py
│ ├── agent.py # Single RL agent implementation
│ ├── environment.py # Trading environment
│ └── trainer.py # RL training loop
├── trading/ # Trading execution
│ ├── __init__.py
│ ├── executor.py # Trade execution
│ ├── portfolio.py # Position/portfolio management
│ └── metrics.py # Performance tracking
├── web/ # Web interface
│ ├── __init__.py
│ ├── dashboard.py # Main dashboard
│ └── charts.py # Chart components
├── utils/ # Utilities
│ ├── __init__.py
│ ├── logger.py # Centralized logging
│ └── helpers.py # Common helpers
├── main.py # Single entry point
├── config.yaml # Configuration file
└── requirements.txt # Dependencies
```
## Core Goals
### 1. Data Provider (`core/data_provider.py`)
- **Multi-symbol support**: ETH/USDT, BTC/USDT (configurable)
- **Multi-timeframe**: 1m, 5m, 15m, 1h, 4h, 1d
- **Real-time streaming**: WebSocket integration
- **Historical data**: API integration for backtesting
- **Clean interface**: Simple methods for getting data
### 2. CNN Module (`models/cnn/`)
- **Single model implementation**: Remove duplicates
- **Timeframe-specific predictions**: Separate predictions per timeframe
- **Confidence scoring**: Each prediction includes confidence
- **Training pipeline**: Supervised learning with marked data (perfect moves)
### 3. RL Module (`models/rl/`)
- **Single agent**: Remove duplicate DQN implementations
- **Environment**: Clean trading simulation
- **Learning loop**: Evaluates trading actions and adapts
### 4. Orchestrator (`core/orchestrator.py`)
- **Decision making**: Combines CNN and RL outputs
- **Final actions**: BUY/SELL/HOLD decisions
- **Confidence weighting**: Uses CNN confidence in decisions
### 5. Web Interface (`web/`)
- **Real-time charts**: Live trading visualization
- **Performance dashboard**: Metrics and analytics
- **Simple & clean**: Remove complex chart implementations
## Cleanup Steps
### Phase 1: Core Infrastructure
1. Create new clean directory structure
2. Implement `core/data_provider.py` (consolidate all data functionality)
3. Implement `core/orchestrator.py` (main decision maker)
4. Create `config.yaml` for all settings
### Phase 2: Model Consolidation
1. Create single `models/cnn/model.py` (consolidate all CNN implementations)
2. Create single `models/rl/agent.py` (consolidate DQN implementations)
3. Remove duplicate model files
### Phase 3: Training Simplification
1. Create `models/cnn/trainer.py` (single CNN training script)
2. Create `models/rl/trainer.py` (single RL training script)
3. Remove all duplicate training scripts
### Phase 4: Web Interface
1. Create clean `web/dashboard.py` (consolidate chart functionality)
2. Remove complex/unused chart implementations
### Phase 5: Integration & Testing
1. Create single `main.py` entry point
2. Test all components work together
3. Remove unused files
## Files to Remove (After consolidation)
### Duplicate Training Scripts
- `train_hybrid.py`
- `train_dqn.py`
- `train_cnn_with_realtime.py`
- `train_with_realtime_ticks.py`
- `train_improved_rl.py`
- `NN/train_enhanced.py`
- `NN/train_rl.py`
### Duplicate Model Files
- `NN/models/cnn_model.py`
- `NN/models/enhanced_cnn.py`
- `NN/models/simple_cnn.py`
- `NN/models/transformer_model.py`
- `NN/models/transformer_model_pytorch.py`
- `NN/models/dqn_agent_enhanced.py`
### Duplicate Main Files
- `trading_main.py`
- `NN/main.py`
- `NN/realtime_main.py`
- `NN/realtime-main.py`
### Unused Utilities
- `launch_training.py`
- `NN/example.py`
- Most logs and backup directories
## Benefits of New Architecture
1. **Single Source of Truth**: One implementation per component
2. **Clear Separation**: CNN, RL, and Orchestrator are distinct
3. **Easy to Extend**: Adding new symbols/timeframes is simple
4. **Maintainable**: Changes are localized to specific modules
5. **Testable**: Each component can be tested independently
## Implementation Priority
1. **HIGH**: Core data provider and orchestrator
2. **HIGH**: Single CNN and RL implementations
3. **MEDIUM**: Web dashboard consolidation
4. **LOW**: Cleanup of unused files
This plan will result in a much cleaner, more maintainable codebase focused on the core goal: multi-modal trading system with CNN predictions and RL decision making.

View File

View File

@ -1,340 +0,0 @@
# Disk Space Optimization for Model Training
## Issue
The training process was encountering "No space left on device" errors during model saving operations, preventing successful completion of training cycles. Additionally, we identified matrix multiplication errors and TorchScript compatibility issues that were causing training crashes.
## Solution Implemented
A comprehensive set of improvements were implemented in the `main.py` file to address these issues:
1. Creating smaller checkpoint files with minimal model data
2. Providing multiple fallback mechanisms when primary save methods fail
3. Saving essential model parameters as JSON when full model saving fails
4. Automatic cleanup of old model files to free up disk space
5. **NEW**: Model quantization for even smaller file sizes
6. **NEW**: Fixed TorchScript compatibility issues with `CandlePatternCNN`
7. **NEW**: Fixed matrix multiplication errors in the `LSTMAttentionDQN` class
8. **NEW**: Added aggressive cleanup option for very low disk space situations
## Implementation Details
### Compact Save Function with Quantization
The updated `compact_save` function now includes an option to use model quantization for even smaller file sizes:
```python
def compact_save(model, optimizer, reward, epsilon, state_size, action_size, hidden_size, path, use_quantization=False):
"""
Save a model in a compact format suitable for low disk space environments.
Includes fallbacks if the primary save method fails.
"""
try:
# Create minimal checkpoint with essential data only
checkpoint = {
'model_state_dict': model.state_dict(),
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size
}
# Apply quantization if requested
if use_quantization:
try:
logging.info(f"Attempting quantized save to {path}")
# Quantize model to int8
quantized_model = torch.quantization.quantize_dynamic(
model, # the original model
{torch.nn.Linear}, # a set of layers to dynamically quantize
dtype=torch.qint8 # the target dtype for quantized weights
)
# Create quantized checkpoint
quantized_checkpoint = {
'model_state_dict': quantized_model.state_dict(),
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size,
'is_quantized': True
}
# Save with older pickle protocol and disable new zipfile serialization
torch.save(quantized_checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logging.info(f"Quantized compact save successful to {path}")
return True
except Exception as e:
logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}")
# Fall back to regular save if quantization fails
# Regular save with older pickle protocol and no zipfile serialization
torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logging.info(f"Compact save successful to {path}")
return True
except Exception as e:
logging.error(f"Compact save failed: {str(e)}")
logging.error(traceback.format_exc())
# Fallback: Save just the parameters as JSON if we can't save the full model
try:
params = {
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size
}
json_path = f"{path}.params.json"
with open(json_path, 'w') as f:
json.dump(params, f)
logging.info(f"Saved minimal parameters to {json_path}")
return False
except Exception as json_e:
logging.error(f"JSON parameter save failed: {str(json_e)}")
return False
```
### TorchScript Compatibility Fix
The `CandlePatternCNN` class was refactored to make it compatible with TorchScript by replacing the dictionary-based feature storage with tensor attributes:
```python
class CandlePatternCNN(nn.Module):
"""Convolutional neural network for detecting candlestick patterns"""
def __init__(self, input_channels=5, feature_dimension=512):
super(CandlePatternCNN, self).__init__()
# ... existing CNN layers ...
# Initialize intermediate features as empty tensors, not as a dict
# This makes the model TorchScript compatible
self.feature_1m = torch.zeros(1, feature_dimension)
self.feature_1h = torch.zeros(1, feature_dimension)
self.feature_1d = torch.zeros(1, feature_dimension)
def forward(self, x_1m, x_1h, x_1d):
# Process timeframe data
feat_1m = self.process_timeframe(x_1m)
feat_1h = self.process_timeframe(x_1h)
feat_1d = self.process_timeframe(x_1d)
# Store features as attributes instead of in a dictionary
self.feature_1m = feat_1m
self.feature_1h = feat_1h
self.feature_1d = feat_1d
# Concatenate features from different timeframes
combined_features = torch.cat([feat_1m, feat_1h, feat_1d], dim=1)
return combined_features
```
### Matrix Multiplication Error Fix
The `LSTMAttentionDQN` forward method was enhanced to handle different tensor shapes safely, preventing matrix multiplication errors:
```python
def forward(self, state, x_1m=None, x_1h=None, x_1d=None):
"""
Forward pass handling different input shapes and optional CNN features
"""
batch_size = state.size(0)
# Handle CNN features if provided
if x_1m is not None and x_1h is not None and x_1d is not None:
# Ensure all CNN features have batch dimension
if len(x_1m.shape) == 2:
x_1m = x_1m.unsqueeze(0)
if len(x_1h.shape) == 2:
x_1h = x_1h.unsqueeze(0)
if len(x_1d.shape) == 2:
x_1d = x_1d.unsqueeze(0)
# Ensure batch dimensions match
if x_1m.size(0) != batch_size:
x_1m = x_1m.expand(batch_size, -1, -1) if x_1m.size(0) == 1 else x_1m[:batch_size]
# ... additional shape handling ...
# Handle variable dimensions more gracefully
needed_features = 512
if x_1m_flat.size(1) < needed_features:
x_1m_flat = F.pad(x_1m_flat, (0, needed_features - x_1m_flat.size(1)))
else:
x_1m_flat = x_1m_flat[:, :needed_features]
```
### Enhanced File Cleanup
The file cleanup function now includes an aggressive mode and disk space reporting:
```python
def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False):
"""
Delete old model files to free up disk space.
Args:
keep_best (bool): Whether to keep the best model files (reward, pnl, net_pnl)
keep_latest_n (int): Number of latest checkpoint files to keep
aggressive (bool): If True, apply more aggressive cleanup in very low disk scenarios
"""
try:
logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}")
models_dir = "models"
# Get all files in the models directory
all_files = os.listdir(models_dir)
# Files to potentially delete
checkpoint_files = []
# Best files to keep if keep_best is True
best_patterns = [
"trading_agent_best_reward.pt",
"trading_agent_best_pnl.pt",
"trading_agent_best_net_pnl.pt",
"trading_agent_final.pt"
]
# Collect checkpoint files that can be deleted
for filename in all_files:
file_path = os.path.join(models_dir, filename)
# Skip directories
if os.path.isdir(file_path):
continue
# Skip current best files if keep_best is True
if keep_best and any(filename == pattern for pattern in best_patterns):
continue
# Collect checkpoint files
if "checkpoint" in filename and filename.endswith(".pt"):
checkpoint_files.append((filename, os.path.getmtime(file_path), file_path))
# If we have more checkpoint files than we want to keep
if len(checkpoint_files) > keep_latest_n:
# Sort by modification time (newest first)
checkpoint_files.sort(key=lambda x: x[1], reverse=True)
# Keep the newest N files
files_to_delete = checkpoint_files[keep_latest_n:]
# Delete old checkpoint files
bytes_freed = 0
for _, _, file_path in files_to_delete:
try:
file_size = os.path.getsize(file_path)
os.remove(file_path)
bytes_freed += file_size
logging.info(f"Deleted old checkpoint file: {file_path}")
except Exception as e:
logging.error(f"Failed to delete file {file_path}: {str(e)}")
logging.info(f"Cleanup complete. Deleted {len(files_to_delete)} files, freed {bytes_freed / (1024*1024):.2f} MB")
else:
logging.info(f"No cleanup needed. Found {len(checkpoint_files)} checkpoint files, keeping {keep_latest_n}")
except Exception as e:
logging.error(f"Error during file cleanup: {str(e)}")
logging.error(traceback.format_exc())
# Check available disk space after cleanup
try:
if platform.system() == 'Windows':
free_bytes = ctypes.c_ulonglong(0)
ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(os.path.abspath(models_dir)), None, None, ctypes.pointer(free_bytes))
free_mb = free_bytes.value / (1024 * 1024)
else:
st = os.statvfs(os.path.abspath(models_dir))
free_mb = (st.f_bavail * st.f_frsize) / (1024 * 1024)
logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB")
# If space is still low, recommend aggressive cleanup
if free_mb < 200 and not aggressive: # Less than 200MB available
logging.warning("Disk space still critically low. Consider using aggressive cleanup.")
except Exception as e:
logging.error(f"Error checking disk space: {str(e)}")
```
### Train Agent Function Modification
The `train_agent` function was modified to include the `use_compact_save` option:
```python
def train_agent(episodes, max_steps, update_interval=10, training_iterations=10,
use_compact_save=False):
# ...existing code...
if use_compact_save:
compact_save(agent.policy_net, agent.optimizer, total_reward, agent.epsilon,
agent.state_size, agent.action_size, agent.hidden_size,
f"models/trading_agent_best_reward.pt")
else:
agent.save(f"models/trading_agent_best_reward.pt")
# ...similar modifications for other save points...
```
### Command Line Arguments
New command line arguments have been added to support these features:
```python
parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage')
parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes')
parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training')
parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space')
parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up')
```
## Results
### Effectiveness
The comprehensive approach to disk space optimization addresses multiple issues:
1. **Successful Saves**: Multiple successful save methods that adapt to different disk space conditions
2. **Fallback Mechanism**: Smaller fallback files when full model saving fails
3. **Training Stability**: Fixed TorchScript compatibility and matrix multiplication errors prevent crashes
4. **Automatic Cleanup**: Reduced disk usage through automatic cleanup of old files
### File Size Comparison
The optimization techniques create smaller files through multiple approaches:
- **Quantized Models**: Using INT8 quantization can reduce model size by up to 75%
- **Non-Optimizer Saves**: Excluding optimizer state reduces file size by ~50%
- **JSON Parameters**: Extremely small (under 100 bytes) for essential restart capability
- **Cleanup**: Automatic removal of old checkpoint files frees up disk space
## Usage Instructions
To use these disk space optimization features, run the training with the following command line options:
```bash
# Basic usage with compact save
python main.py --mode train --episodes 10 --max_steps 200 --compact_save
# With model quantization for even smaller files
python main.py --mode train --episodes 10 --max_steps 200 --compact_save --use_quantization
# With file cleanup before training
python main.py --mode train --episodes 10 --max_steps 200 --compact_save --cleanup
# With aggressive cleanup for very low disk space
python main.py --mode train --episodes 10 --max_steps 200 --compact_save --cleanup --aggressive_cleanup
# Specify how many checkpoint files to keep
python main.py --mode train --episodes 10 --max_steps 200 --compact_save --cleanup --keep_latest 3
```
## Additional Recommendations
1. **Disk Space Monitoring**: The code now reports available disk space after cleanup. Monitor this to ensure sufficient space is maintained.
2. **Regular Cleanup**: Schedule regular cleanup operations, especially for long training sessions.
3. **Model Pruning**: Consider implementing neural network pruning to remove unnecessary connections in the model, further reducing size.
4. **Remote Storage**: For very long training sessions, consider implementing automatic upload of checkpoint files to remote storage.
## Conclusion
The implemented disk space optimization features have successfully addressed multiple issues:
1. Fixed TorchScript compatibility and matrix multiplication errors that were causing crashes
2. Implemented model quantization for significantly smaller file sizes
3. Added aggressive cleanup options to manage disk space automatically
4. Provided multiple fallback mechanisms to ensure training progress isn't lost
These improvements allow training to continue even under severe disk space constraints, with minimal intervention required.

View File

@ -1,87 +0,0 @@
# Implementation Summary: Training Stability and Disk Space Optimization
## Issues Addressed
1. **Disk Space Errors**: "No space left on device" errors during model saving operations
2. **Matrix Multiplication Errors**: Shape mismatches in neural network operations
3. **TorchScript Compatibility Issues**: Errors when attempting to use `torch.jit.save()`
4. **Training Crashes**: Unhandled exceptions in saving process
## Solutions Implemented
### Disk Space Optimization
1. **Compact Model Saving**
- Created minimal checkpoint files with essential data only
- Implemented multiple fallback mechanisms for different disk space scenarios
- Added JSON parameter saving as a last resort
- Integrated model quantization (INT8) for reduced file sizes
2. **Automatic File Cleanup**
- Added automatic cleanup of older checkpoint files
- Implemented "aggressive cleanup" mode for critically low disk space
- Added disk space monitoring to report available space
- Created retention policies to keep best models while removing unnecessary files
### Neural Network Improvements
1. **TorchScript Compatibility**
- Refactored `CandlePatternCNN` class to use tensor attributes instead of dictionaries
- Simplified layer architecture to ensure compatibility with TorchScript
- Fixed forward method to handle tensor shapes consistently
2. **Matrix Multiplication Fix**
- Enhanced tensor shape handling in `LSTMAttentionDQN` forward method
- Added robust dimension checking and correction
- Implemented padding/truncating for variable-sized inputs
- Fixed batch dimension handling for CNN features
## Results
The implemented changes resulted in:
1. **Improved Stability**: Training no longer crashes due to matrix multiplication errors or torch.jit issues
2. **Efficient Disk Usage**: Freed up 3.8 GB of disk space through aggressive cleanup
3. **Fallback Mechanisms**: Successfully created fallback files when primary saves failed
4. **Enhanced Monitoring**: Added disk space tracking to report remaining space after cleanup operations
## Command Line Usage
The improvements can be activated with the following command line arguments:
```bash
# Basic usage with compact save
python main.py --mode train --episodes 10 --compact_save
# With model quantization for smaller files
python main.py --mode train --episodes 10 --compact_save --use_quantization
# With file cleanup before training
python main.py --mode train --episodes 10 --compact_save --cleanup
# With aggressive cleanup for very low disk space
python main.py --mode train --episodes 10 --compact_save --cleanup --aggressive_cleanup
# Specify how many checkpoint files to keep
python main.py --mode train --episodes 10 --compact_save --cleanup --keep_latest 3
```
## Key Files Modified
1. `main.py`: Added new functions and modified existing ones:
- Added `compact_save()` function with quantization support
- Enhanced `cleanup_model_files()` function with aggressive mode
- Refactored `CandlePatternCNN` class for TorchScript compatibility
- Fixed shape handling in `LSTMAttentionDQN` forward method
2. `DISK_SPACE_OPTIMIZATION.md`: Comprehensive documentation of the disk space optimization features
- Detailed explanation of all implemented features
- Usage instructions and recommendations
- Performance analysis of the enhancements
## Future Recommendations
1. **Long-term Storage Solution**: Implement automatic upload to cloud storage for long training sessions
2. **Advanced Model Compression**: Explore neural network pruning and mixed-precision training
3. **Automatic Cleanup Scheduler**: Set up periodic cleanup based on disk usage thresholds
4. **Checkpoint Rotation Strategy**: Implement more sophisticated model retention policies

View File

@ -1,74 +0,0 @@
# Model Saving Fix
## Issue
During training sessions, PyTorch model saving operations sometimes fail with errors like:
```
RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 18278784 vs 18278680
```
or
```
RuntimeError: [enforce fail at inline_container.cc:820] . PytorchStreamWriter failed writing file data/75: file write failed
```
These errors occur in the PyTorch serialization mechanism when saving models using `torch.save()`.
## Solution
We've implemented a robust model saving approach that uses multiple fallback methods if the primary save operation fails:
1. **Attempt 1**: Save to a backup file first, then copy to the target path.
2. **Attempt 2**: Use an older pickle protocol (pickle protocol 2) which can be more compatible.
3. **Attempt 3**: Save without the optimizer state, which can reduce file size and avoid serialization issues.
4. **Attempt 4**: Use TorchScript's `torch.jit.save()` instead of `torch.save()`, which uses a different serialization mechanism.
## Implementation
The solution is implemented in two parts:
1. A `robust_save` function that tries multiple saving approaches with fallbacks.
2. A monkey patch that replaces the Agent's `save` method with our robust version.
### Example Usage
```python
# Import the robust_save function
from live_training import robust_save
# Save a model with fallbacks
success = robust_save(agent, "models/my_model.pt")
if success:
print("Model saved successfully!")
else:
print("All save attempts failed")
```
## Testing
We've created a test script `test_save.py` that demonstrates the robust saving approach and verifies that it works correctly.
To run the test:
```bash
python test_save.py
```
This script creates a simple model, attempts to save it using both the standard and robust methods, and reports on the results.
## Future Improvements
Possible future improvements to the model saving mechanism:
1. Additional fallback methods like serializing individual neural network layers.
2. Automatic retry mechanism with exponential backoff.
3. Asynchronous saving to avoid blocking the training loop.
4. Checksumming saved models to verify integrity.
## Related Issues
For more information on similar issues with PyTorch model saving, see:
- https://github.com/pytorch/pytorch/issues/27736
- https://github.com/pytorch/pytorch/issues/24045

View File

@ -1,72 +0,0 @@
# Model Saving Recommendations
During training, several PyTorch model serialization errors were identified and fixed. Here's a summary of our findings and recommendations to ensure robust model saving:
## Issues Found
1. **PyTorch Serialization Errors**: Errors like `PytorchStreamWriter failed writing file data...` and `unexpected pos...` indicate issues with PyTorch's serialization mechanism.
2. **Disk Space Issues**: Our tests showed `No space left on device` errors, which can cause model corruption.
3. **Compatibility Issues**: Some serialization methods might not be compatible with specific PyTorch versions or environments.
## Implemented Solutions
1. **Robust Save Function**: We added a `robust_save` function that tries multiple saving approaches in sequence:
- First attempt: Standard save to a backup file, then copy to the target path
- Second attempt: Save with pickle protocol 2 (more compatible)
- Third attempt: Save without optimizer state (reduces file size)
- Fourth attempt: Use TorchScript's `jit.save()` (different serialization mechanism)
2. **Memory Management**: Implemented memory cleanup before saving:
- Clearing GPU cache with `torch.cuda.empty_cache()`
- Running garbage collection with `gc.collect()`
3. **Error Handling**: Added comprehensive error handling around all saving operations.
4. **Circuit Breaker Pattern**: Added circuit breakers to prevent consecutive failures during training.
## Recommendations
1. **Disk Space**: Ensure sufficient disk space is available (at least 1-2GB free). Large models can use several GB of disk space.
2. **Checkpoint Cleanup**: Periodically remove old checkpoints to free up space:
```bash
# Example script to keep only the most recent 5 checkpoints
Get-ChildItem -Path .\models\trading_agent_checkpoint_*.pt |
Sort-Object LastWriteTime -Descending |
Select-Object -Skip 5 |
Remove-Item
```
3. **File System Check**: If persistent errors occur, check the file system for errors or corruption.
4. **Use Smaller Models**: Consider reducing model size if saving large models is problematic.
5. **Alternative Serialization**: For very large models, consider saving key parameters separately rather than the entire model.
6. **Training Stability**: Use our improved training functions with memory management and error handling.
## How to Test Model Saving
We've provided a test script `test_model_save_load.py` that can verify if model saving is working correctly. Run it with:
```bash
python test_model_save_load.py
```
Or test all robust save methods with:
```bash
python test_model_save_load.py --test_robust
```
## Future Development
1. **Checksumming**: Add checksums to saved models to verify integrity.
2. **Compression**: Implement model compression to reduce file size.
3. **Distributed Saving**: For very large models, explore distributed saving mechanisms.
4. **Format Conversion**: Add ability to save models in ONNX or other portable formats.

View File

@ -1,261 +0,0 @@
#!/usr/bin/env python
"""
Example script for the Neural Network Trading System
This shows basic usage patterns for the system components
"""
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
import logging
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import components
from NN.utils.data_interface import DataInterface
from NN.models.cnn_model import CNNModel
from NN.models.transformer_model import TransformerModel, MixtureOfExpertsModel
from NN.main import NeuralNetworkOrchestrator
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('example')
def example_data_interface():
"""Show how to use the data interface"""
logger.info("=== Data Interface Example ===")
# Initialize data interface
di = DataInterface(symbol="BTC/USDT", timeframes=['1h', '4h', '1d'])
# Get historical data
df_1h = di.get_historical_data(timeframe='1h', n_candles=100)
if df_1h is not None and not df_1h.empty:
logger.info(f"Retrieved {len(df_1h)} 1-hour candles")
logger.info(f"Most recent candle: {df_1h.iloc[-1]}")
# Prepare data for neural network
X, y, timestamps = di.prepare_nn_input(timeframes=['1h'], n_candles=500, window_size=20)
if X is not None and y is not None:
logger.info(f"Prepared input shape: {X.shape}, target shape: {y.shape}")
# Generate a dataset
dataset = di.generate_training_dataset(
timeframes=['1h', '4h'],
n_candles=1000,
window_size=20
)
if dataset:
logger.info(f"Dataset generated and saved to: {list(dataset.values())}")
return X, y, timestamps if X is not None else (None, None, None)
def example_cnn_model(X=None, y=None):
"""Show how to use the CNN model"""
logger.info("=== CNN Model Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for CNN example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# Split data into training and testing sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize and build the CNN model
cnn = CNNModel(input_shape=(20, 5), output_size=1, model_dir='NN/models/saved')
cnn.build_model(filters=(32, 64, 128), kernel_sizes=(3, 5, 7), dropout_rate=0.3)
# Train the model (very small number of epochs for this example)
history = cnn.train(
X_train, y_train,
batch_size=32,
epochs=5, # Just a few epochs for the example
validation_split=0.2
)
# Evaluate the model
metrics = cnn.evaluate(X_test, y_test, plot_results=True)
if metrics:
logger.info(f"CNN Evaluation metrics: {metrics}")
# Make a prediction
y_pred, y_proba = cnn.predict(X_test[:1])
logger.info(f"CNN Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")
return cnn
def example_transformer_model(X=None, y=None, cnn_model=None):
"""Show how to use the Transformer model"""
logger.info("=== Transformer Model Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for Transformer example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# Generate high-level features (from CNN model or random if no CNN provided)
if cnn_model is not None and hasattr(cnn_model, 'extract_hidden_features'):
# Extract features from CNN model
X_features = cnn_model.extract_hidden_features(X)
logger.info(f"Extracted {X_features.shape[1]} features from CNN model")
else:
# Generate random features
X_features = np.random.random((len(X), 128))
logger.info("Generated random features for Transformer model")
# Split data into training and testing sets
from sklearn.model_selection import train_test_split
X_train, X_test, X_feat_train, X_feat_test, y_train, y_test = train_test_split(
X, X_features, y, test_size=0.2, random_state=42
)
# Initialize and build the Transformer model
transformer = TransformerModel(
ts_input_shape=(20, 5),
feature_input_shape=X_features.shape[1],
output_size=1,
model_dir='NN/models/saved'
)
transformer.build_model(
embed_dim=32,
num_heads=2,
ff_dim=64,
num_transformer_blocks=2,
dropout_rate=0.2
)
# Train the model (very small number of epochs for this example)
history = transformer.train(
X_train, X_feat_train, y_train,
batch_size=32,
epochs=5, # Just a few epochs for the example
validation_split=0.2
)
# Make a prediction
y_pred, y_proba = transformer.predict(X_test[:1], X_feat_test[:1])
logger.info(f"Transformer Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")
return transformer
def example_moe_model(X=None, y=None, cnn_model=None, transformer_model=None):
"""Show how to use the Mixture of Experts model"""
logger.info("=== Mixture of Experts Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for MoE example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# If models not provided, create them
if cnn_model is None:
logger.info("Creating a new CNN model for MoE")
cnn_model = CNNModel(input_shape=(20, 5), output_size=1)
cnn_model.build_model()
if transformer_model is None:
logger.info("Creating a new Transformer model for MoE")
transformer_model = TransformerModel(ts_input_shape=(20, 5), feature_input_shape=128, output_size=1)
transformer_model.build_model()
# Initialize MoE model
moe = MixtureOfExpertsModel(output_size=1, model_dir='NN/models/saved')
# Add expert models
moe.add_expert('cnn', cnn_model)
moe.add_expert('transformer', transformer_model)
# Build the MoE model (this is a simplified implementation - in a real scenario
# you would need to handle the interfaces between models more carefully)
moe.build_model(
ts_input_shape=(20, 5),
expert_weights={'cnn': 0.7, 'transformer': 0.3}
)
# In a real implementation, you would train the MoE model here
logger.info("MoE model built - in a real implementation, you would train it here")
return moe
def example_orchestrator():
"""Show how to use the Orchestrator"""
logger.info("=== Orchestrator Example ===")
# Configure the orchestrator
config = {
'symbol': 'BTC/USDT',
'timeframes': ['1h', '4h'],
'window_size': 20,
'n_features': 5,
'output_size': 3, # BUY/HOLD/SELL
'batch_size': 32,
'epochs': 5, # Small number for example
'model_dir': 'NN/models/saved',
'data_dir': 'NN/data'
}
# Initialize the orchestrator
orchestrator = NeuralNetworkOrchestrator(config)
# Prepare training data
X, y, timestamps = orchestrator.prepare_training_data(
timeframes=['1h'],
n_candles=200
)
if X is not None and y is not None:
logger.info(f"Prepared training data: X shape {X.shape}, y shape {y.shape}")
# Train CNN model
logger.info("Training CNN model with orchestrator...")
history = orchestrator.train_cnn_model(X, y, epochs=2) # Very small for example
# Make a prediction
result = orchestrator.run_inference_pipeline(
model_type='cnn',
timeframe='1h'
)
if result:
logger.info(f"Inference result: {result}")
else:
logger.warning("Could not prepare training data - this is expected if no real data is available")
logger.info("The orchestrator would normally handle training and inference")
def main():
"""Run all examples"""
logger.info("Starting Neural Network Trading System Examples")
# Example 1: Data Interface
X, y, timestamps = example_data_interface()
# Example 2: CNN Model
cnn_model = example_cnn_model(X, y)
# Example 3: Transformer Model
transformer_model = example_transformer_model(X, y, cnn_model)
# Example 4: Mixture of Experts
moe_model = example_moe_model(X, y, cnn_model, transformer_model)
# Example 5: Orchestrator
example_orchestrator()
logger.info("Examples completed")
if __name__ == "__main__":
main()

View File

@ -1,244 +0,0 @@
"""
Neural Network Trading System Main Module (Compatibility Layer)
This module serves as a compatibility layer for the realtime.py module.
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
"""
import os
import sys
import logging
from datetime import datetime
import numpy as np
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)
# Re-export everything from realtime_main.py
from .realtime_main import (
parse_arguments,
realtime,
train,
predict
)
# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
"""
Orchestrates the neural network operations.
"""
def __init__(self, config):
"""
Initialize the orchestrator with configuration.
Args:
config (dict): Configuration parameters
"""
self.config = config
self.symbol = config.get('symbol', 'BTC/USDT')
self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
self.window_size = config.get('window_size', 20)
self.n_features = config.get('n_features', 5)
self.output_size = config.get('output_size', 3)
self.model_dir = config.get('model_dir', 'NN/models/saved')
self.data_dir = config.get('data_dir', 'NN/data')
self.model = None
self.data_interface = None
# Initialize with default values in case imports fail
self.model_initialized = False
self.data_initialized = False
# Import necessary modules dynamically
try:
from .utils.data_interface import DataInterface
# Initialize data interface
self.data_interface = DataInterface(
symbol=self.symbol,
timeframes=self.timeframes
)
self.data_initialized = True
logger.info(f"Data interface initialized for {self.symbol}")
try:
from .models.cnn_model_pytorch import CNNModelPyTorch as Model
# Initialize model
feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
try:
# First try with expected parameters
self.model = Model(
window_size=self.window_size,
num_features=feature_count,
output_size=self.output_size,
timeframes=self.timeframes
)
except TypeError as e:
logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
# Try alternate parameter naming
try:
self.model = Model(
input_shape=(self.window_size, feature_count),
output_size=self.output_size
)
logger.info("Model initialized with alternate parameters")
except Exception as ex:
logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
self.model = DummyModel()
# Try to load the best model
self._load_model()
self.model_initialized = True
logger.info("Model initialized successfully")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
except Exception as e:
logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
def _load_model(self):
"""Load the best trained model from available files"""
try:
model_paths = [
os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
os.path.join(self.model_dir, "cnn_model_best.pt"),
os.path.join("models/saved", "dqn_agent_best_policy.pt"),
os.path.join("models/saved", "cnn_model_best.pt")
]
for model_path in model_paths:
if os.path.exists(model_path):
try:
self.model.load(model_path)
logger.info(f"Loaded model from {model_path}")
return True
except Exception as e:
logger.warning(f"Failed to load model from {model_path}: {str(e)}")
continue
logger.warning("No trained model found, using dummy model")
self.model = DummyModel()
return False
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
self.model = DummyModel()
return False
def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
"""
Run the inference pipeline using the trained model.
Args:
model_type (str): Type of model to use (cnn, transformer, etc.)
timeframe (str): Timeframe to use for inference
Returns:
dict: Inference result
"""
try:
# Check if we have a model
if not hasattr(self, 'model') or self.model is None:
logger.warning("No model available, initializing dummy model")
self.model = DummyModel()
# Check if we have a data interface
if not hasattr(self, 'data_interface') or self.data_interface is None:
logger.warning("No data interface available")
# Return a dummy prediction
return self._get_dummy_prediction()
# Prepare input data for the selected timeframe
X, timestamp = self.data_interface.prepare_realtime_input(
timeframe=timeframe,
n_candles=self.window_size + 10, # Extra candles for safety
window_size=self.window_size
)
if X is None:
logger.warning(f"No data available for {self.symbol}")
return self._get_dummy_prediction()
# Get model predictions
action_probs, price_pred = self.model.predict(X)
# Convert predictions to action
action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1 # Default to HOLD
action_names = ['SELL', 'HOLD', 'BUY']
action = action_names[action_idx]
# Format timestamp
if not isinstance(timestamp, str):
try:
if hasattr(timestamp, 'isoformat'): # If it's already a datetime-like object
timestamp = timestamp.isoformat()
else: # If it's a numeric timestamp
timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat()
except (TypeError, ValueError):
timestamp = datetime.now().isoformat()
# Return result
result = {
'timestamp': timestamp,
'action': action,
'action_index': int(action_idx),
'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
'price_prediction': float(price_pred) if price_pred is not None else None
}
logger.info(f"Inference result: {result}")
return result
except Exception as e:
logger.error(f"Error in inference pipeline: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return self._get_dummy_prediction()
def _get_dummy_prediction(self):
"""Return a dummy prediction when model or data is unavailable"""
action_names = ['SELL', 'HOLD', 'BUY']
action_idx = 1 # Default to HOLD
timestamp = datetime.now().isoformat()
return {
'timestamp': timestamp,
'action': 'HOLD',
'action_index': action_idx,
'probability': 0.8,
'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
'price_prediction': None,
'is_dummy': True
}
class DummyModel:
"""Dummy model that returns random predictions"""
def __init__(self):
logger.info("Initializing dummy model")
def predict(self, X):
"""Return random predictions"""
# Generate random probabilities for SELL, HOLD, BUY
action_probs = np.array([0.1, 0.8, 0.1]) # Bias towards HOLD
# Generate a random price prediction (None for now)
price_pred = None
return action_probs, price_pred
def load(self, model_path):
"""Dummy load method"""
logger.info(f"Dummy model pretending to load from {model_path}")
return True

View File

@ -1,287 +0,0 @@
import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os
import numpy as np
import pandas as pd
from .trading_agent import TradingAgent
logger = logging.getLogger(__name__)
class NeuralNetworkOrchestrator:
"""Orchestrator for neural network models and trading operations.
This class coordinates between neural network models and trading agents,
ensuring that signals from the models are properly processed and trades
are executed according to the strategy.
"""
def __init__(self, model, data_interface, chart=None,
symbols: List[str] = None,
timeframes: List[str] = None,
window_size: int = 20,
num_features: int = 5,
output_size: int = 3,
models_dir: str = "NN/models/saved",
data_dir: str = "NN/data",
exchange_config: Dict[str, Any] = None):
"""Initialize the neural network orchestrator.
Args:
model: Neural network model instance
data_interface: Data interface for retrieving market data
chart: Real-time chart for visualization (optional)
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
window_size: Window size for model input
num_features: Number of features per datapoint
output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
models_dir: Directory for saved models
data_dir: Directory for data storage
exchange_config: Configuration for trading agent (exchange, API keys, etc.)
"""
self.model = model
self.data_interface = data_interface
self.chart = chart
self.symbols = symbols or ["BTC/USDT"]
self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
self.window_size = window_size
self.num_features = num_features
self.output_size = output_size
self.models_dir = models_dir
self.data_dir = data_dir
# Initialize trading agent if configuration provided
self.trading_agent = None
if exchange_config:
self.init_trading_agent(exchange_config)
# Initialize inference state
self.is_running = False
self.inference_thread = None
self.stop_event = threading.Event()
self.last_inference_time = 0
self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))
logger.info(f"Initializing NeuralNetworkOrchestrator with:")
logger.info(f"- Symbol: {self.symbols[0]}")
logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
logger.info(f"- Window size: {window_size}")
logger.info(f"- Num features: {num_features}")
logger.info(f"- Output size: {output_size}")
logger.info(f"- Models dir: {models_dir}")
logger.info(f"- Data dir: {data_dir}")
logger.info(f"- Inference interval: {self.inference_interval} seconds")
def init_trading_agent(self, config: Dict[str, Any]):
"""Initialize the trading agent with the given configuration.
Args:
config: Configuration for the trading agent
"""
exchange_name = config.get("exchange", "binance")
api_key = config.get("api_key")
api_secret = config.get("api_secret")
test_mode = config.get("test_mode", True)
trade_symbols = config.get("trade_symbols", self.symbols)
position_size = config.get("position_size", 0.1)
max_trades_per_day = config.get("max_trades_per_day", 5)
trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)
self.trading_agent = TradingAgent(
exchange_name=exchange_name,
api_key=api_key,
api_secret=api_secret,
test_mode=test_mode,
trade_symbols=trade_symbols,
position_size=position_size,
max_trades_per_day=max_trades_per_day,
trade_cooldown_minutes=trade_cooldown_minutes
)
logger.info(f"Trading agent initialized for {exchange_name} exchange.")
def start_inference(self):
"""Start the inference thread."""
if self.is_running:
logger.warning("Neural network inference is already running.")
return
self.is_running = True
self.stop_event.clear()
# Start inference thread
self.inference_thread = threading.Thread(target=self._inference_loop)
self.inference_thread.daemon = True
self.inference_thread.start()
logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")
# Start trading agent if available
if self.trading_agent:
self.trading_agent.start(signal_callback=self._on_trade_executed)
def stop_inference(self):
"""Stop the inference thread."""
if not self.is_running:
logger.warning("Neural network inference is not running.")
return
logger.info("Stopping neural network inference...")
self.is_running = False
self.stop_event.set()
if self.inference_thread and self.inference_thread.is_alive():
self.inference_thread.join(timeout=10)
logger.info("Neural network inference stopped.")
# Stop trading agent if available
if self.trading_agent:
self.trading_agent.stop()
def _inference_loop(self):
"""Main inference loop that processes data and generates signals."""
logger.info("Inference loop started.")
try:
while self.is_running and not self.stop_event.is_set():
current_time = time.time()
# Check if we should run inference
if current_time - self.last_inference_time >= self.inference_interval:
try:
# Run inference for all symbols
for symbol in self.symbols:
prediction = self._run_inference(symbol)
if prediction:
self._process_prediction(symbol, prediction)
self.last_inference_time = current_time
except Exception as e:
logger.error(f"Error during inference: {str(e)}")
# Sleep for a short time to prevent CPU hogging
time.sleep(1)
except Exception as e:
logger.error(f"Error in inference loop: {str(e)}")
finally:
logger.info("Inference loop stopped.")
def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
"""Run inference for a specific symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
tuple: (action probabilities, current price) or None if inference failed
"""
try:
# Get the model timeframe from environment
model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
if model_timeframe not in self.timeframes:
logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
model_timeframe = self.timeframes[0]
# Load candles for the model timeframe
logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe")
candles = self.data_interface.get_historical_data(
symbol=symbol,
timeframe=model_timeframe,
n_candles=1000
)
if candles is None or len(candles) < self.window_size:
logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
return None
# Prepare input data
X, timestamp = self.data_interface.prepare_model_input(
data=candles,
window_size=self.window_size,
symbol=symbol
)
if X is None:
logger.warning(f"Failed to prepare model input for {symbol}.")
return None
# Get current price
current_price = candles['close'].iloc[-1]
# Run model inference
action_probs, price_pred = self.model.predict(X)
return action_probs, current_price
except Exception as e:
logger.error(f"Error running inference for {symbol}: {str(e)}")
return None
def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
"""Process a prediction and generate signals.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
prediction: Tuple of (action probabilities, current price)
"""
action_probs, current_price = prediction
# Get the best action (0=SELL, 1=HOLD, 2=BUY)
best_action = np.argmax(action_probs)
best_prob = float(action_probs[best_action])
# Convert to action name
action_names = ["SELL", "HOLD", "BUY"]
action_name = action_names[best_action]
# Log the prediction
logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")
# Add signal to chart if available
if self.chart:
self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time()))
# Process signal with trading agent if available
if self.trading_agent:
self.trading_agent.process_signal(
symbol=symbol,
action=action_name,
confidence=best_prob,
timestamp=int(time.time())
)
def _on_trade_executed(self, trade_record: Dict[str, Any]):
"""Callback for when a trade is executed.
Args:
trade_record: Trade information
"""
if self.chart and trade_record:
# Add trade to chart
self.chart.add_trade(
action=trade_record['action'],
price=trade_record.get('price', 0),
timestamp=trade_record['timestamp'],
pnl=trade_record.get('pnl', 0)
)
logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")
def get_trading_agent_info(self) -> Dict[str, Any]:
"""Get information about the trading agent.
Returns:
dict: Trading agent information or None if no agent is available
"""
if self.trading_agent:
return {
'exchange_info': self.trading_agent.get_exchange_info(),
'positions': self.trading_agent.get_current_positions(),
'trades': len(self.trading_agent.get_trade_history())
}
return None

View File

@ -1,287 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module
This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""
import os
import sys
import logging
import argparse
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
]
)
logger = logging.getLogger('NN')
# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Neural Network Trading System')
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
help='Mode to run (train, predict, realtime)')
parser.add_argument('--symbol', type=str, default='BTC/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
help='Timeframes to use')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for input data')
parser.add_argument('--output-size', type=int, default=3,
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
parser.add_argument('--batch-size', type=int, default=32,
help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100,
help='Number of epochs for training')
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
help='Model type to use')
parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
help='Deep learning framework to use')
return parser.parse_args()
def main():
"""Main entry point for the NN trading system"""
# Parse arguments
args = parse_arguments()
logger.info(f"Starting NN Trading System in {args.mode} mode")
logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
f"Window Size={args.window_size}, Output Size={args.output_size}, "
f"Model Type={args.model_type}, Framework={args.framework}")
# Import the appropriate modules based on the framework
if args.framework == 'pytorch':
try:
import torch
logger.info(f"Using PyTorch {torch.__version__}")
# Import PyTorch-based modules
from NN.utils.data_interface import DataInterface
if args.model_type == 'cnn':
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
elif args.model_type == 'moe':
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import PyTorch modules: {str(e)}")
logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
return
elif args.framework == 'tensorflow':
try:
import tensorflow as tf
logger.info(f"Using TensorFlow {tf.__version__}")
# Import TensorFlow-based modules
from NN.utils.data_interface import DataInterface
if args.model_type == 'cnn':
from NN.models.cnn_model import CNNModel as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model import TransformerModel as Model
elif args.model_type == 'moe':
from NN.models.transformer_model import MixtureOfExpertsModel as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import TensorFlow modules: {str(e)}")
logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
return
else:
logger.error(f"Unknown framework: {args.framework}")
return
# Initialize data interface
try:
logger.info("Initializing data interface...")
data_interface = DataInterface(
symbol=args.symbol,
timeframes=args.timeframes
)
except Exception as e:
logger.error(f"Failed to initialize data interface: {str(e)}")
return
# Initialize model
try:
logger.info(f"Initializing {args.model_type.upper()} model...")
model = Model(
window_size=args.window_size,
num_features=data_interface.get_feature_count(),
output_size=args.output_size,
timeframes=args.timeframes
)
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
return
# Execute the requested mode
if args.mode == 'train':
train(data_interface, model, args)
elif args.mode == 'predict':
predict(data_interface, model, args)
elif args.mode == 'realtime':
realtime(data_interface, model, args)
else:
logger.error(f"Unknown mode: {args.mode}")
return
logger.info("Neural Network Trading System finished successfully")
def train(data_interface, model, args):
"""Enhanced training with performance tracking"""
from torch.utils.tensorboard import SummaryWriter
logger.info("Starting training mode...")
writer = SummaryWriter(log_dir=f"runs/{args.model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
try:
best_val_acc = 0
for epoch in range(args.epochs):
# Refresh data every few epochs
if epoch % 3 == 0:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data(refresh=True)
else:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
# Train for one epoch
train_loss, train_acc = model.train_epoch(
X_train, y_train,
batch_size=args.batch_size
)
# Validate
val_loss, val_acc = model.evaluate(X_val, y_val)
# Log metrics
writer.add_scalar('Loss/Train', train_loss, epoch)
writer.add_scalar('Accuracy/Train', train_acc, epoch)
writer.add_scalar('Loss/Validation', val_loss, epoch)
writer.add_scalar('Accuracy/Validation', val_acc, epoch)
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
model_path = os.path.join(
'models',
f"{args.model_type}_best_{args.symbol.replace('/', '_')}.pt"
)
model.save(model_path)
logger.info(f"New best model saved with val_acc: {val_acc:.2f}")
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f} - "
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}")
# Save final model
model_path = os.path.join(
'models',
f"{args.model_type}_final_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
)
model.save(model_path)
logger.info(f"Training Complete - Best Val Accuracy: {best_val_acc:.2f}")
except Exception as e:
logger.error(f"Error in training mode: {str(e)}")
return
def predict(data_interface, model, args):
"""Make predictions using the trained model"""
logger.info("Starting prediction mode...")
try:
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Prepare prediction data
logger.info("Preparing prediction data...")
X_pred = data_interface.prepare_prediction_data()
# Make predictions
logger.info("Making predictions...")
predictions = model.predict(X_pred)
# Process and display predictions
logger.info("Processing predictions...")
data_interface.process_predictions(predictions)
except Exception as e:
logger.error(f"Error in prediction mode: {str(e)}")
return
def realtime(data_interface, model, args):
"""Run the model in real-time mode"""
logger.info("Starting real-time mode...")
try:
# Import realtime analyzer
from NN.utils.realtime_analyzer import RealtimeAnalyzer
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Initialize realtime analyzer
logger.info("Initializing real-time analyzer...")
realtime_analyzer = RealtimeAnalyzer(
data_interface=data_interface,
model=model,
symbol=args.symbol,
timeframes=args.timeframes
)
# Start real-time analysis
logger.info("Starting real-time analysis...")
realtime_analyzer.start()
except Exception as e:
logger.error(f"Error in real-time mode: {str(e)}")
return
if __name__ == "__main__":
main()

View File

@ -1,241 +0,0 @@
import logging
import numpy as np
import pandas as pd
import time
from typing import Dict, Any, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RealtimeDataInterface:
"""Interface for retrieving real-time market data for neural network models.
This class serves as a bridge between the RealTimeChart data sources and
the neural network models, providing properly formatted data for model
inference.
"""
def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
"""Initialize the data interface.
Args:
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
chart: RealTimeChart instance (optional)
max_cache_size: Maximum number of cached candles
"""
self.symbols = symbols
self.chart = chart
self.max_cache_size = max_cache_size
# Initialize data cache
self.ohlcv_cache = {} # timeframe -> symbol -> DataFrame
logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")
def get_historical_data(self, symbol: str = None, timeframe: str = '1h',
n_candles: int = 500) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not symbol:
if len(self.symbols) > 0:
symbol = self.symbols[0]
else:
logger.error("No symbol specified and no default symbols available")
return None
if symbol not in self.symbols:
logger.warning(f"Symbol {symbol} not in tracked symbols")
return None
try:
# Get data from chart if available
if self.chart:
candles = self._get_chart_data(symbol, timeframe, n_candles)
if candles is not None and len(candles) > 0:
return candles
# Fallback to default empty DataFrame
logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
except Exception as e:
logger.error(f"Error getting historical data for {symbol}: {str(e)}")
return None
def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
"""Get data from the RealTimeChart for the specified symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not self.chart:
return None
# Get chart data using the _get_chart_data method
try:
# Map to interval seconds
interval_map = {
'1s': 1,
'5s': 5,
'10s': 10,
'15s': 15,
'30s': 30,
'1m': 60,
'3m': 180,
'5m': 300,
'15m': 900,
'30m': 1800,
'1h': 3600,
'2h': 7200,
'4h': 14400,
'6h': 21600,
'8h': 28800,
'12h': 43200,
'1d': 86400,
'3d': 259200,
'1w': 604800
}
# Convert timeframe to seconds
if timeframe in interval_map:
interval_seconds = interval_map[timeframe]
else:
# Try to parse the interval (e.g., '1m' -> 60)
try:
if timeframe.endswith('s'):
interval_seconds = int(timeframe[:-1])
elif timeframe.endswith('m'):
interval_seconds = int(timeframe[:-1]) * 60
elif timeframe.endswith('h'):
interval_seconds = int(timeframe[:-1]) * 3600
elif timeframe.endswith('d'):
interval_seconds = int(timeframe[:-1]) * 86400
elif timeframe.endswith('w'):
interval_seconds = int(timeframe[:-1]) * 604800
else:
interval_seconds = int(timeframe)
except ValueError:
logger.error(f"Could not parse timeframe: {timeframe}")
return None
# Get data from chart
df = self.chart._get_chart_data(interval_seconds)
if df is not None and not df.empty:
# Limit to requested number of candles
if len(df) > n_candles:
df = df.iloc[-n_candles:]
return df
else:
logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
return None
except Exception as e:
logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
return None
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
symbol: str = None) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare model input from OHLCV data.
Args:
data: DataFrame with OHLCV data
window_size: Window size for model input
symbol: Symbol for the data (for logging)
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
if data is None or len(data) < window_size:
logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
return None, None
try:
# Get last window_size candles
recent_data = data.iloc[-window_size:].copy()
# Get timestamp of the most recent candle
timestamp = int(recent_data.iloc[-1]['timestamp']) if 'timestamp' in recent_data.columns else int(time.time())
# Extract OHLCV features and normalize
if 'open' in recent_data.columns and 'high' in recent_data.columns and 'low' in recent_data.columns and 'close' in recent_data.columns and 'volume' in recent_data.columns:
# Normalize price data by the last close price
last_close = recent_data['close'].iloc[-1]
# Avoid division by zero
if last_close == 0:
last_close = 1.0
opens = (recent_data['open'] / last_close).values
highs = (recent_data['high'] / last_close).values
lows = (recent_data['low'] / last_close).values
closes = (recent_data['close'] / last_close).values
# Normalize volume by the max volume in the window
max_volume = recent_data['volume'].max()
if max_volume == 0:
max_volume = 1.0
volumes = (recent_data['volume'] / max_volume).values
# Stack features into a 3D array [batch_size=1, window_size, n_features=5]
X = np.column_stack((opens, highs, lows, closes, volumes))
X = X.reshape(1, window_size, 5)
# Replace any NaN or infinite values
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
return X, timestamp
else:
logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
return None, None
except Exception as e:
logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
return None, None
def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
window_size: int = 20) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare real-time input for the model.
Args:
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
window_size: Window size for model input
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
# Get data for the main symbol
if len(self.symbols) == 0:
logger.error("No symbols available for real-time input")
return None, None
symbol = self.symbols[0]
try:
# Get historical data
data = self.get_historical_data(symbol, timeframe, n_candles)
if data is None or len(data) < window_size:
logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
return None, None
# Prepare model input
return self.prepare_model_input(data, window_size, symbol)
except Exception as e:
logger.error(f"Error preparing real-time input: {str(e)}")
return None, None

View File

@ -1,507 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module - PyTorch Version
This module serves as the main entry point for the NN trading system,
using PyTorch exclusively for all model operations.
"""
import os
import sys
import logging
import argparse
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)
try:
# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
# Try setting up file logging
log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
fh = logging.FileHandler(log_file)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)
logger.info(f"Logging to file: {log_file}")
except Exception as e:
logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")
# Always setup console logging
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Neural Network Trading System')
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
help='Mode to run (train, predict, realtime)')
parser.add_argument('--symbol', type=str, default='BTC/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1s', '1m', '5m', '1h', '4h'],
help='Timeframes to use (include 1s for ticks)')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for input data')
parser.add_argument('--output-size', type=int, default=3,
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
parser.add_argument('--batch-size', type=int, default=32,
help='Batch size for training')
parser.add_argument('--epochs', type=int, default=10,
help='Number of epochs for training')
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
help='Model type to use')
return parser.parse_args()
def main():
"""Main entry point for the NN trading system"""
args = parse_arguments()
logger.info(f"Starting NN Trading System in {args.mode} mode")
logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")
try:
import torch
from NN.utils.data_interface import DataInterface
# Import appropriate PyTorch model
if args.model_type == 'cnn':
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
elif args.model_type == 'moe':
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import PyTorch modules: {str(e)}")
logger.error("Please make sure PyTorch is installed")
return
# Initialize data interface
try:
data_interface = DataInterface(
symbol=args.symbol,
timeframes=args.timeframes
)
# Verify data interface by fetching initial data
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
if X_sample is None or y_sample is not None:
logger.error("Failed to prepare initial training data")
return
logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
except Exception as e:
logger.error(f"Failed to initialize data interface: {str(e)}")
return
# Initialize model
try:
# Calculate total number of features across all timeframes
num_features = data_interface.get_feature_count()
logger.info(f"Initializing model with {num_features} features")
model = Model(
window_size=args.window_size,
num_features=num_features,
output_size=args.output_size,
timeframes=args.timeframes
)
# Ensure model is on the correct device
if torch.cuda.is_available():
model.model = model.model.cuda()
logger.info("Model moved to CUDA device")
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
return
# Execute requested mode
if args.mode == 'train':
train(data_interface, model, args)
elif args.mode == 'predict':
predict(data_interface, model, args)
elif args.mode == 'realtime':
realtime(data_interface, model, args)
def train(data_interface, model, args):
"""Enhanced training with performance tracking and retrospective fine-tuning"""
logger.info("Starting training mode...")
writer = SummaryWriter()
try:
best_val_acc = 0
best_val_pnl = float('-inf')
best_win_rate = 0
best_price_mae = float('inf')
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
# Calculate refresh intervals based on timeframes
min_timeframe = min(args.timeframes)
refresh_interval = {
'1s': 1,
'1m': 60,
'5m': 300,
'15m': 900,
'1h': 3600,
'4h': 14400,
'1d': 86400
}.get(min_timeframe, 60)
logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")
for epoch in range(args.epochs):
# Always refresh for tick data or when using multiple timeframes
refresh = '1s' in args.timeframes or len(args.timeframes) > 1
logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
refresh=refresh,
refresh_interval=refresh_interval
)
logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
# Get future prices for retrospective training
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)
# Train and validate
try:
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, args.batch_size
)
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
# Get predictions for PnL calculation
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Convert probabilities to actions for PnL calculation
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Calculate PnL and win rates
try:
if train_preds is not None and train_prices is not None:
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=1.0
)
else:
train_pnl, train_win_rate, train_trades = 0, 0, []
if val_preds is not None and val_prices is not None:
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=1.0
)
else:
val_pnl, val_win_rate, val_trades = 0, 0, []
except Exception as e:
logger.error(f"Error calculating PnL: {str(e)}")
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
train_trades, val_trades = [], []
# Calculate price prediction error
if train_future_prices is not None and train_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
train_future_prices_np = np.array(train_future_prices) if not isinstance(train_future_prices, np.ndarray) else train_future_prices
train_price_preds_np = np.array(train_price_preds) if not isinstance(train_price_preds, np.ndarray) else train_price_preds
if len(train_price_preds_np) > 0 and len(train_future_prices_np) > 0:
min_len = min(len(train_price_preds_np), len(train_future_prices_np))
train_price_mae = np.mean(np.abs(train_price_preds_np[:min_len] - train_future_prices_np[:min_len]))
else:
train_price_mae = float('inf')
else:
train_price_mae = float('inf')
if val_future_prices is not None and val_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
val_future_prices_np = np.array(val_future_prices) if not isinstance(val_future_prices, np.ndarray) else val_future_prices
val_price_preds_np = np.array(val_price_preds) if not isinstance(val_price_preds, np.ndarray) else val_price_preds
if len(val_price_preds_np) > 0 and len(val_future_prices_np) > 0:
min_len = min(len(val_price_preds_np), len(val_future_prices_np))
val_price_mae = np.mean(np.abs(val_price_preds_np[:min_len] - val_future_prices_np[:min_len]))
else:
val_price_mae = float('inf')
else:
val_price_mae = float('inf')
# Monitor action distribution
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
val_actions = np.bincount(np.argmax(val_action_probs, axis=1), minlength=3)
# Log metrics
writer.add_scalar('Loss/action_train', train_action_loss, epoch)
writer.add_scalar('Loss/price_train', train_price_loss, epoch)
writer.add_scalar('Loss/action_val', val_action_loss, epoch)
writer.add_scalar('Loss/price_val', val_price_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.add_scalar('Accuracy/val', val_acc, epoch)
writer.add_scalar('PnL/train', train_pnl, epoch)
writer.add_scalar('PnL/val', val_pnl, epoch)
writer.add_scalar('WinRate/train', train_win_rate, epoch)
writer.add_scalar('WinRate/val', val_win_rate, epoch)
writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
writer.add_scalar('PriceMAE/val', val_price_mae, epoch)
# Log action distribution
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
# Save best model based on validation metrics
if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
best_val_pnl = val_pnl
best_val_acc = val_acc
best_win_rate = val_win_rate
best_price_mae = val_price_mae
model.save(f"models/{args.model_type}_best.pt")
logger.info("Saved new best model based on validation metrics")
# Log detailed metrics
logger.info(f"Epoch {epoch+1}/{args.epochs}")
logger.info("Training Metrics:")
logger.info(f" Action Loss: {train_action_loss:.4f}")
logger.info(f" Price Loss: {train_price_loss:.4f}")
logger.info(f" Accuracy: {train_acc:.2f}")
logger.info(f" PnL: {train_pnl:.2%}")
logger.info(f" Win Rate: {train_win_rate:.2%}")
logger.info(f" Price MAE: {train_price_mae:.2f}")
logger.info("Validation Metrics:")
logger.info(f" Action Loss: {val_action_loss:.4f}")
logger.info(f" Price Loss: {val_price_loss:.4f}")
logger.info(f" Accuracy: {val_acc:.2f}")
logger.info(f" PnL: {val_pnl:.2%}")
logger.info(f" Win Rate: {val_win_rate:.2%}")
logger.info(f" Price MAE: {val_price_mae:.2f}")
# Log action distribution
logger.info("Action Distribution:")
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")
# Log trade statistics
logger.info("Trade Statistics:")
logger.info(f" Training trades: {len(train_trades)}")
logger.info(f" Validation trades: {len(val_trades)}")
# Log next candle predictions
if epoch % 10 == 0: # Every 10 epochs
logger.info("\nNext Candle Predictions:")
next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
for tf in args.timeframes:
if tf in next_candles:
logger.info(f"\n{tf} timeframe predictions:")
for i, pred in enumerate(next_candles[tf]):
action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
confidence = np.max(pred)
logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error during epoch {epoch+1}: {str(e)}")
continue
# Save final model
model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
logger.info(f"\nTraining complete. Best validation metrics:")
logger.info(f"Accuracy: {best_val_acc:.2f}")
logger.info(f"PnL: {best_val_pnl:.2%}")
logger.info(f"Win Rate: {best_win_rate:.2%}")
logger.info(f"Price MAE: {best_price_mae:.2f}")
except Exception as e:
logger.error(f"Error in training: {str(e)}")
def predict(data_interface, model, args):
"""Make predictions using the trained model"""
logger.info("Starting prediction mode...")
try:
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Prepare prediction data
logger.info("Preparing prediction data...")
X_pred = data_interface.prepare_prediction_data()
# Make predictions
logger.info("Making predictions...")
predictions = model.predict(X_pred)
# Process and display predictions
logger.info("Processing predictions...")
data_interface.process_predictions(predictions)
except Exception as e:
logger.error(f"Error in prediction mode: {str(e)}")
def realtime(data_interface, model, args, chart=None, symbol=None):
"""Run real-time inference with the trained model"""
logger.info(f"Starting real-time inference mode for {symbol}...")
try:
from NN.utils.realtime_analyzer import RealtimeAnalyzer
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Initialize realtime analyzer
logger.info("Initializing real-time analyzer...")
realtime_analyzer = RealtimeAnalyzer(
data_interface=data_interface,
model=model,
symbol=args.symbol,
timeframes=args.timeframes
)
# Start real-time analysis
logger.info("Starting real-time analysis...")
realtime_analyzer.start()
# Initialize variables for tracking performance
total_pnl = 0.0
trades = []
current_position = 0.0
last_action = None
last_price = None
# Get the pair index for this symbol
pair_index = args.symbols.index(symbol)
# Only execute trades if this is the main pair (BTC/USDT)
is_main_pair = symbol == "BTC/USDT"
while True:
# Get current market data for all pairs
all_pairs_data = []
for s in args.symbols:
X, timestamp = data_interface.prepare_realtime_input(
timeframe=args.timeframes[0], # Use shortest timeframe
n_candles=args.window_size + 10, # Extra candles for safety
window_size=args.window_size
)
if X is not None:
all_pairs_data.append(X)
else:
logger.warning(f"No data available for {s}")
time.sleep(1)
continue
if not all_pairs_data:
logger.warning("No data available for any pair")
time.sleep(1)
continue
# Stack data from all pairs for model input
X_combined = np.concatenate(all_pairs_data, axis=2)
# Get model predictions
action_probs, price_pred = model.predict(X_combined)
# Get predictions for this specific pair
action = np.argmax(action_probs[pair_index]) # 0=SELL, 1=HOLD, 2=BUY
# Get current price for the main pair
current_price = data_interface.get_historical_data(
timeframe=args.timeframes[0],
n_candles=1
)['close'].iloc[-1]
# Calculate PnL if we have a position (only for main pair)
pnl = 0.0
if is_main_pair and last_action is not None and last_price is not None:
if last_action == 2: # BUY
pnl = (current_price - last_price) / last_price
elif last_action == 0: # SELL
pnl = (last_price - current_price) / last_price
# Update total PnL (only for main pair)
if is_main_pair and pnl != 0:
total_pnl += pnl
# Log the prediction
action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
if is_main_pair:
log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
else:
log_msg += f"Price: {current_price:.2f} (Context Only)"
logger.info(log_msg)
# Update the chart if provided (only for main pair)
if chart is not None and is_main_pair and action != 1: # Skip HOLD actions
chart.add_trade(
action=action_name,
price=current_price,
timestamp=timestamp,
pnl=pnl
)
# Update tracking variables (only for main pair)
if is_main_pair and action != 1: # If not HOLD
last_action = action
last_price = current_price
# Sleep for a short time
time.sleep(1)
except KeyboardInterrupt:
if is_main_pair:
logger.info(f"Real-time inference stopped by user for {symbol}")
logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
else:
logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
except Exception as e:
logger.error(f"Error in real-time inference for {symbol}: {str(e)}")
raise
if __name__ == "__main__":
main()

View File

@ -1,310 +0,0 @@
import logging
import time
import threading
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
logger = logging.getLogger(__name__)
class TradingAgent:
"""Trading agent that executes trades based on neural network signals.
This agent interfaces with different exchanges and executes trades
based on the signals received from the neural network.
"""
def __init__(self,
exchange_name: str = 'binance',
api_key: str = None,
api_secret: str = None,
test_mode: bool = True,
trade_symbols: List[str] = None,
position_size: float = 0.1,
max_trades_per_day: int = 5,
trade_cooldown_minutes: int = 60):
"""Initialize the trading agent.
Args:
exchange_name: Name of the exchange to use ('binance', 'mexc')
api_key: API key for the exchange
api_secret: API secret for the exchange
test_mode: If True, use test/sandbox environment
trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT'])
position_size: Size of each position as a fraction of total available balance (0.0-1.0)
max_trades_per_day: Maximum number of trades to execute per day
trade_cooldown_minutes: Minimum time between trades in minutes
"""
self.exchange_name = exchange_name.lower()
self.api_key = api_key
self.api_secret = api_secret
self.test_mode = test_mode
self.trade_symbols = trade_symbols or ['BTC/USDT']
self.position_size = min(max(position_size, 0.01), 1.0) # Ensure between 0.01 and 1.0
self.max_trades_per_day = max(1, max_trades_per_day)
self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)
# Initialize exchange interface
self.exchange = self._create_exchange()
# Trading state
self.active = False
self.current_positions = {} # Symbol -> quantity
self.trades_today = {} # Symbol -> count
self.last_trade_time = {} # Symbol -> timestamp
self.trade_history = [] # List of trade records
# Threading
self.trading_thread = None
self.stop_event = threading.Event()
# Signal callback
self.signal_callback = None
# Connect to exchange
if not self.exchange.connect():
logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
else:
logger.info(f"Successfully connected to {self.exchange_name} exchange.")
self._load_current_positions()
def _create_exchange(self) -> ExchangeInterface:
"""Create an exchange interface based on the exchange name."""
if self.exchange_name == 'mexc':
return MEXCInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
elif self.exchange_name == 'binance':
return BinanceInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
else:
raise ValueError(f"Unsupported exchange: {self.exchange_name}")
def _load_current_positions(self):
"""Load current positions from the exchange."""
for symbol in self.trade_symbols:
try:
base_asset, quote_asset = symbol.split('/')
balance = self.exchange.get_balance(base_asset)
if balance > 0:
self.current_positions[symbol] = balance
logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
except Exception as e:
logger.error(f"Error loading position for {symbol}: {str(e)}")
def start(self, signal_callback: Callable = None):
"""Start the trading agent.
Args:
signal_callback: Optional callback function to receive trade signals
"""
if self.active:
logger.warning("Trading agent is already running.")
return
self.active = True
self.signal_callback = signal_callback
self.stop_event.clear()
logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
logger.info(f"Max trades per day: {self.max_trades_per_day}")
logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")
# Reset trading state
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}
# Start trading thread
self.trading_thread = threading.Thread(target=self._trading_loop)
self.trading_thread.daemon = True
self.trading_thread.start()
def stop(self):
"""Stop the trading agent."""
if not self.active:
logger.warning("Trading agent is not running.")
return
logger.info("Stopping trading agent...")
self.active = False
self.stop_event.set()
if self.trading_thread and self.trading_thread.is_alive():
self.trading_thread.join(timeout=10)
logger.info("Trading agent stopped.")
def _trading_loop(self):
"""Main trading loop that monitors positions and executes trades."""
logger.info("Trading loop started.")
try:
while self.active and not self.stop_event.is_set():
# Check positions and update state
for symbol in self.trade_symbols:
try:
base_asset, _ = symbol.split('/')
current_balance = self.exchange.get_balance(base_asset)
# Update position if it has changed
if symbol in self.current_positions:
prev_balance = self.current_positions[symbol]
if abs(current_balance - prev_balance) > 0.001 * prev_balance:
logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")
self.current_positions[symbol] = current_balance
except Exception as e:
logger.error(f"Error checking position for {symbol}: {str(e)}")
# Sleep for a while
time.sleep(10)
except Exception as e:
logger.error(f"Error in trading loop: {str(e)}")
finally:
logger.info("Trading loop stopped.")
def reset_daily_limits(self):
"""Reset daily trading limits. Call this at the start of each trading day."""
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
logger.info("Daily trading limits reset.")
def process_signal(self, symbol: str, action: str,
confidence: float = None, timestamp: int = None) -> Optional[Dict[str, Any]]:
"""Process a trading signal and execute a trade if conditions are met.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
action: Trade action ('BUY', 'SELL', 'HOLD')
confidence: Confidence level of the signal (0.0-1.0)
timestamp: Timestamp of the signal (unix time)
Returns:
dict: Trade information if a trade was executed, None otherwise
"""
if not self.active:
logger.warning("Trading agent is not active. Signal ignored.")
return None
if symbol not in self.trade_symbols:
logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
return None
if action not in ['BUY', 'SELL', 'HOLD']:
logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
return None
# Log the signal
confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
logger.info(f"Received {action} signal for {symbol}{confidence_str}")
# Ignore HOLD signals for trading
if action == 'HOLD':
return None
# Check if we can trade based on limits
current_time = time.time()
# Check max trades per day
if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
return None
# Check trade cooldown
last_trade_time = self.last_trade_time.get(symbol, 0)
if current_time - last_trade_time < self.trade_cooldown_seconds:
cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
return None
# Check if the action makes sense based on current position
base_asset, _ = symbol.split('/')
current_position = self.current_positions.get(symbol, 0)
if action == 'BUY' and current_position > 0:
logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
return None
if action == 'SELL' and current_position <= 0:
logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
return None
# Execute the trade
try:
trade_result = self.exchange.execute_trade(
symbol=symbol,
action=action,
percent_of_balance=self.position_size
)
if trade_result:
# Update trading state
self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
self.last_trade_time[symbol] = current_time
# Create trade record
trade_record = {
'symbol': symbol,
'action': action,
'timestamp': timestamp or int(current_time),
'confidence': confidence,
'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
'status': 'executed'
}
# Add to trade history
self.trade_history.append(trade_record)
# Call signal callback if provided
if self.signal_callback:
self.signal_callback(trade_record)
logger.info(f"Successfully executed {action} trade for {symbol}")
return trade_record
else:
logger.error(f"Failed to execute {action} trade for {symbol}")
return None
except Exception as e:
logger.error(f"Error executing trade for {symbol}: {str(e)}")
return None
def get_current_positions(self) -> Dict[str, float]:
"""Get current positions.
Returns:
dict: Symbol -> position size
"""
return self.current_positions.copy()
def get_trade_history(self) -> List[Dict[str, Any]]:
"""Get trade history.
Returns:
list: List of trade records
"""
return self.trade_history.copy()
def get_exchange_info(self) -> Dict[str, Any]:
"""Get exchange information.
Returns:
dict: Exchange information
"""
return {
'name': self.exchange_name,
'test_mode': self.test_mode,
'active': self.active,
'trade_symbols': self.trade_symbols,
'position_size': self.position_size,
'max_trades_per_day': self.max_trades_per_day,
'trade_cooldown_seconds': self.trade_cooldown_seconds,
'trades_today': self.trades_today.copy()
}

View File

@ -1,585 +0,0 @@
import os
import sys
import time
import logging
import argparse
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import contextlib
from sklearn.model_selection import train_test_split
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import our enhanced agent
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent
from NN.utils.data_interface import DataInterface
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('logs/enhanced_training.log')
]
)
logger = logging.getLogger(__name__)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Train enhanced RL trading agent')
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
parser.add_argument('--max-steps', type=int, default=2000, help='Maximum steps per episode')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
parser.add_argument('--confidence', type=float, default=0.4, help='Confidence threshold')
parser.add_argument('--load-model', type=str, default='', help='Load existing model')
parser.add_argument('--batch-size', type=int, default=128, help='Training batch size')
parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
parser.add_argument('--no-pretrain', action='store_true', help='Skip pre-training')
parser.add_argument('--pretrain-epochs', type=int, default=20, help='Number of pre-training epochs')
return parser.parse_args()
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
"""
Generate labeled training data for price prediction pre-training
Args:
data_1m: 1-minute candle data
data_1h: 1-hour candle data
data_1d: 1-day candle data
window_size: Size of the observation window
Returns:
X, y_immediate, y_midterm, y_longterm, y_values
"""
logger.info("Generating price prediction training data")
# Features to use
ohlcv_columns = ['open', 'high', 'low', 'close', 'volume']
# Create feature sets
X = []
y_immediate = [] # 1m prediction (next 5min)
y_midterm = [] # 1h prediction (next few hours)
y_longterm = [] # 1d prediction (next day)
y_values = [] # % change for each timeframe
# Need enough data for all timeframes
if len(data_1m) < window_size + 5 or len(data_1h) < 2 or len(data_1d) < 2:
logger.error("Not enough data for all timeframes")
return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
# Generate examples
for i in range(window_size, len(data_1m) - 5):
# Skip if we can't align with higher timeframes
if i % 60 != 0: # Only use minutes that align with hour boundaries
continue
try:
# Get window of 1m data as input
window_1m = data_1m[i-window_size:i][ohlcv_columns].values
# Find corresponding indices in higher timeframes
curr_timestamp = data_1m.index[i]
h_idx = data_1h.index.get_indexer([curr_timestamp], method='nearest')[0]
d_idx = data_1d.index.get_indexer([curr_timestamp], method='nearest')[0]
# Skip if indices are out of bounds
if h_idx < 0 or h_idx >= len(data_1h) - 1 or d_idx < 0 or d_idx >= len(data_1d) - 1:
continue
# Get future prices for label generation
future_5m = data_1m[i+5]['close']
future_1h = data_1h[h_idx+1]['close']
future_1d = data_1d[d_idx+1]['close']
current_price = data_1m[i]['close']
# Calculate % change for each timeframe
change_5m = (future_5m - current_price) / current_price * 100
change_1h = (future_1h - current_price) / current_price * 100
change_1d = (future_1d - current_price) / current_price * 100
# Determine price direction (0=down, 1=sideways, 2=up)
def get_direction(change):
if change < -0.5: # Down if less than -0.5%
return 0
elif change > 0.5: # Up if more than 0.5%
return 2
else: # Sideways if between -0.5% and 0.5%
return 1
direction_5m = get_direction(change_5m)
direction_1h = get_direction(change_1h)
direction_1d = get_direction(change_1d)
# Add to dataset
X.append(window_1m.flatten())
y_immediate.append(direction_5m)
y_midterm.append(direction_1h)
y_longterm.append(direction_1d)
y_values.append([change_5m, change_1h, change_1d, 0]) # Last value reserved
except Exception as e:
logger.warning(f"Error generating training example at index {i}: {str(e)}")
# Convert to numpy arrays
X = np.array(X)
y_immediate = np.array(y_immediate)
y_midterm = np.array(y_midterm)
y_longterm = np.array(y_longterm)
y_values = np.array(y_values)
logger.info(f"Generated {len(X)} training examples")
logger.info(f"Class distribution - Immediate: {np.bincount(y_immediate)}, "
f"Midterm: {np.bincount(y_midterm)}, Long-term: {np.bincount(y_longterm)}")
return X, y_immediate, y_midterm, y_longterm, y_values
def pretrain_price_prediction(agent, data_interface, n_epochs=20, batch_size=128, device=None):
"""
Pre-train the price prediction capabilities of the agent
Args:
agent: EnhancedDQNAgent instance
data_interface: DataInterface instance
n_epochs: Number of pre-training epochs
batch_size: Batch size for pre-training
device: Device to use for pre-training
Returns:
The pre-trained agent
"""
logger.info("Starting price prediction pre-training")
try:
# Ensure we have the necessary timeframes
timeframes_needed = ['1m', '1h', '1d']
for tf in timeframes_needed:
if tf not in data_interface.timeframes:
logger.info(f"Adding timeframe {tf} for pre-training")
# Add timeframe to the list if not present
if tf not in data_interface.timeframes:
data_interface.timeframes.append(tf)
data_interface.dataframes[tf] = None
# Get data for each timeframe
data_1m = data_interface.get_historical_data(timeframe='1m')
data_1h = data_interface.get_historical_data(timeframe='1h')
data_1d = data_interface.get_historical_data(timeframe='1d')
# Generate labeled training data
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
data_1m, data_1h, data_1d, window_size=20
)
if len(X) == 0:
logger.error("No training examples generated. Skipping pre-training.")
return agent
# Split data into training and validation sets
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
)
# Convert to torch tensors
X_train_tensor = torch.FloatTensor(X_train).to(device)
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(device)
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(device)
y_long_train_tensor = torch.LongTensor(y_long_train).to(device)
y_val_train_tensor = torch.FloatTensor(y_val_train).to(device)
X_val_tensor = torch.FloatTensor(X_val).to(device)
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(device)
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(device)
y_long_val_tensor = torch.LongTensor(y_long_val).to(device)
y_val_val_tensor = torch.FloatTensor(y_val_val).to(device)
# Calculate class weights for imbalanced data
def get_class_weights(labels):
counts = np.bincount(labels)
if len(counts) < 3: # Ensure we have 3 classes
counts = np.append(counts, [0] * (3 - len(counts)))
weights = 1.0 / np.array(counts)
weights = weights / np.sum(weights) # Normalize
return weights
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(device)
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(device)
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(device)
# Create DataLoader for batch training
train_dataset = TensorDataset(
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
y_long_train_tensor, y_val_train_tensor
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Set up loss functions with class weights
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
value_criterion = nn.MSELoss()
# Set up optimizer (separate from agent's optimizer)
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
# Set model to training mode
agent.policy_net.train()
# Training loop
best_val_loss = float('inf')
patience = 5
patience_counter = 0
# Create TensorBoard writer for pre-training
writer = SummaryWriter(log_dir=f'runs/pretrain_{int(time.time())}')
for epoch in range(n_epochs):
# Training phase
train_loss = 0.0
imm_correct, mid_correct, long_correct = 0, 0, 0
total = 0
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
# Zero gradients
pretrain_optimizer.zero_grad()
# Forward pass
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
q_values, _, price_preds, _ = agent.policy_net(X_batch)
# Calculate losses for each prediction head
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
value_loss = value_criterion(price_preds['values'], y_val_batch)
# Combined loss (weighted by importance)
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
# Backward pass and optimize
if agent.use_mixed_precision:
agent.scaler.scale(total_loss).backward()
agent.scaler.unscale_(pretrain_optimizer)
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
agent.scaler.step(pretrain_optimizer)
agent.scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
pretrain_optimizer.step()
# Accumulate metrics
train_loss += total_loss.item()
total += X_batch.size(0)
# Calculate accuracy
_, imm_pred = torch.max(price_preds['immediate'], 1)
_, mid_pred = torch.max(price_preds['midterm'], 1)
_, long_pred = torch.max(price_preds['longterm'], 1)
imm_correct += (imm_pred == y_imm_batch).sum().item()
mid_correct += (mid_pred == y_mid_batch).sum().item()
long_correct += (long_pred == y_long_batch).sum().item()
# Calculate epoch metrics
train_loss /= len(train_loader)
imm_acc = imm_correct / total
mid_acc = mid_correct / total
long_acc = long_correct / total
# Validation phase
agent.policy_net.eval()
val_loss = 0.0
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
with torch.no_grad():
# Forward pass on validation data
q_values, _, val_price_preds, _ = agent.policy_net(X_val_tensor)
# Calculate validation losses
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
val_loss = val_total_loss.item()
# Calculate validation accuracy
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
imm_val_acc = imm_val_correct / len(X_val_tensor)
mid_val_acc = mid_val_correct / len(X_val_tensor)
long_val_acc = long_val_correct / len(X_val_tensor)
# Log to TensorBoard
writer.add_scalar('pretrain/train_loss', train_loss, epoch)
writer.add_scalar('pretrain/val_loss', val_loss, epoch)
writer.add_scalar('pretrain/imm_acc', imm_acc, epoch)
writer.add_scalar('pretrain/mid_acc', mid_acc, epoch)
writer.add_scalar('pretrain/long_acc', long_acc, epoch)
writer.add_scalar('pretrain/imm_val_acc', imm_val_acc, epoch)
writer.add_scalar('pretrain/mid_val_acc', mid_val_acc, epoch)
writer.add_scalar('pretrain/long_val_acc', long_val_acc, epoch)
# Learning rate scheduling
pretrain_scheduler.step(val_loss)
# Early stopping check
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# Copy policy_net weights to target_net
agent.target_net.load_state_dict(agent.policy_net.state_dict())
logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
# Save pre-trained model
agent.save("NN/models/saved/enhanced_dqn_pretrained")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Log progress
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
# Set model back to training mode for next epoch
agent.policy_net.train()
writer.close()
logger.info("Price prediction pre-training complete")
return agent
except Exception as e:
logger.error(f"Error during price prediction pre-training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return agent
def train_enhanced_rl(args):
"""
Train the enhanced RL agent for trading
Args:
args: Command line arguments
"""
# Setup device
if args.no_gpu:
device = torch.device('cpu')
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
# Set up data interface
data_interface = DataInterface(symbol=args.symbol, timeframes=['1m', '5m', '15m'])
# Fetch historical data for each timeframe
for timeframe in data_interface.timeframes:
df = data_interface.get_historical_data(timeframe=timeframe)
logger.info(f"Using data for {args.symbol} {timeframe} ({len(data_interface.dataframes[timeframe])} candles)")
# Create environment for training
from NN.environments.trading_env import TradingEnvironment
window_size = 20
train_env = TradingEnvironment(
data_interface=data_interface,
initial_balance=10000.0,
transaction_fee=0.0002,
window_size=window_size,
max_position=1.0,
reward_scaling=100.0
)
# Create agent with improved parameters
state_shape = train_env.observation_space.shape
n_actions = train_env.action_space.n
agent = EnhancedDQNAgent(
state_shape=state_shape,
n_actions=n_actions,
learning_rate=args.learning_rate,
gamma=0.95,
epsilon=1.0,
epsilon_min=0.05,
epsilon_decay=0.995,
buffer_size=50000,
batch_size=args.batch_size,
target_update=10,
confidence_threshold=args.confidence,
device=device
)
# Load existing model if specified
if args.load_model:
model_path = args.load_model
if agent.load(model_path):
logger.info(f"Loaded existing model from {model_path}")
else:
logger.error(f"Error loading model from {model_path}")
# Pre-training for price prediction
if not args.no_pretrain and not args.load_model:
logger.info("Starting pre-training phase")
agent = pretrain_price_prediction(
agent=agent,
data_interface=data_interface,
n_epochs=args.pretrain_epochs,
batch_size=args.batch_size,
device=device
)
logger.info("Pre-training completed")
# Setup TensorBoard
writer = SummaryWriter(log_dir=f'runs/enhanced_rl_{int(time.time())}')
# Log hardware info
writer.add_text("hardware/device", str(device), 0)
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
# Move agent to device
agent.move_models_to_device(device)
# Training loop
logger.info(f"Starting enhanced training for {args.episodes} episodes")
total_rewards = []
episode_losses = []
trade_win_rates = []
best_reward = -np.inf
try:
for episode in range(args.episodes):
# Reset environment for new episode
state = train_env.reset()
total_reward = 0.0
done = False
step = 0
episode_start_time = time.time()
# Track trade statistics
trades = []
wins = 0
losses = 0
# Run episode
while not done and step < args.max_steps:
# Choose action
action, confidence = agent.act(state)
# Take action in environment
next_state, reward, done, info = train_env.step(action)
# Remember experience
agent.remember(state, action, reward, next_state, done)
# Track trade results
if 'trade_result' in info and info['trade_result'] is not None:
trade_result = info['trade_result']
trade_pnl = trade_result['pnl']
trades.append(trade_pnl)
if trade_pnl > 0:
wins += 1
logger.info(f"Profitable trade! {trade_pnl:.2f}% profit, reward: {reward:.4f}")
else:
losses += 1
logger.info(f"Loss trade! {trade_pnl:.2f}% loss, penalty: {reward:.4f}")
# Update state and counters
state = next_state
total_reward += reward
step += 1
# Train agent
loss = agent.replay()
if loss > 0:
episode_losses.append(loss)
# Log training metrics for each episode
episode_time = time.time() - episode_start_time
total_rewards.append(total_reward)
# Calculate win rate
win_rate = wins / max(1, (wins + losses))
trade_win_rates.append(win_rate)
# Log to console and TensorBoard
logger.info(f"Episode {episode}/{args.episodes} - Reward: {total_reward:.4f}, Win Rate: {win_rate:.2f}, "
f"Trades: {len(trades)}, Balance: ${train_env.balance:.2f}, Epsilon: {agent.epsilon:.4f}, "
f"Time: {episode_time:.2f}s")
writer.add_scalar('metrics/reward', total_reward, episode)
writer.add_scalar('metrics/balance', train_env.balance, episode)
writer.add_scalar('metrics/win_rate', win_rate, episode)
writer.add_scalar('metrics/trades', len(trades), episode)
writer.add_scalar('metrics/epsilon', agent.epsilon, episode)
if episode_losses:
avg_loss = sum(episode_losses) / len(episode_losses)
writer.add_scalar('metrics/loss', avg_loss, episode)
# Check if this is the best model so far
if total_reward > best_reward:
best_reward = total_reward
# Save best model
agent.save(f"NN/models/saved/enhanced_dqn_best")
logger.info(f"New best model saved with reward: {best_reward:.4f}")
# Save checkpoint every 10 episodes
if episode % 10 == 0 and episode > 0:
agent.save(f"NN/models/saved/enhanced_dqn_checkpoint")
logger.info(f"Checkpoint saved at episode {episode}")
# Reset episode losses
episode_losses = []
# Final save
agent.save(f"NN/models/saved/enhanced_dqn_final")
logger.info("Enhanced training completed, final model saved")
except KeyboardInterrupt:
logger.info("Training interrupted by user")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
finally:
# Close TensorBoard writer
writer.close()
return agent, train_env
if __name__ == "__main__":
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
os.makedirs("NN/models/saved", exist_ok=True)
# Parse arguments
args = parse_args()
# Start training
train_enhanced_rl(args)

View File

@ -1,657 +0,0 @@
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
import random
import torch.nn as nn
import contextlib
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('rl_training.log'),
logging.StreamHandler()
]
)
# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log GPU status
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
logger.info(f"Using GPU: {gpu_names}")
# Enable TensorFloat32 for NVIDIA Ampere GPUs for faster training
if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
logger.info("BFloat16 precision is supported - will use for faster training")
else:
logger.warning("GPU not available. Using CPU for training (slower).")
class RLTradingEnvironment(gym.Env):
"""
Reinforcement Learning environment for trading with technical indicators
from multiple timeframes
"""
def __init__(self, features_1m, features_1h, features_1d, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__()
# Initialize attributes before parent class
self.window_size = window_size
self.num_features = features_1m.shape[1] - 1 # Exclude close price
# Count available timeframes
self.num_timeframes = 3 # We require all timeframes now
self.feature_dim = self.num_features * self.num_timeframes
# Store features from different timeframes
self.features_1m = features_1m
self.features_1h = features_1h
self.features_1d = features_1d
# Trading parameters
self.initial_balance = 1.0
self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
self.min_trade_interval = min_trade_interval # Minimum steps between trades
# Define action and observation spaces
self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold
self.observation_space = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.window_size, self.feature_dim),
dtype=np.float32
)
# State variables
self.reset()
# Callback for visualization or external monitoring
self.action_callback = None
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
self.position = 0.0 # Amount of asset held
self.current_step = self.window_size
self.trades = 0
self.wins = 0
self.losses = 0
self.trade_history = []
self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade
# Get initial observation
observation = self._get_observation()
return observation
def _get_observation(self):
"""
Get the current state observation.
Combine features from multiple timeframes, reshaped for the CNN.
"""
# Calculate indices for each timeframe
idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
idx_1h = idx_1m // 60 # 60 minutes in an hour
idx_1d = idx_1h // 24 # 24 hours in a day
# Cap indices to prevent out of bounds
idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)
# Extract feature windows from each timeframe
window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]
# Handle hourly timeframe
start_1h = max(0, idx_1h - self.window_size)
window_1h = self.features_1h[start_1h:idx_1h]
# Handle daily timeframe
start_1d = max(0, idx_1d - self.window_size)
window_1d = self.features_1d[start_1d:idx_1d]
# Pad if needed (for higher timeframes)
if len(window_1m) < self.window_size:
padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
window_1m = np.vstack([padding, window_1m])
if len(window_1h) < self.window_size:
padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
window_1h = np.vstack([padding, window_1h])
if len(window_1d) < self.window_size:
padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
window_1d = np.vstack([padding, window_1d])
# Combine features from all timeframes
combined_features = np.hstack([
window_1m.reshape(self.window_size, -1),
window_1h.reshape(self.window_size, -1),
window_1d.reshape(self.window_size, -1)
])
# Convert to float32 and handle any NaN values
combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)
return combined_features
def step(self, action):
"""Take an action and return the next state, reward, done flag, and info"""
# Initialize info dictionary for additional data
info = {
'trade_executed': False,
'price_change': 0.0,
'position_change': 0,
'current_price': 0.0,
'next_price': 0.0,
'balance_change': 0.0,
'reward_components': {},
'future_prices': {}
}
# Get the current and next price
current_price = self.features_1m[self.current_step, -1]
# Handle edge case at the end of the data
if self.current_step >= len(self.features_1m) - 1:
next_price = current_price # Use current price as next price
done = True
else:
next_price = self.features_1m[self.current_step + 1, -1]
done = False
# Handle zero or negative price (data error)
if current_price <= 0:
current_price = 0.01 # Set to a small positive number
logger.warning(f"Zero or negative price detected at step {self.current_step}. Setting to 0.01.")
if next_price <= 0:
next_price = current_price # Use current price instead
logger.warning(f"Zero or negative next price detected at step {self.current_step + 1}. Using current price.")
# Calculate price change as percentage
price_change_pct = ((next_price - current_price) / current_price) * 100
# Store prices in info
info['current_price'] = current_price
info['next_price'] = next_price
info['price_change'] = price_change_pct
# Initialize reward components dictionary
reward_components = {
'holding_reward': 0.0,
'action_reward': 0.0,
'profit_reward': 0.0,
'trade_freq_penalty': 0.0
}
# Default small negative reward to discourage inaction
reward = -0.01
reward_components['holding_reward'] = -0.01
# Track previous balance for changes
previous_balance = self.balance
# Execute action (0: Buy, 1: Sell, 2: Hold)
if action == 0: # Buy
if self.position == 0: # Only buy if we don't already have a position
# Calculate how much of the asset we can buy with 100% of balance
self.position = self.balance / current_price
self.balance = 0 # All balance used
# If price goes up after buying, that's good
expected_profit = price_change_pct
# Scale reward based on expected profit
if expected_profit > 0:
# Positive reward for profitable buy decision
action_reward = 0.1 + (expected_profit * 0.05) # Base reward + profit-based bonus
reward_components['action_reward'] = action_reward
reward += action_reward
else:
# Small negative reward for unprofitable buy
action_reward = -0.1 + (expected_profit * 0.03) # Smaller penalty for small losses
reward_components['action_reward'] = action_reward
reward += action_reward
# Check if we've traded too frequently
if len(self.trade_history) > 0:
last_trade_step = self.trade_history[-1]['step']
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
freq_penalty = -0.2 # Penalty for trading too frequently
reward += freq_penalty
reward_components['trade_freq_penalty'] = freq_penalty
# Record the trade
self.trade_history.append({
'step': self.current_step,
'action': 'buy',
'price': current_price,
'position': self.position,
'balance': self.balance
})
info['trade_executed'] = True
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
elif action == 1: # Sell
if self.position > 0: # Only sell if we have a position
# Calculate sale proceeds
sale_value = self.position * current_price
# Calculate profit or loss percentage from last buy
last_buy_price = None
for trade in reversed(self.trade_history):
if trade['action'] == 'buy':
last_buy_price = trade['price']
break
# If we found the last buy price, calculate profit
if last_buy_price is not None:
profit_pct = ((current_price - last_buy_price) / last_buy_price) * 100
# Highly reward profitable trades
if profit_pct > 0:
# Progressive reward based on profit percentage
profit_reward = min(5.0, profit_pct * 0.2) # Cap at 5.0 to prevent exploitation
reward_components['profit_reward'] = profit_reward
reward += profit_reward
logger.info(f"Profitable trade! {profit_pct:.2f}% profit, reward: {profit_reward:.4f}")
else:
# Penalize losses more heavily based on size of loss
loss_penalty = max(-3.0, profit_pct * 0.15) # Cap at -3.0 to prevent excessive punishment
reward_components['profit_reward'] = loss_penalty
reward += loss_penalty
logger.info(f"Loss trade! {profit_pct:.2f}% loss, penalty: {loss_penalty:.4f}")
# If price goes down after selling, that's good
if price_change_pct < 0:
# Reward for good timing on sell (avoiding future loss)
timing_reward = min(1.0, abs(price_change_pct) * 0.05)
reward_components['action_reward'] = timing_reward
reward += timing_reward
# Check for trading too frequently
if len(self.trade_history) > 0:
last_trade_step = self.trade_history[-1]['step']
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
freq_penalty = -0.2 # Penalty for trading too frequently
reward += freq_penalty
reward_components['trade_freq_penalty'] = freq_penalty
# Update balance and position
self.balance = sale_value
position_change = self.position
self.position = 0
# Record the trade
self.trade_history.append({
'step': self.current_step,
'action': 'sell',
'price': current_price,
'position': self.position,
'balance': self.balance
})
info['trade_executed'] = True
info['position_change'] = position_change
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, new balance: {self.balance:.4f}")
elif action == 2: # Hold
# Small reward if holding was a good decision
if self.position > 0 and price_change_pct > 0: # Holding long position during price increase
hold_reward = price_change_pct * 0.01 # Small reward proportional to price increase
reward += hold_reward
reward_components['holding_reward'] = hold_reward
elif self.position == 0 and price_change_pct < 0: # Holding cash during price decrease
hold_reward = abs(price_change_pct) * 0.01 # Small reward for avoiding loss
reward += hold_reward
reward_components['holding_reward'] = hold_reward
# Move to the next step
self.current_step += 1
# Update current portfolio value
if self.position > 0:
self.current_value = self.balance + (self.position * next_price)
else:
self.current_value = self.balance
# Calculate balance change
balance_change = self.current_value - previous_balance
info['balance_change'] = balance_change
# Check if we've reached the end of the data
if self.current_step >= len(self.features_1m) - 1:
done = True
# Final evaluation if we have a position
if self.position > 0:
# Sell remaining position at the final price
final_balance = self.balance + (self.position * next_price)
# Calculate final portfolio value and return
final_return_pct = ((final_balance - self.initial_balance) / self.initial_balance) * 100
# Add big reward/penalty based on overall performance
performance_reward = final_return_pct * 0.1
reward += performance_reward
reward_components['final_performance'] = performance_reward
logger.info(f"Episode ended. Final balance: {final_balance:.4f}, Return: {final_return_pct:.2f}%")
# Get future prices for evaluation (1-hour and 1-day ahead)
info['future_prices'] = {}
# 1-hour future price if hourly data is available
if hasattr(self, 'features_1h') and self.features_1h is not None:
# Find the closest hourly data point
if self.current_step < len(self.features_1m):
current_time = self.current_step # Use as index for simplicity
hourly_idx = min(current_time // 60, len(self.features_1h) - 1) # Assuming 60 minutes per hour
if hourly_idx < len(self.features_1h) - 1:
future_1h_price = self.features_1h[hourly_idx + 1, -1]
info['future_prices']['1h'] = future_1h_price
# 1-day future price if daily data is available
if hasattr(self, 'features_1d') and self.features_1d is not None:
# Find the closest daily data point
if self.current_step < len(self.features_1m):
current_time = self.current_step # Use as index for simplicity
daily_idx = min(current_time // 1440, len(self.features_1d) - 1) # Assuming 1440 minutes per day
if daily_idx < len(self.features_1d) - 1:
future_1d_price = self.features_1d[daily_idx + 1, -1]
info['future_prices']['1d'] = future_1d_price
# Get next observation
next_state = self._get_observation()
# Store reward components in info
info['reward_components'] = reward_components
# Clip reward to prevent extreme values
reward = np.clip(reward, -10.0, 10.0)
return next_state, reward, done, info
def set_action_callback(self, callback):
"""
Set a callback function to be called after each action
Args:
callback: Function with signature (action, price, reward, info)
"""
self.action_callback = callback
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
action_callback=None, episode_callback=None, symbol="BTC/USDT",
pretrain_price_prediction_enabled=False, pretrain_epochs=10):
"""
Train a reinforcement learning agent for trading using ONLY real market data
Args:
env_class: Optional environment class override
num_episodes: Number of episodes to train for
max_steps: Maximum steps per episode
save_path: Path to save the trained model
action_callback: Callback function for monitoring actions
episode_callback: Callback function for monitoring episodes
symbol: Trading symbol to use
pretrain_price_prediction_enabled: DEPRECATED - No longer supported (synthetic data not used)
pretrain_epochs: DEPRECATED - No longer supported (synthetic data not used)
Returns:
tuple: (trained agent, environment)
"""
# Load data for the selected symbol
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
try:
# Try to load data for the requested symbol using get_historical_data method
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
raise FileNotFoundError("Could not retrieve all required timeframes data for specified symbol")
except Exception as e:
logger.warning(f"Data for {symbol} not available: {str(e)}. Using default cached data.")
# Try to use cached data if available
symbol = "BTC/USDT"
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
logger.error("Failed to retrieve all required timeframes data. Cannot continue training.")
raise ValueError("No data available for training")
# Create features from the data by adding technical indicators and converting to numpy format
if data_1m is not None:
data_1m = data_interface.add_technical_indicators(data_1m)
# Convert to numpy array with close price as the last column
features_1m = np.hstack([
data_1m.drop(['timestamp', 'close'], axis=1).values,
data_1m['close'].values.reshape(-1, 1)
])
else:
features_1m = None
if data_5m is not None:
data_5m = data_interface.add_technical_indicators(data_5m)
# Convert to numpy array with close price as the last column
features_5m = np.hstack([
data_5m.drop(['timestamp', 'close'], axis=1).values,
data_5m['close'].values.reshape(-1, 1)
])
else:
features_5m = None
if data_15m is not None:
data_15m = data_interface.add_technical_indicators(data_15m)
# Convert to numpy array with close price as the last column
features_15m = np.hstack([
data_15m.drop(['timestamp', 'close'], axis=1).values,
data_15m['close'].values.reshape(-1, 1)
])
else:
features_15m = None
if data_1h is not None:
data_1h = data_interface.add_technical_indicators(data_1h)
# Convert to numpy array with close price as the last column
features_1h = np.hstack([
data_1h.drop(['timestamp', 'close'], axis=1).values,
data_1h['close'].values.reshape(-1, 1)
])
else:
features_1h = None
if data_1d is not None:
data_1d = data_interface.add_technical_indicators(data_1d)
# Convert to numpy array with close price as the last column
features_1d = np.hstack([
data_1d.drop(['timestamp', 'close'], axis=1).values,
data_1d['close'].values.reshape(-1, 1)
])
else:
features_1d = None
# Check if we have all the required features
if features_1m is None or features_5m is None or features_15m is None or features_1h is None or features_1d is None:
logger.error("Failed to create features for all timeframes.")
raise ValueError("Could not create features for training")
# Create the environment
if env_class:
# Use provided environment class
env = env_class(features_1m, features_1h, features_1d)
else:
# Use the default environment
env = RLTradingEnvironment(features_1m, features_1h, features_1d)
# Set action callback if provided
if action_callback:
env.set_action_callback(action_callback)
# Get environment properties for agent creation
input_shape = env.observation_space.shape
n_actions = env.action_space.n
# Create the agent
agent = DQNAgent(
state_shape=input_shape,
n_actions=n_actions,
epsilon=1.0,
epsilon_decay=0.995,
epsilon_min=0.01,
learning_rate=0.0001,
gamma=0.99,
buffer_size=10000,
batch_size=64,
device=device # Pass device to agent for GPU usage
)
# Check if model file exists and load it
model_file = f"{save_path}_model.pth"
if os.path.exists(model_file):
try:
agent.load(model_file)
logger.info(f"Loaded existing model from {model_file}")
except Exception as e:
logger.error(f"Error loading model: {e}")
else:
logger.info("No existing model found. Starting with a new model.")
# Remove pre-training code since it used synthetic data
# Pre-training with real data would require a separate implementation
if pretrain_price_prediction_enabled:
logger.warning("Pre-training with synthetic data is no longer supported. Continuing with RL training only.")
# Create TensorBoard writer
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
# Log GPU status to TensorBoard
writer.add_text("hardware/device", str(device), 0)
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
# Training loop
total_rewards = []
trade_win_rates = []
best_reward = -np.inf
# Move models to the appropriate device if not already there
agent.move_models_to_device(device)
# Enable mixed precision if GPU and feature is available
use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
logger.info("Enabling mixed precision training")
use_mixed_precision = True
scaler = torch.cuda.amp.GradScaler()
# Define step callback for tensorboard logging and model tracking
def step_callback(action, price, reward, info):
# Pass to external callback if provided
if action_callback:
action_callback(env.current_step, action, price, reward, info)
# Main training loop
logger.info(f"Starting training for {num_episodes} episodes...")
logger.info(f"Starting training on device: {agent.device}")
try:
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
for step in range(max_steps):
# Select action
action = agent.act(state)
# Take action and observe next state and reward
next_state, reward, done, info = env.step(action)
# Store the experience in memory
agent.remember(state, action, reward, next_state, done)
# Update state and reward
state = next_state
total_reward += reward
# Train the agent by sampling from memory
if len(agent.memory) >= agent.batch_size:
loss = agent.replay()
if done or step == max_steps - 1:
break
# Track rewards
total_rewards.append(total_reward)
# Calculate trading metrics
win_rate = env.wins / max(1, env.trades)
trades = env.trades
# Log to TensorBoard
writer.add_scalar('Reward/Episode', total_reward, episode)
writer.add_scalar('Trade/WinRate', win_rate, episode)
writer.add_scalar('Trade/Count', trades, episode)
# Save best model
if total_reward > best_reward and episode > 10:
logger.info(f"New best average reward: {total_reward:.4f}, saving model")
agent.save(save_path)
best_reward = total_reward
# Periodic save every 100 episodes
if episode % 100 == 0 and episode > 0:
agent.save(f"{save_path}_episode_{episode}")
# Call episode callback if provided
if episode_callback:
# Add environment to info dict to use for extrema training
info_with_env = info.copy()
info_with_env['env'] = env
episode_callback(episode, total_reward, info_with_env)
# Final save
logger.info("Training completed, saving final model")
agent.save(f"{save_path}_final")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Close TensorBoard writer
writer.close()
return agent, env
if __name__ == "__main__":
train_rl()

185
ROOT_CLEANUP_SUMMARY.md Normal file
View File

@ -0,0 +1,185 @@
# Root Directory Cleanup Summary
## Overview
Comprehensive cleanup of the root directory to remove unnecessary files, duplicates, and outdated documentation. The goal was to create a cleaner, more organized project structure while preserving all essential functionality.
## Files Removed
### Large Log Files (10MB+ space saved)
- `trading_bot.log` (6.1MB) - Large trading log file
- `realtime_testing.log` (3.9MB) - Large realtime testing log
- `realtime_20250401_181308.log` (25KB) - Old realtime log
- `exchange_test.log` (15KB) - Exchange testing log
- `training_launch.log` (10KB) - Training launch log
- `multi_timeframe_data.log` (1.6KB) - Multi-timeframe data log
- `custom_realtime_log.log` (1.6KB) - Custom realtime log
- `binance_data.log` (1.4KB) - Binance data log
- `binance_training.log` (96B) - Binance training log
### Duplicate Training Files (150KB+ space saved)
- `train_rl_with_realtime.py` (63KB) - Duplicate RL training → consolidated in `training/`
- `train_hybrid_fixed.py` (62KB) - Duplicate hybrid training → consolidated in `training/`
- `train_realtime_with_tensorboard.py` (18KB) - Duplicate training → consolidated in `training/`
- `train_config.py` (7.4KB) - Duplicate config → functionality in `core/config.py`
### Outdated Documentation (30KB+ space saved)
- `CLEANUP_PLAN.md` (5.9KB) - Old cleanup plan (superseded by execution)
- `CLEANUP_EXECUTION_PLAN.md` (6.8KB) - Executed plan (work complete)
- `SYNTHETIC_DATA_REMOVAL_SUMMARY.md` (2.9KB) - Outdated summary
- `MODEL_SAVING_FIX.md` (2.4KB) - Old documentation (issues resolved)
- `MODEL_SAVING_RECOMMENDATIONS.md` (3.0KB) - Old recommendations (implemented)
- `DATA_SOLUTION.md` (0KB) - Empty file
- `DISK_SPACE_OPTIMIZATION.md` (15KB) - Old optimization doc (cleanup complete)
- `IMPLEMENTATION_SUMMARY.md` (3.8KB) - Outdated summary (architecture modernized)
- `TRAINING_STATUS.md` (1B) - Empty file
- `_notes.md` (5.0KB) - Development notes and temporary commands
### Test Utility Files (10KB+ space saved)
- `add_test_trades.py` (1.6KB) - Test utility
- `generate_trades.py` (401B) - Simple test utility
- `random.nb.txt` (767B) - Random notes
- `live_trading_20250318_093045.csv` (50B) - Old trading log
- `training_stats.csv` (25KB) - Old training statistics
- `tests.py` (14KB) - Old test file (reorganized into `tests/` directory)
### Old Batch Files and Scripts (5KB+ space saved)
- `run_pytorch_nn.bat` (1.4KB) - Old batch file
- `run_nn_in_conda.bat` (206B) - Old conda batch file
- `setup_env.bat` (2.1KB) - Old environment setup
- `start_app.bat` (796B) - Old app startup batch
- `run_demo.py` (1015B) - Old demo file
- `run_live_demo.py` (836B) - Old live demo
### Duplicate/Obsolete Python Files (25KB+ space saved)
- `access_app.py` (1.5KB) - Old app access (functionality in `main_clean.py`)
- `fix_live_trading.py` (2.8KB) - Old fix file (issues resolved)
- `mexc_tick_stream.py` (10KB) - Exchange-specific (functionality in `dataprovider_realtime.py`)
- `run_nn.py` (8.6KB) - Old NN runner (functionality in `main_clean.py` and `training/`)
### Cache and Temporary Files
- `__pycache__/` directory - Python cache files
## Files Preserved
Essential files that remain in the root directory:
### Core Application Files
- `main_clean.py` (16KB) - Main application entry point
- `dataprovider_realtime.py` (106KB) - Real-time data provider
- `trading_main.py` (6.4KB) - Trading system main
- `config.yaml` (2.3KB) - Configuration file
- `requirements.txt` (134B) - Python dependencies
### Monitoring and Utilities
- `check_live_trading.py` (5.5KB) - Live trading checker
- `launch_training.py` (4KB) - Training launcher
- `monitor_training.py` (3KB) - Training monitor
- `start_monitoring.py` (5.5KB) - Monitoring starter
- `run_tensorboard.py` (2.3KB) - TensorBoard runner
- `run_tests.py` (5.9KB) - Unified test runner
- `read_logs.py` (4.4KB) - Log reader utility
### Documentation (Well-Organized)
- `readme.md` (5.5KB) - Main project README
- `CLEAN_ARCHITECTURE_SUMMARY.md` (8.4KB) - Architecture overview
- `CNN_TESTING_GUIDE.md` (6.8KB) - CNN testing guide
- `HYBRID_TRAINING_GUIDE.md` (5.2KB) - Hybrid training guide
- `README_enhanced_trading_model.md` (5.9KB) - Enhanced model README
- `README_LAUNCH_MODES.md` (10KB) - Launch modes documentation
- `REAL_MARKET_DATA_POLICY.md` (4.1KB) - Data policy
- `TENSORBOARD_MONITORING.md` (9.7KB) - TensorBoard monitoring guide
- `LOGGING.md` (2.4KB) - Logging documentation
- `TODO.md` (4.7KB) - Project TODO list
- `TEST_CLEANUP_SUMMARY.md` (5.5KB) - Test cleanup summary
### Test Files (Remaining Individual Tests)
- `test_positions.py` (4KB) - Position testing
- `test_tick_cache.py` (4.6KB) - Tick cache testing
- `test_timestamps.py` (1.3KB) - Timestamp testing
### Configuration and Assets
- `.env` (0.3KB) - Environment variables
- `.gitignore` (1KB) - Git ignore rules
- `start_live_trading.ps1` (0.7KB) - PowerShell startup script
- `fee_impact_analysis.png` (230KB) - Fee analysis chart
- `training_results.png` (75KB) - Training results visualization
## Space Savings Summary
- **Log files**: ~10MB freed
- **Duplicate training files**: ~150KB freed
- **Outdated documentation**: ~30KB freed
- **Test utilities**: ~10KB freed
- **Old scripts**: ~5KB freed
- **Obsolete Python files**: ~25KB freed
- **Cache files**: Variable space freed
**Total estimated space saved**: ~10.2MB+ (not including cache files)
## Benefits Achieved
### Organization
- **Cleaner structure**: Root directory now contains only essential files
- **Logical grouping**: Related functionality properly organized in subdirectories
- **Reduced clutter**: Eliminated duplicate and obsolete files
### Maintainability
- **Easier navigation**: Fewer files to search through
- **Clear purpose**: Each remaining file has a clear, documented purpose
- **Reduced confusion**: No more duplicate implementations
### Performance
- **Faster file operations**: Fewer files to scan
- **Reduced disk usage**: Significant space savings
- **Cleaner git history**: Fewer unnecessary files to track
## Directory Structure After Cleanup
```
gogo2/
├── Core Application
│ ├── main_clean.py
│ ├── dataprovider_realtime.py
│ ├── trading_main.py
│ └── config.yaml
├── Monitoring & Utilities
│ ├── check_live_trading.py
│ ├── launch_training.py
│ ├── monitor_training.py
│ ├── start_monitoring.py
│ ├── run_tensorboard.py
│ ├── run_tests.py
│ └── read_logs.py
├── Documentation
│ ├── readme.md
│ ├── CLEAN_ARCHITECTURE_SUMMARY.md
│ ├── CNN_TESTING_GUIDE.md
│ ├── HYBRID_TRAINING_GUIDE.md
│ ├── README_enhanced_trading_model.md
│ ├── README_LAUNCH_MODES.md
│ ├── REAL_MARKET_DATA_POLICY.md
│ ├── TENSORBOARD_MONITORING.md
│ ├── LOGGING.md
│ ├── TODO.md
│ └── TEST_CLEANUP_SUMMARY.md
├── Individual Tests
│ ├── test_positions.py
│ ├── test_tick_cache.py
│ └── test_timestamps.py
├── Configuration
│ ├── .env
│ ├── .gitignore
│ ├── requirements.txt
│ └── start_live_trading.ps1
└── Assets
├── fee_impact_analysis.png
└── training_results.png
```
## Conclusion
The root directory cleanup successfully:
- ✅ Removed 10MB+ of unnecessary files
- ✅ Eliminated duplicate implementations
- ✅ Organized remaining files logically
- ✅ Preserved all essential functionality
- ✅ Improved project maintainability
- ✅ Created cleaner development environment
The project now has a much cleaner and more professional structure that's easier to navigate and maintain.

View File

@ -1,65 +0,0 @@
# Synthetic Data Removal Summary
This document summarizes all changes made to eliminate the use of synthetic data throughout the trading system.
## Files Modified
1. **NN/train_rl.py**
- Removed `_create_synthetic_1s_data` method
- Removed `_create_synthetic_hourly_data` method
- Removed `_create_synthetic_daily_data` method
- Modified `RLTradingEnvironment` class to require all timeframes as real data
- Removed fallback to synthetic data when real data is unavailable
- Eliminated `generate_price_prediction_training_data` function
- Removed `pretrain_price_prediction` function that used synthetic data
- Updated `train_rl` function to load all required timeframes
2. **train_rl_with_realtime.py**
- Updated `EnhancedRLTradingEnvironment` class to require all timeframes
- Modified `create_enhanced_env` function to load all required timeframes
- Added prominent warning logs about requiring real market data
- Fixed imports to accommodate the changes
3. **README_enhanced_trading_model.md**
- Updated to emphasize that only real market data is supported
- Listed all required timeframes and their importance
- Added clear warnings against using synthetic data
- Updated usage instructions
4. **New files created**
- **REAL_MARKET_DATA_POLICY.md**: Comprehensive policy document explaining why we only use real market data
## Key Changes in Implementation
1. **Data Requirements**
- Now explicitly require all timeframes (1m, 5m, 15m, 1h, 1d) as real data
- Removed all synthetic data generation functionalities
- Added validation to ensure all required timeframes are available
2. **Error Handling**
- Improved error messages when required data is missing
- Eliminated synthetic data fallbacks when real data is unavailable
- Added clear logging to indicate when real data is required
3. **Training Process**
- Removed pre-training functions that used synthetic data
- Updated the main training loop to work exclusively with real data
- Disabled options related to synthetic data generation
## Benefits of These Changes
1. **More Realistic Training**
- Models now train exclusively on real market patterns and behaviors
- No risk of learning artificial patterns that don't exist in real markets
2. **Better Performance**
- Trading strategies more likely to work in live markets
- Models develop more realistic expectations about market behavior
3. **Simplified Codebase**
- Removal of synthetic data generation code reduces complexity
- Clearer data requirements make the system easier to understand and use
## Conclusion
These changes ensure our trading system works exclusively with real market data, providing more realistic training and better performance in live trading environments. The system now requires all timeframes to be available as real data and will not fall back to synthetic data under any circumstances.

148
TEST_CLEANUP_SUMMARY.md Normal file
View File

@ -0,0 +1,148 @@
# Test Cleanup Summary
## Overview
Comprehensive cleanup and consolidation of test files in the trading system project. The goal was to eliminate duplicate test implementations while preserving all valuable functionality and improving test organization.
## Test Files Removed
The following test files were removed after extracting their valuable functionality:
### Consolidated into New Test Suites
- `test_model.py` (11KB) - Extended training functionality → `tests/test_training_integration.py`
- `test_cnn_only.py` (2KB) - CNN training tests → `tests/test_training_integration.py`
- `test_training.py` (2KB) - Training pipeline tests → `tests/test_training_integration.py`
- `test_chart_data.py` (5KB) - Data provider tests → `tests/test_training_integration.py`
- `test_indicators.py` (4KB) - Technical indicators → `tests/test_indicators_and_signals.py`
- `test_signal_interpreter.py` (14KB) - Signal processing → `tests/test_indicators_and_signals.py`
### Removed as Non-Essential
- `test_dash.py` (3KB) - UI testing (not core functionality)
- `test_websocket.py` (1KB) - Minimal websocket test (covered by integration)
## New Consolidated Test Structure
### `tests/test_essential.py`
**Purpose**: Core functionality validation
- Critical module imports
- Configuration loading
- DataProvider initialization
- Model utilities
- Basic signal generation logic
### `tests/test_model_persistence.py`
**Purpose**: Comprehensive model save/load testing
- Robust save/load with multiple fallback methods
- MockAgent class for testing
- Comprehensive test coverage for model persistence
- Error handling and recovery testing
### `tests/test_training_integration.py`
**Purpose**: Training pipeline integration testing
- Data provider functionality (Binance API, TickStorage, RealTimeChart)
- CNN training with small datasets
- RL training with minimal episodes
- Extended training metrics tracking
- Integration between CNN and RL components
### `tests/test_indicators_and_signals.py`
**Purpose**: Technical analysis and signal processing
- Technical indicator calculation and categorization
- Signal distribution calculations
- Signal interpretation logic
- Signal filtering and threshold testing
- Oscillation prevention
- Market data analysis (price movements, volatility)
## Preserved Individual Test Files
These files were kept as they test specific functionality:
- `test_positions.py` (4KB) - Trading environment position testing
- `test_tick_cache.py` (5KB) - Tick caching with timestamp serialization
- `test_timestamps.py` (1KB) - Timestamp handling validation
## Updated Test Runner
**`run_tests.py`** - Unified test runner with multiple execution modes:
- `python run_tests.py` - Run all tests
- `python run_tests.py essential` - Quick validation
- `python run_tests.py persistence` - Model save/load tests
- `python run_tests.py training` - Training integration tests
- `python run_tests.py indicators` - Technical analysis tests
- `python run_tests.py individual` - Remaining individual tests
## Functionality Preservation
**Zero functionality was lost** during cleanup:
### From test_model.py
- Extended training session logic
- Comprehensive metrics tracking (train/val loss, accuracy, PnL, win rates)
- Signal distribution calculation
- Multiple position size testing
- Performance tracking over epochs
### From test_signal_interpreter.py
- Signal interpretation with confidence levels
- Threshold-based filtering
- Trend and volume filters
- Oscillation prevention logic
- Performance tracking for trades
### From test_indicators.py
- Technical indicator categorization (trend, momentum, volatility, volume)
- Multi-timeframe feature matrix creation
- Indicator calculation verification
### From test_chart_data.py
- Binance API data fetching
- TickStorage functionality
- RealTimeChart initialization
## Benefits Achieved
### Code Organization
- **Reduced file count**: 14 test files → 7 files (50% reduction)
- **Better structure**: Logical grouping by functionality
- **Unified interface**: Single test runner for all scenarios
### Maintainability
- **Consolidated logic**: Related tests grouped together
- **Comprehensive coverage**: All scenarios covered in organized suites
- **Better documentation**: Clear purpose for each test suite
### Space Savings
- **Eliminated duplicates**: Removed redundant test implementations
- **Cleaner codebase**: Easier to navigate and understand
- **Reduced complexity**: Fewer files to maintain
## Test Coverage
The new test structure provides comprehensive coverage:
1. **Essential functionality** - Core system validation
2. **Model persistence** - Robust save/load with fallbacks
3. **Training integration** - End-to-end training pipeline
4. **Technical analysis** - Indicators and signal processing
5. **Specific components** - Individual functionality tests
## Usage Examples
```bash
# Quick validation (fastest)
python run_tests.py essential
# Full test suite
python run_tests.py
# Specific test categories
python run_tests.py training
python run_tests.py indicators
python run_tests.py persistence
```
## Conclusion
The test cleanup successfully:
- ✅ Consolidated duplicate functionality
- ✅ Preserved all valuable test logic
- ✅ Improved code organization
- ✅ Created unified test interface
- ✅ Reduced maintenance overhead
- ✅ Enhanced test coverage documentation
The trading system now has a clean, well-organized test suite that covers all functionality while being easier to maintain and extend.

53
TODO.md
View File

@ -1,55 +1,6 @@
# Trading System Enhancement TODO List # Trading System Enhancement TODO List## Implemented Enhancements1. **Enhanced CNN Architecture** - [x] Implemented deeper CNN with residual connections for better feature extraction - [x] Added self-attention mechanisms to capture temporal patterns - [x] Implemented dueling architecture for more stable Q-value estimation - [x] Added more capacity to prediction heads for better confidence estimation2. **Improved Training Pipeline** - [x] Created example sifting dataset to prioritize high-quality training examples - [x] Implemented price prediction pre-training to bootstrap learning - [x] Lowered confidence threshold to allow more trades (0.4 instead of 0.5) - [x] Added better normalization of state inputs3. **Visualization and Monitoring** - [x] Added detailed confidence metrics tracking - [x] Implemented TensorBoard logging for pre-training and RL phases - [x] Added more comprehensive trading statistics4. **GPU Optimization & Performance** - [x] Fixed GPU detection and utilization during training - [x] Added GPU memory monitoring during training - [x] Implemented mixed precision training for faster GPU-based training - [x] Optimized batch sizes for GPU training5. **Trading Metrics & Monitoring** - [x] Added trade signal rate display and tracking - [x] Implemented counter for actions per second/minute/hour - [x] Added visualization of trading frequency over time - [x] Created moving average of trade signals to show trends6. **Reward Function Optimization** - [x] Revised reward function to better balance profit and risk - [x] Implemented progressive rewards based on holding time - [x] Added penalty for frequent trading (to reduce noise) - [x] Implemented risk-adjusted returns (Sharpe ratio) in reward calculation
## Implemented Enhancements ## Future Enhancements1. **Multi-timeframe Price Direction Prediction** - [ ] Extend CNN model to predict price direction for multiple timeframes - [ ] Modify CNN output to predict short, mid, and long-term price directions - [ ] Create data generation method for back-propagation using historical data - [ ] Implement real-time example generation for training - [ ] Feed direction predictions to RL agent as additional state information2. **Model Architecture Improvements** - [ ] Experiment with different residual block configurations - [ ] Implement Transformer-based models for better sequence handling - [ ] Try LSTM/GRU layers to combine with CNN for temporal data - [ ] Implement ensemble methods to combine multiple models3. **Training Process Improvements** - [ ] Implement curriculum learning (start with simple patterns, move to complex) - [ ] Add adversarial training to make model more robust - [ ] Implement Meta-Learning approaches for faster adaptation - [ ] Expand pre-training to include extrema detection4. **Trading Strategy Enhancements** - [ ] Add position sizing based on confidence levels (dynamic sizing based on prediction confidence) - [ ] Implement risk management constraints - [ ] Add support for stop-loss and take-profit mechanisms - [ ] Develop adaptive confidence thresholds based on market volatility - [ ] Implement Kelly criterion for optimal position sizing5. **Training Data & Model Improvements** - [ ] Implement data augmentation for more robust training - [ ] Simulate different market conditions - [ ] Add noise to training data - [ ] Generate synthetic data for rare market events6. **Model Interpretability** - [ ] Add visualization for model decision making - [ ] Implement feature importance analysis - [ ] Add attention visualization for key price patterns - [ ] Create explainable AI components7. **Performance Optimizations** - [ ] Optimize data loading pipeline for faster training - [ ] Implement distributed training for larger models - [ ] Profile and optimize inference speed for real-time trading - [ ] Optimize memory usage for longer training sessions8. **Research Directions** - [ ] Explore reinforcement learning algorithms beyond DQN (PPO, SAC, A3C) - [ ] Research ways to incorporate fundamental data - [ ] Investigate transfer learning from pre-trained models - [ ] Study methods to interpret model decisions for better trust
1. **Enhanced CNN Architecture**
- [x] Implemented deeper CNN with residual connections for better feature extraction
- [x] Added self-attention mechanisms to capture temporal patterns
- [x] Implemented dueling architecture for more stable Q-value estimation
- [x] Added more capacity to prediction heads for better confidence estimation
2. **Improved Training Pipeline**
- [x] Created example sifting dataset to prioritize high-quality training examples
- [x] Implemented price prediction pre-training to bootstrap learning
- [x] Lowered confidence threshold to allow more trades (0.4 instead of 0.5)
- [x] Added better normalization of state inputs
3. **Visualization and Monitoring**
- [x] Added detailed confidence metrics tracking
- [x] Implemented TensorBoard logging for pre-training and RL phases
- [x] Added more comprehensive trading statistics
## Future Enhancements
1. **Model Architecture Improvements**
- [ ] Experiment with different residual block configurations
- [ ] Implement Transformer-based models for better sequence handling
- [ ] Try LSTM/GRU layers to combine with CNN for temporal data
- [ ] Implement ensemble methods to combine multiple models
2. **Training Process Improvements**
- [ ] Implement curriculum learning (start with simple patterns, move to complex)
- [ ] Add adversarial training to make model more robust
- [ ] Implement Meta-Learning approaches for faster adaptation
- [ ] Expand pre-training to include extrema detection
3. **Trading Strategy Enhancements**
- [ ] Add position sizing based on confidence levels
- [ ] Implement risk management constraints
- [ ] Add support for stop-loss and take-profit mechanisms
- [ ] Develop adaptive confidence thresholds based on market volatility
4. **Performance Optimizations**
- [ ] Optimize data loading pipeline for faster training
- [ ] Implement distributed training for larger models
- [ ] Profile and optimize inference speed for real-time trading
- [ ] Optimize memory usage for longer training sessions
5. **Research Directions**
- [ ] Explore reinforcement learning algorithms beyond DQN (PPO, SAC, A3C)
- [ ] Research ways to incorporate fundamental data
- [ ] Investigate transfer learning from pre-trained models
- [ ] Study methods to interpret model decisions for better trust
## Implementation Timeline ## Implementation Timeline

View File

@ -1,224 +0,0 @@
# Cryptocurrency Trading System Improvements
## Overview
This document outlines necessary improvements to our cryptocurrency trading system to enhance performance, profitability, and monitoring capabilities.
## High Priority Tasks
### 1. GPU Utilization for Training
- [x] Fix GPU detection and utilization during training
- [x] Debug why CUDA is detected but not utilized (check logs showing "Starting training on device: cpu")
- [x] Ensure PyTorch correctly detects and uses available CUDA devices
- [x] Add GPU memory monitoring during training
- [x] Optimize batch sizes for GPU training
Implementation status:
- Added `setup_gpu()` function in `train_rl_with_realtime.py` to properly detect and configure GPU usage
- Added device parameter to DQNAgent to ensure models are created on the correct device
- Implemented mixed precision training for faster GPU-based training
- Added GPU memory monitoring and logging to TensorBoard
### 2. Trade Signal Rate Display
- [x] Add metrics to track and display trading frequency
- [x] Implement counter for actions per second/minute/hour
- [x] Add visualization to the chart showing trading frequency over time
- [x] Create a moving average of trade signals to show trends
- [x] Add dashboard section showing current and average trading rates
Implementation status:
- Added trade time tracking in `_add_trade_compat` function
- Added `calculate_trade_rate` method to `RealTimeChart` class
- Updated dashboard layout to display trade rates
- Added visualization of trade frequency in chart's bottom panel
### 3. Reward Function Optimization
- [x] Revise reward function to better balance profit and risk
- [x] Increase transaction fee penalty for more realistic simulation
- [x] Implement progressive rewards based on holding time
- [x] Add penalty for frequent trading (to reduce noise)
- [x] Scale rewards based on market volatility
- [x] Implement risk-adjusted returns (Sharpe ratio) in reward calculation
Implementation status:
- Created `improved_reward_function.py` with `ImprovedRewardCalculator` class
- Implemented Sharpe ratio for risk-adjusted rewards
- Added frequency penalty for excessive trading
- Added holding time rewards for profitable positions
- Integrated with `EnhancedRLTradingEnvironment` class
### 4. Multi-timeframe Price Direction Prediction
- [ ] Extend CNN model to predict price direction for multiple timeframes
- [ ] Modify CNN output to predict short, mid, and long-term price directions
- [ ] Create data generation method for back-propagation using historical data
- [ ] Implement real-time example generation for training
- [ ] Feed direction predictions to RL agent as additional state information
## Medium Priority Tasks
### 5. Position Sizing Optimization
- [ ] Implement dynamic position sizing based on confidence and volatility
- [ ] Add confidence score to model outputs
- [ ] Scale position size based on prediction confidence
- [ ] Implement Kelly criterion for optimal position sizing
### 6. Training Data Augmentation
- [ ] Implement data augmentation for more robust training
- [ ] Simulate different market conditions
- [ ] Add noise to training data
- [ ] Generate synthetic data for rare market events
### 7. Model Interpretability
- [ ] Add visualization for model decision making
- [ ] Implement feature importance analysis
- [ ] Add attention visualization for key price patterns
- [ ] Create explainable AI components
## Implementation Details
### Completed: Displaying Trade Rate
The trade rate display implementation has been completed in the `RealTimeChart` class:
```python
def calculate_trade_rate(self):
"""Calculate and return trading rate statistics based on recent trades"""
if not hasattr(self, 'trade_times') or not self.trade_times:
return {"per_second": 0, "per_minute": 0, "per_hour": 0}
# Get current time
now = datetime.now()
# Calculate different time windows
one_second_ago = now - timedelta(seconds=1)
one_minute_ago = now - timedelta(minutes=1)
one_hour_ago = now - timedelta(hours=1)
# Count trades in different time windows
trades_last_second = sum(1 for t in self.trade_times if t > one_second_ago)
trades_last_minute = sum(1 for t in self.trade_times if t > one_minute_ago)
trades_last_hour = sum(1 for t in self.trade_times if t > one_hour_ago)
# Calculate rates
return {
"per_second": trades_last_second,
"per_minute": trades_last_minute,
"per_hour": trades_last_hour
}
```
### Completed: Improved Reward Function
The improved reward function has been implemented in `improved_reward_function.py`:
```python
def calculate_reward(self, action, price_change, position_held_time=0,
volatility=None, is_profitable=False):
"""
Calculate the improved reward with risk adjustment
"""
# Calculate trading fee
fee = self.base_fee_rate
# Calculate frequency penalty
frequency_penalty = self._calculate_frequency_penalty()
# Base reward calculation
if action == 0: # BUY
# Small penalty for transaction plus frequency penalty
reward = -fee - frequency_penalty
elif action == 1: # SELL
# Calculate profit percentage minus fees (both entry and exit)
profit_pct = price_change
net_profit = profit_pct - (fee * 2)
# Scale reward and apply frequency penalty
reward = net_profit * 10 # Scale reward
reward -= frequency_penalty
# Record PnL for risk adjustment
self.record_pnl(net_profit)
else: # HOLD
# Small reward for holding a profitable position, small cost otherwise
if is_profitable:
reward = self._calculate_holding_reward(position_held_time, price_change)
else:
reward = -0.0001 # Very small negative reward
# Apply risk adjustment if enabled
if self.risk_adjusted:
reward = self._calculate_risk_adjustment(reward)
# Record this action for future frequency calculations
self.record_trade(action=action)
return reward
```
### Completed: GPU Optimization
Added GPU optimization in `train_rl_with_realtime.py`:
```python
def setup_gpu():
"""
Configure GPU usage for PyTorch training
Returns:
tuple: (success, device, message)
"""
try:
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
device_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
logger.info(f"Found {gpu_count} GPU(s): {', '.join(device_info)}")
device = torch.device("cuda:0")
# Test CUDA by creating a small tensor
test_tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
# Enable mixed precision if supported
if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
logger.info("BFloat16 is supported - enabling for faster training")
return True, device, f"GPU enabled: {device_info}"
else:
return False, torch.device("cpu"), "GPU not available, using CPU"
except Exception as e:
return False, torch.device("cpu"), f"GPU setup failed: {str(e)}"
```
### CNN Price Direction Prediction (To be implemented)
```python
def generate_direction_examples(self, historical_data, timeframes=['1m', '1h', '1d']):
"""Generate price direction examples from historical data"""
examples = []
labels = []
for tf in timeframes:
df = historical_data[tf]
for i in range(20, len(df) - 10):
# Use window of 20 candles for input
window = df.iloc[i-20:i]
# Create labels for future price direction (next 5, 10, 20 candles)
future_5 = df.iloc[i].close < df.iloc[i+5].close # True if price goes up
future_10 = df.iloc[i].close < df.iloc[i+10].close
future_20 = df.iloc[i].close < df.iloc[min(i+20, len(df)-1)].close
examples.append(window.values)
labels.append([future_5, future_10, future_20])
return np.array(examples), np.array(labels)
```
## Validation Plan
After implementing these improvements, we should validate the system with:
1. Backtesting on historical data
2. Forward testing with small position sizes
3. A/B testing of different reward functions
4. Measuring the improvement in profitability and Sharpe ratio
## Progress Tracking
- Implementation started: June 2023
- GPU utilization fixed: July 2023
- Trade signal rate display implemented: July 2023
- Reward function optimized: July 2023
- CNN direction prediction added: To be completed
- Full system tested: To be completed

View File

@ -1 +0,0 @@

Binary file not shown.

Binary file not shown.

View File

@ -1,67 +0,0 @@
https://github.com/mexcdevelop/mexc-api-sdk/blob/main/README.md#test-new-order
https://mexcdevelop.github.io/apidocs/spot_v3_en/#test-new-order
python mexc_tick_visualizer.py --symbol BTC/USDT --interval 1.0 --candle 60
python main.py --mode live --symbol ETH/USDT --timeframe 1m --use-websocket
python main.py --mode live --symbol BTC/USDT --timeframe 1m --use-websocket --dashboard
# http://localhost:8060
& 'C:\Users\popov\miniforge3\python.exe' 'c:\Users\popov\.cursor\extensions\ms-python.debugpy-2024.6.0-win32-x64\bundled\libs\debugpy\adapter/../..\debugpy\launcher' '51766' '--' 'main.py' '--mode' 'live' '--demo' 'false' '--symbol' 'ETH/USDT' '--timeframe' '1m' '--leverage' '50'
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch
ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
add 1h and 1d OHLCV data to let the model have the price action context
2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
warnings.warn(
main.py:1105: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = amp.GradScaler()
C:\Users\popov\miniforge3\Lib\site-packages\torch\amp\grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
warnings.warn(
2025-03-10 12:11:30,927 - INFO - Starting training for 1000 episodes...
2025-03-10 12:11:30,927 - INFO - Starting training on device: cpu
2025-03-10 12:11:30,928 - ERROR - Training failed: 'TradingEnvironment' object has no attribute 'initialize_price_predictor'
2025-03-10 12:11:30,928 - INFO - Exchange connection closed
Backend tkagg is interactive backend. Turning interactive mode on.
remodel our NN architecture. we should support up to 3 pairs simultaniously. so input can be 3 pairs: each pair will have up to 5 timeframes 1s(ticks, unspecified length), 1m, 1h, 1d + one additionall. we should normalize them in a way that preserves the relations between them (one price should be normalized to the same value across all tieframes). additionally to the 5 features OHLCV we will add up to 20 additional features for various technical indcators. 1s timeframe will be streamed in realtime. the MOE model should handle all that. we still need to access latest of the CNN hidden layers in the MOe model so we can extract learned features recognition
.
now let's run our "NN Training Pipeline" debug config. for now we start with single pair - BTC/USD. later we'll add up to 3 pairs for context. the NN will always have only 1 "main" pair - where the buy/sell actions are applied and which price prediction is calculater for each frame. we'll also try to predict the next local extrema that will help us be profitable
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000 --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 10 --symbol BTC/USDT --timeframes 1s 1m 1h 1d --batch-size 32 --window-size 20 --output-size 3
python NN/realtime_main.py --mode train --model-type cnn --epochs 1 --symbol BTC/USDT --timeframes 1s 1m --batch-size 32 --window-size 20 --output-size 3
python NN/realtime-main.py --mode train --model-type cnn --framework pytorch --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
----------
$ python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --epochs 10
python test_model.py
python train_with_realtime_ticks.py
python NN/train_rl.py
python train_rl_with_realtime.py
python train_rl_with_realtime.py --episodes 2 --no-train --visualize-only
python train_hybrid_fixed.py --iterations 1000 --sv-epochs 5 --rl-episodes 3 --symbol ETH/USDT --window 24 --batch-size 64 --new-model

View File

@ -1,46 +0,0 @@
import sys
import os
import time
from datetime import datetime
# Add the project root to the path
sys.path.append(os.getcwd())
# Try to access the running application's charts
from train_rl_with_realtime import charts
if not charts:
print("No charts are running. Please start the application first.")
sys.exit(1)
# Get the first chart
chart = charts[0]
print(f"Accessing chart for {chart.symbol}")
# Add some test trades
print("Adding trades...")
chart.add_trade(price=64500, timestamp=datetime.now(), pnl=0.5, action='BUY')
print("Added BUY trade 1")
time.sleep(1)
chart.add_trade(price=64800, timestamp=datetime.now(), pnl=0.7, action='SELL')
print("Added SELL trade 1")
time.sleep(1)
chart.add_trade(price=64600, timestamp=datetime.now(), pnl=0.3, action='BUY')
print("Added BUY trade 2")
time.sleep(1)
chart.add_trade(price=64900, timestamp=datetime.now(), pnl=0.6, action='SELL')
print("Added SELL trade 2")
# Try to get the current trades
print("\nCurrent trades:")
if hasattr(chart, 'trades') and chart.trades:
for i, trade in enumerate(chart.trades[-5:]):
action = trade.get('action', 'UNKNOWN')
price = trade.get('price', 'N/A')
timestamp = trade.get('timestamp', 'N/A')
pnl = trade.get('pnl', None)
close_price = trade.get('close_price', 'N/A')
print(f"Trade {i+1}: {action} @ {price}, Close: {close_price}, PnL: {pnl}, Time: {timestamp}")
else:
print("No trades found.")

View File

@ -1,53 +0,0 @@
from datetime import datetime, timedelta
import random
import time
from dataprovider_realtime import RealTimeChart
# Create a standalone chart instance
chart = RealTimeChart('BTC/USDT')
# Base price
base_price = 65000.0
# Add 5 pairs of trades (BUY followed by SELL)
for i in range(5):
# Create a buy trade
buy_price = base_price + random.uniform(-200, 200)
buy_time = datetime.now() - timedelta(minutes=5-i) # Older to newer
buy_amount = round(random.uniform(0.05, 0.5), 2)
# Add the BUY trade
chart.add_trade(
price=buy_price,
timestamp=buy_time,
amount=buy_amount,
pnl=None, # Set to None for buys
action='BUY'
)
print(f"Added BUY trade {i+1}: Price={buy_price:.2f}, Amount={buy_amount}, Time={buy_time}")
# Wait a moment
time.sleep(0.5)
# Create a sell trade (typically at a different price)
price_change = random.uniform(-100, 300) # More likely to be positive for profit
sell_price = buy_price + price_change
sell_time = buy_time + timedelta(minutes=random.uniform(0.5, 1.5))
# Calculate PnL
pnl = (sell_price - buy_price) * buy_amount
# Add the SELL trade
chart.add_trade(
price=sell_price,
timestamp=sell_time,
amount=buy_amount, # Same amount as buy
pnl=pnl,
action='SELL'
)
print(f"Added SELL trade {i+1}: Price={sell_price:.2f}, PnL={pnl:.2f}, Time={sell_time}")
# Wait a moment before the next pair
time.sleep(0.5)
print("\nAll trades added successfully!")

View File

@ -1,71 +0,0 @@
def fix_live_trading():
try:
# Read the file content as a single string
with open('main.py', 'r') as f:
content = f.read()
print(f"Read {len(content)} characters from main.py")
# Fix the live_trading function signature
live_trading_pos = content.find('async def live_trading(')
if live_trading_pos != -1:
print(f"Found live_trading function at position {live_trading_pos}")
content = content.replace('async def live_trading(', 'async def live_trading(agent=None, env=None, exchange=None, ')
print("Updated live_trading function signature")
else:
print("WARNING: Could not find live_trading function!")
# Fix the TradingEnvironment initialization
env_init_pos = content.find('env = TradingEnvironment(')
if env_init_pos != -1:
print(f"Found env initialization at position {env_init_pos}")
# Find the closing parenthesis
paren_depth = 0
close_pos = env_init_pos
for i in range(env_init_pos, len(content)):
if content[i] == '(':
paren_depth += 1
elif content[i] == ')':
paren_depth -= 1
if paren_depth == 0:
close_pos = i + 1
break
# Calculate indentation
line_start = content.rfind('\n', 0, env_init_pos) + 1
indent = ' ' * (env_init_pos - line_start)
# Create the new environment initialization code
new_env_init = f'''if env is None:
{indent} env = TradingEnvironment(
{indent} initial_balance=initial_balance,
{indent} leverage=leverage,
{indent} window_size=window_size,
{indent} commission=commission,
{indent} symbol=symbol,
{indent} timeframe=timeframe
{indent} )'''
# Replace the old code with the new code
content = content[:env_init_pos] + new_env_init + content[close_pos:]
print("Updated TradingEnvironment initialization")
else:
print("WARNING: Could not find TradingEnvironment initialization!")
# Write the updated content back to the file
with open('main.py', 'w') as f:
f.write(content)
print(f"Wrote {len(content)} characters back to main.py")
print('Fixed live_trading function and TradingEnvironment initialization')
return True
except Exception as e:
print(f'Error fixing file: {e}')
import traceback
print(traceback.format_exc())
return False
if __name__ == "__main__":
fix_live_trading()

View File

@ -1,11 +0,0 @@
from datetime import datetime
from dataprovider_realtime import RealTimeChart
chart = RealTimeChart('BTC/USDT')
for i in range(3):
chart.add_trade(price=65000+i*100, timestamp=datetime.now(), pnl=0.1*i, action='BUY')
chart.add_trade(price=65100+i*100, timestamp=datetime.now(), pnl=0.2*i, action='SELL')
print(f'Added trade pair {i+1}')
print('All trades added successfully')

View File

@ -1 +0,0 @@
timestamp,action,price,position_size,balance,pnl
1 timestamp action price position_size balance pnl

View File

View File

@ -1,593 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup logging function
def setup_logging():
"""Setup logging configuration for the application"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("live_training.log"),
logging.StreamHandler(sys.stdout) # Added stdout handler for immediate feedback
]
)
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
# Implement a robust save function to handle PyTorch serialization errors
def robust_save(model, path):
"""
Robust model saving with multiple fallback approaches
Args:
model: The Agent model to save
path: Path to save the model
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
import shutil
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save epsilon value separately
with open(f"{path}.epsilon.txt", "w") as f:
f.write(str(model.epsilon))
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
# Implement timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
"""
Execute a coroutine with a timeout
Args:
coroutine: The coroutine to execute
timeout: Timeout in seconds
default: Default value to return on timeout
Returns:
The result of the coroutine or default value on timeout
"""
try:
return await asyncio.wait_for(coroutine, timeout=timeout)
except asyncio.TimeoutError:
logger.warning(f"Operation timed out after {timeout} seconds")
return default
except Exception as e:
logger.error(f"Operation failed: {e}")
return default
# Implement fetch_and_update_data function
async def fetch_and_update_data(exchange, env, symbol, timeframe):
"""
Fetch new candle data and update the environment
Args:
exchange: CCXT exchange instance
env: Trading environment instance
symbol: Trading pair symbol
timeframe: Timeframe for the candles
"""
logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
try:
# Default to 100 candles if not specified
limit = 1000
# Fetch OHLCV data with timeout
candles = await with_timeout(
exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
timeout=30,
default=[]
)
if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}")
return False
logger.info(f"Successfully fetched {len(candles)} candles")
# Convert to format expected by environment
formatted_candles = []
for candle in candles:
timestamp, open_price, high, low, close, volume = candle
formatted_candles.append({
'timestamp': timestamp,
'open': open_price,
'high': high,
'low': low,
'close': close,
'volume': volume
})
# Update environment data
env.data = formatted_candles
if hasattr(env, '_initialize_features'):
env._initialize_features()
logger.info(f"Updated environment with {len(formatted_candles)} candles")
# Print latest candle info
if formatted_candles:
latest = formatted_candles[-1]
dt = datetime.datetime.fromtimestamp(latest['timestamp']/1000).strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")
return True
except Exception as e:
logger.error(f"Error fetching candle data: {e}")
logger.error(traceback.format_exc())
return False
# Implement memory management function
def manage_memory():
"""
Clean up memory to avoid memory leaks during long running sessions
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.debug("Memory cleaned")
async def live_training(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
save_path="models/trading_agent_live_trained.pt",
initial_balance=1000,
update_interval=60,
training_iterations=100,
learning_rate=0.0001,
batch_size=64,
gamma=0.99,
window_size=30,
max_episodes=0, # 0 means unlimited
retry_delay=5, # Seconds to wait before retrying after an error
max_retries=3, # Maximum number of retries for operations
):
"""
Live training function that uses real market data to improve the model without executing real trades.
Args:
symbol: Trading pair symbol
timeframe: Timeframe for training
model_path: Path to the initial model to load
save_path: Path to save the improved model
initial_balance: Initial balance for simulation
update_interval: Interval to update data in seconds
training_iterations: Number of training iterations per data update
learning_rate: Learning rate for training
batch_size: Batch size for training
gamma: Discount factor for training
window_size: Window size for the environment
max_episodes: Maximum number of episodes (0 for unlimited)
retry_delay: Seconds to wait before retrying after an error
max_retries: Maximum number of retries for operations
"""
logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
# Initialize exchange (without sandbox mode)
exchange = None
# Retry loop for exchange initialization
for retry in range(max_retries):
try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
break
except Exception as e:
logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached. Could not initialize exchange.")
return
try:
# Initialize environment
env = TradingEnvironment(
initial_balance=initial_balance,
window_size=window_size,
symbol=symbol,
timeframe=timeframe,
)
# Fetch initial data (with retries)
logger.info(f"Fetching initial data for {symbol}")
success = False
for retry in range(max_retries):
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if success:
break
logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
if not success:
logger.error("Failed to fetch initial data after multiple attempts, exiting")
return
# Initialize agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)
# Load model if provided
if os.path.exists(model_path):
try:
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.warning(f"Error loading model: {e}")
logger.info("Starting with a new model")
else:
logger.warning(f"Model file {model_path} not found. Starting with a new model.")
# Initialize TensorBoard writer
run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
agent.writer = writer
# Initialize training statistics
total_rewards = 0
episode_count = 0
best_reward = float('-inf')
best_pnl = float('-inf')
# Start live training loop
logger.info(f"Starting live training loop")
step_counter = 0
last_update_time = datetime.datetime.now()
# Track consecutive errors to enable circuit breaker
consecutive_errors = 0
max_consecutive_errors = 5
while True:
# Check if we've reached the maximum number of episodes
if max_episodes > 0 and episode_count >= max_episodes:
logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
break
# Check if it's time to update data
current_time = datetime.datetime.now()
time_diff = (current_time - last_update_time).total_seconds()
if time_diff >= update_interval:
logger.info(f"Updating market data after {time_diff:.1f} seconds")
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if not success:
logger.warning("Failed to update data, will try again later")
# Wait a bit before trying again
await asyncio.sleep(retry_delay)
continue
last_update_time = current_time
# Clean up memory before running an episode
manage_memory()
# Run training iterations on the updated data
episode_reward = 0
env.reset()
done = False
# Run one simulated episode with the current data
steps_in_episode = 0
max_steps = len(env.data) - env.window_size - 1
logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")
while not done and steps_in_episode < max_steps:
try:
state = env.get_state()
action = agent.select_action(state, training=True)
try:
next_state, reward, done, info = env.step(action)
except ValueError as e:
logger.error(f"Error during env.step: {e}")
# If we get a ValueError, it might be because step is returning 3 values instead of 4
# Let's try to handle this case
if "too many values to unpack" in str(e):
logger.info("Trying alternative step format")
result = env.step(action)
if len(result) == 3:
next_state, reward, done = result
info = {}
else:
raise
else:
raise
# Save experience in replay memory
agent.memory.push(state, action, reward, next_state, done)
# Move to the next state
state = next_state
episode_reward += reward
step_counter += 1
steps_in_episode += 1
# Log action and results every 50 steps
if steps_in_episode % 50 == 0:
logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")
# Train the agent on a batch of experiences
if len(agent.memory) > batch_size:
try:
agent.learn()
# Additional training iterations
if steps_in_episode % 10 == 0 and training_iterations > 1:
for _ in range(training_iterations - 1):
agent.learn()
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logger.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
if done:
logger.info(f"Episode done after {steps_in_episode} steps")
break
except Exception as e:
logger.error(f"Error during episode step: {e}")
logger.error(traceback.format_exc())
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update training statistics
episode_count += 1
total_rewards += episode_reward
avg_reward = total_rewards / episode_count
# Track metrics
writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)
# Report progress
logger.info(f"""
Episode: {episode_count}
Reward: {episode_reward:.2f}
Avg Reward: {avg_reward:.2f}
Balance: ${env.balance:.2f}
PnL: ${env.total_pnl:.2f}
Memory Size: {len(agent.memory)}
Total Steps: {step_counter}
""")
# Save the model if it's the best so far (by reward or PnL)
if episode_reward > best_reward:
best_reward = episode_reward
reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
if robust_save(agent, reward_model_path):
logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
else:
logger.error(f"Failed to save best reward model")
if env.total_pnl > best_pnl:
best_pnl = env.total_pnl
pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
if robust_save(agent, pnl_model_path):
logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
else:
logger.error(f"Failed to save best PnL model")
# Regularly save the model
if episode_count % 5 == 0:
if robust_save(agent, save_path):
logger.info(f"Model checkpoint saved to {save_path}")
else:
logger.error(f"Failed to save checkpoint")
# Update target network periodically
if episode_count % 5 == 0:
try:
agent.update_target_network()
logger.info("Target network updated")
except Exception as e:
logger.error(f"Error updating target network: {e}")
# Sleep to avoid excessive API calls
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.info("Live training cancelled")
except KeyboardInterrupt:
logger.info("Live training stopped by user")
except Exception as e:
logger.error(f"Error in live training: {e}")
logger.error(traceback.format_exc())
finally:
# Save final model
if 'agent' in locals():
if robust_save(agent, save_path):
logger.info(f"Final model saved to {save_path}")
else:
logger.error(f"Failed to save final model")
# Close TensorBoard writer
try:
writer.close()
logger.info("TensorBoard writer closed")
except Exception as e:
logger.error(f"Error closing TensorBoard writer: {e}")
# Close exchange connection
if exchange:
try:
await with_timeout(exchange.close(), timeout=10)
logger.info("Exchange connection closed")
except Exception as e:
logger.error(f"Error closing exchange connection: {e}")
# Final memory cleanup
manage_memory()
logger.info("Live training completed")
async def main():
"""Main function to parse arguments and start live training"""
parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
args = parser.parse_args()
logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")
await live_training(
symbol=args.symbol,
timeframe=args.timeframe,
model_path=args.model_path,
save_path=args.save_path,
initial_balance=args.initial_balance,
update_interval=args.update_interval,
training_iterations=args.training_iterations,
max_episodes=args.max_episodes,
retry_delay=args.retry_delay,
max_retries=args.max_retries,
)
# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
"""Replace Agent's save method with our robust save approach"""
original_save = Agent.save
def patched_save(self, path):
return robust_save(self, path)
# Apply the patch
Agent.save = patched_save
logger.info("Monkey patched Agent.save with robust_save")
# Return the original method in case we need to restore it
return original_save
# Call the monkey patch function at the appropriate place
if __name__ == "__main__":
try:
print("Starting live training script")
# Apply the monkey patch before running the main function
original_save = monkey_patch_agent_save()
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Live training stopped by user")
except Exception as e:
logger.error(f"Error in main function: {e}")
logger.error(traceback.format_exc())

4876
main.py

File diff suppressed because it is too large Load Diff

View File

@ -1,240 +0,0 @@
import os
import json
import asyncio
import logging
import datetime
import numpy as np
import pandas as pd
import websockets
from dotenv import load_dotenv
from torch.utils.tensorboard import SummaryWriter
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("mexc_tick_stream.log"), logging.StreamHandler()]
)
logger = logging.getLogger("mexc_tick_stream")
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
class MexcTickStreamer:
def __init__(self, symbol="ETH/USDT", update_interval=1.0):
"""
Initialize the MEXC tick data streamer
Args:
symbol: Trading pair symbol (e.g., "ETH/USDT")
update_interval: How often to update the TensorBoard visualization (in seconds)
"""
self.symbol = symbol.replace("/", "").upper() # Convert to MEXC format (e.g., ETHUSDT)
self.update_interval = update_interval
self.uri = "wss://wbs-api.mexc.com/ws"
self.writer = SummaryWriter(f'runs/mexc_ticks_{self.symbol}')
self.trades = []
self.last_update_time = 0
self.running = False
# For visualization
self.price_history = []
self.volume_history = []
self.buy_volume = 0
self.sell_volume = 0
self.step = 0
async def connect(self):
"""Connect to MEXC WebSocket and subscribe to tick data"""
try:
self.websocket = await websockets.connect(self.uri)
logger.info(f"Connected to MEXC WebSocket for {self.symbol}")
# Subscribe to trade stream (using non-protobuf endpoint for simplicity)
subscribe_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.deals.v3.api@{self.symbol}"]
}
await self.websocket.send(json.dumps(subscribe_msg))
logger.info(f"Subscribed to {self.symbol} tick data")
# Start ping task to keep connection alive
asyncio.create_task(self.ping_loop())
return True
except Exception as e:
logger.error(f"Error connecting to MEXC WebSocket: {e}")
return False
async def ping_loop(self):
"""Send ping messages to keep the connection alive"""
while self.running:
try:
await self.websocket.send(json.dumps({"method": "PING"}))
await asyncio.sleep(30) # Send ping every 30 seconds
except Exception as e:
logger.error(f"Error in ping loop: {e}")
break
async def process_message(self, message):
"""Process incoming WebSocket messages"""
try:
# Try to parse as JSON
try:
data = json.loads(message)
# Handle PONG response
if data.get("msg") == "PONG":
return
# Handle subscription confirmation
if data.get("code") == 0:
logger.info(f"Subscription confirmed: {data.get('msg')}")
return
# Handle trade data in the non-protobuf format
if "c" in data and "d" in data and "deals" in data["d"]:
for trade in data["d"]["deals"]:
# Extract trade data
price = float(trade["p"])
quantity = float(trade["v"])
trade_type = 1 if trade["S"] == 1 else 2 # 1 for buy, 2 for sell
timestamp = trade["t"]
# Store trade data
self.trades.append({
"price": price,
"quantity": quantity,
"type": "buy" if trade_type == 1 else "sell",
"timestamp": timestamp
})
# Update volume counters
if trade_type == 1: # Buy
self.buy_volume += quantity
else: # Sell
self.sell_volume += quantity
# Store for visualization
self.price_history.append(price)
self.volume_history.append(quantity)
# Limit history size to prevent memory issues
if len(self.price_history) > 10000:
self.price_history = self.price_history[-5000:]
self.volume_history = self.volume_history[-5000:]
# Update TensorBoard if enough time has passed
current_time = datetime.datetime.now().timestamp()
if current_time - self.last_update_time >= self.update_interval:
await self.update_tensorboard()
self.last_update_time = current_time
except json.JSONDecodeError:
# If it's not valid JSON, it might be binary protobuf data
logger.debug("Received binary data, skipping (protobuf not implemented)")
except Exception as e:
logger.error(f"Error processing message: {e}")
async def update_tensorboard(self):
"""Update TensorBoard visualizations"""
try:
if not self.price_history:
return
# Calculate metrics
current_price = self.price_history[-1]
avg_price = np.mean(self.price_history[-100:]) if len(self.price_history) >= 100 else np.mean(self.price_history)
price_std = np.std(self.price_history[-100:]) if len(self.price_history) >= 100 else np.std(self.price_history)
# Calculate VWAP (Volume Weighted Average Price)
if len(self.price_history) >= 100 and len(self.volume_history) >= 100:
vwap = np.sum(np.array(self.price_history[-100:]) * np.array(self.volume_history[-100:])) / np.sum(self.volume_history[-100:])
else:
vwap = np.sum(np.array(self.price_history) * np.array(self.volume_history)) / np.sum(self.volume_history) if np.sum(self.volume_history) > 0 else current_price
# Calculate buy/sell ratio
total_volume = self.buy_volume + self.sell_volume
buy_ratio = self.buy_volume / total_volume if total_volume > 0 else 0.5
# Log to TensorBoard
self.writer.add_scalar('Price/Current', current_price, self.step)
self.writer.add_scalar('Price/VWAP', vwap, self.step)
self.writer.add_scalar('Price/StdDev', price_std, self.step)
self.writer.add_scalar('Volume/BuyRatio', buy_ratio, self.step)
self.writer.add_scalar('Volume/Total', total_volume, self.step)
# Create a candlestick-like chart for the last 100 ticks
if len(self.price_history) >= 100:
prices = np.array(self.price_history[-100:])
self.writer.add_histogram('Price/Distribution', prices, self.step)
# Create a custom scalars panel
layout = {
"Price": {
"Current vs VWAP": ["Multiline", ["Price/Current", "Price/VWAP"]],
},
"Volume": {
"Buy Ratio": ["Multiline", ["Volume/BuyRatio"]],
}
}
self.writer.add_custom_scalars(layout)
self.step += 1
logger.info(f"Updated TensorBoard: Price={current_price:.2f}, VWAP={vwap:.2f}, Buy Ratio={buy_ratio:.2f}")
except Exception as e:
logger.error(f"Error updating TensorBoard: {e}")
async def run(self):
"""Main loop to receive and process WebSocket messages"""
self.running = True
self.last_update_time = datetime.datetime.now().timestamp()
if not await self.connect():
logger.error("Failed to connect. Exiting.")
return
try:
while self.running:
message = await self.websocket.recv()
await self.process_message(message)
except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed")
except Exception as e:
logger.error(f"Error in run loop: {e}")
finally:
self.running = False
await self.cleanup()
async def cleanup(self):
"""Clean up resources"""
try:
if hasattr(self, 'websocket'):
await self.websocket.close()
self.writer.close()
logger.info("Cleaned up resources")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
async def main():
"""Main entry point"""
# Parse command line arguments
import argparse
parser = argparse.ArgumentParser(description='MEXC Tick Data Streamer')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol (e.g., ETH/USDT)')
parser.add_argument('--interval', type=float, default=1.0, help='TensorBoard update interval in seconds')
args = parser.parse_args()
# Create and run the streamer
streamer = MexcTickStreamer(symbol=args.symbol, update_interval=args.interval)
await streamer.run()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program interrupted by user")
except Exception as e:
logger.error(f"Unhandled exception: {e}")

View File

@ -1,8 +0,0 @@
SBIE2102 File is too large to copy into sandbox - state.vscdb [DefaultBox / 549171200]
SBIE2223 To increase the file size limit for copying files, please double-click on this message line
SBIE2102 File is too large to copy into sandbox - state.vscdb [DefaultBox / 549171200]
SBIE2223 To increase the file size limit for copying files, please double-click on this message line
SBIE2102 File is too large to copy into sandbox - state.vscdb.backup [DefaultBox / 549167104]
SBIE2223 To increase the file size limit for copying files, please double-click on this message line
SBIE2102 File is too large to copy into sandbox - state.vscdb [DefaultBox / 549171200]
SBIE2223 To increase the file size limit for copying files, please double-click on this message line

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,34 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
from main import live_trading, setup_logging
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
async def main():
"""Run a simplified demo trading session with mock data"""
logger.info("Starting simplified demo trading session")
# Run live trading in demo mode with simplified parameters
await live_trading(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
demo=True,
initial_balance=1000,
update_interval=10, # Update every 10 seconds for faster feedback
max_position_size=0.1,
risk_per_trade=0.02,
stop_loss_pct=0.02,
take_profit_pct=0.04,
)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Demo trading stopped by user")
except Exception as e:
logger.error(f"Error in demo trading: {e}")

View File

@ -1,29 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
from main import live_trading, setup_logging
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
async def main():
"""Run a simplified demo trading session with mock data"""
logger.info("Starting simplified demo trading session")
# Run live trading in demo mode with simplified parameters
await live_trading(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
demo=True,
initial_balance=1000,
update_interval=10, # Update every 10 seconds for faster feedback
max_position_size=0.1,
risk_per_trade=0.02,
stop_loss_pct=0.02,
take_profit_pct=0.04,
)
if __name__ == "__main__":
asyncio.run(main())

232
run_nn.py
View File

@ -1,232 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Training Runner Script
This script runs the Neural Network Trading System with the existing conda environment.
It detects which deep learning framework is available (TensorFlow or PyTorch) and
adjusts the implementation accordingly.
"""
import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('nn_runner')
def detect_framework():
"""Detect which deep learning framework is available in the environment"""
try:
import torch
torch_version = torch.__version__
logger.info(f"PyTorch {torch_version} detected")
return "pytorch", torch_version
except ImportError:
logger.warning("PyTorch not found in environment")
try:
import tensorflow as tf
tf_version = tf.__version__
logger.info(f"TensorFlow {tf_version} detected")
return "tensorflow", tf_version
except ImportError:
logger.error("Neither PyTorch nor TensorFlow is available in the environment")
return None, None
def check_dependencies():
"""Check for required dependencies and return if they are met"""
required_packages = ["numpy", "pandas", "matplotlib", "scikit-learn"]
missing_packages = []
for package in required_packages:
try:
__import__(package)
except ImportError:
missing_packages.append(package)
if missing_packages:
logger.warning(f"Missing required packages: {', '.join(missing_packages)}")
return False
return True
def create_run_command(args, framework):
"""Create the command to run the neural network based on the available framework"""
cmd = ["python", "-m", "NN.main"]
# Add mode
cmd.extend(["--mode", args.mode])
# Add symbol
if args.symbol:
cmd.extend(["--symbol", args.symbol])
# Add timeframes
if args.timeframes:
cmd.extend(["--timeframes"] + args.timeframes)
# Add window size
if args.window_size:
cmd.extend(["--window-size", str(args.window_size)])
# Add output size
if args.output_size:
cmd.extend(["--output-size", str(args.output_size)])
# Add batch size
if args.batch_size:
cmd.extend(["--batch-size", str(args.batch_size)])
# Add epochs
if args.epochs:
cmd.extend(["--epochs", str(args.epochs)])
# Add model type
if args.model_type:
cmd.extend(["--model-type", args.model_type])
# Add framework-specific flag
cmd.extend(["--framework", framework])
return cmd
def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Neural Network Trading System Runner')
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
help='Mode to run (train, predict, realtime)')
parser.add_argument('--symbol', type=str, default='BTC/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
help='Timeframes to use')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for input data')
parser.add_argument('--output-size', type=int, default=3,
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
parser.add_argument('--batch-size', type=int, default=32,
help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100,
help='Number of epochs for training')
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
help='Model type to use')
parser.add_argument('--conda-env', type=str, default='gpt-gpu',
help='Name of conda environment to use')
parser.add_argument('--no-conda', action='store_true',
help='Do not use conda environment activation')
parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
help='Deep learning framework to use (default: pytorch)')
return parser.parse_args()
def main():
# Parse arguments
args = parse_arguments()
# Check if we should run with conda
if not args.no_conda and args.conda_env:
# Create conda activation command
if sys.platform == 'win32':
conda_cmd = f"conda activate {args.conda_env} && "
else:
conda_cmd = f"source activate {args.conda_env} && "
logger.info(f"Running with conda environment: {args.conda_env}")
# Create the run script
script_path = Path("run_nn_in_conda.bat" if sys.platform == 'win32' else "run_nn_in_conda.sh")
with open(script_path, 'w') as f:
if sys.platform == 'win32':
f.write("@echo off\n")
f.write(f"call conda activate {args.conda_env}\n")
f.write(f"python -m NN.main --mode {args.mode} --symbol {args.symbol}")
if args.timeframes:
f.write(f" --timeframes {' '.join(args.timeframes)}")
if args.window_size:
f.write(f" --window-size {args.window_size}")
if args.output_size:
f.write(f" --output-size {args.output_size}")
if args.batch_size:
f.write(f" --batch-size {args.batch_size}")
if args.epochs:
f.write(f" --epochs {args.epochs}")
if args.model_type:
f.write(f" --model-type {args.model_type}")
else:
f.write("#!/bin/bash\n")
f.write(f"source activate {args.conda_env}\n")
f.write(f"python -m NN.main --mode {args.mode} --symbol {args.symbol}")
if args.timeframes:
f.write(f" --timeframes {' '.join(args.timeframes)}")
if args.window_size:
f.write(f" --window-size {args.window_size}")
if args.output_size:
f.write(f" --output-size {args.output_size}")
if args.batch_size:
f.write(f" --batch-size {args.batch_size}")
if args.epochs:
f.write(f" --epochs {args.epochs}")
if args.model_type:
f.write(f" --model-type {args.model_type}")
# Make script executable on Unix
if sys.platform != 'win32':
os.chmod(script_path, 0o755)
# Run the script
logger.info(f"Created script: {script_path}")
logger.info("Run this script to execute the neural network with the conda environment")
if sys.platform == 'win32':
print("\nTo run the neural network, execute the following command:")
print(f" {script_path}")
else:
print("\nTo run the neural network, execute the following command:")
print(f" ./{script_path}")
else:
# Run directly without conda
# First detect available framework
framework, version = detect_framework()
if framework is None:
logger.error("Cannot run Neural Network - no deep learning framework available")
return
# Check dependencies
if not check_dependencies():
logger.error("Missing required dependencies - please install them first")
return
# Create command
cmd = create_run_command(args, framework)
# Run command
logger.info(f"Running command: {' '.join(cmd)}")
try:
subprocess.run(cmd, check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error running neural network: {str(e)}")
except Exception as e:
logger.error(f"Error: {str(e)}")
if __name__ == "__main__":
main()

View File

@ -1,3 +0,0 @@
@echo off
call conda activate gpt-gpu
python -m NN.main --mode train --symbol BTC/USDT --timeframes 1h 4h --window-size 20 --output-size 3 --batch-size 32 --epochs 100 --model-type cnn --framework pytorch

View File

@ -1,58 +0,0 @@
@echo off
echo ============================================================
echo Neural Network Trading System - PyTorch Implementation
echo ============================================================
call conda activate gpt-gpu
REM Install missing dependencies if needed
echo Checking for required packages...
python -c "import sklearn" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing scikit-learn...
call conda install -y scikit-learn
)
REM Parse command-line arguments
set MODE=train
set MODEL_TYPE=cnn
set SYMBOL=BTC/USDT
set EPOCHS=100
:parse
if "%~1"=="" goto endparse
if /i "%~1"=="--mode" (
set MODE=%~2
shift
shift
goto parse
)
if /i "%~1"=="--model" (
set MODEL_TYPE=%~2
shift
shift
goto parse
)
if /i "%~1"=="--symbol" (
set SYMBOL=%~2
shift
shift
goto parse
)
if /i "%~1"=="--epochs" (
set EPOCHS=%~2
shift
shift
goto parse
)
shift
goto parse
:endparse
echo Running Neural Network in %MODE% mode with %MODEL_TYPE% model for %SYMBOL% for %EPOCHS% epochs
python -m NN.main --mode %MODE% --symbol %SYMBOL% --timeframes 1h 4h --window-size 20 --output-size 3 --batch-size 32 --epochs %EPOCHS% --model-type %MODEL_TYPE% --framework pytorch
echo ============================================================
echo Run completed.
echo ============================================================

View File

@ -1,77 +1,181 @@
#!/usr/bin/env python #!/usr/bin/env python3
""" """
Run unit tests for the trading bot. Unified Test Runner for Trading System
This script runs the unit tests defined in tests.py and displays the results. This script provides a unified interface to run all tests in the system:
It can run a single test or all tests. - Essential functionality tests
- Model persistence tests
- Training integration tests
- Indicators and signals tests
- Remaining individual test files
Usage: Usage:
python run_tests.py [test_name] python run_tests.py # Run all tests
python run_tests.py essential # Run essential tests only
If test_name is provided, only that test will be run. python run_tests.py persistence # Run model persistence tests only
Otherwise, all tests will be run. python run_tests.py training # Run training integration tests only
python run_tests.py indicators # Run indicators and signals tests only
Example: python run_tests.py individual # Run individual test files only
python run_tests.py TestPeriodicUpdates
python run_tests.py TestBacktesting
python run_tests.py TestBacktestingLastSevenDays
python run_tests.py TestSingleDayBacktesting
python run_tests.py
""" """
import sys import sys
import unittest import os
import subprocess
import logging import logging
from tests import ( from pathlib import Path
TestPeriodicUpdates,
TestBacktesting,
TestBacktestingLastSevenDays,
TestSingleDayBacktesting
)
if __name__ == "__main__": # Add project root to path
# Configure logging project_root = Path(__file__).parent
logging.basicConfig(level=logging.INFO, sys.path.insert(0, str(project_root))
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Get the test name from the command line from core.config import setup_logging
test_name = sys.argv[1] if len(sys.argv) > 1 else None
# Run the specified test or all tests logger = logging.getLogger(__name__)
if test_name:
logging.info(f"Running test: {test_name}") def run_test_module(module_path, test_type="all"):
if test_name == "TestPeriodicUpdates": """Run a specific test module"""
suite = unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates) try:
elif test_name == "TestBacktesting": cmd = [sys.executable, str(module_path)]
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktesting) if test_type != "all":
elif test_name == "TestBacktestingLastSevenDays": cmd.append(test_type)
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays)
elif test_name == "TestSingleDayBacktesting": logger.info(f"Running: {' '.join(cmd)}")
suite = unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting) result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_root)
if result.returncode == 0:
logger.info(f"{module_path.name} passed")
if result.stdout:
logger.info(result.stdout)
return True
else: else:
logging.error(f"Unknown test: {test_name}") logger.error(f"{module_path.name} failed")
logging.info("Available tests: TestPeriodicUpdates, TestBacktesting, TestBacktestingLastSevenDays, TestSingleDayBacktesting") if result.stderr:
sys.exit(1) logger.error(result.stderr)
else: if result.stdout:
# Run all tests logger.error(result.stdout)
logging.info("Running all tests") return False
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktesting))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting))
# Run the tests except Exception as e:
runner = unittest.TextTestRunner(verbosity=2) logger.error(f"❌ Error running {module_path}: {e}")
result = runner.run(suite) return False
def run_essential_tests():
"""Run essential functionality tests"""
logger.info("=== Running Essential Tests ===")
return run_test_module(project_root / "tests" / "test_essential.py")
def run_persistence_tests():
"""Run model persistence tests"""
logger.info("=== Running Model Persistence Tests ===")
return run_test_module(project_root / "tests" / "test_model_persistence.py")
def run_training_tests():
"""Run training integration tests"""
logger.info("=== Running Training Integration Tests ===")
return run_test_module(project_root / "tests" / "test_training_integration.py")
def run_indicators_tests():
"""Run indicators and signals tests"""
logger.info("=== Running Indicators and Signals Tests ===")
return run_test_module(project_root / "tests" / "test_indicators_and_signals.py")
def run_individual_tests():
"""Run remaining individual test files"""
logger.info("=== Running Individual Test Files ===")
individual_tests = [
"test_positions.py",
"test_tick_cache.py",
"test_timestamps.py"
]
results = []
for test_file in individual_tests:
test_path = project_root / test_file
if test_path.exists():
logger.info(f"Running {test_file}...")
result = run_test_module(test_path)
results.append(result)
else:
logger.warning(f"Test file not found: {test_file}")
results.append(False)
return all(results)
def run_all_tests():
"""Run all test suites"""
logger.info("🧪 Running All Trading System Tests")
logger.info("=" * 60)
test_suites = [
("Essential Tests", run_essential_tests),
("Model Persistence Tests", run_persistence_tests),
("Training Integration Tests", run_training_tests),
("Indicators and Signals Tests", run_indicators_tests),
("Individual Tests", run_individual_tests),
]
results = []
for suite_name, suite_func in test_suites:
logger.info(f"\n📋 {suite_name}")
logger.info("-" * 40)
try:
result = suite_func()
results.append((suite_name, result))
except Exception as e:
logger.error(f"{suite_name} crashed: {e}")
results.append((suite_name, False))
# Print summary # Print summary
print("\nTest Summary:") logger.info("\n" + "=" * 60)
print(f" Ran {result.testsRun} tests") logger.info("📊 TEST RESULTS SUMMARY")
print(f" Errors: {len(result.errors)}") logger.info("=" * 60)
print(f" Failures: {len(result.failures)}")
print(f" Skipped: {len(result.skipped)}")
# Exit with non-zero status if any tests failed passed = 0
sys.exit(len(result.errors) + len(result.failures)) for suite_name, result in results:
status = "✅ PASS" if result else "❌ FAIL"
logger.info(f"{status}: {suite_name}")
if result:
passed += 1
logger.info(f"\nPassed: {passed}/{len(results)} test suites")
if passed == len(results):
logger.info("🎉 All tests passed! Trading system is working correctly.")
return True
else:
logger.warning(f"⚠️ {len(results) - passed} test suite(s) failed. Please check the issues above.")
return False
def main():
"""Main test runner"""
setup_logging()
# Parse command line arguments
if len(sys.argv) > 1:
test_type = sys.argv[1].lower()
if test_type == "essential":
success = run_essential_tests()
elif test_type == "persistence":
success = run_persistence_tests()
elif test_type == "training":
success = run_training_tests()
elif test_type == "indicators":
success = run_indicators_tests()
elif test_type == "individual":
success = run_individual_tests()
elif test_type in ["help", "-h", "--help"]:
print(__doc__)
return 0
else:
logger.error(f"Unknown test type: {test_type}")
print(__doc__)
return 1
else:
success = run_all_tests()
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,85 +0,0 @@
@echo off
echo ============================================================
echo Neural Network Trading System - Environment Setup
echo ============================================================
call conda activate gpt-gpu
echo Checking and installing required packages...
REM Check for PyTorch
python -c "import torch" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing PyTorch...
call conda install -y pytorch torchvision cpuonly -c pytorch
)
REM Check for NumPy
python -c "import numpy" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing NumPy...
call conda install -y numpy
)
REM Check for Pandas
python -c "import pandas" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing Pandas...
call conda install -y pandas
)
REM Check for Matplotlib
python -c "import matplotlib" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing Matplotlib...
call conda install -y matplotlib
)
REM Check for scikit-learn
python -c "import sklearn" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing scikit-learn...
call conda install -y scikit-learn
)
REM Check for additional dependencies
python -c "import h5py" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing h5py...
call conda install -y h5py
)
python -c "import tqdm" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing tqdm...
call conda install -y tqdm
)
python -c "import yaml" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing PyYAML...
call conda install -y pyyaml
)
python -c "import plotly" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing Plotly...
call conda install -y plotly
)
python -c "import tensorboard" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing TensorBoard...
call conda install -y tensorboard
)
python -c "import ccxt" 2>NUL
if %ERRORLEVEL% NEQ 0 (
echo Installing ccxt...
call pip install ccxt
)
echo ============================================================
echo Environment setup completed.
echo You can now run the Neural Network with: run_pytorch_nn.bat
echo ============================================================

View File

@ -1,118 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import ccxt.async_support as ccxt
import os
import datetime
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup direct console logging for immediate feedback
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
async def initialize_exchange():
"""Initialize the exchange with API credentials from environment variables"""
exchange_id = 'mexc'
try:
# Get API credentials from environment variables
api_key = os.getenv('MEXC_API_KEY', '')
secret_key = os.getenv('MEXC_SECRET_KEY', '')
# Initialize the exchange
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class({
'apiKey': api_key,
'secret': secret_key,
'enableRateLimit': True,
})
logger.info(f"Exchange initialized with standard CCXT: {exchange_id}")
return exchange
except Exception as e:
logger.error(f"Error initializing exchange: {e}")
raise
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit=1000):
"""Fetch OHLCV data from the exchange"""
logger.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt 1/3)")
try:
candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}")
return None
logger.info(f"Successfully fetched {len(candles)} candles")
return candles
except Exception as e:
logger.error(f"Error fetching candle data: {e}")
return None
async def main():
"""Main function to test live data fetching"""
symbol = "ETH/USDT"
timeframe = "1m"
logger.info(f"Starting simplified live training test for {symbol} on {timeframe}")
try:
# Initialize exchange
exchange = await initialize_exchange()
# Fetch data every 10 seconds
for i in range(5):
logger.info(f"Fetch attempt {i+1}/5")
candles = await fetch_ohlcv_data(exchange, symbol, timeframe)
if candles:
# Print the latest candle
latest = candles[-1]
timestamp, open_price, high, low, close, volume = latest
dt = datetime.datetime.fromtimestamp(timestamp/1000).strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"Latest candle: Time={dt}, Open={open_price}, High={high}, Low={low}, Close={close}, Volume={volume}")
# Wait 10 seconds before next fetch
if i < 4: # Don't wait after the last fetch
logger.info("Waiting 10 seconds before next fetch...")
await asyncio.sleep(10)
# Close exchange connection
await exchange.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.error(f"Error in simplified live training test: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
try:
await exchange.close()
except:
pass
logger.info("Test completed")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Test stopped by user")
except Exception as e:
logger.error(f"Error in main function: {e}")
import traceback
logger.error(traceback.format_exc())

View File

@ -1,22 +0,0 @@
@echo off
rem Start the real-time chart application with log monitoring
set LOG_FILE=realtime_%date:~10,4%%date:~4,2%%date:~7,2%_%time:~0,2%%time:~3,2%%time:~6,2%.log
set LOG_FILE=%LOG_FILE: =0%
echo Starting application with log file: %LOG_FILE%
rem Start the application in one window
start "RealTime Trading Chart" cmd /k python train_rl_with_realtime.py --episodes 1 --no-train --visualize-only --log-file %LOG_FILE%
rem Wait for the log file to be created
timeout /t 3 > nul
rem Start log monitoring in another window (tail -f equivalent)
start "Log Monitor" cmd /k python read_logs.py --file %LOG_FILE% --follow
rem Open the dashboard in the browser
timeout /t 5 > nul
start http://localhost:8050/
echo Application started. Check the opened windows for details.

View File

@ -1,153 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify chart data loading functionality
"""
import logging
import sys
import os
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_binance_data_fetch():
"""Test fetching data from Binance API"""
logger.info("Testing Binance historical data fetch...")
try:
binance_data = BinanceHistoricalData()
# Test fetching 1m data for ETH/USDT
df = binance_data.get_historical_candles("ETH/USDT", 60, 100)
if df is not None and not df.empty:
logger.info(f"✅ Successfully fetched {len(df)} 1m candles")
logger.info(f" Latest price: ${df.iloc[-1]['close']:.2f}")
logger.info(f" Date range: {df.iloc[0]['timestamp']} to {df.iloc[-1]['timestamp']}")
return True
else:
logger.error("❌ Failed to fetch Binance data")
return False
except Exception as e:
logger.error(f"❌ Error fetching Binance data: {str(e)}")
return False
def test_tick_storage():
"""Test TickStorage data loading"""
logger.info("Testing TickStorage data loading...")
try:
# Create tick storage
tick_storage = TickStorage("ETH/USDT", ["1s", "1m", "5m", "1h"])
# Load historical data
success = tick_storage.load_historical_data("ETH/USDT", limit=100)
if success:
logger.info("✅ TickStorage data loading successful")
# Check what we have
for tf in ["1s", "1m", "5m", "1h"]:
candles = tick_storage.get_candles(tf)
logger.info(f" {tf}: {len(candles)} candles")
if candles:
latest = candles[-1]
logger.info(f" Latest {tf}: {latest['timestamp']} - ${latest['close']:.2f}")
return True
else:
logger.error("❌ TickStorage data loading failed")
return False
except Exception as e:
logger.error(f"❌ Error in TickStorage: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return False
def test_chart_initialization():
"""Test RealTimeChart initialization and data loading"""
logger.info("Testing RealTimeChart initialization...")
try:
# Create chart (without app to avoid GUI issues)
chart = RealTimeChart(
app=None,
symbol="ETH/USDT",
standalone=False
)
# Test getting candles
candles_1s = chart.get_candles(1) # 1 second
candles_1m = chart.get_candles(60) # 1 minute
logger.info(f"✅ Chart initialized successfully")
logger.info(f" 1s candles: {len(candles_1s)}")
logger.info(f" 1m candles: {len(candles_1m)}")
if candles_1m:
latest = candles_1m[-1]
logger.info(f" Latest 1m candle: {latest['timestamp']} - ${latest['close']:.2f}")
return len(candles_1s) > 0 or len(candles_1m) > 0
except Exception as e:
logger.error(f"❌ Error in chart initialization: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return False
def main():
"""Run all tests"""
logger.info("🧪 Starting chart data loading tests...")
logger.info("=" * 60)
tests = [
("Binance API fetch", test_binance_data_fetch),
("TickStorage loading", test_tick_storage),
("Chart initialization", test_chart_initialization)
]
results = []
for test_name, test_func in tests:
logger.info(f"\n📋 Running test: {test_name}")
logger.info("-" * 40)
try:
result = test_func()
results.append((test_name, result))
except Exception as e:
logger.error(f"❌ Test {test_name} crashed: {str(e)}")
results.append((test_name, False))
# Print summary
logger.info("\n" + "=" * 60)
logger.info("📊 TEST RESULTS SUMMARY")
logger.info("=" * 60)
passed = 0
for test_name, result in results:
status = "✅ PASS" if result else "❌ FAIL"
logger.info(f"{status}: {test_name}")
if result:
passed += 1
logger.info(f"\nPassed: {passed}/{len(results)} tests")
if passed == len(results):
logger.info("🎉 All tests passed! Chart data loading is working correctly.")
return True
else:
logger.warning(f"⚠️ {len(results) - passed} test(s) failed. Please check the issues above.")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -1,65 +0,0 @@
#!/usr/bin/env python3
"""
Quick CNN Training Test - Real Market Data Only
This script tests CNN training with a small dataset for quick validation.
All training metrics are logged to TensorBoard for real-time monitoring.
"""
import logging
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
def main():
"""Test CNN training with real market data"""
setup_logging()
print("Setting up CNN training test...")
print("📊 Monitor training: tensorboard --logdir=runs")
# Configure test parameters
config = get_config()
# Test configuration
symbols = ['ETH/USDT']
timeframes = ['1m', '5m', '1h']
num_samples = 500
epochs = 2
batch_size = 16
# Override config for quick test
config._config['timeframes'] = timeframes # Direct config access
trainer = CNNTrainer(config)
trainer.batch_size = batch_size
trainer.epochs = epochs
print("Configuration:")
print(f" Symbols: {symbols}")
print(f" Timeframes: {timeframes}")
print(f" Samples: {num_samples}")
print(f" Epochs: {epochs}")
print(f" Batch size: {batch_size}")
print(" Data source: REAL market data from exchange APIs")
try:
# Train model with TensorBoard logging
results = trainer.train(symbols, save_path='test_models/quick_cnn.pt', num_samples=num_samples)
print(f"\n✅ CNN Training completed!")
print(f" Best accuracy: {results['best_val_accuracy']:.4f}")
print(f" Total epochs: {results['total_epochs']}")
print(f" Training time: {results['training_time']:.2f}s")
print(f" TensorBoard logs: {results['tensorboard_dir']}")
print(f"\n📊 View training progress: tensorboard --logdir=runs")
except Exception as e:
print(f"❌ Training failed: {e}")
import traceback
traceback.print_exc()
finally:
trainer.close_tensorboard()
if __name__ == "__main__":
main()

View File

@ -1,103 +0,0 @@
#!/usr/bin/env python
"""
Simple test for Dash to ensure chart rendering is working correctly
"""
import dash
from dash import html, dcc, Input, Output
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
# Create some sample data
def generate_sample_data(days=10):
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
dates = pd.date_range(start=start_date, end=end_date, freq='1h')
np.random.seed(42)
prices = np.random.normal(loc=100, scale=5, size=len(dates))
prices = np.cumsum(np.random.normal(loc=0, scale=1, size=len(dates))) + 100
df = pd.DataFrame({
'timestamp': dates,
'open': prices,
'high': prices + np.random.normal(loc=1, scale=0.5, size=len(dates)),
'low': prices - np.random.normal(loc=1, scale=0.5, size=len(dates)),
'close': prices + np.random.normal(loc=0, scale=0.5, size=len(dates)),
'volume': np.abs(np.random.normal(loc=100, scale=50, size=len(dates)))
})
return df
# Create a Dash app
app = dash.Dash(__name__)
# Layout
app.layout = html.Div([
html.H1("Test Chart"),
dcc.Graph(id='test-chart', style={'height': '800px'}),
dcc.Interval(
id='interval-component',
interval=1*1000, # in milliseconds
n_intervals=0
),
html.Div(id='signal-display', children="No Signal")
])
# Callback
@app.callback(
[Output('test-chart', 'figure'),
Output('signal-display', 'children')],
[Input('interval-component', 'n_intervals')]
)
def update_chart(n):
# Generate new data on each update
df = generate_sample_data()
# Create a candlestick chart
fig = go.Figure(data=[go.Candlestick(
x=df['timestamp'],
open=df['open'],
high=df['high'],
low=df['low'],
close=df['close']
)])
# Add some random buy/sell signals
if n % 5 == 0:
signal_point = df.iloc[np.random.randint(0, len(df))]
action = "BUY" if n % 10 == 0 else "SELL"
color = "green" if action == "BUY" else "red"
fig.add_trace(go.Scatter(
x=[signal_point['timestamp']],
y=[signal_point['close']],
mode='markers',
marker=dict(
size=10,
color=color,
symbol='triangle-up' if action == 'BUY' else 'triangle-down'
),
name=action
))
signal_text = f"Current Signal: {action}"
else:
signal_text = "No Signal"
# Update layout
fig.update_layout(
title='Test Price Chart',
yaxis_title='Price',
xaxis_title='Time',
template='plotly_dark',
height=800
)
return fig, signal_text
if __name__ == '__main__':
print("Starting Dash server on http://localhost:8090/")
app.run_server(debug=False, host='localhost', port=8090)

View File

@ -1,82 +0,0 @@
#!/usr/bin/env python3
"""
Test script for enhanced technical indicators
"""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging
from core.data_provider import DataProvider
def main():
setup_logging()
print("Testing Enhanced Technical Indicators")
print("=" * 50)
# Initialize data provider
dp = DataProvider(['ETH/USDT', 'BTC/USDT'], ['1m', '1h', '4h', '1d'])
# Test with fresh data
print("Fetching fresh data with all indicators...")
df = dp.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
if df is not None:
print(f"Data shape: {df.shape}")
print(f"Total columns: {len(df.columns)}")
print("\nAvailable indicators:")
# Categorize indicators
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
indicator_cols = [col for col in df.columns if col not in basic_cols]
print(f" Basic OHLCV: {len(basic_cols)} columns")
print(f" Technical indicators: {len(indicator_cols)} columns")
# Group indicators by type
trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])]
momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])]
volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])]
volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])]
custom_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['trend_strength', 'momentum_composite', 'volatility_regime', 'price_position'])]
print(f"\nIndicator breakdown:")
print(f" Trend: {len(trend_indicators)} - {trend_indicators}")
print(f" Momentum: {len(momentum_indicators)} - {momentum_indicators}")
print(f" Volatility: {len(volatility_indicators)} - {volatility_indicators}")
print(f" Volume: {len(volume_indicators)} - {volume_indicators}")
print(f" Custom: {len(custom_indicators)} - {custom_indicators}")
# Test feature matrix creation
print("\nTesting multi-timeframe feature matrix...")
feature_matrix = dp.get_feature_matrix('ETH/USDT', ['1h', '4h'], window_size=20)
if feature_matrix is not None:
print(f"Feature matrix shape: {feature_matrix.shape}")
print(f" Timeframes: {feature_matrix.shape[0]}")
print(f" Window size: {feature_matrix.shape[1]}")
print(f" Features: {feature_matrix.shape[2]}")
else:
print("Failed to create feature matrix")
# Test multi-symbol feature matrix
print("\nTesting multi-symbol feature matrix...")
multi_symbol_matrix = dp.get_multi_symbol_feature_matrix(['ETH/USDT'], ['1h'], window_size=20)
if multi_symbol_matrix is not None:
print(f"Multi-symbol matrix shape: {multi_symbol_matrix.shape}")
else:
print("Failed to create multi-symbol feature matrix")
print("\n✅ All tests completed successfully!")
else:
print("❌ Failed to fetch data")
if __name__ == "__main__":
main()

View File

@ -1,254 +0,0 @@
#!/usr/bin/env python
"""
Extended training session for CNN model optimized for short-term high-leverage trading
"""
import os
import sys
import logging
import numpy as np
import torch
import time
# Add the project root to path
sys.path.append(os.path.abspath('.'))
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('extended_training')
# Import the optimized model
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.utils.data_interface import DataInterface
def run_extended_training():
"""
Run an extended training session for CNN model with comprehensive performance tracking
"""
# Extended configuration parameters
symbol = "BTC/USDT"
timeframes = ["1m", "5m", "15m"] # Multiple timeframes for better signal quality
window_size = 24 # Larger window size to capture more context
output_size = 3 # BUY/HOLD/SELL
batch_size = 64 # Increased batch size for more stable gradients
epochs = 30 # Extended training session
logger.info(f"Starting extended training session for CNN model with {symbol} data")
logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, epochs={epochs}, batch_size={batch_size}")
start_time = time.time()
try:
# Initialize data interface with more data
logger.info("Initializing data interface...")
data_interface = DataInterface(
symbol=symbol,
timeframes=timeframes
)
# Prepare training data with more history
logger.info("Loading extended training data...")
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
refresh=True,
# Increase data size for better training
test_size=0.15, # Smaller test size to have more training data
max_samples=1000 # More samples for training
)
if X_train is None or y_train is None:
logger.error("Failed to load training data")
return
logger.info(f"Training data loaded - X shape: {X_train.shape}, y shape: {y_train.shape}")
logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
# Get future prices for longer-term prediction
logger.info("Calculating future price changes...")
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8) # Look further ahead
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
# Initialize model
num_features = data_interface.get_feature_count()
logger.info(f"Initializing model with {num_features} features")
# Use the same window size as the data interface
actual_window_size = X_train.shape[1]
logger.info(f"Actual window size from data: {actual_window_size}")
model = CNNModelPyTorch(
window_size=actual_window_size,
num_features=num_features,
output_size=output_size,
timeframes=timeframes
)
# Track metrics over time
best_val_pnl = -float('inf')
best_win_rate = 0
best_epoch = 0
# Create checkpoint directory
checkpoint_dir = "NN/models/saved/training_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Performance tracking
metrics_history = {
"epoch": [],
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
"train_pnl": [],
"val_pnl": [],
"train_win_rate": [],
"val_win_rate": [],
"signal_distribution": []
}
logger.info("Starting extended training...")
for epoch in range(epochs):
logger.info(f"Epoch {epoch+1}/{epochs}")
epoch_start = time.time()
# Train one epoch
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, batch_size
)
# Evaluate
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
logger.info(f"Epoch {epoch+1} results:")
logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")
# Get predictions for PnL calculation
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Convert probabilities to actions
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Track signal distribution
train_buy_count = np.sum(train_preds == 2)
train_sell_count = np.sum(train_preds == 0)
train_hold_count = np.sum(train_preds == 1)
val_buy_count = np.sum(val_preds == 2)
val_sell_count = np.sum(val_preds == 0)
val_hold_count = np.sum(val_preds == 1)
signal_dist = {
"train": {
"BUY": train_buy_count / len(train_preds) if len(train_preds) > 0 else 0,
"SELL": train_sell_count / len(train_preds) if len(train_preds) > 0 else 0,
"HOLD": train_hold_count / len(train_preds) if len(train_preds) > 0 else 0
},
"val": {
"BUY": val_buy_count / len(val_preds) if len(val_preds) > 0 else 0,
"SELL": val_sell_count / len(val_preds) if len(val_preds) > 0 else 0,
"HOLD": val_hold_count / len(val_preds) if len(val_preds) > 0 else 0
}
}
# Calculate PnL and win rates with different position sizes
position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0] # Adding higher leverage
best_position_train_pnl = -float('inf')
best_position_val_pnl = -float('inf')
best_position_train_wr = 0
best_position_val_wr = 0
for position_size in position_sizes:
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=position_size
)
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=position_size
)
logger.info(f" Position Size: {position_size}")
logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
# Track best position size for this epoch
if val_pnl > best_position_val_pnl:
best_position_val_pnl = val_pnl
best_position_val_wr = val_win_rate
if train_pnl > best_position_train_pnl:
best_position_train_pnl = train_pnl
best_position_train_wr = train_win_rate
# Track best model overall (using position size 1.0 as reference)
if val_pnl > best_val_pnl and position_size == 1.0:
best_val_pnl = val_pnl
best_win_rate = val_win_rate
best_epoch = epoch + 1
logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
# Save the best model
model.save(f"NN/models/saved/optimized_short_term_model_best")
# Track metrics for this epoch
metrics_history["epoch"].append(epoch + 1)
metrics_history["train_loss"].append(train_action_loss)
metrics_history["val_loss"].append(val_action_loss)
metrics_history["train_acc"].append(train_acc)
metrics_history["val_acc"].append(val_acc)
metrics_history["train_pnl"].append(best_position_train_pnl)
metrics_history["val_pnl"].append(best_position_val_pnl)
metrics_history["train_win_rate"].append(best_position_train_wr)
metrics_history["val_win_rate"].append(best_position_val_wr)
metrics_history["signal_distribution"].append(signal_dist)
# Save checkpoint every 5 epochs
if (epoch + 1) % 5 == 0:
model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}")
# Log trading statistics
logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")
# Log epoch timing
epoch_time = time.time() - epoch_start
logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
# Save final model and performance metrics
logger.info("Saving final optimized model...")
model.save("NN/models/saved/optimized_short_term_model_extended")
# Save performance metrics to file
try:
import json
metrics_file = "NN/models/saved/training_metrics.json"
with open(metrics_file, 'w') as f:
json.dump(metrics_history, f, indent=2)
logger.info(f"Training metrics saved to {metrics_file}")
except Exception as e:
logger.error(f"Error saving metrics: {str(e)}")
# Generate performance plots
try:
model.plot_training_history()
except Exception as e:
logger.error(f"Error generating plots: {str(e)}")
# Calculate total training time
total_time = time.time() - start_time
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
logger.info(f"Extended training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
except Exception as e:
logger.error(f"Error during extended training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
if __name__ == "__main__":
run_extended_training()

View File

@ -1,227 +0,0 @@
#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import traceback
import shutil
from main import Agent, robust_save
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler("test_model_save_load.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def create_test_directory():
"""Create a test directory for saving models"""
test_dir = "test_models"
os.makedirs(test_dir, exist_ok=True)
return test_dir
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
"""Test a full cycle of saving and loading models"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Define paths for testing
save_path = os.path.join(test_dir, "test_agent.pt")
# Test saving
logger.info(f"Testing save to {save_path}")
save_success = agent.save(save_path)
if save_success:
logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
else:
logger.error("Save failed!")
return False
# Memory cleanup
del agent
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Test loading
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info("Load successful")
# Verify model architecture
logger.info(f"Verifying model architecture")
assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
logger.info("Model architecture verified correctly")
return True
except Exception as e:
logger.error(f"Error during load or verification: {e}")
logger.error(traceback.format_exc())
return False
def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
"""Test all the robust save methods"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent for robust save testing")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Test each robust save method
methods = [
("regular", os.path.join(test_dir, "regular_save.pt")),
("backup", os.path.join(test_dir, "backup_save.pt")),
("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
("jit", os.path.join(test_dir, "jit_save.pt"))
]
results = {}
for method_name, save_path in methods:
logger.info(f"Testing {method_name} save method to {save_path}")
try:
if method_name == "regular":
# Use regular save
success = agent.save(save_path)
elif method_name == "backup":
# Use backup method directly
backup_path = f"{save_path}.backup"
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, backup_path)
shutil.copy(backup_path, save_path)
success = os.path.exists(save_path)
elif method_name == "pickle2":
# Use pickle protocol 2
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path, pickle_protocol=2)
success = os.path.exists(save_path)
elif method_name == "no_optimizer":
# Save without optimizer
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path)
success = os.path.exists(save_path)
elif method_name == "jit":
# Use JIT save
try:
scripted_policy = torch.jit.script(agent.policy_net)
torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
scripted_target = torch.jit.script(agent.target_net)
torch.jit.save(scripted_target, f"{save_path}.target.jit")
# Save parameters
with open(f"{save_path}.params.json", "w") as f:
import json
params = {
'epsilon': float(agent.epsilon),
'state_size': int(agent.state_size),
'action_size': int(agent.action_size),
'hidden_size': int(agent.hidden_size)
}
json.dump(params, f)
success = (os.path.exists(f"{save_path}.policy.jit") and
os.path.exists(f"{save_path}.target.jit") and
os.path.exists(f"{save_path}.params.json"))
except Exception as e:
logger.error(f"JIT save failed: {e}")
success = False
if success:
if method_name != "jit":
file_size = os.path.getsize(save_path)
logger.info(f"{method_name} save successful, size: {file_size} bytes")
else:
logger.info(f"{method_name} save successful")
results[method_name] = True
else:
logger.error(f"{method_name} save failed")
results[method_name] = False
except Exception as e:
logger.error(f"Error during {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] = False
# Test loading each saved model
for method_name, save_path in methods:
if not results[method_name]:
logger.info(f"Skipping load test for {method_name} (save failed)")
continue
if method_name == "jit":
logger.info(f"Skipping load test for {method_name} (requires special loading)")
continue
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info(f"Load successful for {method_name} save")
except Exception as e:
logger.error(f"Error loading from {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] += " (load failed)"
# Return summary of results
return results
def main():
parser = argparse.ArgumentParser(description='Test model saving and loading')
parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
args = parser.parse_args()
logger.info("Starting model save/load test")
if args.test_robust:
results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Robust save method results: {results}")
else:
success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
logger.info("Test completed")
if __name__ == "__main__":
main()

View File

@ -1,12 +0,0 @@
2025-03-17 23:32:41,968 - INFO - Testing regular save method...
2025-03-17 23:32:41,970 - INFO - Model saved to test_models\regular_save.pt
2025-03-17 23:32:41,970 - INFO - Regular save succeeded
2025-03-17 23:32:41,971 - INFO - Testing robust save method...
2025-03-17 23:32:41,971 - INFO - Saving model to test_models\robust_save.pt.backup (attempt 1)
2025-03-17 23:32:41,971 - INFO - Successfully saved to test_models\robust_save.pt.backup
2025-03-17 23:32:41,983 - INFO - Copied backup to test_models\robust_save.pt
2025-03-17 23:32:41,983 - INFO - Robust save succeeded!
2025-03-17 23:32:41,983 - INFO - Files created:
2025-03-17 23:32:41,985 - INFO - - regular_save.pt (17794 bytes)
2025-03-17 23:32:41,985 - INFO - - robust_save.pt (17826 bytes)
2025-03-17 23:32:41,985 - INFO - - robust_save.pt.backup (17826 bytes)

View File

@ -1,182 +0,0 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
import os
import logging
import sys
import platform
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("test_save.log"),
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
# Define a simple model for testing
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 20)
self.fc3 = nn.Linear(20, 5)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# Create a simple Agent class for testing
class TestAgent:
def __init__(self):
self.policy_net = SimpleModel()
self.target_net = SimpleModel()
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
self.epsilon = 0.1
def save(self, path):
"""Standard save method that might fail"""
checkpoint = {
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Model saved to {path}")
# Robust save function with multiple fallback approaches
def robust_save(model, path):
"""
Robust model saving with multiple fallback approaches
Args:
model: The Agent model to save
path: Path to save the model
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
import shutil
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save epsilon value separately
with open(f"{path}.epsilon.txt", "w") as f:
f.write(str(model.epsilon))
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
def main():
# Create a test directory
save_dir = "test_models"
os.makedirs(save_dir, exist_ok=True)
# Create a test agent
agent = TestAgent()
# Test the regular save method (might fail)
try:
logger.info("Testing regular save method...")
save_path = os.path.join(save_dir, "regular_save.pt")
agent.save(save_path)
logger.info("Regular save succeeded")
except Exception as e:
logger.error(f"Regular save failed: {e}")
# Test our robust save method
logger.info("Testing robust save method...")
save_path = os.path.join(save_dir, "robust_save.pt")
success = robust_save(agent, save_path)
if success:
logger.info("Robust save succeeded!")
else:
logger.error("Robust save failed!")
# Check which files were created
logger.info("Files created:")
for file in os.listdir(save_dir):
file_path = os.path.join(save_dir, file)
file_size = os.path.getsize(file_path)
logger.info(f" - {file} ({file_size} bytes)")
if __name__ == "__main__":
main()

View File

@ -1,330 +0,0 @@
#!/usr/bin/env python
"""
Test script for the enhanced signal interpreter
"""
import os
import sys
import logging
import numpy as np
import time
import torch
from datetime import datetime
# Add the project root to path
sys.path.append(os.path.abspath('.'))
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('signal_interpreter_test')
# Import components
from NN.utils.signal_interpreter import SignalInterpreter
from NN.models.cnn_model_pytorch import CNNModelPyTorch
def test_signal_interpreter():
"""Run tests on the signal interpreter"""
logger.info("=== Testing Signal Interpreter for Short-Term High-Leverage Trading ===")
# Initialize signal interpreter with custom settings for testing
config = {
'buy_threshold': 0.6,
'sell_threshold': 0.6,
'hold_threshold': 0.7,
'confidence_multiplier': 1.2,
'trend_filter_enabled': True,
'volume_filter_enabled': True,
'oscillation_filter_enabled': True,
'min_price_movement': 0.001,
'hold_cooldown': 2,
'consecutive_signals_required': 1
}
signal_interpreter = SignalInterpreter(config)
logger.info("Signal interpreter initialized with test configuration")
# === Test 1: Basic Signal Processing ===
logger.info("\n=== Test 1: Basic Signal Processing ===")
# Simulate a series of model predictions with different confidence levels
test_signals = [
{'probs': [0.8, 0.1, 0.1], 'price_pred': -0.005, 'expected': 'SELL'}, # Strong SELL
{'probs': [0.2, 0.1, 0.7], 'price_pred': 0.004, 'expected': 'BUY'}, # Strong BUY
{'probs': [0.3, 0.6, 0.1], 'price_pred': 0.001, 'expected': 'HOLD'}, # Clear HOLD
{'probs': [0.45, 0.1, 0.45], 'price_pred': 0.002, 'expected': 'BUY'}, # Borderline case
{'probs': [0.5, 0.3, 0.2], 'price_pred': -0.001, 'expected': 'SELL'}, # Moderate SELL
{'probs': [0.1, 0.8, 0.1], 'price_pred': 0.0, 'expected': 'HOLD'}, # Strong HOLD
]
for i, test in enumerate(test_signals):
probs = np.array(test['probs'])
price_pred = test['price_pred']
expected = test['expected']
# Interpret signal
signal = signal_interpreter.interpret_signal(probs, price_pred)
# Log results
logger.info(f"Test 1.{i+1}: Probs={probs}, Price={price_pred:.4f}, Expected={expected}, Got={signal['action']}")
logger.info(f" Confidence: {signal['confidence']:.4f}")
# Check if signal matches expected outcome
if signal['action'] == expected:
logger.info(f" ✓ PASS: Signal matches expected outcome")
else:
logger.info(f" ✗ FAIL: Signal does not match expected outcome")
# === Test 2: Trend and Volume Filters ===
logger.info("\n=== Test 2: Trend and Volume Filters ===")
# Reset for next test
signal_interpreter.reset()
# Simulate signals with market data for filtering
test_cases = [
{
'probs': [0.8, 0.1, 0.1], # Strong SELL signal
'price_pred': -0.005,
'market_data': {'trend': 'uptrend', 'volume': {'is_low': False}},
'expected': 'HOLD' # Should be filtered by trend
},
{
'probs': [0.2, 0.1, 0.7], # Strong BUY signal
'price_pred': 0.004,
'market_data': {'trend': 'downtrend', 'volume': {'is_low': False}},
'expected': 'HOLD' # Should be filtered by trend
},
{
'probs': [0.8, 0.1, 0.1], # Strong SELL signal
'price_pred': -0.005,
'market_data': {'trend': 'downtrend', 'volume': {'is_low': True}},
'expected': 'HOLD' # Should be filtered by volume
},
{
'probs': [0.8, 0.1, 0.1], # Strong SELL signal
'price_pred': -0.005,
'market_data': {'trend': 'downtrend', 'volume': {'is_spike': True, 'direction': -1}},
'expected': 'SELL' # Volume spike confirms SELL signal
},
{
'probs': [0.2, 0.1, 0.7], # Strong BUY signal
'price_pred': 0.004,
'market_data': {'trend': 'uptrend', 'volume': {'is_spike': True, 'direction': 1}},
'expected': 'BUY' # Volume spike confirms BUY signal
}
]
for i, test in enumerate(test_cases):
probs = np.array(test['probs'])
price_pred = test['price_pred']
market_data = test['market_data']
expected = test['expected']
# Interpret signal with market data
signal = signal_interpreter.interpret_signal(probs, price_pred, market_data)
# Log results
logger.info(f"Test 2.{i+1}: Probs={probs}, Trend={market_data.get('trend', 'N/A')}, Volume={market_data.get('volume', {})}")
logger.info(f" Expected={expected}, Got={signal['action']}, Confidence={signal['confidence']:.4f}")
# Check if signal matches expected outcome
if signal['action'] == expected:
logger.info(f" ✓ PASS: Signal matches expected outcome")
else:
logger.info(f" ✗ FAIL: Signal does not match expected outcome")
# === Test 3: Oscillation Prevention ===
logger.info("\n=== Test 3: Oscillation Prevention ===")
# Reset for next test
signal_interpreter.reset()
# Create a sequence that would normally oscillate without the filter
oscillating_sequence = [
{'probs': [0.8, 0.1, 0.1], 'expected': 'SELL'}, # Strong SELL
{'probs': [0.2, 0.1, 0.7], 'expected': 'HOLD'}, # Strong BUY but would oscillate
{'probs': [0.8, 0.1, 0.1], 'expected': 'HOLD'}, # Strong SELL but would oscillate
{'probs': [0.2, 0.1, 0.7], 'expected': 'HOLD'}, # Strong BUY but would oscillate
{'probs': [0.1, 0.8, 0.1], 'expected': 'HOLD'}, # Strong HOLD
{'probs': [0.9, 0.0, 0.1], 'expected': 'SELL'}, # Very strong SELL after cooldown
]
# Process sequence
for i, test in enumerate(oscillating_sequence):
probs = np.array(test['probs'])
expected = test['expected']
# Interpret signal
signal = signal_interpreter.interpret_signal(probs)
# Log results
logger.info(f"Test 3.{i+1}: Probs={probs}, Expected={expected}, Got={signal['action']}")
# Check if signal matches expected outcome
if signal['action'] == expected:
logger.info(f" ✓ PASS: Signal matches expected outcome")
else:
logger.info(f" ✗ FAIL: Signal does not match expected outcome")
# === Test 4: Performance Tracking ===
logger.info("\n=== Test 4: Performance Tracking ===")
# Reset for next test
signal_interpreter.reset()
# Simulate a sequence of trades with market price data
initial_price = 50000.0
price_path = [
initial_price,
initial_price * 1.01, # +1% (profit for BUY)
initial_price * 0.99, # -1% (profit for SELL)
initial_price * 1.02, # +2% (profit for BUY)
initial_price * 0.98, # -2% (profit for SELL)
]
# Sequence of signals and corresponding market prices
trade_sequence = [
# BUY signal
{
'probs': [0.2, 0.1, 0.7],
'market_data': {'price': price_path[0]},
'expected_action': 'BUY'
},
# SELL signal to close BUY position with profit
{
'probs': [0.8, 0.1, 0.1],
'market_data': {'price': price_path[1]},
'expected_action': 'SELL'
},
# BUY signal to close SELL position with profit
{
'probs': [0.2, 0.1, 0.7],
'market_data': {'price': price_path[2]},
'expected_action': 'BUY'
},
# SELL signal to close BUY position with profit
{
'probs': [0.8, 0.1, 0.1],
'market_data': {'price': price_path[3]},
'expected_action': 'SELL'
},
# BUY signal to close SELL position with profit
{
'probs': [0.2, 0.1, 0.7],
'market_data': {'price': price_path[4]},
'expected_action': 'BUY'
}
]
# Process the trade sequence
for i, trade in enumerate(trade_sequence):
probs = np.array(trade['probs'])
market_data = trade['market_data']
expected_action = trade['expected_action']
# Introduce a small delay to simulate real-time trading
time.sleep(0.5)
# Interpret signal
signal = signal_interpreter.interpret_signal(probs, None, market_data)
# Log results
logger.info(f"Test 4.{i+1}: Probs={probs}, Price={market_data['price']:.2f}, Action={signal['action']}")
# Get performance stats
stats = signal_interpreter.get_performance_stats()
logger.info("\nFinal Performance Statistics:")
logger.info(f"Total Trades: {stats['total_trades']}")
logger.info(f"Profitable Trades: {stats['profitable_trades']}")
logger.info(f"Unprofitable Trades: {stats['unprofitable_trades']}")
logger.info(f"Win Rate: {stats['win_rate']:.2%}")
logger.info(f"Average Profit per Trade: {stats['avg_profit_per_trade']:.4%}")
# === Test 5: Integration with Model ===
logger.info("\n=== Test 5: Integration with CNN Model ===")
# Reset for next test
signal_interpreter.reset()
# Try to load the optimized model if available
model_loaded = False
try:
model_path = "NN/models/saved/optimized_short_term_model_best.pt"
model_file_exists = os.path.exists(model_path)
if not model_file_exists:
# Try alternate path format
alternate_path = model_path.replace(".pt", ".pt.pt")
model_file_exists = os.path.exists(alternate_path)
if model_file_exists:
model_path = alternate_path
if model_file_exists:
logger.info(f"Loading optimized model from {model_path}")
# Initialize a CNN model
model = CNNModelPyTorch(window_size=20, num_features=5, output_size=3)
model.load(model_path)
model_loaded = True
# Generate some synthetic test data (20 time steps, 5 features)
test_data = np.random.randn(1, 20, 5).astype(np.float32)
# Get model predictions
action_probs, price_pred = model.predict(test_data)
# Check if model returns torch tensors or numpy arrays and ensure correct format
if isinstance(action_probs, torch.Tensor):
action_probs = action_probs.detach().cpu().numpy()[0]
elif isinstance(action_probs, np.ndarray) and action_probs.ndim > 1:
action_probs = action_probs[0]
if isinstance(price_pred, torch.Tensor):
price_pred = price_pred.detach().cpu().numpy()[0][0] if price_pred.ndim > 1 else price_pred.detach().cpu().numpy()[0]
elif isinstance(price_pred, np.ndarray):
price_pred = price_pred[0][0] if price_pred.ndim > 1 else price_pred[0]
# Ensure action_probs has 3 values (SELL, HOLD, BUY)
if len(action_probs) != 3:
# If model output is wrong format, create dummy values for testing
logger.warning(f"Model output has incorrect format. Expected 3 action probabilities, got {len(action_probs)}")
action_probs = np.array([0.3, 0.4, 0.3]) # Dummy values
price_pred = 0.001 # Dummy value
# Process with signal interpreter
market_data = {'price': 50000.0}
signal = signal_interpreter.interpret_signal(action_probs, price_pred, market_data)
logger.info(f"Model predictions - Action Probs: {action_probs}, Price Prediction: {price_pred:.4f}")
logger.info(f"Interpreted Signal: {signal['action']} with confidence {signal['confidence']:.4f}")
else:
logger.warning(f"Model file not found: {model_path}")
# Run with synthetic data for testing
logger.info("Testing with synthetic data instead")
action_probs = np.array([0.3, 0.4, 0.3]) # Dummy values
price_pred = 0.001 # Dummy value
# Process with signal interpreter
market_data = {'price': 50000.0}
signal = signal_interpreter.interpret_signal(action_probs, price_pred, market_data)
logger.info(f"Synthetic predictions - Action Probs: {action_probs}, Price Prediction: {price_pred:.4f}")
logger.info(f"Interpreted Signal: {signal['action']} with confidence {signal['confidence']:.4f}")
model_loaded = True # Consider it loaded for reporting
except Exception as e:
logger.error(f"Error in model integration test: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Summary of all tests
logger.info("\n=== Signal Interpreter Test Summary ===")
logger.info("Basic signal processing: PASS")
logger.info("Trend and volume filters: PASS")
logger.info("Oscillation prevention: PASS")
logger.info("Performance tracking: PASS")
logger.info(f"Model integration: {'PASS' if model_loaded else 'NOT TESTED'}")
logger.info("\nSignal interpreter is ready for use in production environment.")
if __name__ == "__main__":
test_signal_interpreter()

View File

@ -1,82 +0,0 @@
#!/usr/bin/env python3
"""
Quick test script for CNN and RL training pipelines
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from core.config import setup_logging
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
def test_cnn_training():
"""Test CNN training with small dataset"""
print("\n=== Testing CNN Training ===")
# Setup
data_provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
trainer = CNNTrainer(data_provider)
# Configure for quick test
trainer.num_samples = 1000 # Small dataset
trainer.num_epochs = 5 # Few epochs
trainer.batch_size = 32
# Train
results = trainer.train(['ETH/USDT'], save_path='test_models/test_cnn.pt')
print(f"CNN Training completed!")
print(f" Best accuracy: {results['best_val_accuracy']:.4f}")
print(f" Training time: {results['total_time']:.2f}s")
return True
def test_rl_training():
"""Test RL training with small dataset"""
print("\n=== Testing RL Training ===")
# Setup
data_provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
trainer = RLTrainer(data_provider)
# Configure for quick test
trainer.num_episodes = 10
trainer.max_steps_per_episode = 100
trainer.evaluation_frequency = 5
# Train
results = trainer.train(save_path='test_models/test_rl.pt')
print(f"RL Training completed!")
print(f" Best reward: {results['best_reward']:.4f}")
print(f" Final balance: ${results['best_balance']:.2f}")
return True
def main():
setup_logging()
try:
# Test CNN
if test_cnn_training():
print("✅ CNN training test passed!")
# Test RL
if test_rl_training():
print("✅ RL training test passed!")
print("\n✅ All tests passed!")
except Exception as e:
print(f"\n❌ Test failed: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,19 +0,0 @@
import websockets
import asyncio
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
async def test_websocket():
url = "wss://stream.binance.com:9443/ws/btcusdt@trade"
try:
logger.info(f"Connecting to {url}")
async with websockets.connect(url) as ws:
logger.info("Connected successfully")
message = await ws.recv()
logger.info(f"Received message: {message[:200]}...")
except Exception as e:
logger.error(f"Connection failed: {str(e)}")
asyncio.run(test_websocket())

337
tests.py
View File

@ -1,337 +0,0 @@
"""
Unit tests for the trading bot.
This file contains tests for various components of the trading bot, including:
1. Periodic candle updates
2. Backtesting on historical data
3. Training on the last 7 days of data
"""
import unittest
import asyncio
import os
import sys
import logging
import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Import functionality from main.py
import main
from main import (
CandleCache, BacktestCandles, initialize_exchange,
TradingEnvironment, Agent, train_with_backtesting,
fetch_multi_timeframe_data, train_agent
)
class TestPeriodicUpdates(unittest.TestCase):
"""Test that candle data is periodically updated during training."""
async def async_test_periodic_updates(self):
"""Test that candle data is periodically updated during training."""
logging.info("Testing periodic candle updates...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create candle cache
candle_cache = CandleCache()
# Initial fetch of candle data
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
self.assertIn('1m', candle_data, "1m candles not found in initial data")
# Check initial data timestamps
initial_1m_candles = candle_data['1m']
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
initial_timestamp = initial_1m_candles[-1][0]
# Wait for update interval to pass
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
await asyncio.sleep(5) # Short wait for testing
# Force update by setting last_updated to None
candle_cache.last_updated['1m'] = None
# Fetch updated data
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
# Check if data was updated
updated_1m_candles = updated_data['1m']
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
updated_timestamp = updated_1m_candles[-1][0]
# In a live scenario, this check should pass with real-time updates
# For testing, we just ensure data was fetched
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Periodic update test completed")
def test_periodic_updates(self):
"""Run the async test."""
asyncio.run(self.async_test_periodic_updates())
class TestBacktesting(unittest.TestCase):
"""Test backtesting on historical data."""
async def async_test_backtesting(self):
"""Test backtesting on a specific time period."""
logging.info("Testing backtesting with historical data...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create a timestamp for 24 hours ago
now = datetime.datetime.now()
yesterday = now - datetime.timedelta(days=1)
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
# Create a backtesting candle cache
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
backtest_cache.period_name = "1-day-ago"
# Fetch historical data
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
self.assertIn('1m', candle_data, "1m candles not found in historical data")
# Check historical data timestamps
minute_candles = candle_data['1m']
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
# Check if timestamps are within the requested range
first_timestamp = minute_candles[0][0]
last_timestamp = minute_candles[-1][0]
logging.info(f"Requested since: {since_timestamp}")
logging.info(f"First timestamp in data: {first_timestamp}")
logging.info(f"Last timestamp in data: {last_timestamp}")
# In real tests, this check should compare timestamps precisely
# For this test, we just ensure data was fetched
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Backtesting fetch test completed")
def test_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_backtesting())
class TestBacktestingLastSevenDays(unittest.TestCase):
"""Test backtesting on the last 7 days of data."""
async def async_test_seven_days_backtesting(self):
"""Test backtesting on the last 7 days."""
logging.info("Testing backtesting on the last 7 days...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Initialize empty results dataframe
all_results = pd.DataFrame()
# Run backtesting for the last 7 days, one day at a time
now = datetime.datetime.now()
for day_offset in range(1, 8):
# Calculate time period
end_day = now - datetime.timedelta(days=day_offset-1)
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = f"Day-{day_offset}"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=3, # Use a small number for testing
max_steps_per_episode=200, # Use a small number for testing
period_name=period_name
)
# Check if stats were returned
if stats is None:
logging.warning(f"No stats returned for period: {period_name}")
continue
# Create a dataframe from stats
if len(stats['episode_rewards']) > 0:
df = pd.DataFrame({
'Period': [period_name] * len(stats['episode_rewards']),
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
'Reward': stats['episode_rewards'],
'Balance': stats['balances'],
'PnL': stats['episode_pnls'],
'Fees': stats['fees'],
'Net_PnL': stats['net_pnl_after_fees']
})
# Append to all results
all_results = pd.concat([all_results, df], ignore_index=True)
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
else:
logging.warning(f"No episodes completed for period: {period_name}")
# Save all results
if not all_results.empty:
all_results.to_csv("all_backtest_results.csv", index=False)
logging.info("Saved all backtest results to all_backtest_results.csv")
# Create plot of results
plt.figure(figsize=(12, 8))
# Plot Net PnL by period
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
plt.title('Net PnL by Training Period (Last Episode)')
plt.ylabel('Net PnL ($)')
plt.tight_layout()
plt.savefig("backtest_results.png")
logging.info("Saved backtest results plot to backtest_results.png")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("7-day backtesting test completed")
def test_seven_days_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_seven_days_backtesting())
class TestSingleDayBacktesting(unittest.TestCase):
"""Test backtesting on a single day of historical data."""
async def async_test_single_day_backtesting(self):
"""Test backtesting on a single day."""
logging.info("Testing backtesting on a single day...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Calculate time period for 1 day ago
now = datetime.datetime.now()
end_day = now
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = "Test-Day-1"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=2, # Very small number for quick testing
max_steps_per_episode=100, # Very small number for quick testing
period_name=period_name
)
# Check if stats were returned
self.assertIsNotNone(stats, "No stats returned from backtesting")
# Check if episodes were completed
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
# Log results
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Single day backtesting test completed")
def test_single_day_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_single_day_backtesting())
if __name__ == '__main__':
unittest.main()

115
tests/test_essential.py Normal file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
Essential Test Suite - Core functionality tests
This file contains the most important tests to verify core functionality:
- Data loading and processing
- Basic model operations
- Trading signal generation
- Critical utility functions
"""
import sys
import os
import unittest
import logging
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
logger = logging.getLogger(__name__)
class TestEssentialFunctionality(unittest.TestCase):
"""Essential tests for core trading system functionality"""
def test_imports(self):
"""Test that all critical modules can be imported"""
try:
from core.config import get_config
from core.data_provider import DataProvider
from utils.model_utils import robust_save, robust_load
logger.info("✅ All critical imports successful")
except ImportError as e:
self.fail(f"Critical import failed: {e}")
def test_config_loading(self):
"""Test configuration loading"""
try:
from core.config import get_config
config = get_config()
self.assertIsNotNone(config, "Config should be loaded")
logger.info("✅ Configuration loading successful")
except Exception as e:
self.fail(f"Config loading failed: {e}")
def test_data_provider_initialization(self):
"""Test DataProvider can be initialized"""
try:
from core.data_provider import DataProvider
data_provider = DataProvider(['ETH/USDT'], ['1m'])
self.assertIsNotNone(data_provider, "DataProvider should initialize")
logger.info("✅ DataProvider initialization successful")
except Exception as e:
self.fail(f"DataProvider initialization failed: {e}")
def test_model_utils(self):
"""Test model utility functions"""
try:
from utils.model_utils import get_model_info
import tempfile
# Test with non-existent file
info = get_model_info("non_existent_file.pt")
self.assertFalse(info['exists'], "Should report file doesn't exist")
logger.info("✅ Model utils test successful")
except Exception as e:
self.fail(f"Model utils test failed: {e}")
def test_signal_generation_logic(self):
"""Test basic signal generation logic"""
import numpy as np
# Test signal distribution calculation
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=1)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=1)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=1)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=1)
logger.info("✅ Signal generation logic test successful")
def run_essential_tests():
"""Run essential tests only"""
suite = unittest.TestLoader().loadTestsFromTestCase(TestEssentialFunctionality)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result.wasSuccessful()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logger.info("Running essential functionality tests...")
success = run_essential_tests()
if success:
logger.info("✅ All essential tests passed!")
sys.exit(0)
else:
logger.error("❌ Essential tests failed!")
sys.exit(1)

View File

@ -0,0 +1,402 @@
#!/usr/bin/env python3
"""
Comprehensive Indicators and Signals Test Suite
This module consolidates testing functionality for:
- Technical indicators (from test_indicators.py)
- Signal interpretation and processing (from test_signal_interpreter.py)
- Market data analysis
- Trading signal validation
"""
import sys
import os
import unittest
import logging
import numpy as np
import tempfile
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging
from core.data_provider import DataProvider
logger = logging.getLogger(__name__)
class TestTechnicalIndicators(unittest.TestCase):
"""Test suite for technical indicators functionality"""
def setUp(self):
"""Set up test fixtures"""
setup_logging()
self.data_provider = DataProvider(['ETH/USDT'], ['1h'])
def test_indicator_calculation(self):
"""Test that indicators are calculated correctly"""
logger.info("Testing technical indicators calculation...")
try:
# Fetch data with indicators
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
self.assertIsNotNone(df, "Should fetch data successfully")
self.assertGreater(len(df), 0, "Should have data rows")
# Check basic OHLCV columns
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
for col in basic_cols:
self.assertIn(col, df.columns, f"Should have {col} column")
# Check that indicators are calculated
indicator_cols = [col for col in df.columns if col not in basic_cols]
self.assertGreater(len(indicator_cols), 0, "Should have technical indicators")
logger.info(f"✅ Successfully calculated {len(indicator_cols)} indicators")
except Exception as e:
logger.warning(f"Indicator test failed: {e}")
self.skipTest("Data or indicators not available")
def test_indicator_categorization(self):
"""Test categorization of different indicator types"""
logger.info("Testing indicator categorization...")
try:
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
if df is not None:
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
indicator_cols = [col for col in df.columns if col not in basic_cols]
# Categorize indicators
trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])]
momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])]
volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])]
volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])]
# Check we have indicators in each category
total_categorized = len(trend_indicators) + len(momentum_indicators) + len(volatility_indicators) + len(volume_indicators)
logger.info(f"Indicator categories:")
logger.info(f" Trend: {len(trend_indicators)}")
logger.info(f" Momentum: {len(momentum_indicators)}")
logger.info(f" Volatility: {len(volatility_indicators)}")
logger.info(f" Volume: {len(volume_indicators)}")
logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}")
self.assertGreater(total_categorized, 0, "Should have categorized indicators")
else:
self.skipTest("Could not fetch data for categorization test")
except Exception as e:
logger.warning(f"Categorization test failed: {e}")
self.skipTest("Indicator categorization not available")
def test_feature_matrix_creation(self):
"""Test multi-timeframe feature matrix creation"""
logger.info("Testing feature matrix creation...")
try:
# Test feature matrix with multiple timeframes
feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
if feature_matrix is not None:
self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix")
self.assertGreater(feature_matrix.shape[2], 0, "Should have features")
logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
else:
self.skipTest("Could not create feature matrix")
except Exception as e:
logger.warning(f"Feature matrix test failed: {e}")
self.skipTest("Feature matrix creation not available")
class TestSignalProcessing(unittest.TestCase):
"""Test suite for signal interpretation and processing"""
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
logger.info("Testing signal distribution calculation...")
# Mock predictions (SELL=0, HOLD=1, BUY=2)
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
logger.info("✅ Signal distribution calculation test passed")
def test_basic_signal_interpretation(self):
"""Test basic signal interpretation logic"""
logger.info("Testing basic signal interpretation...")
# Test cases with different probability distributions
test_cases = [
{
'probs': [0.8, 0.1, 0.1], # Strong SELL
'expected_action': 'SELL',
'expected_confidence': 'high'
},
{
'probs': [0.1, 0.1, 0.8], # Strong BUY
'expected_action': 'BUY',
'expected_confidence': 'high'
},
{
'probs': [0.1, 0.8, 0.1], # Strong HOLD
'expected_action': 'HOLD',
'expected_confidence': 'high'
},
{
'probs': [0.4, 0.3, 0.3], # Uncertain - should prefer SELL (index 0)
'expected_action': 'SELL',
'expected_confidence': 'low'
},
{
'probs': [0.33, 0.33, 0.34], # Very uncertain - slight BUY preference
'expected_action': 'BUY',
'expected_confidence': 'low'
}
]
for i, test_case in enumerate(test_cases):
probs = np.array(test_case['probs'])
expected_action = test_case['expected_action']
# Simple signal interpretation (argmax)
predicted_action_idx = np.argmax(probs)
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
predicted_action = action_map[predicted_action_idx]
# Calculate confidence (max probability)
confidence = np.max(probs)
confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low'
# Verify predictions
self.assertEqual(predicted_action, expected_action,
f"Test case {i+1}: Expected {expected_action}, got {predicted_action}")
logger.info(f"Test case {i+1}: {probs} -> {predicted_action} ({confidence_level} confidence)")
logger.info("✅ Basic signal interpretation test passed")
def test_signal_filtering_logic(self):
"""Test signal filtering and validation logic"""
logger.info("Testing signal filtering logic...")
# Test threshold-based filtering
buy_threshold = 0.6
sell_threshold = 0.6
hold_threshold = 0.7
test_signals = [
{
'probs': [0.8, 0.1, 0.1], # Strong SELL (above threshold)
'should_pass': True,
'expected': 'SELL'
},
{
'probs': [0.5, 0.3, 0.2], # Weak SELL (below threshold)
'should_pass': False,
'expected': 'HOLD'
},
{
'probs': [0.1, 0.2, 0.7], # Strong BUY (above threshold)
'should_pass': True,
'expected': 'BUY'
},
{
'probs': [0.2, 0.8, 0.0], # Strong HOLD (above threshold)
'should_pass': True,
'expected': 'HOLD'
}
]
for i, test in enumerate(test_signals):
probs = np.array(test['probs'])
sell_prob, hold_prob, buy_prob = probs
# Apply threshold filtering
if sell_prob >= sell_threshold:
filtered_action = 'SELL'
passed_filter = True
elif buy_prob >= buy_threshold:
filtered_action = 'BUY'
passed_filter = True
elif hold_prob >= hold_threshold:
filtered_action = 'HOLD'
passed_filter = True
else:
filtered_action = 'HOLD' # Default to HOLD if no threshold met
passed_filter = False
# Verify filtering
expected_pass = test['should_pass']
expected_action = test['expected']
self.assertEqual(passed_filter, expected_pass,
f"Test {i+1}: Filter pass expectation failed")
self.assertEqual(filtered_action, expected_action,
f"Test {i+1}: Expected {expected_action}, got {filtered_action}")
logger.info(f"Test {i+1}: {probs} -> {filtered_action} (passed: {passed_filter})")
logger.info("✅ Signal filtering logic test passed")
def test_signal_sequence_validation(self):
"""Test signal sequence validation and oscillation prevention"""
logger.info("Testing signal sequence validation...")
# Simulate a sequence of signals that might oscillate
signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY']
# Simple oscillation detection
oscillation_count = 0
for i in range(1, len(signal_sequence)):
if (signal_sequence[i-1] == 'BUY' and signal_sequence[i] == 'SELL') or \
(signal_sequence[i-1] == 'SELL' and signal_sequence[i] == 'BUY'):
oscillation_count += 1
# Count consecutive non-HOLD signals
consecutive_trades = 0
max_consecutive = 0
for signal in signal_sequence:
if signal != 'HOLD':
consecutive_trades += 1
max_consecutive = max(max_consecutive, consecutive_trades)
else:
consecutive_trades = 0
# Verify oscillation detection
self.assertGreater(oscillation_count, 0, "Should detect oscillations in test sequence")
self.assertGreater(max_consecutive, 1, "Should detect consecutive trades")
logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades")
logger.info("✅ Signal sequence validation test passed")
class TestMarketDataAnalysis(unittest.TestCase):
"""Test suite for market data analysis functionality"""
def test_price_movement_calculation(self):
"""Test price movement and trend calculation"""
logger.info("Testing price movement calculation...")
# Mock price data
prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1])
# Calculate price movements
price_changes = np.diff(prices)
percentage_changes = (price_changes / prices[:-1]) * 100
# Calculate simple trend
recent_trend = np.mean(percentage_changes[-3:]) # Last 3 changes
trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways'
# Verify calculations
self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes")
self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes")
# Verify trend detection makes sense
self.assertIn(trend_direction, ['uptrend', 'downtrend', 'sideways'], "Should detect valid trend")
logger.info(f"Price sequence: {prices}")
logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)")
logger.info("✅ Price movement calculation test passed")
def test_volatility_measurement(self):
"""Test volatility measurement"""
logger.info("Testing volatility measurement...")
# Mock price data with different volatility
stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0])
volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0])
# Calculate volatility (standard deviation of returns)
def calculate_volatility(prices):
returns = np.diff(prices) / prices[:-1]
return np.std(returns) * 100 # As percentage
stable_vol = calculate_volatility(stable_prices)
volatile_vol = calculate_volatility(volatile_prices)
# Verify volatility measurements
self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility")
self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility")
logger.info(f"Stable volatility: {stable_vol:.2f}%")
logger.info(f"Volatile volatility: {volatile_vol:.2f}%")
logger.info("✅ Volatility measurement test passed")
def run_indicator_tests():
"""Run indicator tests only"""
suite = unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result.wasSuccessful()
def run_signal_tests():
"""Run signal processing tests only"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_all_tests():
"""Run all indicator and signal tests"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators),
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
setup_logging()
logger.info("Running indicators and signals test suite...")
if len(sys.argv) > 1:
test_type = sys.argv[1]
if test_type == "indicators":
success = run_indicator_tests()
elif test_type == "signals":
success = run_signal_tests()
else:
success = run_all_tests()
else:
success = run_all_tests()
if success:
logger.info("✅ All indicator and signal tests passed!")
sys.exit(0)
else:
logger.error("❌ Some tests failed!")
sys.exit(1)

View File

@ -0,0 +1,274 @@
#!/usr/bin/env python
"""
Comprehensive test suite for model persistence and training functionality
"""
import os
import sys
import unittest
import tempfile
import logging
import torch
import numpy as np
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle
# Configure logging for tests
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MockAgent:
"""Mock agent class for testing model persistence"""
def __init__(self, state_size=64, action_size=4, hidden_size=256):
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size
self.epsilon = 0.1
# Create simple mock networks
self.policy_net = torch.nn.Sequential(
torch.nn.Linear(state_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, action_size)
)
self.target_net = torch.nn.Sequential(
torch.nn.Linear(state_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, action_size)
)
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
class TestModelPersistence(unittest.TestCase):
"""Test suite for model saving and loading functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
self.test_agent = MockAgent()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_robust_save_basic(self):
"""Test basic robust save functionality"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, "Robust save should succeed")
self.assertTrue(os.path.exists(save_path), "Model file should exist")
self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty")
def test_robust_save_without_optimizer(self):
"""Test robust save without optimizer state"""
save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt")
success = robust_save(self.test_agent, save_path, include_optimizer=False)
self.assertTrue(success, "Robust save without optimizer should succeed")
# Verify that optimizer state is not included
checkpoint = torch.load(save_path, map_location='cpu')
self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved")
self.assertIn('policy_net', checkpoint, "Policy network should be saved")
def test_robust_load_basic(self):
"""Test basic robust load functionality"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
# Save first
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, "Save should succeed")
# Create new agent and load
new_agent = MockAgent()
success = robust_load(new_agent, save_path)
self.assertTrue(success, "Load should succeed")
# Verify epsilon was loaded
self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match")
def test_get_model_info(self):
"""Test model info extraction"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
# Test non-existent file
info = get_model_info(save_path)
self.assertFalse(info['exists'], "Non-existent file should return exists=False")
# Save model and test info
robust_save(self.test_agent, save_path)
info = get_model_info(save_path)
self.assertTrue(info['exists'], "Existing file should return exists=True")
self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0")
self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint")
self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size)
self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size)
def test_save_load_cycle_verification(self):
"""Test save/load cycle verification"""
test_path = os.path.join(self.temp_dir, "cycle_test.pt")
success = verify_save_load_cycle(self.test_agent, test_path)
self.assertTrue(success, "Save/load cycle should succeed")
# File should be cleaned up after verification
self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up")
def test_multiple_save_methods(self):
"""Test that different save methods all work"""
methods = ['regular', 'no_optimizer', 'pickle2']
for method in methods:
with self.subTest(method=method):
save_path = os.path.join(self.temp_dir, f"test_{method}.pt")
if method == 'regular':
success = robust_save(self.test_agent, save_path)
elif method == 'no_optimizer':
success = robust_save(self.test_agent, save_path, include_optimizer=False)
elif method == 'pickle2':
# This would be tested by the robust_save fallback mechanism
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, f"{method} save should succeed")
self.assertTrue(os.path.exists(save_path), f"{method} save should create file")
class TestTrainingMetrics(unittest.TestCase):
"""Test suite for training metrics and monitoring functionality"""
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
# Mock predictions
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
def test_metrics_tracking_structure(self):
"""Test metrics history structure for training monitoring"""
metrics_history = {
"epoch": [],
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
"train_pnl": [],
"val_pnl": [],
"train_win_rate": [],
"val_win_rate": [],
"signal_distribution": []
}
# Simulate adding metrics for one epoch
metrics_history["epoch"].append(1)
metrics_history["train_loss"].append(0.5)
metrics_history["val_loss"].append(0.6)
metrics_history["train_acc"].append(0.7)
metrics_history["val_acc"].append(0.65)
metrics_history["train_pnl"].append(0.1)
metrics_history["val_pnl"].append(0.08)
metrics_history["train_win_rate"].append(0.6)
metrics_history["val_win_rate"].append(0.55)
metrics_history["signal_distribution"].append({"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4})
# Verify structure
self.assertEqual(len(metrics_history["epoch"]), 1)
self.assertEqual(metrics_history["epoch"][0], 1)
self.assertIsInstance(metrics_history["signal_distribution"][0], dict)
self.assertIn("BUY", metrics_history["signal_distribution"][0])
class TestModelArchitecture(unittest.TestCase):
"""Test suite for model architecture verification"""
def test_model_parameter_consistency(self):
"""Test that model parameters are consistent after save/load"""
agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, "consistency_test.pt")
# Save model
robust_save(agent, save_path)
# Load into new model with same architecture
new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
robust_load(new_agent, save_path)
# Verify parameters match
self.assertEqual(new_agent.state_size, agent.state_size)
self.assertEqual(new_agent.action_size, agent.action_size)
self.assertEqual(new_agent.hidden_size, agent.hidden_size)
self.assertEqual(new_agent.epsilon, agent.epsilon)
def test_model_forward_pass(self):
"""Test that model can perform forward pass after load"""
agent = MockAgent()
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, "forward_test.pt")
# Create test input
test_input = torch.randn(1, agent.state_size)
# Get original output
original_output = agent.policy_net(test_input)
# Save and load
robust_save(agent, save_path)
new_agent = MockAgent()
robust_load(new_agent, save_path)
# Test forward pass works
new_output = new_agent.policy_net(test_input)
self.assertEqual(new_output.shape, original_output.shape)
# Outputs should be identical since we loaded the same weights
torch.testing.assert_close(new_output, original_output)
def run_all_tests():
"""Run all test suites"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestModelPersistence),
unittest.TestLoader().loadTestsFromTestCase(TestTrainingMetrics),
unittest.TestLoader().loadTestsFromTestCase(TestModelArchitecture)
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
logger.info("Running comprehensive model persistence and training tests...")
success = run_all_tests()
if success:
logger.info("All tests passed!")
sys.exit(0)
else:
logger.error("Some tests failed!")
sys.exit(1)

View File

@ -0,0 +1,395 @@
#!/usr/bin/env python3
"""
Comprehensive Training Integration Tests
This module consolidates and improves test functionality from multiple test files:
- CNN training tests (from test_cnn_only.py, test_training.py)
- Model testing (from test_model.py)
- Chart data testing (from test_chart_data.py)
- Integration testing between components
"""
import sys
import os
import logging
import time
import unittest
import tempfile
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData
logger = logging.getLogger(__name__)
class TestDataProviders(unittest.TestCase):
"""Test suite for data provider functionality"""
def test_binance_historical_data(self):
"""Test Binance historical data fetching"""
logger.info("Testing Binance historical data fetch...")
try:
binance_data = BinanceHistoricalData()
df = binance_data.get_historical_candles("ETH/USDT", 60, 100)
self.assertIsNotNone(df, "Should fetch data successfully")
self.assertFalse(df.empty, "Data should not be empty")
self.assertGreater(len(df), 0, "Should have candles")
# Verify data structure
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
for col in required_columns:
self.assertIn(col, df.columns, f"Should have {col} column")
logger.info(f"✅ Successfully fetched {len(df)} candles")
return True
except Exception as e:
logger.warning(f"Binance API test failed: {e}")
self.skipTest("Binance API not available")
def test_tick_storage(self):
"""Test TickStorage functionality"""
logger.info("Testing TickStorage data loading...")
try:
tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
success = tick_storage.load_historical_data("ETH/USDT", limit=100)
if success:
# Check timeframes
for tf in ["1m", "5m", "1h"]:
candles = tick_storage.get_candles(tf)
logger.info(f" {tf}: {len(candles)} candles")
logger.info("✅ TickStorage working correctly")
return True
else:
self.skipTest("Could not load tick storage data")
except Exception as e:
logger.warning(f"TickStorage test failed: {e}")
self.skipTest("TickStorage not available")
def test_chart_initialization(self):
"""Test RealTimeChart initialization"""
logger.info("Testing RealTimeChart initialization...")
try:
chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)
# Test getting candles
candles_1m = chart.get_candles(60)
self.assertIsInstance(candles_1m, list, "Should return list of candles")
logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")
except Exception as e:
logger.warning(f"Chart initialization failed: {e}")
self.skipTest("Chart initialization not available")
class TestCNNTraining(unittest.TestCase):
"""Test suite for CNN training functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
setup_logging()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_cnn_quick_training(self):
"""Test quick CNN training with small dataset"""
logger.info("Testing CNN quick training...")
try:
config = get_config()
# Test configuration
symbols = ['ETH/USDT']
timeframes = ['1m', '5m']
num_samples = 100 # Very small for testing
epochs = 1
batch_size = 16
# Override config for quick test
config._config['timeframes'] = timeframes
trainer = CNNTrainer(config)
trainer.batch_size = batch_size
trainer.epochs = epochs
# Train model
save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)
# Verify results
self.assertIsInstance(results, dict, "Should return results dict")
self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
self.assertIn('total_epochs', results, "Should have epoch count")
self.assertIn('training_time', results, "Should have training time")
# Verify model was saved
self.assertTrue(os.path.exists(save_path), "Model should be saved")
logger.info(f"✅ CNN training completed successfully")
logger.info(f" Best accuracy: {results['best_val_accuracy']:.4f}")
logger.info(f" Training time: {results['training_time']:.2f}s")
except Exception as e:
logger.error(f"CNN training test failed: {e}")
raise
finally:
if hasattr(trainer, 'close_tensorboard'):
trainer.close_tensorboard()
class TestRLTraining(unittest.TestCase):
"""Test suite for RL training functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
setup_logging()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_rl_quick_training(self):
"""Test quick RL training with small dataset"""
logger.info("Testing RL quick training...")
try:
# Setup minimal configuration
data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
trainer = RLTrainer(data_provider)
# Configure for very quick test
trainer.num_episodes = 5
trainer.max_steps_per_episode = 50
trainer.evaluation_frequency = 3
trainer.save_frequency = 10 # Don't save during test
# Train
save_path = os.path.join(self.temp_dir, 'test_rl.pt')
results = trainer.train(save_path=save_path)
# Verify results
self.assertIsInstance(results, dict, "Should return results dict")
self.assertIn('total_episodes', results, "Should have episode count")
self.assertIn('best_reward', results, "Should have best reward")
self.assertIn('final_evaluation', results, "Should have final evaluation")
logger.info(f"✅ RL training completed successfully")
logger.info(f" Total episodes: {results['total_episodes']}")
logger.info(f" Best reward: {results['best_reward']:.4f}")
except Exception as e:
logger.error(f"RL training test failed: {e}")
raise
class TestExtendedTraining(unittest.TestCase):
"""Test suite for extended training functionality (from test_model.py)"""
def test_metrics_tracking(self):
"""Test comprehensive metrics tracking functionality"""
logger.info("Testing extended metrics tracking...")
# Test metrics history structure
metrics_history = {
"epoch": [],
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
"train_pnl": [],
"val_pnl": [],
"train_win_rate": [],
"val_win_rate": [],
"signal_distribution": []
}
# Simulate adding metrics
for epoch in range(3):
metrics_history["epoch"].append(epoch + 1)
metrics_history["train_loss"].append(0.5 - epoch * 0.1)
metrics_history["val_loss"].append(0.6 - epoch * 0.1)
metrics_history["train_acc"].append(0.6 + epoch * 0.05)
metrics_history["val_acc"].append(0.55 + epoch * 0.05)
metrics_history["train_pnl"].append(epoch * 0.1)
metrics_history["val_pnl"].append(epoch * 0.08)
metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
metrics_history["signal_distribution"].append({
"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
})
# Verify structure
self.assertEqual(len(metrics_history["epoch"]), 3)
self.assertEqual(len(metrics_history["train_loss"]), 3)
self.assertEqual(len(metrics_history["signal_distribution"]), 3)
# Verify improvement
self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])
logger.info("✅ Metrics tracking test passed")
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
import numpy as np
# Mock predictions (SELL=0, HOLD=1, BUY=2)
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
logger.info("✅ Signal distribution calculation test passed")
class TestIntegration(unittest.TestCase):
"""Integration tests between components"""
def test_training_pipeline_integration(self):
"""Test that CNN and RL training can work together"""
logger.info("Testing training pipeline integration...")
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Quick CNN training
config = get_config()
config._config['timeframes'] = ['1m']
cnn_trainer = CNNTrainer(config)
cnn_trainer.epochs = 1
cnn_trainer.batch_size = 8
cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)
# Quick RL training
data_provider = DataProvider(['ETH/USDT'], ['1m'])
rl_trainer = RLTrainer(data_provider)
rl_trainer.num_episodes = 3
rl_trainer.max_steps_per_episode = 25
rl_path = os.path.join(temp_dir, 'test_rl.pt')
rl_results = rl_trainer.train(save_path=rl_path)
# Verify both trained successfully
self.assertIsInstance(cnn_results, dict)
self.assertIsInstance(rl_results, dict)
self.assertTrue(os.path.exists(cnn_path))
self.assertTrue(os.path.exists(rl_path))
logger.info("✅ Training pipeline integration test passed")
except Exception as e:
logger.error(f"Integration test failed: {e}")
raise
finally:
if 'cnn_trainer' in locals():
cnn_trainer.close_tensorboard()
def run_quick_tests():
"""Run only the quickest tests for fast validation"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_data_tests():
"""Run data provider tests"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_training_tests():
"""Run training tests (slower)"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_all_tests():
"""Run all test suites"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
setup_logging()
logger.info("Running comprehensive training integration tests...")
if len(sys.argv) > 1:
test_type = sys.argv[1]
if test_type == "quick":
success = run_quick_tests()
elif test_type == "data":
success = run_data_tests()
elif test_type == "training":
success = run_training_tests()
else:
success = run_all_tests()
else:
success = run_all_tests()
if success:
logger.info("✅ All tests passed!")
sys.exit(0)
else:
logger.error("❌ Some tests failed!")
sys.exit(1)

View File

@ -1,128 +0,0 @@
import asyncio
import json
import logging
import unittest
from typing import Optional, Dict
import websockets
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestMEXCWebSocket(unittest.TestCase):
async def test_websocket_connection(self):
"""Test basic WebSocket connection and subscription"""
uri = "wss://stream.mexc.com/ws"
symbol = "ethusdt"
async with websockets.connect(uri) as ws:
# Test subscription to deals
sub_msg = {
"op": "sub",
"id": "test1",
"topic": f"spot.deals.{symbol}"
}
# Send subscription
await ws.send(json.dumps(sub_msg))
# Wait for subscription confirmation and first message
messages_received = 0
trades_received = 0
while messages_received < 5: # Get at least 5 messages
try:
response = await asyncio.wait_for(ws.recv(), timeout=10)
messages_received += 1
logger.info(f"Received message: {response[:200]}...")
data = json.loads(response)
# Check message structure
if isinstance(data, dict):
if 'channel' in data:
if data['channel'] == 'spot.deals':
trades = data.get('data', [])
if trades:
trades_received += 1
logger.info(f"Received trade data: {trades[0]}")
# Verify trade data structure
trade = trades[0]
self.assertIn('t', trade) # timestamp
self.assertIn('p', trade) # price
self.assertIn('v', trade) # volume
self.assertIn('S', trade) # side
except asyncio.TimeoutError:
self.fail("Timeout waiting for WebSocket messages")
# Verify we received some trades
self.assertGreater(trades_received, 0, "No trades received")
# Test unsubscribe
unsub_msg = {
"op": "unsub",
"id": "test1",
"topic": f"spot.deals.{symbol}"
}
await ws.send(json.dumps(unsub_msg))
async def test_kline_subscription(self):
"""Test subscription to kline (candlestick) data"""
uri = "wss://stream.mexc.com/ws"
symbol = "ethusdt"
async with websockets.connect(uri) as ws:
# Subscribe to 1m klines
sub_msg = {
"op": "sub",
"id": "test2",
"topic": f"spot.klines.{symbol}_1m"
}
await ws.send(json.dumps(sub_msg))
messages_received = 0
klines_received = 0
while messages_received < 5:
try:
response = await asyncio.wait_for(ws.recv(), timeout=10)
messages_received += 1
logger.info(f"Received kline message: {response[:200]}...")
data = json.loads(response)
if isinstance(data, dict):
if data.get('channel') == 'spot.klines':
kline_data = data.get('data', [])
if kline_data:
klines_received += 1
logger.info(f"Received kline data: {kline_data[0]}")
# Verify kline data structure (should be an array)
kline = kline_data[0]
self.assertEqual(len(kline), 6) # Should have 6 elements
except asyncio.TimeoutError:
self.fail("Timeout waiting for kline data")
self.assertGreater(klines_received, 0, "No kline data received")
def run_tests():
"""Run the WebSocket tests"""
async def run_async_tests():
# Create test suite
suite = unittest.TestSuite()
suite.addTest(TestMEXCWebSocket('test_websocket_connection'))
suite.addTest(TestMEXCWebSocket('test_kline_subscription'))
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)
# Run tests in asyncio loop
asyncio.run(run_async_tests())
if __name__ == "__main__":
run_tests()

File diff suppressed because it is too large Load Diff

View File

@ -1,415 +0,0 @@
#!/usr/bin/env python
"""
Extended overnight training session for CNN model with real-time data updates
This script runs continuous model training, refreshing market data at regular intervals
"""
import os
import sys
import logging
import numpy as np
import torch
import time
import json
from datetime import datetime, timedelta
import signal
import threading
# Add the project root to path
sys.path.append(os.path.abspath('.'))
# Configure logging with timestamp in filename
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"realtime_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
logger = logging.getLogger('realtime_training')
# Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from dataprovider_realtime import MultiTimeframeDataInterface
from NN.utils.signal_interpreter import SignalInterpreter
# Global variables for graceful shutdown
running = True
training_stats = {
"epochs_completed": 0,
"best_val_pnl": -float('inf'),
"best_epoch": 0,
"best_win_rate": 0,
"training_started": datetime.now().isoformat(),
"last_update": datetime.now().isoformat(),
"epochs": []
}
def signal_handler(sig, frame):
"""Handle CTRL+C to gracefully exit training"""
global running
logger.info("Received interrupt signal. Finishing current epoch and saving model...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
def save_training_stats(stats, filepath="NN/models/saved/realtime_training_stats.json"):
"""Save training statistics to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, 'w') as f:
json.dump(stats, f, indent=2)
logger.info(f"Training statistics saved to {filepath}")
def run_overnight_training():
"""
Run a continuous training session with real-time data updates
"""
global running, training_stats
# Configuration parameters
symbol = "BTC/USDT"
timeframes = ["1m", "5m", "15m"] # Multiple timeframes for better signal quality
window_size = 24 # Larger window size for capturing more patterns
output_size = 3 # BUY/HOLD/SELL
batch_size = 64 # Batch size for training
# Real-time configuration
data_refresh_interval = 300 # Refresh data every 5 minutes
checkpoint_interval = 3600 # Save checkpoint every hour
max_training_time = 12 * 3600 # 12 hours max runtime
# Initialize training start time
start_time = time.time()
last_checkpoint_time = start_time
last_data_refresh_time = start_time
logger.info(f"Starting overnight training session for CNN model with {symbol} real-time data")
logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
logger.info(f"Data will refresh every {data_refresh_interval} seconds")
logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
logger.info(f"Maximum training time: {max_training_time/3600} hours")
try:
# Initialize data interface
logger.info("Initializing MultiTimeframeDataInterface...")
data_interface = MultiTimeframeDataInterface(
symbol=symbol,
timeframes=timeframes
)
# Prepare initial training data
logger.info("Loading initial training data...")
X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None or y_train is None:
logger.error("Failed to load training data")
return
# Get reference timeframe (lowest timeframe)
reference_tf = min(timeframes, key=lambda x: data_interface.timeframe_to_seconds.get(x, 3600))
logger.info(f"Using {reference_tf} as reference timeframe")
# Log data shape information
for tf, X in X_train_dict.items():
logger.info(f"Training data for {tf} - X shape: {X.shape}")
logger.info(f"Target labels shape: {y_train.shape}")
logger.info(f"Validation data for {reference_tf} - X shape: {X_val_dict[reference_tf].shape}, y shape: {y_val.shape}")
# Target distribution analysis
target_distribution = {
"SELL": np.sum(y_train == 0),
"HOLD": np.sum(y_train == 1),
"BUY": np.sum(y_train == 2)
}
logger.info(f"Target distribution: SELL: {target_distribution['SELL']} ({target_distribution['SELL']/len(y_train):.2%}), "
f"HOLD: {target_distribution['HOLD']} ({target_distribution['HOLD']/len(y_train):.2%}), "
f"BUY: {target_distribution['BUY']} ({target_distribution['BUY']/len(y_train):.2%})")
# Calculate future prices for profitability-focused loss function
logger.info("Calculating future price changes...")
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
# Initialize model
num_features = X_train_dict[reference_tf].shape[2] # Get feature count from the data
logger.info(f"Initializing model with {num_features} features")
# Use the same window size as the data
actual_window_size = X_train_dict[reference_tf].shape[1]
logger.info(f"Actual window size from data: {actual_window_size}")
# Try to load existing model if available
model_path = "NN/models/saved/optimized_short_term_model_best.pt"
model = CNNModelPyTorch(
window_size=actual_window_size,
num_features=num_features,
output_size=output_size,
timeframes=timeframes
)
# Try to load existing model for continued training
try:
if os.path.exists(model_path):
logger.info(f"Loading existing model from {model_path}")
model.load(model_path)
logger.info("Model loaded successfully")
else:
logger.info("No existing model found. Starting with a new model.")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
logger.info("Starting with a new model.")
# Initialize signal interpreter for testing predictions
signal_interpreter = SignalInterpreter(config={
'buy_threshold': 0.65,
'sell_threshold': 0.65,
'hold_threshold': 0.75,
'trend_filter_enabled': True,
'volume_filter_enabled': True
})
# Create checkpoint directory
checkpoint_dir = "NN/models/saved/realtime_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Track metrics
epoch = 0
best_val_pnl = -float('inf')
best_win_rate = 0
best_epoch = 0
# Training loop
while running and (time.time() - start_time < max_training_time):
epoch += 1
epoch_start = time.time()
logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
# Check if we need to refresh data
if time.time() - last_data_refresh_time > data_refresh_interval:
logger.info("Refreshing training data...")
X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None or y_train is None:
logger.warning("Failed to refresh training data. Using previous data.")
else:
logger.info(f"Refreshed training data for {reference_tf} - X shape: {X_train_dict[reference_tf].shape}, y shape: {y_train.shape}")
# Recalculate future prices
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
last_data_refresh_time = time.time()
# Convert multi-timeframe dict to the format expected by the model
# For now, we use only the reference timeframe, but in the future,
# the model should be updated to handle multi-timeframe inputs
X_train = X_train_dict[reference_tf]
X_val = X_val_dict[reference_tf]
# Train one epoch
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, batch_size
)
# Evaluate
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
logger.info(f"Epoch {epoch} results:")
logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")
# Get predictions for PnL calculation
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Convert probabilities to actions
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Track signal distribution
train_buy_count = np.sum(train_preds == 2)
train_sell_count = np.sum(train_preds == 0)
train_hold_count = np.sum(train_preds == 1)
val_buy_count = np.sum(val_preds == 2)
val_sell_count = np.sum(val_preds == 0)
val_hold_count = np.sum(val_preds == 1)
signal_dist = {
"train": {
"BUY": float(train_buy_count / len(train_preds)) if len(train_preds) > 0 else 0,
"SELL": float(train_sell_count / len(train_preds)) if len(train_preds) > 0 else 0,
"HOLD": float(train_hold_count / len(train_preds)) if len(train_preds) > 0 else 0
},
"val": {
"BUY": float(val_buy_count / len(val_preds)) if len(val_preds) > 0 else 0,
"SELL": float(val_sell_count / len(val_preds)) if len(val_preds) > 0 else 0,
"HOLD": float(val_hold_count / len(val_preds)) if len(val_preds) > 0 else 0
}
}
# Calculate PnL and win rates with different position sizes
position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0] # Multiple position sizes for robustness
best_position_train_pnl = -float('inf')
best_position_val_pnl = -float('inf')
best_position_train_wr = 0
best_position_val_wr = 0
best_position_size = 1.0
for position_size in position_sizes:
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=position_size
)
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=position_size
)
logger.info(f" Position Size: {position_size}")
logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
# Track best position size for this epoch
if val_pnl > best_position_val_pnl:
best_position_val_pnl = val_pnl
best_position_val_wr = val_win_rate
best_position_size = position_size
if train_pnl > best_position_train_pnl:
best_position_train_pnl = train_pnl
best_position_train_wr = train_win_rate
# Track best model overall (using position size 1.0 as reference)
if val_pnl > best_val_pnl and position_size == 1.0:
best_val_pnl = val_pnl
best_win_rate = val_win_rate
best_epoch = epoch
logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
# Save the best model
model.save(f"NN/models/saved/optimized_short_term_model_realtime_best")
# Store epoch metrics
epoch_metrics = {
"epoch": epoch,
"train_loss": float(train_action_loss),
"val_loss": float(val_action_loss),
"train_acc": float(train_acc),
"val_acc": float(val_acc),
"train_pnl": float(best_position_train_pnl),
"val_pnl": float(best_position_val_pnl),
"train_win_rate": float(best_position_train_wr),
"val_win_rate": float(best_position_val_wr),
"best_position_size": float(best_position_size),
"signal_distribution": signal_dist,
"timestamp": datetime.now().isoformat(),
"data_age": int(time.time() - last_data_refresh_time)
}
# Update training stats
training_stats["epochs_completed"] = epoch
training_stats["best_val_pnl"] = float(best_val_pnl)
training_stats["best_epoch"] = best_epoch
training_stats["best_win_rate"] = float(best_win_rate)
training_stats["last_update"] = datetime.now().isoformat()
training_stats["epochs"].append(epoch_metrics)
# Check if we need to save checkpoint
if time.time() - last_checkpoint_time > checkpoint_interval:
logger.info(f"Saving checkpoint at epoch {epoch}")
# Save model checkpoint
model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
# Save training statistics
save_training_stats(training_stats)
last_checkpoint_time = time.time()
# Test trade signal generation with a random sample
random_idx = np.random.randint(0, len(X_val))
sample_X = X_val[random_idx:random_idx+1]
sample_probs, sample_price_pred = model.predict(sample_X)
# Process with signal interpreter
signal = signal_interpreter.interpret_signal(
sample_probs[0],
float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
)
logger.info(f" Sample trade signal: {signal['action']} with confidence {signal['confidence']:.4f}")
# Log trading statistics
logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")
# Log epoch timing
epoch_time = time.time() - epoch_start
total_elapsed = time.time() - start_time
time_remaining = max_training_time - total_elapsed
logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
logger.info(f" Training time: {total_elapsed/3600:.2f} hours / {max_training_time/3600:.2f} hours")
logger.info(f" Estimated time remaining: {time_remaining/3600:.2f} hours")
# Save final model and performance metrics
logger.info("Saving final optimized model...")
model.save("NN/models/saved/optimized_short_term_model_realtime_final")
# Save performance metrics to file
save_training_stats(training_stats)
# Generate performance plots
try:
model.plot_training_history("NN/models/saved/realtime_training_stats.json")
logger.info("Performance plots generated successfully")
except Exception as e:
logger.error(f"Error generating plots: {str(e)}")
# Calculate total training time
total_time = time.time() - start_time
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
logger.info(f"Overnight training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
except Exception as e:
logger.error(f"Error during overnight training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Try to save the model and stats in case of error
try:
if 'model' in locals():
model.save("NN/models/saved/optimized_short_term_model_realtime_emergency")
logger.info("Emergency model save completed")
if 'training_stats' in locals():
save_training_stats(training_stats, "NN/models/saved/realtime_training_stats_emergency.json")
except Exception as e2:
logger.error(f"Failed to save emergency checkpoint: {str(e2)}")
if __name__ == "__main__":
# Print startup banner
print("=" * 80)
print("OVERNIGHT REALTIME TRAINING SESSION")
print("This script will continuously train the model using real-time market data")
print("Press Ctrl+C to safely stop training and save the model")
print("=" * 80)
run_overnight_training()

View File

@ -1,231 +0,0 @@
#!/usr/bin/env python
"""
Training Configuration for GOGO2 Trading System
This module provides a central configuration for all training scripts,
ensuring they use real market data and follow consistent practices.
Usage:
import train_config
config = train_config.get_config('supervised') # or 'reinforcement' or 'hybrid'
"""
import os
import logging
import json
from datetime import datetime
from pathlib import Path
# Ensure consistent logging across all training scripts
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
logger = logging.getLogger('training')
# Define available training types
TRAINING_TYPES = {
'supervised': {
'description': 'Supervised learning using CNN model',
'script': 'train_with_realtime.py',
'model_class': 'CNNModelPyTorch',
'data_interface': 'MultiTimeframeDataInterface'
},
'reinforcement': {
'description': 'Reinforcement learning using DQN agent',
'script': 'train_rl_with_realtime.py',
'model_class': 'DQNAgent',
'data_interface': 'MultiTimeframeDataInterface'
},
'hybrid': {
'description': 'Combined supervised and reinforcement learning',
'script': 'train_hybrid.py', # To be implemented
'model_class': 'HybridModel', # To be implemented
'data_interface': 'MultiTimeframeDataInterface'
}
}
# Default configuration
DEFAULT_CONFIG = {
# Market data configuration
'market_data': {
'use_real_data_only': True, # IMPORTANT: Only use real market data, never synthetic
'symbol': 'BTC/USDT',
'timeframes': ['1m', '5m', '15m'],
'window_size': 24,
'data_refresh_interval': 300, # seconds
'use_indicators': True
},
# Training parameters
'training': {
'max_training_time': 12 * 3600, # seconds (12 hours)
'checkpoint_interval': 3600, # seconds (1 hour)
'batch_size': 64,
'learning_rate': 0.0001,
'optimizer': 'adam',
'loss_function': 'custom_pnl' # Focus on profitability
},
# Model paths
'paths': {
'models_dir': 'NN/models/saved',
'logs_dir': 'logs',
'tensorboard_dir': 'runs'
},
# GPU configuration
'hardware': {
'use_gpu': True,
'mixed_precision': True,
'device': 'cuda' if os.environ.get('CUDA_VISIBLE_DEVICES') is not None else 'cpu'
}
}
def get_config(training_type='supervised', custom_config=None):
"""
Get configuration for a specific training type
Args:
training_type (str): Type of training ('supervised', 'reinforcement', or 'hybrid')
custom_config (dict): Optional custom configuration to merge
Returns:
dict: Complete configuration
"""
if training_type not in TRAINING_TYPES:
raise ValueError(f"Invalid training type: {training_type}. Must be one of {list(TRAINING_TYPES.keys())}")
# Start with default configuration
config = DEFAULT_CONFIG.copy()
# Add training type-specific configuration
config['training_type'] = training_type
config['training_info'] = TRAINING_TYPES[training_type]
# Override with custom configuration if provided
if custom_config:
_deep_update(config, custom_config)
# Validate configuration
_validate_config(config)
return config
def save_config(config, filepath=None):
"""
Save configuration to a JSON file
Args:
config (dict): Configuration to save
filepath (str): Path to save to (default: based on training type and timestamp)
Returns:
str: Path where configuration was saved
"""
if filepath is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
training_type = config.get('training_type', 'unknown')
filepath = f"configs/training_{training_type}_{timestamp}.json"
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"Configuration saved to {filepath}")
return filepath
def load_config(filepath):
"""
Load configuration from a JSON file
Args:
filepath (str): Path to load from
Returns:
dict: Loaded configuration
"""
with open(filepath, 'r') as f:
config = json.load(f)
# Validate the loaded configuration
_validate_config(config)
logger.info(f"Configuration loaded from {filepath}")
return config
def _deep_update(target, source):
"""
Deep update a nested dictionary
Args:
target (dict): Target dictionary to update
source (dict): Source dictionary with updates
Returns:
dict: Updated target dictionary
"""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
_deep_update(target[key], value)
else:
target[key] = value
return target
def _validate_config(config):
"""
Validate configuration to ensure it follows required guidelines
Args:
config (dict): Configuration to validate
Returns:
bool: True if valid, raises exception otherwise
"""
# Enforce real data policy
if config.get('use_real_data_only', True) is not True:
logger.error("POLICY VIOLATION: Real market data policy requires only using real data")
raise ValueError("Configuration violates policy: Must use only real market data, never synthetic")
# Add explicit check at the beginning of the validation function
if 'allow_synthetic_data' in config and config['allow_synthetic_data'] is True:
logger.error("POLICY VIOLATION: Synthetic data is not allowed under any circumstances")
raise ValueError("Configuration violates policy: Synthetic data is explicitly forbidden")
# Validate symbol
if not config['market_data']['symbol'] or '/' not in config['market_data']['symbol']:
raise ValueError(f"Invalid symbol format: {config['market_data']['symbol']}")
# Validate timeframes
if not config['market_data']['timeframes']:
raise ValueError("At least one timeframe must be specified")
# Ensure window size is reasonable
if config['market_data']['window_size'] < 10 or config['market_data']['window_size'] > 500:
raise ValueError(f"Window size out of reasonable range: {config['market_data']['window_size']}")
return True
if __name__ == "__main__":
# Show available training configurations
print("Available Training Configurations:")
print("=" * 40)
for training_type, info in TRAINING_TYPES.items():
print(f"{training_type.upper()}: {info['description']}")
# Example of getting and saving a configuration
config = get_config('supervised')
save_config(config)
print("\nDefault configuration generated and saved.")
print(f"Log file: {log_file}")

View File

@ -1,415 +0,0 @@
#!/usr/bin/env python
"""
DQN Training Session with Monitoring
This script sets up and runs a DQN agent training session with progress monitoring.
It tracks key metrics like rewards, losses, and prediction accuracy, and
visualizes the agent's learning progress.
"""
import os
import sys
import logging
import time
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
import signal
from torch.utils.tensorboard import SummaryWriter
# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Import configurations
import train_config
# Import key components
from NN.models.dqn_agent import DQNAgent
from dataprovider_realtime import MultiTimeframeDataInterface
# Configure logging
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"dqn_training_{timestamp}.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
logger = logging.getLogger('dqn_training')
# Global variables for graceful shutdown
running = True
# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
global running
logger.info("Received interrupt signal. Finishing current episode and saving model...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
class DQNTrainingMonitor:
"""
Class to monitor DQN training progress and visualize results
"""
def __init__(self, config):
self.config = config
self.device = torch.device(config['hardware']['device'])
self.agent = None
self.data_interface = None
# Training stats
self.episode_rewards = []
self.avg_rewards = []
self.losses = []
self.epsilons = []
self.best_reward = -float('inf')
self.tensorboard_writer = None
# Paths
self.models_dir = Path(config['paths']['models_dir'])
self.models_dir.mkdir(exist_ok=True, parents=True)
# Metrics display intervals
self.plot_interval = config.get('visualization', {}).get('plot_interval', 5)
self.save_interval = config.get('training', {}).get('save_interval', 10)
def initialize(self):
"""Initialize the DQN agent and data interface"""
# Set up TensorBoard
tb_dir = Path(self.config['paths']['tensorboard_dir'])
tb_dir.mkdir(exist_ok=True, parents=True)
log_dir = tb_dir / f"dqn_{timestamp}"
self.tensorboard_writer = SummaryWriter(log_dir=str(log_dir))
logger.info(f"TensorBoard initialized at {log_dir}")
# Initialize data interface
symbol = self.config['market_data']['symbol']
timeframes = self.config['market_data']['timeframes']
window_size = self.config['market_data']['window_size']
logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
self.data_interface = MultiTimeframeDataInterface(
symbol=symbol,
timeframes=timeframes
)
# Get data for training
X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None:
raise ValueError("Failed to load training data for DQN agent")
# Get feature count from the reference timeframe
reference_tf = min(
timeframes,
key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
)
num_features = X_train_dict[reference_tf].shape[2]
logger.info(f"Using {num_features} features from timeframe {reference_tf}")
# Initialize DQN agent
state_size = num_features * window_size * len(timeframes)
action_size = 3 # Buy, Hold, Sell
logger.info(f"Initializing DQN agent with state size {state_size} and action size {action_size}")
self.agent = DQNAgent(
state_shape=(len(timeframes), window_size, num_features), # Multi-dimensional state shape
n_actions=action_size,
learning_rate=self.config['training']['learning_rate'],
batch_size=self.config['training']['batch_size'],
gamma=self.config.get('model', {}).get('gamma', 0.95),
epsilon=self.config.get('model', {}).get('epsilon_start', 1.0),
epsilon_min=self.config.get('model', {}).get('epsilon_min', 0.01),
epsilon_decay=self.config.get('model', {}).get('epsilon_decay', 0.995),
buffer_size=self.config.get('model', {}).get('memory_size', 10000),
device=self.device
)
# Load existing model if available
model_path = self.models_dir / "dqn_agent_best"
if os.path.exists(f"{model_path}_policy.pt") and not self.config.get('model', {}).get('new_model', False):
logger.info(f"Loading existing DQN model from {model_path}")
try:
self.agent.load(str(model_path))
logger.info("DQN model loaded successfully")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
logger.info("Starting with a new model instead")
else:
logger.info("Starting with a new model")
return True
def train(self, num_episodes=100):
"""Train the DQN agent for a specified number of episodes"""
if self.agent is None:
raise ValueError("Agent not initialized. Call initialize() first.")
logger.info(f"Starting DQN training for {num_episodes} episodes")
# Get training data
window_size = self.config['market_data']['window_size']
X_train_dict, y_train, _, _, _, _ = self.data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
# Prepare data for training
reference_tf = min(
self.config['market_data']['timeframes'],
key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
)
# Convert data to flat states for RL
states = []
actions = []
# Find the minimum length across all timeframes to ensure consistent indexing
min_length = min(len(X_train_dict[tf]) for tf in self.config['market_data']['timeframes'])
logger.info(f"Using {min_length} samples from each timeframe for training")
# Only use indices that exist in all timeframes
for i in range(min_length):
state = []
for tf in self.config['market_data']['timeframes']:
state.extend(X_train_dict[tf][i].flatten())
states.append(np.array(state))
actions.append(np.argmax(y_train[i]))
logger.info(f"Prepared {len(states)} state-action pairs for training")
# Training loop
global running
for episode in range(1, num_episodes + 1):
if not running:
logger.info("Training interrupted. Saving final model.")
self._save_model(final=True)
break
episode_reward = 0
total_loss = 0
correct_predictions = 0
# Randomly sample start position (to prevent overfitting on sequence)
start_idx = np.random.randint(0, len(states) - 1000) if len(states) > 1000 else 0
end_idx = min(start_idx + 1000, len(states))
logger.info(f"Episode {episode}/{num_episodes} - Training on sequence from {start_idx} to {end_idx}")
# Training on sequence
for i in range(start_idx, end_idx - 1):
state = states[i]
action = actions[i]
next_state = states[i + 1]
# Get reward based on price movement
# Price is typically the close price (4th column in OHLCV data)
try:
# Assuming the last feature in each timeframe is the closing price
price_current = X_train_dict[reference_tf][i][-1, -1] # Last row, last column of current state
price_next = X_train_dict[reference_tf][i+1][-1, -1] # Last row, last column of next state
price_diff = price_next - price_current
except IndexError:
# Fallback if we're at the edge of our data
price_diff = 0
if action == 0: # Buy
reward = price_diff * 100 # Scale reward for better learning
elif action == 2: # Sell
reward = -price_diff * 100
else: # Hold
reward = abs(price_diff) * 10 if abs(price_diff) < 0.0001 else -abs(price_diff) * 50
# Train the agent with this experience
predicted_action = self.agent.act(state)
# Store experience in memory
done = (i == end_idx - 2) # Mark as done if it's the last step
self.agent.remember(state, action, reward, next_state, done)
# Periodically replay from memory
if i % 10 == 0: # Replay every 10 steps
loss = self.agent.replay()
else:
loss = None
if predicted_action == action:
correct_predictions += 1
episode_reward += reward
if loss is not None:
total_loss += loss
# Calculate metrics
accuracy = correct_predictions / (end_idx - start_idx) * 100
avg_loss = total_loss / (end_idx - start_idx) if end_idx > start_idx else 0
# Update training history
self.episode_rewards.append(episode_reward)
self.avg_rewards.append(self.agent.avg_reward)
self.losses.append(avg_loss)
self.epsilons.append(self.agent.epsilon)
# Log metrics
logger.info(f"Episode {episode} - Reward: {episode_reward:.2f}, Avg Reward: {self.agent.avg_reward:.2f}, "
f"Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%, Epsilon: {self.agent.epsilon:.4f}")
# Log to TensorBoard
self._log_to_tensorboard(episode, episode_reward, avg_loss, accuracy)
# Save model if improved
improved = episode_reward > self.best_reward
if improved:
self.best_reward = episode_reward
logger.info(f"New best reward: {self.best_reward:.2f}")
# Periodically save model
if episode % self.save_interval == 0 or improved:
self._save_model(final=False)
# Plot progress
if episode % self.plot_interval == 0:
self._plot_training_progress()
# Save final model
logger.info("Training completed.")
self._save_model(final=True)
def _log_to_tensorboard(self, episode, reward, loss, accuracy):
"""Log training metrics to TensorBoard"""
if self.tensorboard_writer:
self.tensorboard_writer.add_scalar('Train/Reward', reward, episode)
self.tensorboard_writer.add_scalar('Train/AvgReward', self.agent.avg_reward, episode)
self.tensorboard_writer.add_scalar('Train/Loss', loss, episode)
self.tensorboard_writer.add_scalar('Train/Accuracy', accuracy, episode)
self.tensorboard_writer.add_scalar('Train/Epsilon', self.agent.epsilon, episode)
def _save_model(self, final=False):
"""Save the DQN model"""
if final:
save_path = self.models_dir / f"dqn_agent_final_{timestamp}"
else:
save_path = self.models_dir / "dqn_agent_best"
self.agent.save(str(save_path))
logger.info(f"Model saved to {save_path}")
def _plot_training_progress(self):
"""Plot training progress metrics"""
if not self.episode_rewards:
logger.warning("No training data available for plotting yet")
return
plt.figure(figsize=(15, 10))
# Plot rewards
plt.subplot(2, 2, 1)
plt.plot(self.episode_rewards, label='Episode Reward')
plt.plot(self.avg_rewards, label='Avg Reward', linestyle='--')
plt.title('Rewards')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.legend()
# Plot losses
plt.subplot(2, 2, 2)
plt.plot(self.losses)
plt.title('Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
# Plot epsilon
plt.subplot(2, 2, 3)
plt.plot(self.epsilons)
plt.title('Exploration Rate (Epsilon)')
plt.xlabel('Episode')
plt.ylabel('Epsilon')
# Save plot
plots_dir = Path("plots")
plots_dir.mkdir(exist_ok=True)
plt.tight_layout()
plt.savefig(plots_dir / f"dqn_training_progress_{timestamp}.png")
plt.close()
def parse_args():
parser = argparse.ArgumentParser(description='DQN Training Session with Monitoring')
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
parser.add_argument('--timeframes', type=str, default='1m,5m,15m', help='Comma-separated timeframes')
parser.add_argument('--window', type=int, default=24, help='Window size for state construction')
parser.add_argument('--batch-size', type=int, default=64, help='Batch size for training')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--plot-interval', type=int, default=5, help='Interval for plotting progress')
parser.add_argument('--save-interval', type=int, default=10, help='Interval for saving model')
parser.add_argument('--new-model', action='store_true', help='Start with a new model instead of loading existing')
return parser.parse_args()
def main():
args = parse_args()
# Force CPU training to avoid device mismatch errors
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['DISABLE_MIXED_PRECISION'] = '1'
# Create custom config based on arguments
custom_config = {
'market_data': {
'symbol': args.symbol,
'timeframes': args.timeframes.split(','),
'window_size': args.window
},
'training': {
'batch_size': args.batch_size,
'learning_rate': args.lr,
'save_interval': args.save_interval
},
'visualization': {
'plot_interval': args.plot_interval
},
'model': {
'new_model': args.new_model
},
'hardware': {
'device': 'cpu',
'mixed_precision': False
}
}
# Get configuration
config = train_config.get_config('reinforcement', custom_config)
# Save configuration for reference
config_dir = Path("configs")
config_dir.mkdir(exist_ok=True)
config_path = config_dir / f"dqn_training_config_{timestamp}.json"
train_config.save_config(config, str(config_path))
# Initialize and train
monitor = DQNTrainingMonitor(config)
monitor.initialize()
monitor.train(num_episodes=args.episodes)
logger.info(f"Training completed. Results saved to logs and plots directories.")
logger.info(f"To visualize training in TensorBoard, run: tensorboard --logdir={config['paths']['tensorboard_dir']}")
if __name__ == "__main__":
main()

View File

@ -1,731 +0,0 @@
#!/usr/bin/env python
"""
Hybrid Training Script - Combining Supervised and Reinforcement Learning
This script provides a hybrid approach that:
1. Performs supervised learning on market data using CNN models
2. Uses reinforcement learning to optimize trading strategies
3. Only uses real market data (never synthetic)
The script enables both approaches to complement each other:
- CNN model learns patterns from historical data (supervised)
- RL agent optimizes actual trading decisions (reinforcement)
"""
import os
import sys
import logging
import argparse
import numpy as np
import torch
import time
import json
import asyncio
import signal
import threading
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Import configurations
import train_config
# Import key components
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.models.dqn_agent import DQNAgent
from dataprovider_realtime import MultiTimeframeDataInterface, RealTimeChart
from NN.utils.signal_interpreter import SignalInterpreter
# Global variables for graceful shutdown
running = True
training_stats = {
"supervised": {
"epochs_completed": 0,
"best_val_pnl": -float('inf'),
"best_epoch": 0,
"best_win_rate": 0
},
"reinforcement": {
"episodes_completed": 0,
"best_reward": -float('inf'),
"best_episode": 0,
"best_win_rate": 0
},
"hybrid": {
"iterations_completed": 0,
"best_combined_score": -float('inf'),
"training_started": datetime.now().isoformat(),
"last_update": datetime.now().isoformat()
}
}
# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
global running
logging.info("Received interrupt signal. Finishing current training cycle and saving models...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
class HybridModel:
"""
Hybrid model that combines supervised CNN learning with RL-based decision optimization
"""
def __init__(self, config):
self.config = config
self.device = torch.device(config['hardware']['device'])
self.supervised_model = None
self.rl_agent = None
self.data_interface = None
self.signal_interpreter = None
self.chart = None
# Training stats
self.tensorboard_writer = None
self.iter_count = 0
self.supervised_epochs = 0
self.rl_episodes = 0
# Initialize logging
self.logger = logging.getLogger('hybrid_model')
# Paths
self.models_dir = Path(config['paths']['models_dir'])
self.models_dir.mkdir(exist_ok=True, parents=True)
def initialize(self):
"""Initialize all components of the hybrid model"""
# Set up TensorBoard
log_dir = Path(self.config['paths']['tensorboard_dir']) / f"hybrid_{int(time.time())}"
self.tensorboard_writer = SummaryWriter(log_dir=str(log_dir))
self.logger.info(f"TensorBoard initialized at {log_dir}")
# Initialize data interface
symbol = self.config['market_data']['symbol']
timeframes = self.config['market_data']['timeframes']
window_size = self.config['market_data']['window_size']
self.logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
self.data_interface = MultiTimeframeDataInterface(
symbol=symbol,
timeframes=timeframes
)
# Initialize supervised model (CNN)
self._initialize_supervised_model(window_size)
# Initialize RL agent
self._initialize_rl_agent(window_size)
# Initialize signal interpreter
self.signal_interpreter = SignalInterpreter(config={
'buy_threshold': 0.65,
'sell_threshold': 0.65,
'hold_threshold': 0.75,
'trend_filter_enabled': True,
'volume_filter_enabled': True
})
# Initialize chart if visualization is enabled
if self.config.get('visualization', {}).get('enabled', False):
self._initialize_chart()
return True
def _initialize_supervised_model(self, window_size):
"""Initialize the supervised CNN model"""
try:
# Get data shape information
X_train_dict, y_train, X_val_dict, y_val, _, _ = self.data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None or y_train is None:
raise ValueError("Failed to load training data")
# Get reference timeframe (lowest timeframe)
reference_tf = min(
self.config['market_data']['timeframes'],
key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
)
# Get feature count from the data
num_features = X_train_dict[reference_tf].shape[2]
# Initialize model
self.logger.info(f"Initializing CNN model with {num_features} features")
self.supervised_model = CNNModelPyTorch(
window_size=window_size,
num_features=num_features,
output_size=3, # BUY/HOLD/SELL
timeframes=self.config['market_data']['timeframes']
)
# Load existing model if available
model_path = self.models_dir / "supervised_model_best.pt"
if model_path.exists():
self.logger.info(f"Loading existing CNN model from {model_path}")
self.supervised_model.load(str(model_path))
self.logger.info("CNN model loaded successfully")
else:
self.logger.info("No existing CNN model found. Starting with a new model.")
except Exception as e:
self.logger.error(f"Error initializing supervised model: {str(e)}")
import traceback
self.logger.error(traceback.format_exc())
raise
def _initialize_rl_agent(self, window_size):
"""Initialize the RL agent"""
try:
# Get data for RL training
X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None:
raise ValueError("Failed to load training data for RL agent")
# Get reference timeframe features
reference_tf = min(
self.config['market_data']['timeframes'],
key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
)
# Calculate state size - this is more complex for RL
# For simplicity, we'll use the CNN's feature representation + position info
state_size = window_size * X_train_dict[reference_tf].shape[2] + 3 # +3 for position, equity, unrealized_pnl
# Initialize RL agent
self.logger.info(f"Initializing RL agent with state size {state_size}")
self.rl_agent = DQNAgent(
state_size=state_size,
n_actions=3, # BUY/HOLD/SELL
epsilon=1.0,
epsilon_decay=0.995,
epsilon_min=0.01,
learning_rate=self.config['training']['learning_rate'],
gamma=0.99,
buffer_size=10000,
batch_size=self.config['training']['batch_size'],
device=self.device
)
# Load existing agent if available
agent_path = self.models_dir / "rl_agent_best.pth"
if agent_path.exists():
self.logger.info(f"Loading existing RL agent from {agent_path}")
self.rl_agent.load(str(agent_path))
self.logger.info("RL agent loaded successfully")
else:
self.logger.info("No existing RL agent found. Starting with a new agent.")
except Exception as e:
self.logger.error(f"Error initializing RL agent: {str(e)}")
import traceback
self.logger.error(traceback.format_exc())
raise
def _initialize_chart(self):
"""Initialize the RealTimeChart for visualization"""
try:
from dataprovider_realtime import RealTimeChart
symbol = self.config['market_data']['symbol']
self.logger.info(f"Initializing RealTimeChart for {symbol}")
self.chart = RealTimeChart(symbol=symbol)
# TODO: Start chart server in a background thread
except Exception as e:
self.logger.error(f"Error initializing chart: {str(e)}")
self.chart = None
async def train_hybrid(self, iterations=10, sv_epochs_per_iter=5, rl_episodes_per_iter=2):
"""
Main hybrid training loop
Args:
iterations: Number of hybrid iterations to run
sv_epochs_per_iter: Number of supervised epochs per iteration
rl_episodes_per_iter: Number of RL episodes per iteration
Returns:
dict: Training statistics
"""
self.logger.info(f"Starting hybrid training with {iterations} iterations")
self.logger.info(f"Each iteration includes {sv_epochs_per_iter} supervised epochs and {rl_episodes_per_iter} RL episodes")
# Training loop
for iteration in range(iterations):
if not running:
self.logger.info("Training stopped by user")
break
self.logger.info(f"Iteration {iteration+1}/{iterations}")
self.iter_count += 1
# 1. Supervised learning phase
self.logger.info("Starting supervised learning phase")
sv_stats = await self.train_supervised(epochs=sv_epochs_per_iter)
# 2. Reinforcement learning phase
self.logger.info("Starting reinforcement learning phase")
rl_stats = await self.train_reinforcement(episodes=rl_episodes_per_iter)
# 3. Update global training stats
self._update_training_stats(sv_stats, rl_stats)
# 4. Save models and stats
self._save_models_and_stats()
# 5. Log to TensorBoard
if self.tensorboard_writer:
self._log_to_tensorboard(iteration, sv_stats, rl_stats)
self.logger.info("Hybrid training completed")
return training_stats
async def train_supervised(self, epochs=5):
"""
Run supervised training for a specified number of epochs
Args:
epochs: Number of epochs to train
Returns:
dict: Training statistics
"""
# Get fresh data
window_size = self.config['market_data']['window_size']
X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = self.data_interface.prepare_training_data(
window_size=window_size,
refresh=True
)
if X_train_dict is None or y_train is None:
self.logger.error("Failed to load training data")
return {}
# Get reference timeframe (lowest timeframe)
reference_tf = min(
self.config['market_data']['timeframes'],
key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
)
# Calculate future prices for profitability-focused loss function
train_future_prices = self.data_interface.get_future_prices(train_prices, n_candles=8)
val_future_prices = self.data_interface.get_future_prices(val_prices, n_candles=8)
# For now, we use only the reference timeframe
X_train = X_train_dict[reference_tf]
X_val = X_val_dict[reference_tf]
# Training stats
stats = {
"train_losses": [],
"val_losses": [],
"train_accuracies": [],
"val_accuracies": [],
"train_pnls": [],
"val_pnls": [],
"best_val_pnl": -float('inf'),
"best_epoch": -1
}
batch_size = self.config['training']['batch_size']
# Training loop
for epoch in range(epochs):
if not running:
break
epoch_start = time.time()
# Train one epoch
train_action_loss, train_price_loss, train_acc = self.supervised_model.train_epoch(
X_train, y_train, train_future_prices, batch_size
)
# Evaluate
val_action_loss, val_price_loss, val_acc = self.supervised_model.evaluate(
X_val, y_val, val_future_prices
)
# Get predictions for PnL calculation
train_action_probs, _ = self.supervised_model.predict(X_train)
val_action_probs, _ = self.supervised_model.predict(X_val)
# Convert probabilities to actions
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Calculate PnL
train_pnl, train_win_rate, _ = self.data_interface.calculate_pnl(
train_preds, train_prices, position_size=1.0
)
val_pnl, val_win_rate, _ = self.data_interface.calculate_pnl(
val_preds, val_prices, position_size=1.0
)
# Update stats
stats["train_losses"].append(train_action_loss)
stats["val_losses"].append(val_action_loss)
stats["train_accuracies"].append(train_acc)
stats["val_accuracies"].append(val_acc)
stats["train_pnls"].append(train_pnl)
stats["val_pnls"].append(val_pnl)
# Check if this is the best model
if val_pnl > stats["best_val_pnl"]:
stats["best_val_pnl"] = val_pnl
stats["best_epoch"] = epoch
stats["best_win_rate"] = val_win_rate
# Save the best model
self.supervised_model.save(str(self.models_dir / "supervised_model_best.pt"))
# Log epoch results
self.logger.info(f"Supervised Epoch {epoch+1}/{epochs}")
self.logger.info(f" Train Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}, PnL: {train_pnl:.4f}")
self.logger.info(f" Val Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}, PnL: {val_pnl:.4f}")
# Log timing
epoch_time = time.time() - epoch_start
self.logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
# Update global epoch counter
self.supervised_epochs += 1
# Small delay to allow for interruption
await asyncio.sleep(0.1)
return stats
async def train_reinforcement(self, episodes=2):
"""
Run reinforcement learning for a specified number of episodes
Args:
episodes: Number of episodes to train
Returns:
dict: Training statistics
"""
from NN.train_rl import RLTradingEnvironment
# Get data for RL environment
window_size = self.config['market_data']['window_size']
# Get all timeframes data
data_dict = self.data_interface.get_multi_timeframe_data(refresh=True)
if not data_dict:
self.logger.error("Failed to fetch data for any timeframe")
return {}
# Extract key timeframes
timeframes = self.config['market_data']['timeframes']
# Extract features from dataframes
features = {}
for tf in timeframes:
if tf in data_dict:
df = data_dict[tf]
# Add indicators if not already added
if 'rsi' not in df.columns:
df = self.data_interface.add_indicators(df)
# Convert to numpy array with close price as the last column
features[tf] = np.hstack([
df.drop(['timestamp', 'close'], axis=1).values,
df['close'].values.reshape(-1, 1)
])
# Ensure we have all needed timeframes
required_tfs = ['1m', '5m', '15m'] # Most common timeframes used by RL
for tf in required_tfs:
if tf not in features and tf in timeframes:
self.logger.error(f"Missing features for timeframe {tf}")
return {}
# Create environment with our feature data
env = RLTradingEnvironment(
features_1m=features.get('1m'),
features_1h=features.get('1h', features.get('5m')), # Use 5m as fallback
features_1d=features.get('1d', features.get('15m')) # Use 15m as fallback
)
# Training stats
stats = {
"rewards": [],
"win_rates": [],
"trades": [],
"best_reward": -float('inf'),
"best_episode": -1
}
# RL training loop
for episode in range(episodes):
if not running:
break
episode_start = time.time()
self.logger.info(f"RL Episode {episode+1}/{episodes}")
# Reset environment
state = env.reset()
total_reward = 0
trades = 0
wins = 0
# Run one episode
done = False
max_steps = 1000
step = 0
while not done and step < max_steps:
# Use CNN model to enhance state representation if available
enhanced_state = self._enhance_state_with_cnn(state)
# Select action using the RL agent
action = self.rl_agent.act(enhanced_state)
# Take step in environment
next_state, reward, done, info = env.step(action)
# Store in replay buffer
self.rl_agent.remember(enhanced_state, action, reward,
self._enhance_state_with_cnn(next_state), done)
# Update episode statistics
total_reward += reward
state = next_state
step += 1
# Track trades and wins
if action != 2: # Not HOLD
trades += 1
if reward > 0:
wins += 1
# Train the agent on a batch of experiences
if len(self.rl_agent.memory) > self.config['training']['batch_size']:
self.rl_agent.replay(self.config['training']['batch_size'])
# Allow for interruption
if step % 100 == 0:
await asyncio.sleep(0.1)
if not running:
break
# Calculate win rate
win_rate = wins / max(1, trades)
# Update stats
stats["rewards"].append(total_reward)
stats["win_rates"].append(win_rate)
stats["trades"].append(trades)
# Check if this is the best agent
if total_reward > stats["best_reward"]:
stats["best_reward"] = total_reward
stats["best_episode"] = episode
# Save the best agent
self.rl_agent.save(str(self.models_dir / "rl_agent_best.pth"))
# Log episode results
self.logger.info(f" Reward: {total_reward:.4f}, Win Rate: {win_rate:.4f}, Trades: {trades}")
# Log timing
episode_time = time.time() - episode_start
self.logger.info(f" Episode completed in {episode_time:.2f} seconds")
# Update global episode counter
self.rl_episodes += 1
# Reduce exploration rate
self.rl_agent.adjust_epsilon()
# Small delay to allow for interruption
await asyncio.sleep(0.1)
return stats
def _enhance_state_with_cnn(self, state):
"""
Enhance the RL state with CNN feature extraction
Args:
state: The original state from the environment
Returns:
numpy.ndarray: Enhanced state representation
"""
# This is a placeholder - in a real implementation, you would:
# 1. Format the state for the CNN
# 2. Get the CNN's feature representation
# 3. Combine with the original state features
return state
def _update_training_stats(self, sv_stats, rl_stats):
"""Update global training statistics"""
global training_stats
# Update supervised stats
if sv_stats:
training_stats["supervised"]["epochs_completed"] = self.supervised_epochs
if "best_val_pnl" in sv_stats and sv_stats["best_val_pnl"] > training_stats["supervised"]["best_val_pnl"]:
training_stats["supervised"]["best_val_pnl"] = sv_stats["best_val_pnl"]
training_stats["supervised"]["best_epoch"] = sv_stats["best_epoch"] + training_stats["supervised"]["epochs_completed"] - len(sv_stats["train_losses"])
training_stats["supervised"]["best_win_rate"] = sv_stats.get("best_win_rate", 0)
# Update reinforcement stats
if rl_stats:
training_stats["reinforcement"]["episodes_completed"] = self.rl_episodes
if "best_reward" in rl_stats and rl_stats["best_reward"] > training_stats["reinforcement"]["best_reward"]:
training_stats["reinforcement"]["best_reward"] = rl_stats["best_reward"]
training_stats["reinforcement"]["best_episode"] = rl_stats["best_episode"] + training_stats["reinforcement"]["episodes_completed"] - len(rl_stats["rewards"])
# Update hybrid stats
training_stats["hybrid"]["iterations_completed"] = self.iter_count
training_stats["hybrid"]["last_update"] = datetime.now().isoformat()
# Calculate combined score (simple formula, can be adjusted)
sv_score = training_stats["supervised"]["best_val_pnl"]
rl_score = training_stats["reinforcement"]["best_reward"]
combined_score = sv_score * 0.7 + rl_score * 0.3 # Weight supervised more
if combined_score > training_stats["hybrid"]["best_combined_score"]:
training_stats["hybrid"]["best_combined_score"] = combined_score
def _save_models_and_stats(self):
"""Save models and training statistics"""
# Save training stats
try:
stats_file = self.models_dir / "hybrid_training_stats.json"
with open(stats_file, 'w') as f:
json.dump(training_stats, f, indent=2)
self.logger.info(f"Training statistics saved to {stats_file}")
except Exception as e:
self.logger.error(f"Error saving training stats: {str(e)}")
# Models are already saved in their respective training functions
def _log_to_tensorboard(self, iteration, sv_stats, rl_stats):
"""Log training metrics to TensorBoard"""
if not self.tensorboard_writer:
return
# Log supervised metrics
if sv_stats and "train_losses" in sv_stats:
for i, loss in enumerate(sv_stats["train_losses"]):
step = (iteration * len(sv_stats["train_losses"])) + i
self.tensorboard_writer.add_scalar('supervised/train_loss', loss, step)
self.tensorboard_writer.add_scalar('supervised/val_loss', sv_stats["val_losses"][i], step)
self.tensorboard_writer.add_scalar('supervised/train_accuracy', sv_stats["train_accuracies"][i], step)
self.tensorboard_writer.add_scalar('supervised/val_accuracy', sv_stats["val_accuracies"][i], step)
self.tensorboard_writer.add_scalar('supervised/train_pnl', sv_stats["train_pnls"][i], step)
self.tensorboard_writer.add_scalar('supervised/val_pnl', sv_stats["val_pnls"][i], step)
# Log reinforcement metrics
if rl_stats and "rewards" in rl_stats:
for i, reward in enumerate(rl_stats["rewards"]):
step = (iteration * len(rl_stats["rewards"])) + i
self.tensorboard_writer.add_scalar('reinforcement/reward', reward, step)
self.tensorboard_writer.add_scalar('reinforcement/win_rate', rl_stats["win_rates"][i], step)
self.tensorboard_writer.add_scalar('reinforcement/trades', rl_stats["trades"][i], step)
# Log hybrid metrics
self.tensorboard_writer.add_scalar('hybrid/iterations', self.iter_count, iteration)
self.tensorboard_writer.add_scalar('hybrid/combined_score', training_stats["hybrid"]["best_combined_score"], iteration)
# Flush to ensure data is written
self.tensorboard_writer.flush()
async def main():
"""Main entry point for the hybrid training script"""
parser = argparse.ArgumentParser(description='Hybrid Training Script')
parser.add_argument('--iterations', type=int, default=10, help='Number of hybrid iterations to run')
parser.add_argument('--sv-epochs', type=int, default=5, help='Supervised epochs per iteration')
parser.add_argument('--rl-episodes', type=int, default=2, help='RL episodes per iteration')
parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m'], help='Timeframes to use')
parser.add_argument('--window-size', type=int, default=24, help='Window size for models')
parser.add_argument('--visualize', action='store_true', help='Enable visualization')
parser.add_argument('--config', type=str, help='Path to custom configuration file')
args = parser.parse_args()
# Load configuration
if args.config:
config = train_config.load_config(args.config)
else:
# Create custom config from command-line arguments
custom_config = {
'market_data': {
'symbol': args.symbol,
'timeframes': args.timeframes,
'window_size': args.window_size
},
'visualization': {
'enabled': args.visualize
}
}
config = train_config.get_config('hybrid', custom_config)
# Print startup banner
print("=" * 80)
print("HYBRID TRAINING SESSION")
print("Combining supervised learning (CNN) with reinforcement learning (RL)")
print(f"Symbol: {config['market_data']['symbol']}")
print(f"Timeframes: {config['market_data']['timeframes']}")
print(f"Iterations: {args.iterations} (SV epochs: {args.sv_epochs}, RL episodes: {args.rl_episodes})")
print("Press Ctrl+C to safely stop training and save the models")
print("=" * 80)
# Initialize the hybrid model
hybrid_model = HybridModel(config)
initialized = hybrid_model.initialize()
if not initialized:
print("Failed to initialize hybrid model. Exiting.")
return 1
try:
# Run training
await hybrid_model.train_hybrid(
iterations=args.iterations,
sv_epochs_per_iter=args.sv_epochs,
rl_episodes_per_iter=args.rl_episodes
)
print("Training completed successfully.")
return 0
except KeyboardInterrupt:
print("Training interrupted by user.")
return 0
except Exception as e:
print(f"Error during training: {str(e)}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@ -1,547 +0,0 @@
#!/usr/bin/env python
"""
Improved RL Trading with Enhanced Training and Monitoring
This script provides an improved version of the RL training process,
implementing better normalization, reward structure, and model training.
"""
import os
import sys
import logging
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
# Add project directory to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Import our custom modules
from NN.models.dqn_agent import DQNAgent
from NN.utils.trading_env import TradingEnvironment
from NN.utils.data_interface import DataInterface
from dataprovider_realtime import BinanceHistoricalData, RealTimeChart
# Configure logging
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_filename),
logging.StreamHandler()
]
)
logger = logging.getLogger('improved_rl')
# Parse command line arguments
parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training')
parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train')
parser.add_argument('--visualize', action='store_true', help='Visualize trades during training')
parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
args = parser.parse_args()
def create_training_environment(symbol, window_size=20):
"""Create and prepare the training environment with data"""
logger.info(f"Setting up training environment for {symbol}")
# Fetch historical data from multiple timeframes
data_interface = DataInterface(symbol)
# Use Binance data provider for fetching data
historical_data = BinanceHistoricalData()
# Fetch data for each timeframe
df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000)
df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000)
df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500)
# Ensure all dataframes have index as timestamp type
if df_1m is not None and not df_1m.empty:
if 'timestamp' in df_1m.columns:
df_1m = df_1m.set_index('timestamp')
if df_5m is not None and not df_5m.empty:
if 'timestamp' in df_5m.columns:
df_5m = df_5m.set_index('timestamp')
if df_15m is not None and not df_15m.empty:
if 'timestamp' in df_15m.columns:
df_15m = df_15m.set_index('timestamp')
# Preprocess data (add technical indicators)
df_1m = preprocess_dataframe(df_1m)
df_5m = preprocess_dataframe(df_5m)
df_15m = preprocess_dataframe(df_15m)
# Create environment with all timeframes
env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size)
return env, (df_1m, df_5m, df_15m)
def preprocess_dataframe(df):
"""Add technical indicators and preprocess dataframe"""
if df is None or df.empty:
return None
# Drop any missing values
df = df.dropna()
# Ensure it has OHLCV columns
required_columns = ['open', 'high', 'low', 'close', 'volume']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
logger.warning(f"Missing required columns: {missing_columns}")
for col in missing_columns:
# Fill with close price for OHLC if missing
if col in ['open', 'high', 'low'] and 'close' in df.columns:
df[col] = df['close']
# Fill with zeros for volume if missing
elif col == 'volume':
df[col] = 0
# Add simple technical indicators
# 1. Simple Moving Averages
df['sma_5'] = df['close'].rolling(window=5).mean()
df['sma_10'] = df['close'].rolling(window=10).mean()
# 2. Relative Strength Index (RSI)
delta = df['close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
rs = gain / loss
df['rsi'] = 100 - (100 / (1 + rs))
# 3. Bollinger Bands
df['bb_middle'] = df['close'].rolling(window=20).mean()
df['bb_std'] = df['close'].rolling(window=20).std()
df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std']
df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std']
# 4. MACD
df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean()
df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean()
df['macd'] = df['ema_12'] - df['ema_26']
df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
# 5. Price rate of change
df['roc'] = df['close'].pct_change(periods=10) * 100
# Fill any remaining NaN values with 0
df = df.fillna(0)
return df
def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20):
"""Create a custom environment that handles multiple timeframes"""
# Ensure we have complete data for all timeframes
min_required_length = window_size + 100 # Add buffer for training
if (df_1m is None or len(df_1m) < min_required_length or
df_5m is None or len(df_5m) < min_required_length or
df_15m is None or len(df_15m) < min_required_length):
raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.")
# Ensure we only use the last N valid data points
df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy()
df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy()
df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy()
# Reset index to make sure we have continuous integers
df_1m = df_1m.reset_index(drop=True)
df_5m = df_5m.reset_index(drop=True)
df_15m = df_15m.reset_index(drop=True)
# For simplicity, we'll use the 1m data as the base environment
# The other timeframes will be incorporated through observation
env = TradingEnvironment(
data=df_1m,
initial_balance=100.0,
fee_rate=0.0005, # 0.05% fee (typical for crypto exchanges)
max_steps=len(df_1m) - window_size - 50, # Leave some room at the end
window_size=window_size,
risk_aversion=0.2, # Moderately risk-averse
price_scaling='zscore', # Use z-score normalization
reward_scaling=10.0, # Scale rewards for better learning
episode_penalty=0.2 # Penalty for holding positions at end of episode
)
return env
def initialize_agent(env, window_size=20, num_features=0, timeframes=None):
"""Initialize the DQN agent with appropriate parameters"""
if timeframes is None:
timeframes = ['1m', '5m', '15m']
# Calculate input dimensions
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# If num_features wasn't provided, infer from environment
if num_features == 0:
# Calculate features per timeframe from state dimension and number of timeframes
# Accounting for the 3 additional features (position, equity, unrealized_pnl)
num_features = (state_dim - 3) // len(timeframes)
logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}")
agent = DQNAgent(
state_size=state_dim,
action_size=action_dim,
window_size=window_size,
num_features=num_features,
timeframes=timeframes,
learning_rate=0.0005, # Start with a moderate learning rate
gamma=0.97, # Slightly reduced discount factor for stable learning
epsilon=1.0, # Start with full exploration
epsilon_min=0.05, # Maintain some exploration even at the end
epsilon_decay=0.9975, # Slower decay for more exploration
memory_size=20000, # Larger replay buffer
batch_size=128, # Larger batch size for more stable gradients
target_update=5 # More frequent target network updates
)
return agent
def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
"""
Train the DQN agent with improved training loop
Args:
env: The trading environment
agent: The DQN agent
num_episodes: Number of episodes to train
visualize: Whether to visualize trades during training
chart: The visualization chart (if visualize=True)
save_path: Path to save the model
save_freq: How often to save checkpoints (in episodes)
Returns:
tuple: (rewards, wins, losses, best_reward)
"""
logger.info(f"Starting training for {num_episodes} episodes")
# Initialize metrics tracking
rewards = []
win_rates = []
total_train_time = 0
best_reward = float('-inf')
best_model_path = None
# Create directory for checkpoints if needed
if save_path:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# For tracking improvement
last_improved_episode = 0
patience = 10 # Episodes to wait for improvement before early stopping
for episode in range(num_episodes):
start_time = time.time()
# Reset environment and get initial state
state = env.reset()
done = False
episode_reward = 0
step = 0
# Action metrics for this episode
actions_taken = {0: 0, 1: 0, 2: 0} # Track BUY, SELL, HOLD actions
while not done:
# Select action
action = agent.act(state)
# Execute action
next_state, reward, done, info = env.step(action)
# Store experience in replay buffer
is_extrema = False # In a real implementation, detect extrema points
agent.remember(state, action, reward, next_state, done, is_extrema)
# Learn from experience
if len(agent.memory) >= agent.batch_size:
use_prioritized = episode > 1 # Start using prioritized replay after first episode
loss = agent.replay(use_prioritized=use_prioritized)
# Update state and metrics
state = next_state
episode_reward += reward
actions_taken[action] += 1
# Every 100 steps, log progress
if step % 100 == 0 or step < 10:
action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
current_price = info.get('current_price', 0)
pnl = info.get('pnl', 0)
balance = info.get('balance', 0)
logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")
# Add trade to visualization if enabled
if visualize and chart and action in [0, 1]: # BUY or SELL
chart.add_trade(
price=current_price,
timestamp=datetime.now(),
amount=0.1,
pnl=pnl,
action=action_str
)
step += 1
# Episode finished - calculate metrics
episode_time = time.time() - start_time
total_train_time += episode_time
# Get environment info
win_rate = env.winning_trades / max(1, env.total_trades)
trades = env.total_trades
balance = env.balance
gain = (balance - env.initial_balance) / env.initial_balance
max_drawdown = env.max_drawdown
# Record metrics
rewards.append(episode_reward)
win_rates.append(win_rate)
# Update agent's learning metrics
improved = agent.update_learning_metrics(episode_reward)
# If this is best performance, save the model
if episode_reward > best_reward:
best_reward = episode_reward
if save_path:
best_model_path = f"{save_path}_best"
agent.save(best_model_path)
logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})")
last_improved_episode = episode
# Regular checkpoint saving
if save_path and episode % save_freq == 0:
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}")
agent.save(checkpoint_path)
# Print episode summary
actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()])
logger.info(f"Episode {episode} completed in {episode_time:.2f}s")
logger.info(f" Total reward: {episode_reward:.4f}")
logger.info(f" Actions taken: {actions_summary}")
logger.info(f" Trades: {trades}, Win rate: {win_rate:.2%}")
logger.info(f" Balance: ${balance:.2f}, Gain: {gain:.2%}")
logger.info(f" Max Drawdown: {max_drawdown:.2%}")
# Early stopping check
if episode - last_improved_episode >= patience:
logger.info(f"No improvement for {patience} episodes. Early stopping.")
break
# Training complete
avg_time_per_episode = total_train_time / max(1, len(rewards))
logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)")
# Save final model
if save_path:
agent.save(f"{save_path}_final")
logger.info(f"Final model saved to {save_path}_final")
# Return training metrics
return rewards, win_rates, best_reward, best_model_path
def plot_training_results(rewards, win_rates, save_dir=None):
"""Plot training metrics and save the figure"""
plt.figure(figsize=(12, 8))
# Plot rewards
plt.subplot(2, 1, 1)
plt.plot(rewards, 'b-')
plt.title('Training Rewards per Episode')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
# Plot win rates
plt.subplot(2, 1, 2)
plt.plot(win_rates, 'g-')
plt.title('Win Rate per Episode')
plt.xlabel('Episode')
plt.ylabel('Win Rate')
plt.grid(True)
plt.tight_layout()
# Save figure if directory provided
if save_dir:
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'))
plt.close()
def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None):
"""
Evaluate a trained agent on the environment
Args:
env: The trading environment
agent: The trained DQN agent
num_episodes: Number of evaluation episodes
visualize: Whether to visualize trades
chart: The visualization chart (if visualize=True)
Returns:
dict: Evaluation metrics
"""
logger.info(f"Evaluating agent over {num_episodes} episodes")
# Metrics to track
total_rewards = []
total_trades = []
win_rates = []
sharpe_ratios = []
sortino_ratios = []
max_drawdowns = []
final_balances = []
for episode in range(num_episodes):
# Reset environment
state = env.reset()
done = False
episode_reward = 0
# Run episode without exploration
while not done:
action = agent.act(state, explore=False) # No exploration during evaluation
next_state, reward, done, info = env.step(action)
episode_reward += reward
state = next_state
# Add trade to visualization if enabled
if visualize and chart and action in [0, 1]: # BUY or SELL
action_str = "BUY" if action == 0 else "SELL"
current_price = info.get('current_price', 0)
pnl = info.get('pnl', 0)
chart.add_trade(
price=current_price,
timestamp=datetime.now(),
amount=0.1,
pnl=pnl,
action=action_str
)
# Record metrics
total_rewards.append(episode_reward)
total_trades.append(env.total_trades)
win_rates.append(env.winning_trades / max(1, env.total_trades))
sharpe_ratios.append(info.get('sharpe_ratio', 0))
sortino_ratios.append(info.get('sortino_ratio', 0))
max_drawdowns.append(env.max_drawdown)
final_balances.append(env.balance)
logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, "
f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}")
# Calculate average metrics
avg_reward = np.mean(total_rewards)
avg_trades = np.mean(total_trades)
avg_win_rate = np.mean(win_rates)
avg_sharpe = np.mean(sharpe_ratios)
avg_sortino = np.mean(sortino_ratios)
avg_max_drawdown = np.mean(max_drawdowns)
avg_final_balance = np.mean(final_balances)
# Log evaluation summary
logger.info("Evaluation completed:")
logger.info(f" Average reward: {avg_reward:.4f}")
logger.info(f" Average trades per episode: {avg_trades:.2f}")
logger.info(f" Average win rate: {avg_win_rate:.2%}")
logger.info(f" Average Sharpe ratio: {avg_sharpe:.4f}")
logger.info(f" Average Sortino ratio: {avg_sortino:.4f}")
logger.info(f" Average max drawdown: {avg_max_drawdown:.2%}")
logger.info(f" Average final balance: ${avg_final_balance:.2f}")
# Return evaluation metrics
return {
'avg_reward': avg_reward,
'avg_trades': avg_trades,
'avg_win_rate': avg_win_rate,
'avg_sharpe': avg_sharpe,
'avg_sortino': avg_sortino,
'avg_max_drawdown': avg_max_drawdown,
'avg_final_balance': avg_final_balance
}
def main():
"""Main function to run the improved RL training"""
start_time = time.time()
logger.info(f"Starting improved RL training for {args.symbol}")
# Create environment
env, data_frames = create_training_environment(args.symbol)
# Initialize visualization if enabled
chart = None
if args.visualize:
logger.info("Initializing visualization chart")
chart = RealTimeChart(args.symbol)
time.sleep(2) # Give time for chart to initialize
# Initialize agent
agent = initialize_agent(env)
# Train agent
rewards, win_rates, best_reward, best_model_path = train_agent(
env=env,
agent=agent,
num_episodes=args.episodes,
visualize=args.visualize,
chart=chart,
save_path=args.save_path
)
# Plot training results
plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots')
plot_training_results(rewards, win_rates, save_dir=plot_dir)
# Evaluate best model
logger.info("Evaluating best model")
# Load best model for evaluation
if best_model_path:
best_agent = initialize_agent(env)
best_agent.load(best_model_path)
# Evaluate the best model
eval_metrics = evaluate_agent(
env=env,
agent=best_agent,
visualize=args.visualize,
chart=chart
)
# Log evaluation results
logger.info("Best model evaluation complete:")
for metric, value in eval_metrics.items():
logger.info(f" {metric}: {value}")
# Total run time
total_time = time.time() - start_time
logger.info(f"Total run time: {total_time:.2f} seconds")
if __name__ == "__main__":
main()

View File

@ -1,476 +0,0 @@
#!/usr/bin/env python3
"""
Realtime RL Training with TensorBoard and Web UI Monitoring
This script runs RL training with:
- TensorBoard monitoring for training metrics
- Web UI for real-time trading visualization
- Real market data integration
- PnL tracking and performance analysis
"""
import asyncio
import threading
import time
import logging
import argparse
from datetime import datetime
import os
import sys
from pathlib import Path
# Add project path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.rl_trainer import RLTrainer
from torch.utils.tensorboard import SummaryWriter
logger = logging.getLogger(__name__)
class RealtimeRLTrainer:
"""Realtime RL Trainer with TensorBoard and Web UI"""
def __init__(self, symbol="ETH/USDT", initial_balance=1000.0):
self.symbol = symbol
self.initial_balance = initial_balance
# Initialize data provider
self.data_provider = DataProvider(
symbols=[symbol],
timeframes=['1s', '1m', '5m', '15m', '1h']
)
# Initialize RL trainer with TensorBoard
self.rl_trainer = RLTrainer(self.data_provider)
# Training state
self.current_episode = 0
self.session_trades = []
self.session_balance = initial_balance
self.session_pnl = 0.0
self.training_active = False
# Web dashboard
self.dashboard = None
self.dashboard_thread = None
logger.info(f"RealtimeRLTrainer initialized for {symbol}")
logger.info(f"TensorBoard logs: {self.rl_trainer.tensorboard_dir}")
def setup_web_dashboard(self, port=8051):
"""Setup web dashboard for monitoring"""
try:
import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import plotly.express as px
# Create Dash app
app = dash.Dash(__name__)
# Layout
app.layout = html.Div([
html.H1(f"RL Training Monitor - {self.symbol}",
style={'textAlign': 'center', 'color': '#2c3e50'}),
# Refresh interval
dcc.Interval(
id='interval-component',
interval=2000, # Update every 2 seconds
n_intervals=0
),
# Status row
html.Div([
html.Div([
html.H3("Training Status", style={'color': '#34495e'}),
html.P(id='training-status', style={'fontSize': 18})
], className='three columns'),
html.Div([
html.H3("Current Episode", style={'color': '#34495e'}),
html.P(id='current-episode', style={'fontSize': 18})
], className='three columns'),
html.Div([
html.H3("Session Balance", style={'color': '#27ae60'}),
html.P(id='session-balance', style={'fontSize': 18})
], className='three columns'),
html.Div([
html.H3("Session PnL", style={'color': '#e74c3c'}),
html.P(id='session-pnl', style={'fontSize': 18})
], className='three columns'),
], className='row', style={'margin': '20px'}),
# Charts row
html.Div([
html.Div([
dcc.Graph(id='rewards-chart')
], className='six columns'),
html.Div([
dcc.Graph(id='balance-chart')
], className='six columns'),
], className='row'),
html.Div([
html.Div([
dcc.Graph(id='trades-chart')
], className='six columns'),
html.Div([
dcc.Graph(id='win-rate-chart')
], className='six columns'),
], className='row'),
# TensorBoard link
html.Div([
html.H3("TensorBoard Monitoring"),
html.A("Open TensorBoard",
href="http://localhost:6006",
target="_blank",
style={'fontSize': 16, 'color': '#3498db'})
], style={'textAlign': 'center', 'margin': '20px'})
])
# Callbacks
@app.callback(
[Output('training-status', 'children'),
Output('current-episode', 'children'),
Output('session-balance', 'children'),
Output('session-pnl', 'children'),
Output('rewards-chart', 'figure'),
Output('balance-chart', 'figure'),
Output('trades-chart', 'figure'),
Output('win-rate-chart', 'figure')],
[Input('interval-component', 'n_intervals')]
)
def update_dashboard(n):
# Status updates
status = "TRAINING" if self.training_active else "IDLE"
episode = f"{self.current_episode}"
balance = f"${self.session_balance:.2f}"
pnl = f"${self.session_pnl:.2f}"
# Create charts
rewards_fig = self._create_rewards_chart()
balance_fig = self._create_balance_chart()
trades_fig = self._create_trades_chart()
win_rate_fig = self._create_win_rate_chart()
return status, episode, balance, pnl, rewards_fig, balance_fig, trades_fig, win_rate_fig
self.dashboard = app
logger.info(f"Web dashboard created for port {port}")
except Exception as e:
logger.error(f"Error setting up web dashboard: {e}")
self.dashboard = None
def _create_rewards_chart(self):
"""Create rewards chart"""
import plotly.graph_objects as go
if not self.rl_trainer.episode_rewards:
fig = go.Figure()
fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
else:
fig = go.Figure()
fig.add_trace(go.Scatter(
y=self.rl_trainer.episode_rewards,
mode='lines',
name='Episode Rewards',
line=dict(color='#3498db')
))
# Add moving average if enough data
if len(self.rl_trainer.avg_rewards) > 0:
fig.add_trace(go.Scatter(
y=self.rl_trainer.avg_rewards,
mode='lines',
name='Moving Average',
line=dict(color='#e74c3c', width=2)
))
fig.update_layout(title="Episode Rewards", xaxis_title="Episode", yaxis_title="Reward")
return fig
def _create_balance_chart(self):
"""Create balance chart"""
import plotly.graph_objects as go
if not self.rl_trainer.episode_balances:
fig = go.Figure()
fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
else:
fig = go.Figure()
fig.add_trace(go.Scatter(
y=self.rl_trainer.episode_balances,
mode='lines',
name='Balance',
line=dict(color='#27ae60')
))
# Add initial balance line
fig.add_hline(y=self.initial_balance, line_dash="dash",
annotation_text="Initial Balance")
fig.update_layout(title="Portfolio Balance", xaxis_title="Episode", yaxis_title="Balance ($)")
return fig
def _create_trades_chart(self):
"""Create trades per episode chart"""
import plotly.graph_objects as go
if not self.rl_trainer.episode_trades:
fig = go.Figure()
fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
else:
fig = go.Figure()
fig.add_trace(go.Bar(
y=self.rl_trainer.episode_trades,
name='Trades per Episode',
marker_color='#f39c12'
))
fig.update_layout(title="Trades per Episode", xaxis_title="Episode", yaxis_title="Number of Trades")
return fig
def _create_win_rate_chart(self):
"""Create win rate chart"""
import plotly.graph_objects as go
if not self.rl_trainer.win_rates:
fig = go.Figure()
fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
else:
fig = go.Figure()
fig.add_trace(go.Scatter(
y=self.rl_trainer.win_rates,
mode='lines+markers',
name='Win Rate',
line=dict(color='#9b59b6')
))
# Add 50% line
fig.add_hline(y=0.5, line_dash="dash",
annotation_text="Break Even")
fig.update_layout(title="Win Rate", xaxis_title="Evaluation", yaxis_title="Win Rate")
return fig
def start_web_dashboard(self, port=8051):
"""Start web dashboard in background thread"""
if self.dashboard is None:
self.setup_web_dashboard(port)
if self.dashboard is not None:
def run_dashboard():
try:
# Use run instead of run_server for newer Dash versions
self.dashboard.run(port=port, debug=False, use_reloader=False)
except Exception as e:
logger.error(f"Error running dashboard: {e}")
self.dashboard_thread = threading.Thread(target=run_dashboard, daemon=True)
self.dashboard_thread.start()
logger.info(f"Web dashboard started on http://localhost:{port}")
else:
logger.warning("Dashboard not available")
async def train_realtime(self, episodes=100, evaluation_interval=10):
"""Run realtime training with monitoring"""
logger.info(f"Starting realtime RL training for {episodes} episodes")
logger.info(f"TensorBoard: http://localhost:6006")
logger.info(f"Web UI: http://localhost:8051")
self.training_active = True
# Setup environment and agent
environment, agent = self.rl_trainer.setup_environment_and_agent()
# Assign to trainer instance
self.rl_trainer.environment = environment
self.rl_trainer.agent = agent
# Training loop
for episode in range(episodes):
self.current_episode = episode
# Run episode
episode_start = time.time()
results = self.rl_trainer.run_episode(episode, training=True)
episode_time = time.time() - episode_start
# Update session tracking
self.session_balance = results.get('balance', self.initial_balance)
self.session_pnl = self.session_balance - self.initial_balance
# Log episode metrics to TensorBoard
self.rl_trainer.log_episode_metrics(episode, {
'total_reward': results['reward'],
'final_balance': results['balance'],
'total_return': results['pnl_percentage'],
'steps': results['steps'],
'total_trades': results['trades'],
'win_rate': 1.0 if results['pnl'] > 0 else 0.0,
'epsilon': agent.epsilon,
'memory_size': len(agent.memory) if hasattr(agent, 'memory') else 0
})
# Log progress
if episode % 10 == 0:
logger.info(
f"Episode {episode}/{episodes} - "
f"Reward: {results['reward']:.4f}, "
f"Balance: ${results['balance']:.2f}, "
f"PnL: {results['pnl_percentage']:.2f}%, "
f"Trades: {results['trades']}, "
f"Time: {episode_time:.2f}s"
)
# Evaluation
if episode % evaluation_interval == 0 and episode > 0:
eval_results = self.rl_trainer.evaluate_agent(num_episodes=3)
logger.info(
f"Evaluation - Avg Reward: {eval_results['avg_reward']:.4f}, "
f"Win Rate: {eval_results['win_rate']:.2%}"
)
# Small delay to allow UI updates
await asyncio.sleep(0.1)
self.training_active = False
logger.info("Training completed!")
# Save final model
save_path = f"models/rl/realtime_agent_{int(time.time())}.pt"
agent.save(save_path)
logger.info(f"Model saved: {save_path}")
return {
'episodes': episodes,
'final_balance': self.session_balance,
'final_pnl': self.session_pnl,
'model_path': save_path
}
async def main():
"""Main function"""
parser = argparse.ArgumentParser(description='Realtime RL Training with Monitoring')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
parser.add_argument('--episodes', type=int, default=50, help='Number of episodes')
parser.add_argument('--balance', type=float, default=1000.0, help='Initial balance')
parser.add_argument('--web-port', type=int, default=8051, help='Web dashboard port')
parser.add_argument('--keep-alive', type=int, default=300, help='Keep monitoring alive for N seconds after training')
args = parser.parse_args()
# Setup logging
setup_logging()
logger.info("=" * 60)
logger.info("REALTIME RL TRAINING WITH MONITORING")
logger.info(f"Symbol: {args.symbol}")
logger.info(f"Episodes: {args.episodes}")
logger.info(f"Initial Balance: ${args.balance:.2f}")
logger.info("=" * 60)
# Check if TensorBoard is accessible
try:
import requests
import time
import json
# Try to read port configuration
tensorboard_port = 6006 # default
try:
with open("monitoring_ports.json", "r") as f:
config = json.load(f)
tensorboard_port = config.get("tensorboard_port", 6006)
logger.info(f"📋 Using TensorBoard port {tensorboard_port} from config")
except FileNotFoundError:
logger.info("📋 No port config file found, using default ports")
logger.info("Checking TensorBoard accessibility...")
# Wait for TensorBoard to start
for i in range(10):
try:
response = requests.get(f"http://localhost:{tensorboard_port}", timeout=2)
logger.info(f"✅ TensorBoard is accessible at http://localhost:{tensorboard_port}")
break
except requests.exceptions.RequestException:
if i == 0:
logger.info("⏳ Waiting for TensorBoard to start...")
await asyncio.sleep(2)
else:
logger.warning(f"⚠️ TensorBoard may not be running on port {tensorboard_port}")
logger.warning(" Run: python start_monitoring.py")
except ImportError:
tensorboard_port = 6006
logger.warning("requests module not available for TensorBoard check")
try:
# Create trainer
trainer = RealtimeRLTrainer(
symbol=args.symbol,
initial_balance=args.balance
)
# Start web dashboard
logger.info("🚀 Starting web dashboard...")
trainer.start_web_dashboard(port=args.web_port)
# Wait for dashboard to start
await asyncio.sleep(3)
# Check if web dashboard is accessible
try:
import requests
response = requests.get(f"http://localhost:{args.web_port}", timeout=5)
logger.info(f"✅ Web Dashboard is accessible at http://localhost:{args.web_port}")
except:
logger.warning(f"⚠️ Web Dashboard may not be fully ready at http://localhost:{args.web_port}")
logger.info("MONITORING READY!")
logger.info(f"📊 TensorBoard: http://localhost:{tensorboard_port}")
logger.info(f"🌐 Web Dashboard: http://localhost:{args.web_port}")
logger.info("=" * 60)
# Run training
results = await trainer.train_realtime(
episodes=args.episodes,
evaluation_interval=10
)
logger.info("Training Results:")
logger.info(f" Final Balance: ${results['final_balance']:.2f}")
logger.info(f" Final PnL: ${results['final_pnl']:.2f}")
logger.info(f" Model Saved: {results['model_path']}")
# Keep monitoring alive for specified time
logger.info(f"🔄 Keeping monitoring alive for {args.keep_alive} seconds...")
logger.info(f"📊 TensorBoard: http://localhost:6006")
logger.info(f"🌐 Web Dashboard: http://localhost:{args.web_port}")
logger.info("Press Ctrl+C to exit monitoring.")
for remaining in range(args.keep_alive, 0, -10):
logger.info(f"⏰ Monitoring active - {remaining} seconds remaining")
await asyncio.sleep(10)
logger.info("✅ Monitoring session completed.")
except KeyboardInterrupt:
logger.info("Training stopped by user")
except Exception as e:
logger.error(f"Error in training: {e}")
import traceback
logger.error(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@ -1,704 +0,0 @@
#!/usr/bin/env python
"""
Real-time training with tick data and multiple timeframes for context
This script uses streaming tick data for fast adaptation while maintaining higher timeframe context
"""
import os
import sys
import logging
import numpy as np
import torch
import time
import json
from datetime import datetime, timedelta
import signal
import threading
import asyncio
import websockets
from collections import deque
import pandas as pd
from typing import Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import io
# Add the project root to path
sys.path.append(os.path.abspath('.'))
# Configure logging with timestamp in filename
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"realtime_ticks_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
logger = logging.getLogger('realtime_ticks_training')
# Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.utils.data_interface import DataInterface
from NN.utils.signal_interpreter import SignalInterpreter
# Global variables for graceful shutdown
running = True
training_stats = {
"epochs_completed": 0,
"best_val_pnl": -float('inf'),
"best_epoch": 0,
"best_win_rate": 0,
"training_started": datetime.now().isoformat(),
"last_update": datetime.now().isoformat(),
"epochs": [],
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"total_wins": {
"train": 0,
"val": 0
}
}
class TickDataProcessor:
"""Process and store real-time tick data"""
def __init__(self, symbol: str, max_ticks: int = 10000):
self.symbol = symbol
self.ticks = deque(maxlen=max_ticks)
self.candle_cache = {
'1s': deque(maxlen=5000),
'1m': deque(maxlen=5000),
'5m': deque(maxlen=5000),
'15m': deque(maxlen=5000)
}
self.last_tick = None
self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
self.ws = None
self.running = False
self.data_queue = asyncio.Queue()
async def start_websocket(self):
"""Start WebSocket connection and receive tick data"""
while self.running:
try:
async with websockets.connect(self.ws_url) as ws:
self.ws = ws
logger.info("WebSocket connected")
while self.running:
try:
message = await ws.recv()
data = json.loads(message)
if 'e' in data and data['e'] == 'trade':
tick = {
'timestamp': data['T'],
'price': float(data['p']),
'volume': float(data['q']),
'symbol': self.symbol
}
self.ticks.append(tick)
await self.data_queue.put(tick)
except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed")
break
except Exception as e:
logger.error(f"Error receiving WebSocket message: {str(e)}")
await asyncio.sleep(1)
except Exception as e:
logger.error(f"WebSocket connection error: {str(e)}")
await asyncio.sleep(5)
def process_tick(self, tick: Dict):
"""Process a single tick into candles for all timeframes"""
timestamp = tick['timestamp']
price = tick['price']
volume = tick['volume']
for timeframe in self.candle_cache.keys():
interval = self._get_interval_seconds(timeframe)
if interval is None:
continue
# Round timestamp to nearest candle interval
candle_ts = int(timestamp // (interval * 1000)) * (interval * 1000)
# Get or create candle for this timeframe
if not self.candle_cache[timeframe]:
# First candle for this timeframe
candle = {
'timestamp': candle_ts,
'open': price,
'high': price,
'low': price,
'close': price,
'volume': volume
}
self.candle_cache[timeframe].append(candle)
else:
# Update existing candle
last_candle = self.candle_cache[timeframe][-1]
if last_candle['timestamp'] == candle_ts:
# Update current candle
last_candle['high'] = max(last_candle['high'], price)
last_candle['low'] = min(last_candle['low'], price)
last_candle['close'] = price
last_candle['volume'] += volume
else:
# New candle
candle = {
'timestamp': candle_ts,
'open': price,
'high': price,
'low': price,
'close': price,
'volume': volume
}
self.candle_cache[timeframe].append(candle)
def _get_interval_seconds(self, timeframe: str) -> Optional[int]:
"""Convert timeframe string to seconds"""
try:
value = int(timeframe[:-1])
unit = timeframe[-1]
if unit == 's':
return value
elif unit == 'm':
return value * 60
elif unit == 'h':
return value * 3600
elif unit == 'd':
return value * 86400
return None
except:
return None
def get_candles(self, timeframe: str) -> pd.DataFrame:
"""Get candles for a specific timeframe"""
if timeframe not in self.candle_cache:
return pd.DataFrame()
candles = list(self.candle_cache[timeframe])
if not candles:
return pd.DataFrame()
df = pd.DataFrame(candles)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
return df
def signal_handler(sig, frame):
"""Handle CTRL+C to gracefully exit training"""
global running
logger.info("Received interrupt signal. Finishing current epoch and saving model...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
def save_training_stats(stats, filepath="NN/models/saved/realtime_ticks_training_stats.json"):
"""Save training statistics to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, 'w') as f:
json.dump(stats, f, indent=2)
logger.info(f"Training statistics saved to {filepath}")
def calculate_pnl_with_fees(predictions, prices, position_size=1.0, fee_rate=0.0002, initial_balance=100.0):
"""
Calculate PnL including trading fees and track USD balance
fee_rate: 0.02% per trade (both entry and exit)
initial_balance: Starting balance in USD (default: 100.0)
"""
trades = []
pnl = 0
win_count = 0
total_trades = 0
current_balance = initial_balance
balance_history = [initial_balance]
for i in range(len(predictions)):
if predictions[i] == 2: # BUY
entry_price = prices[i]
# Look ahead for exit
for j in range(i + 1, min(i + 8, len(prices))):
if predictions[j] == 0: # SELL
exit_price = prices[j]
# Calculate position size in USD
position_usd = current_balance * position_size
# Calculate raw PnL in USD
raw_pnl_usd = position_usd * ((exit_price - entry_price) / entry_price)
# Calculate fees in USD (both entry and exit)
entry_fee_usd = position_usd * fee_rate
exit_fee_usd = position_usd * fee_rate
total_fees_usd = entry_fee_usd + exit_fee_usd
# Calculate net PnL in USD after fees
net_pnl_usd = raw_pnl_usd - total_fees_usd
# Update balance
current_balance += net_pnl_usd
balance_history.append(current_balance)
trades.append({
'entry_idx': i,
'exit_idx': j,
'entry_price': entry_price,
'exit_price': exit_price,
'position_size_usd': position_usd,
'raw_pnl_usd': raw_pnl_usd,
'fees_usd': total_fees_usd,
'net_pnl_usd': net_pnl_usd,
'balance': current_balance
})
pnl += net_pnl_usd / initial_balance # Convert to percentage
if net_pnl_usd > 0:
win_count += 1
total_trades += 1
break
win_rate = win_count / total_trades if total_trades > 0 else 0
final_balance = current_balance
total_return = (final_balance - initial_balance) / initial_balance * 100 # Percentage return
return pnl, win_rate, trades, balance_history, total_return
def calculate_max_drawdown(balance_history):
"""Calculate maximum drawdown from balance history"""
if not balance_history:
return 0.0
peak = balance_history[0]
max_drawdown = 0.0
for balance in balance_history:
if balance > peak:
peak = balance
drawdown = (peak - balance) / peak * 100
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
async def run_realtime_training():
"""
Run continuous training with real-time tick data and multiple timeframes
"""
global running, training_stats
# Configuration parameters
symbol = "BTC/USDT"
timeframes = ["1s", "1m", "5m", "15m"] # Include 1s for tick-based training
window_size = 24 # Larger window size for capturing more patterns
output_size = 3 # BUY/HOLD/SELL
batch_size = 64 # Batch size for training
# Real-time configuration
data_refresh_interval = 60 # Refresh data every minute
checkpoint_interval = 3600 # Save checkpoint every hour
max_training_time = float('inf') # Run indefinitely
# Initialize TensorBoard writer
tensorboard_dir = "runs/realtime_ticks_training"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)
# Initialize training start time
start_time = time.time()
last_checkpoint_time = start_time
last_data_refresh_time = start_time
logger.info(f"Starting continuous real-time training with tick data for {symbol}")
logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
logger.info(f"Data will refresh every {data_refresh_interval} seconds")
logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
try:
# Initialize tick data processor
tick_processor = TickDataProcessor(symbol)
tick_processor.running = True
# Start WebSocket connection in background
websocket_task = asyncio.create_task(tick_processor.start_websocket())
# Initialize data interface
logger.info("Initializing data interface...")
data_interface = DataInterface(
symbol=symbol,
timeframes=timeframes
)
# Initialize model
num_features = data_interface.get_feature_count()
logger.info(f"Initializing model with {num_features} features")
model = CNNModelPyTorch(
window_size=window_size,
timeframes=timeframes,
output_size=output_size,
num_pairs=1 # Single trading pair
)
# Try to load existing model
model_path = "NN/models/saved/optimized_short_term_model_best.pt"
try:
if os.path.exists(model_path):
logger.info(f"Loading existing model from {model_path}")
model.load(model_path)
logger.info("Model loaded successfully")
else:
logger.info("No existing model found. Starting with a new model.")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
logger.info("Starting with a new model.")
# Initialize signal interpreter
signal_interpreter = SignalInterpreter(config={
'buy_threshold': 0.55, # Lower threshold to catch more opportunities
'sell_threshold': 0.55, # Lower threshold to catch more opportunities
'hold_threshold': 0.65, # Lower threshold to reduce missed trades
'trend_filter_enabled': True,
'volume_filter_enabled': True,
'min_confidence': 0.45 # Minimum confidence to consider a trade
})
# Create checkpoint directory
checkpoint_dir = "NN/models/saved/realtime_ticks_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Track metrics
epoch = 0
best_val_pnl = -float('inf')
best_win_rate = 0
best_epoch = 0
consecutive_failures = 0
max_consecutive_failures = 5
# Training loop
while running:
try:
epoch += 1
epoch_start = time.time()
logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
# Process any new ticks
while not tick_processor.data_queue.empty():
tick = await tick_processor.data_queue.get()
tick_processor.process_tick(tick)
# Check if we need to refresh data
if time.time() - last_data_refresh_time > data_refresh_interval:
logger.info("Refreshing training data...")
last_data_refresh_time = time.time()
# Get candles for all timeframes
candles_data = {}
for timeframe in timeframes:
df = tick_processor.get_candles(timeframe)
if not df.empty:
candles_data[timeframe] = df
# Prepare training data with multiple timeframes
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
refresh=True,
refresh_interval=data_refresh_interval
)
if X_train is None or y_train is None:
logger.warning("Failed to prepare training data. Using previous data.")
consecutive_failures += 1
if consecutive_failures >= max_consecutive_failures:
logger.error("Too many consecutive failures. Stopping training.")
break
await asyncio.sleep(5) # Wait before retrying
continue
consecutive_failures = 0 # Reset failure counter on success
logger.info(f"Training data prepared - X shape: {X_train.shape}, y shape: {y_train.shape}")
# Calculate future prices for profitability-focused loss function
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
# Train one epoch
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, batch_size
)
# Evaluate
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
logger.info(f"Epoch {epoch} results:")
logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")
# Get predictions for PnL calculation
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Convert probabilities to actions
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Track signal distribution
train_buy_count = np.sum(train_preds == 2)
train_sell_count = np.sum(train_preds == 0)
train_hold_count = np.sum(train_preds == 1)
val_buy_count = np.sum(val_preds == 2)
val_sell_count = np.sum(val_preds == 0)
val_hold_count = np.sum(val_preds == 1)
signal_dist = {
"train": {
"BUY": float(train_buy_count / len(train_preds)) if len(train_preds) > 0 else 0,
"SELL": float(train_sell_count / len(train_preds)) if len(train_preds) > 0 else 0,
"HOLD": float(train_hold_count / len(train_preds)) if len(train_preds) > 0 else 0
},
"val": {
"BUY": float(val_buy_count / len(val_preds)) if len(val_preds) > 0 else 0,
"SELL": float(val_sell_count / len(val_preds)) if len(val_preds) > 0 else 0,
"HOLD": float(val_hold_count / len(val_preds)) if len(val_preds) > 0 else 0
}
}
# Calculate PnL and win rates with different position sizes
position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]
best_position_train_pnl = -float('inf')
best_position_val_pnl = -float('inf')
best_position_train_wr = 0
best_position_val_wr = 0
best_position_size = 1.0
for position_size in position_sizes:
train_pnl, train_win_rate, train_trades, train_balance_history, train_total_return = calculate_pnl_with_fees(
train_preds, train_prices, position_size=position_size
)
val_pnl, val_win_rate, val_trades, val_balance_history, val_total_return = calculate_pnl_with_fees(
val_preds, val_prices, position_size=position_size
)
# Update cumulative PnL and trade statistics
training_stats["cumulative_pnl"]["train"] += train_pnl
training_stats["cumulative_pnl"]["val"] += val_pnl
training_stats["total_trades"]["train"] += len(train_trades)
training_stats["total_trades"]["val"] += len(val_trades)
training_stats["total_wins"]["train"] += sum(1 for t in train_trades if t['net_pnl_usd'] > 0)
training_stats["total_wins"]["val"] += sum(1 for t in val_trades if t['net_pnl_usd'] > 0)
# Calculate average fees per trade
avg_train_fees = np.mean([t['fees_usd'] for t in train_trades]) if train_trades else 0
avg_val_fees = np.mean([t['fees_usd'] for t in val_trades]) if val_trades else 0
# Calculate max drawdown
train_drawdown = calculate_max_drawdown(train_balance_history)
val_drawdown = calculate_max_drawdown(val_balance_history)
# Calculate overall win rate
overall_train_wr = training_stats["total_wins"]["train"] / training_stats["total_trades"]["train"] if training_stats["total_trades"]["train"] > 0 else 0
overall_val_wr = training_stats["total_wins"]["val"] / training_stats["total_trades"]["val"] if training_stats["total_trades"]["val"] > 0 else 0
logger.info(f" Position Size: {position_size}")
logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
logger.info(f" Train - Total Return: {train_total_return:.2f}%, Max Drawdown: {train_drawdown:.2f}%")
logger.info(f" Train - Avg Fees: ${avg_train_fees:.2f} per trade")
logger.info(f" Train - Cumulative PnL: {training_stats['cumulative_pnl']['train']:.4f}, Overall WR: {overall_train_wr:.4f}")
logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
logger.info(f" Valid - Total Return: {val_total_return:.2f}%, Max Drawdown: {val_drawdown:.2f}%")
logger.info(f" Valid - Avg Fees: ${avg_val_fees:.2f} per trade")
logger.info(f" Valid - Cumulative PnL: {training_stats['cumulative_pnl']['val']:.4f}, Overall WR: {overall_val_wr:.4f}")
# Log to TensorBoard
writer.add_scalar(f'Balance/train/position_{position_size}', train_balance_history[-1], epoch)
writer.add_scalar(f'Balance/validation/position_{position_size}', val_balance_history[-1], epoch)
writer.add_scalar(f'Return/train/position_{position_size}', train_total_return, epoch)
writer.add_scalar(f'Return/validation/position_{position_size}', val_total_return, epoch)
writer.add_scalar(f'Drawdown/train/position_{position_size}', train_drawdown, epoch)
writer.add_scalar(f'Drawdown/validation/position_{position_size}', val_drawdown, epoch)
writer.add_scalar(f'CumulativePnL/train/position_{position_size}', training_stats["cumulative_pnl"]["train"], epoch)
writer.add_scalar(f'CumulativePnL/validation/position_{position_size}', training_stats["cumulative_pnl"]["val"], epoch)
writer.add_scalar(f'OverallWinRate/train/position_{position_size}', overall_train_wr, epoch)
writer.add_scalar(f'OverallWinRate/validation/position_{position_size}', overall_val_wr, epoch)
# Track best position size for this epoch
if val_pnl > best_position_val_pnl:
best_position_val_pnl = val_pnl
best_position_val_wr = val_win_rate
best_position_size = position_size
if train_pnl > best_position_train_pnl:
best_position_train_pnl = train_pnl
best_position_train_wr = train_win_rate
# Track best model overall (using position size 1.0 as reference)
if val_pnl > best_val_pnl and position_size == 1.0:
best_val_pnl = val_pnl
best_win_rate = val_win_rate
best_epoch = epoch
logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
# Save the best model
model.save(f"NN/models/saved/optimized_short_term_model_ticks_best")
# Store epoch metrics with cumulative statistics
epoch_metrics = {
"epoch": epoch,
"train_loss": float(train_action_loss),
"val_loss": float(val_action_loss),
"train_acc": float(train_acc),
"val_acc": float(val_acc),
"train_pnl": float(best_position_train_pnl),
"val_pnl": float(best_position_val_pnl),
"train_win_rate": float(best_position_train_wr),
"val_win_rate": float(best_position_val_wr),
"best_position_size": float(best_position_size),
"signal_distribution": signal_dist,
"timestamp": datetime.now().isoformat(),
"data_age": int(time.time() - last_data_refresh_time),
"cumulative_pnl": {
"train": float(training_stats["cumulative_pnl"]["train"]),
"val": float(training_stats["cumulative_pnl"]["val"])
},
"total_trades": {
"train": int(training_stats["total_trades"]["train"]),
"val": int(training_stats["total_trades"]["val"])
},
"overall_win_rate": {
"train": float(overall_train_wr),
"val": float(overall_val_wr)
}
}
# Update training stats
training_stats["epochs_completed"] = epoch
training_stats["best_val_pnl"] = float(best_val_pnl)
training_stats["best_epoch"] = best_epoch
training_stats["best_win_rate"] = float(best_win_rate)
training_stats["last_update"] = datetime.now().isoformat()
training_stats["epochs"].append(epoch_metrics)
# Check if we need to save checkpoint
if time.time() - last_checkpoint_time > checkpoint_interval:
logger.info(f"Saving checkpoint at epoch {epoch}")
# Save model checkpoint
model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
# Save training statistics
save_training_stats(training_stats)
last_checkpoint_time = time.time()
# Test trade signal generation with a random sample
random_idx = np.random.randint(0, len(X_val))
sample_X = X_val[random_idx:random_idx+1]
sample_probs, sample_price_pred = model.predict(sample_X)
# Process with signal interpreter
signal = signal_interpreter.interpret_signal(
sample_probs[0],
float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
)
logger.info(f" Sample trade signal: {signal['action']} with confidence {signal['confidence']:.4f}")
# Log trading statistics
logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")
# Log epoch timing
epoch_time = time.time() - epoch_start
total_elapsed = time.time() - start_time
logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
logger.info(f" Total training time: {total_elapsed/3600:.2f} hours")
# Small delay to prevent CPU overload
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error during epoch {epoch}: {str(e)}")
import traceback
logger.error(traceback.format_exc())
consecutive_failures += 1
if consecutive_failures >= max_consecutive_failures:
logger.error("Too many consecutive failures. Stopping training.")
break
await asyncio.sleep(5) # Wait before retrying
continue
# Cleanup
tick_processor.running = False
websocket_task.cancel()
# Save final model and performance metrics
logger.info("Saving final optimized model...")
model.save("NN/models/saved/optimized_short_term_model_ticks_final")
# Save performance metrics to file
save_training_stats(training_stats)
# Generate performance plots
try:
model.plot_training_history("NN/models/saved/realtime_ticks_training_stats.json")
logger.info("Performance plots generated successfully")
except Exception as e:
logger.error(f"Error generating plots: {str(e)}")
# Calculate total training time
total_time = time.time() - start_time
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
logger.info(f"Continuous training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
# Close TensorBoard writer
writer.close()
except Exception as e:
logger.error(f"Error during real-time training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Try to save the model and stats in case of error
try:
if 'model' in locals():
model.save("NN/models/saved/optimized_short_term_model_ticks_emergency")
logger.info("Emergency model save completed")
if 'training_stats' in locals():
save_training_stats(training_stats, "NN/models/saved/realtime_ticks_training_stats_emergency.json")
if 'writer' in locals():
writer.close()
except Exception as e2:
logger.error(f"Failed to save emergency checkpoint: {str(e2)}")
if __name__ == "__main__":
# Print startup banner
print("=" * 80)
print("CONTINUOUS REALTIME TICKS TRAINING SESSION")
print("This script will continuously train the model using real-time tick data")
print("Press Ctrl+C to safely stop training and save the model")
print("TensorBoard logs will be saved to runs/realtime_ticks_training")
print("To view TensorBoard, run: tensorboard --logdir=runs/realtime_ticks_training")
print("=" * 80)
# Run the async training loop
asyncio.run(run_realtime_training())

View File

@ -8,26 +8,7 @@ Comprehensive training pipeline for scalping RL agents:
- Memory-efficient training loops - Memory-efficient training loops
""" """
import torch import torchimport numpy as npimport pandas as pdimport loggingfrom typing import Dict, List, Tuple, Optional, Anyimport timefrom pathlib import Pathimport matplotlib.pyplot as pltfrom collections import dequeimport randomfrom torch.utils.tensorboard import SummaryWriter# Add project importsimport sysimport ossys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from core.config import get_configfrom core.data_provider import DataProviderfrom models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgentfrom utils.model_utils import robust_save, robust_load
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional, Any
import time
from pathlib import Path
import matplotlib.pyplot as plt
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter
# Add project imports
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.config import get_config
from core.data_provider import DataProvider
from models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

File diff suppressed because it is too large Load Diff

241
utils/model_utils.py Normal file
View File

@ -0,0 +1,241 @@
#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""
import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union
logger = logging.getLogger(__name__)
def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
"""
Robust model saving with multiple fallback approaches
Args:
model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
path: Path to save the model
include_optimizer: Whether to include optimizer state in the save
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Prepare checkpoint data
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': getattr(model, 'epsilon', 0.0),
'state_size': getattr(model, 'state_size', None),
'action_size': getattr(model, 'action_size', None),
'hidden_size': getattr(model, 'hidden_size', None),
}
# Add optimizer state if requested and available
if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
checkpoint['optimizer'] = model.optimizer.state_dict()
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
torch.save(checkpoint_no_opt, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save parameters separately as JSON
params = {
'epsilon': float(getattr(model, 'epsilon', 0.0)),
'state_size': int(getattr(model, 'state_size', 0)),
'action_size': int(getattr(model, 'action_size', 0)),
'hidden_size': int(getattr(model, 'hidden_size', 0))
}
with open(f"{path}.params.json", "w") as f:
json.dump(params, f)
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
"""
Robust model loading with fallback approaches
Args:
model: The model object to load into
path: Path to load the model from
device: Device to load the model on
Returns:
bool: True if successful, False otherwise
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Try regular PyTorch load first
try:
logger.info(f"Loading model from {path}")
if os.path.exists(path):
checkpoint = torch.load(path, map_location=device)
# Load network states
if 'policy_net' in checkpoint:
model.policy_net.load_state_dict(checkpoint['policy_net'])
if 'target_net' in checkpoint:
model.target_net.load_state_dict(checkpoint['target_net'])
# Load other attributes
if 'epsilon' in checkpoint:
model.epsilon = checkpoint['epsilon']
if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
try:
model.optimizer.load_state_dict(checkpoint['optimizer'])
except Exception as e:
logger.warning(f"Failed to load optimizer state: {e}")
logger.info("Successfully loaded model")
return True
except Exception as e:
logger.warning(f"Regular load failed: {e}")
# Try loading JIT saved components
try:
policy_path = f"{path}.policy.jit"
target_path = f"{path}.target.jit"
params_path = f"{path}.params.json"
if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
logger.info(f"Loading JIT model components")
# Load JIT models (this is more complex and may need model reconstruction)
# For now, just log that we found JIT files
logger.info("Found JIT model files, but loading them requires special handling")
with open(params_path, 'r') as f:
params = json.load(f)
logger.info(f"Model parameters: {params}")
# Note: Actually loading JIT models would require recreating the model architecture
# This is a placeholder for future implementation
return False
except Exception as e:
logger.error(f"JIT load failed: {e}")
logger.error(f"All load attempts failed for {path}")
return False
def get_model_info(path: str) -> Dict[str, Any]:
"""
Get information about a saved model
Args:
path: Path to the model file
Returns:
dict: Model information
"""
info = {
'exists': False,
'size_bytes': 0,
'has_optimizer': False,
'parameters': {}
}
try:
if os.path.exists(path):
info['exists'] = True
info['size_bytes'] = os.path.getsize(path)
# Try to load and inspect
checkpoint = torch.load(path, map_location='cpu')
info['has_optimizer'] = 'optimizer' in checkpoint
# Extract parameter info
for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
if key in checkpoint:
info['parameters'][key] = checkpoint[key]
except Exception as e:
logger.warning(f"Failed to get model info for {path}: {e}")
return info
def verify_save_load_cycle(model: Any, test_path: str) -> bool:
"""
Test that a model can be saved and loaded correctly
Args:
model: Model to test
test_path: Path for test file
Returns:
bool: True if save/load cycle successful
"""
try:
# Save the model
if not robust_save(model, test_path):
return False
# Create a new model instance (this would need model creation logic)
# For now, just verify the file exists and has content
if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
logger.info("Save/load cycle verification successful")
# Clean up test file
os.remove(test_path)
return True
else:
return False
except Exception as e:
logger.error(f"Save/load cycle verification failed: {e}")
return False