cleanup_1
This commit is contained in:
parent
f01047f260
commit
509ad0ae17
150
CLEANUP_PLAN.md
Normal file
150
CLEANUP_PLAN.md
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Project Cleanup & Reorganization Plan
|
||||||
|
|
||||||
|
## Current Issues
|
||||||
|
1. **Code Duplication**: Multiple CNN models, RL agents, training scripts doing similar things
|
||||||
|
2. **Missing Methods**: Core functionality like `run()`, `start_websocket()` missing from classes
|
||||||
|
3. **Unclear Architecture**: No clean separation between components
|
||||||
|
4. **Hard to Maintain**: Scattered implementations make changes difficult
|
||||||
|
|
||||||
|
## New Clean Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
gogo2/
|
||||||
|
├── core/ # Core system components
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── data_provider.py # Multi-timeframe, multi-symbol data
|
||||||
|
│ ├── orchestrator.py # Main decision making module
|
||||||
|
│ └── config.py # Central configuration
|
||||||
|
├── models/ # AI/ML Models
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── cnn/ # CNN module
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── model.py # Single CNN implementation
|
||||||
|
│ │ ├── trainer.py # CNN training pipeline
|
||||||
|
│ │ └── predictor.py # CNN inference with confidence
|
||||||
|
│ └── rl/ # RL module
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── agent.py # Single RL agent implementation
|
||||||
|
│ ├── environment.py # Trading environment
|
||||||
|
│ └── trainer.py # RL training loop
|
||||||
|
├── trading/ # Trading execution
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── executor.py # Trade execution
|
||||||
|
│ ├── portfolio.py # Position/portfolio management
|
||||||
|
│ └── metrics.py # Performance tracking
|
||||||
|
├── web/ # Web interface
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── dashboard.py # Main dashboard
|
||||||
|
│ └── charts.py # Chart components
|
||||||
|
├── utils/ # Utilities
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── logger.py # Centralized logging
|
||||||
|
│ └── helpers.py # Common helpers
|
||||||
|
├── main.py # Single entry point
|
||||||
|
├── config.yaml # Configuration file
|
||||||
|
└── requirements.txt # Dependencies
|
||||||
|
```
|
||||||
|
|
||||||
|
## Core Goals
|
||||||
|
|
||||||
|
### 1. Data Provider (`core/data_provider.py`)
|
||||||
|
- **Multi-symbol support**: ETH/USDT, BTC/USDT (configurable)
|
||||||
|
- **Multi-timeframe**: 1m, 5m, 15m, 1h, 4h, 1d
|
||||||
|
- **Real-time streaming**: WebSocket integration
|
||||||
|
- **Historical data**: API integration for backtesting
|
||||||
|
- **Clean interface**: Simple methods for getting data
|
||||||
|
|
||||||
|
### 2. CNN Module (`models/cnn/`)
|
||||||
|
- **Single model implementation**: Remove duplicates
|
||||||
|
- **Timeframe-specific predictions**: Separate predictions per timeframe
|
||||||
|
- **Confidence scoring**: Each prediction includes confidence
|
||||||
|
- **Training pipeline**: Supervised learning with marked data (perfect moves)
|
||||||
|
|
||||||
|
### 3. RL Module (`models/rl/`)
|
||||||
|
- **Single agent**: Remove duplicate DQN implementations
|
||||||
|
- **Environment**: Clean trading simulation
|
||||||
|
- **Learning loop**: Evaluates trading actions and adapts
|
||||||
|
|
||||||
|
### 4. Orchestrator (`core/orchestrator.py`)
|
||||||
|
- **Decision making**: Combines CNN and RL outputs
|
||||||
|
- **Final actions**: BUY/SELL/HOLD decisions
|
||||||
|
- **Confidence weighting**: Uses CNN confidence in decisions
|
||||||
|
|
||||||
|
### 5. Web Interface (`web/`)
|
||||||
|
- **Real-time charts**: Live trading visualization
|
||||||
|
- **Performance dashboard**: Metrics and analytics
|
||||||
|
- **Simple & clean**: Remove complex chart implementations
|
||||||
|
|
||||||
|
## Cleanup Steps
|
||||||
|
|
||||||
|
### Phase 1: Core Infrastructure
|
||||||
|
1. Create new clean directory structure
|
||||||
|
2. Implement `core/data_provider.py` (consolidate all data functionality)
|
||||||
|
3. Implement `core/orchestrator.py` (main decision maker)
|
||||||
|
4. Create `config.yaml` for all settings
|
||||||
|
|
||||||
|
### Phase 2: Model Consolidation
|
||||||
|
1. Create single `models/cnn/model.py` (consolidate all CNN implementations)
|
||||||
|
2. Create single `models/rl/agent.py` (consolidate DQN implementations)
|
||||||
|
3. Remove duplicate model files
|
||||||
|
|
||||||
|
### Phase 3: Training Simplification
|
||||||
|
1. Create `models/cnn/trainer.py` (single CNN training script)
|
||||||
|
2. Create `models/rl/trainer.py` (single RL training script)
|
||||||
|
3. Remove all duplicate training scripts
|
||||||
|
|
||||||
|
### Phase 4: Web Interface
|
||||||
|
1. Create clean `web/dashboard.py` (consolidate chart functionality)
|
||||||
|
2. Remove complex/unused chart implementations
|
||||||
|
|
||||||
|
### Phase 5: Integration & Testing
|
||||||
|
1. Create single `main.py` entry point
|
||||||
|
2. Test all components work together
|
||||||
|
3. Remove unused files
|
||||||
|
|
||||||
|
## Files to Remove (After consolidation)
|
||||||
|
|
||||||
|
### Duplicate Training Scripts
|
||||||
|
- `train_hybrid.py`
|
||||||
|
- `train_dqn.py`
|
||||||
|
- `train_cnn_with_realtime.py`
|
||||||
|
- `train_with_realtime_ticks.py`
|
||||||
|
- `train_improved_rl.py`
|
||||||
|
- `NN/train_enhanced.py`
|
||||||
|
- `NN/train_rl.py`
|
||||||
|
|
||||||
|
### Duplicate Model Files
|
||||||
|
- `NN/models/cnn_model.py`
|
||||||
|
- `NN/models/enhanced_cnn.py`
|
||||||
|
- `NN/models/simple_cnn.py`
|
||||||
|
- `NN/models/transformer_model.py`
|
||||||
|
- `NN/models/transformer_model_pytorch.py`
|
||||||
|
- `NN/models/dqn_agent_enhanced.py`
|
||||||
|
|
||||||
|
### Duplicate Main Files
|
||||||
|
- `trading_main.py`
|
||||||
|
- `NN/main.py`
|
||||||
|
- `NN/realtime_main.py`
|
||||||
|
- `NN/realtime-main.py`
|
||||||
|
|
||||||
|
### Unused Utilities
|
||||||
|
- `launch_training.py`
|
||||||
|
- `NN/example.py`
|
||||||
|
- Most logs and backup directories
|
||||||
|
|
||||||
|
## Benefits of New Architecture
|
||||||
|
|
||||||
|
1. **Single Source of Truth**: One implementation per component
|
||||||
|
2. **Clear Separation**: CNN, RL, and Orchestrator are distinct
|
||||||
|
3. **Easy to Extend**: Adding new symbols/timeframes is simple
|
||||||
|
4. **Maintainable**: Changes are localized to specific modules
|
||||||
|
5. **Testable**: Each component can be tested independently
|
||||||
|
|
||||||
|
## Implementation Priority
|
||||||
|
|
||||||
|
1. **HIGH**: Core data provider and orchestrator
|
||||||
|
2. **HIGH**: Single CNN and RL implementations
|
||||||
|
3. **MEDIUM**: Web dashboard consolidation
|
||||||
|
4. **LOW**: Cleanup of unused files
|
||||||
|
|
||||||
|
This plan will result in a much cleaner, more maintainable codebase focused on the core goal: multi-modal trading system with CNN predictions and RL decision making.
|
269
CLEAN_ARCHITECTURE_SUMMARY.md
Normal file
269
CLEAN_ARCHITECTURE_SUMMARY.md
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
# Clean Trading System Architecture Summary
|
||||||
|
|
||||||
|
## 🎯 Project Reorganization Complete
|
||||||
|
|
||||||
|
We have successfully transformed the disorganized trading system into a clean, modular, and memory-efficient architecture that fits within **8GB memory constraints** and allows easy plugging of new AI models.
|
||||||
|
|
||||||
|
## 🏗️ New Architecture Overview
|
||||||
|
|
||||||
|
```
|
||||||
|
gogo2/
|
||||||
|
├── core/ # Core system components
|
||||||
|
│ ├── config.py # ✅ Central configuration management
|
||||||
|
│ ├── data_provider.py # ✅ Multi-timeframe, multi-symbol data provider
|
||||||
|
│ ├── orchestrator.py # ✅ Main decision making orchestrator
|
||||||
|
│ └── __init__.py
|
||||||
|
├── models/ # ✅ Modular AI/ML Models
|
||||||
|
│ ├── __init__.py # ✅ Base interfaces & memory management
|
||||||
|
│ ├── cnn/ # 🔄 CNN implementations (to be added)
|
||||||
|
│ └── rl/ # 🔄 RL implementations (to be added)
|
||||||
|
├── web/ # 🔄 Web dashboard (to be added)
|
||||||
|
├── trading/ # 🔄 Trading execution (to be added)
|
||||||
|
├── utils/ # 🔄 Utilities (to be added)
|
||||||
|
├── main_clean.py # ✅ Clean entry point
|
||||||
|
├── config.yaml # ✅ Central configuration
|
||||||
|
└── requirements.txt # 🔄 Dependencies list
|
||||||
|
```
|
||||||
|
|
||||||
|
## ✅ Key Features Implemented
|
||||||
|
|
||||||
|
### 1. **Memory-Efficient Model Registry**
|
||||||
|
- **8GB total memory limit** enforced
|
||||||
|
- **Individual model limits** (configurable per model)
|
||||||
|
- **Automatic memory tracking** and cleanup
|
||||||
|
- **GPU/CPU device management** with fallback
|
||||||
|
- **Model registration/unregistration** with memory checks
|
||||||
|
|
||||||
|
### 2. **Modular Orchestrator System**
|
||||||
|
- **Plugin architecture** - easily add new AI models
|
||||||
|
- **Dynamic weighting** based on model performance
|
||||||
|
- **Multi-model predictions** combining CNN, RL, and any new models
|
||||||
|
- **Confidence-based decisions** with threshold controls
|
||||||
|
- **Real-time memory monitoring**
|
||||||
|
|
||||||
|
### 3. **Unified Data Provider**
|
||||||
|
- **Multi-symbol support**: ETH/USDT, BTC/USDT (extendable)
|
||||||
|
- **Multi-timeframe**: 1m, 5m, 15m, 1h, 4h, 1d
|
||||||
|
- **Real-time streaming** via WebSocket (async)
|
||||||
|
- **Historical data caching** with automatic invalidation
|
||||||
|
- **Technical indicators** computed automatically
|
||||||
|
- **Feature matrix generation** for ML models
|
||||||
|
|
||||||
|
### 4. **Central Configuration System**
|
||||||
|
- **YAML-based configuration** for all settings
|
||||||
|
- **Environment-specific configs** support
|
||||||
|
- **Automatic directory creation**
|
||||||
|
- **Type-safe property access**
|
||||||
|
- **Runtime configuration updates**
|
||||||
|
|
||||||
|
## 🧠 Model Interface Design
|
||||||
|
|
||||||
|
### Base Model Interface
|
||||||
|
```python
|
||||||
|
class ModelInterface(ABC):
|
||||||
|
- predict(features) -> (action_probs, confidence)
|
||||||
|
- get_memory_usage() -> int (MB)
|
||||||
|
- cleanup_memory()
|
||||||
|
- device management (GPU/CPU)
|
||||||
|
```
|
||||||
|
|
||||||
|
### CNN Model Interface
|
||||||
|
```python
|
||||||
|
class CNNModelInterface(ModelInterface):
|
||||||
|
- train(training_data) -> training_metrics
|
||||||
|
- predict_timeframe(features, timeframe) -> prediction
|
||||||
|
- timeframe-specific predictions
|
||||||
|
```
|
||||||
|
|
||||||
|
### RL Agent Interface
|
||||||
|
```python
|
||||||
|
class RLAgentInterface(ModelInterface):
|
||||||
|
- act(state) -> action
|
||||||
|
- act_with_confidence(state) -> (action, confidence)
|
||||||
|
- remember(experience) -> None
|
||||||
|
- replay() -> loss
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📊 Memory Management Features
|
||||||
|
|
||||||
|
### Automatic Memory Tracking
|
||||||
|
- **Per-model memory usage** monitoring
|
||||||
|
- **Total system memory** tracking
|
||||||
|
- **GPU memory management** with CUDA cache clearing
|
||||||
|
- **Memory leak prevention** with periodic cleanup
|
||||||
|
|
||||||
|
### Memory Constraints
|
||||||
|
- **Total system limit**: 8GB (configurable)
|
||||||
|
- **Default per-model limit**: 2GB (configurable)
|
||||||
|
- **Automatic rejection** of models exceeding limits
|
||||||
|
- **Memory stats reporting** for monitoring
|
||||||
|
|
||||||
|
### Example Memory Stats
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
'total_limit_mb': 8192.0,
|
||||||
|
'models': {
|
||||||
|
'CNN': {'memory_mb': 1500, 'device': 'cuda'},
|
||||||
|
'RL': {'memory_mb': 800, 'device': 'cuda'},
|
||||||
|
'Transformer': {'memory_mb': 2000, 'device': 'cuda'}
|
||||||
|
},
|
||||||
|
'total_used_mb': 4300,
|
||||||
|
'total_free_mb': 3892,
|
||||||
|
'utilization_percent': 52.5
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔧 Easy Model Integration
|
||||||
|
|
||||||
|
### Adding a New Model (Example: Transformer)
|
||||||
|
```python
|
||||||
|
from models import ModelInterface, get_model_registry
|
||||||
|
|
||||||
|
class TransformerModel(ModelInterface):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__('Transformer', config)
|
||||||
|
self.model = self._build_transformer()
|
||||||
|
|
||||||
|
def predict(self, features):
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(features)
|
||||||
|
probs = F.softmax(outputs, dim=-1)
|
||||||
|
confidence = torch.max(probs).item()
|
||||||
|
return probs.numpy(), confidence
|
||||||
|
|
||||||
|
def get_memory_usage(self):
|
||||||
|
return sum(p.numel() * 4 for p in self.model.parameters()) // (1024*1024)
|
||||||
|
|
||||||
|
# Register with orchestrator
|
||||||
|
registry = get_model_registry()
|
||||||
|
orchestrator = TradingOrchestrator()
|
||||||
|
|
||||||
|
transformer = TransformerModel(config)
|
||||||
|
if orchestrator.register_model(transformer, weight=0.2):
|
||||||
|
print("Transformer model added successfully!")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 Performance Optimizations
|
||||||
|
|
||||||
|
### Data Provider
|
||||||
|
- **Caching with TTL** (1-hour expiration)
|
||||||
|
- **Parquet storage** for fast I/O
|
||||||
|
- **Batch processing** of technical indicators
|
||||||
|
- **Memory-efficient** pandas operations
|
||||||
|
|
||||||
|
### Model System
|
||||||
|
- **Lazy loading** of models
|
||||||
|
- **Mixed precision** support (GPU)
|
||||||
|
- **Batch inference** where possible
|
||||||
|
- **Memory pooling** for repeated allocations
|
||||||
|
|
||||||
|
### Orchestrator
|
||||||
|
- **Asynchronous processing** for multiple models
|
||||||
|
- **Weighted averaging** of predictions
|
||||||
|
- **Confidence thresholding** to avoid low-quality decisions
|
||||||
|
- **Performance-based** weight adaptation
|
||||||
|
|
||||||
|
## 📈 Testing Results
|
||||||
|
|
||||||
|
### Data Provider Test
|
||||||
|
```
|
||||||
|
[SUCCESS] Historical data: 100 candles loaded
|
||||||
|
[SUCCESS] Feature matrix shape: (1, 20, 8)
|
||||||
|
[SUCCESS] Data provider health check passed
|
||||||
|
```
|
||||||
|
|
||||||
|
### Orchestrator Test
|
||||||
|
```
|
||||||
|
[SUCCESS] Model registry initialized with 8192.0MB limit
|
||||||
|
[SUCCESS] Both models registered successfully
|
||||||
|
[SUCCESS] Memory stats: 0.0% utilization
|
||||||
|
[SUCCESS] Models registered with orchestrator
|
||||||
|
[SUCCESS] Performance metrics available
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🎛️ Configuration Management
|
||||||
|
|
||||||
|
### Sample Configuration (config.yaml)
|
||||||
|
```yaml
|
||||||
|
# 8GB total memory limit
|
||||||
|
performance:
|
||||||
|
total_memory_gb: 8.0
|
||||||
|
use_gpu: true
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Model-specific limits
|
||||||
|
models:
|
||||||
|
cnn:
|
||||||
|
max_memory_mb: 2000
|
||||||
|
window_size: 20
|
||||||
|
rl:
|
||||||
|
max_memory_mb: 1500
|
||||||
|
state_size: 100
|
||||||
|
|
||||||
|
# Trading symbols & timeframes
|
||||||
|
symbols: ["ETH/USDT", "BTC/USDT"]
|
||||||
|
timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
|
||||||
|
|
||||||
|
# Decision making
|
||||||
|
orchestrator:
|
||||||
|
confidence_threshold: 0.5
|
||||||
|
decision_frequency: 60
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔄 Next Steps
|
||||||
|
|
||||||
|
### Phase 1: Complete Core Models
|
||||||
|
- [ ] Implement CNN model using the interface
|
||||||
|
- [ ] Implement RL agent using the interface
|
||||||
|
- [ ] Add model loading/saving functionality
|
||||||
|
|
||||||
|
### Phase 2: Enhanced Features
|
||||||
|
- [ ] Web dashboard integration
|
||||||
|
- [ ] Trading execution module
|
||||||
|
- [ ] Backtresting framework
|
||||||
|
- [ ] Performance analytics
|
||||||
|
|
||||||
|
### Phase 3: Advanced Models
|
||||||
|
- [ ] Transformer model for sequence modeling
|
||||||
|
- [ ] LSTM for temporal patterns
|
||||||
|
- [ ] Ensemble methods
|
||||||
|
- [ ] Meta-learning approaches
|
||||||
|
|
||||||
|
## 🎯 Benefits Achieved
|
||||||
|
|
||||||
|
1. **Memory Efficiency**: Strict 8GB enforcement with monitoring
|
||||||
|
2. **Modularity**: Easy to add/remove/test different AI models
|
||||||
|
3. **Maintainability**: Clear separation of concerns, no code duplication
|
||||||
|
4. **Scalability**: Can handle multiple symbols and timeframes efficiently
|
||||||
|
5. **Testability**: Each component can be tested independently
|
||||||
|
6. **Performance**: Optimized data processing and model inference
|
||||||
|
7. **Flexibility**: Configuration-driven behavior
|
||||||
|
8. **Monitoring**: Real-time memory and performance tracking
|
||||||
|
|
||||||
|
## 🛠️ Usage Examples
|
||||||
|
|
||||||
|
### Basic Testing
|
||||||
|
```bash
|
||||||
|
# Test data provider
|
||||||
|
python main_clean.py --mode test
|
||||||
|
|
||||||
|
# Test orchestrator system
|
||||||
|
python main_clean.py --mode orchestrator
|
||||||
|
|
||||||
|
# Test with specific symbol
|
||||||
|
python main_clean.py --mode test --symbol BTC/USDT
|
||||||
|
```
|
||||||
|
|
||||||
|
### Future Usage
|
||||||
|
```bash
|
||||||
|
# Training mode
|
||||||
|
python main_clean.py --mode train --symbol ETH/USDT
|
||||||
|
|
||||||
|
# Live trading
|
||||||
|
python main_clean.py --mode trade
|
||||||
|
|
||||||
|
# Web dashboard
|
||||||
|
python main_clean.py --mode web
|
||||||
|
```
|
||||||
|
|
||||||
|
This clean architecture provides a solid foundation for building a sophisticated multi-modal trading system that scales efficiently within memory constraints while remaining easy to extend and maintain.
|
94
config.yaml
Normal file
94
config.yaml
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Trading System Configuration
|
||||||
|
|
||||||
|
# Trading Symbols (extendable)
|
||||||
|
symbols:
|
||||||
|
- "ETH/USDT"
|
||||||
|
- "BTC/USDT"
|
||||||
|
|
||||||
|
# Timeframes for multi-timeframe analysis
|
||||||
|
timeframes:
|
||||||
|
- "1m"
|
||||||
|
- "5m"
|
||||||
|
- "15m"
|
||||||
|
- "1h"
|
||||||
|
- "4h"
|
||||||
|
- "1d"
|
||||||
|
|
||||||
|
# Data Provider Settings
|
||||||
|
data:
|
||||||
|
provider: "binance"
|
||||||
|
cache_enabled: true
|
||||||
|
cache_dir: "cache"
|
||||||
|
historical_limit: 1000
|
||||||
|
real_time_enabled: true
|
||||||
|
websocket_reconnect: true
|
||||||
|
|
||||||
|
# CNN Model Configuration
|
||||||
|
cnn:
|
||||||
|
window_size: 20
|
||||||
|
features: ["open", "high", "low", "close", "volume"]
|
||||||
|
hidden_layers: [64, 32, 16]
|
||||||
|
dropout: 0.2
|
||||||
|
learning_rate: 0.001
|
||||||
|
batch_size: 32
|
||||||
|
epochs: 100
|
||||||
|
confidence_threshold: 0.6
|
||||||
|
|
||||||
|
# RL Agent Configuration
|
||||||
|
rl:
|
||||||
|
state_size: 100 # Will be calculated dynamically
|
||||||
|
action_space: 3 # BUY, HOLD, SELL
|
||||||
|
epsilon: 1.0
|
||||||
|
epsilon_decay: 0.995
|
||||||
|
epsilon_min: 0.01
|
||||||
|
learning_rate: 0.0001
|
||||||
|
gamma: 0.99
|
||||||
|
memory_size: 10000
|
||||||
|
batch_size: 64
|
||||||
|
target_update_freq: 1000
|
||||||
|
|
||||||
|
# Orchestrator Settings
|
||||||
|
orchestrator:
|
||||||
|
cnn_weight: 0.7 # Weight for CNN predictions
|
||||||
|
rl_weight: 0.3 # Weight for RL decisions
|
||||||
|
confidence_threshold: 0.5 # Minimum confidence to act
|
||||||
|
decision_frequency: 60 # Seconds between decisions
|
||||||
|
|
||||||
|
# Trading Execution
|
||||||
|
trading:
|
||||||
|
max_position_size: 0.1 # Maximum position size (fraction of balance)
|
||||||
|
stop_loss: 0.02 # 2% stop loss
|
||||||
|
take_profit: 0.05 # 5% take profit
|
||||||
|
trading_fee: 0.0002 # 0.02% trading fee
|
||||||
|
min_trade_interval: 60 # Minimum seconds between trades
|
||||||
|
|
||||||
|
# Web Dashboard
|
||||||
|
web:
|
||||||
|
host: "127.0.0.1"
|
||||||
|
port: 8050
|
||||||
|
debug: false
|
||||||
|
update_interval: 1000 # Milliseconds
|
||||||
|
chart_history: 100 # Number of candles to show
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
logging:
|
||||||
|
level: "INFO"
|
||||||
|
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
file: "logs/trading.log"
|
||||||
|
max_size: 10485760 # 10MB
|
||||||
|
backup_count: 5
|
||||||
|
|
||||||
|
# GPU/Performance
|
||||||
|
performance:
|
||||||
|
use_gpu: true
|
||||||
|
mixed_precision: true
|
||||||
|
num_workers: 4
|
||||||
|
batch_size_multiplier: 1.0
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
paths:
|
||||||
|
models: "models"
|
||||||
|
data: "data"
|
||||||
|
logs: "logs"
|
||||||
|
cache: "cache"
|
||||||
|
plots: "plots"
|
0
core/__init__.py
Normal file
0
core/__init__.py
Normal file
239
core/config.py
Normal file
239
core/config.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
"""
|
||||||
|
Central Configuration Management
|
||||||
|
|
||||||
|
This module handles all configuration for the trading system.
|
||||||
|
It loads settings from config.yaml and provides easy access to all components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Central configuration management for the trading system"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str = "config.yaml"):
|
||||||
|
"""Initialize configuration from YAML file"""
|
||||||
|
self.config_path = Path(config_path)
|
||||||
|
self._config = self._load_config()
|
||||||
|
self._setup_directories()
|
||||||
|
|
||||||
|
def _load_config(self) -> Dict[str, Any]:
|
||||||
|
"""Load configuration from YAML file"""
|
||||||
|
try:
|
||||||
|
if not self.config_path.exists():
|
||||||
|
logger.warning(f"Config file {self.config_path} not found, using defaults")
|
||||||
|
return self._get_default_config()
|
||||||
|
|
||||||
|
with open(self.config_path, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
logger.info(f"Loaded configuration from {self.config_path}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading config: {e}")
|
||||||
|
logger.info("Using default configuration")
|
||||||
|
return self._get_default_config()
|
||||||
|
|
||||||
|
def _get_default_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get default configuration if file is missing"""
|
||||||
|
return {
|
||||||
|
'symbols': ['ETH/USDT', 'BTC/USDT'],
|
||||||
|
'timeframes': ['1m', '5m', '15m', '1h', '4h', '1d'],
|
||||||
|
'data': {
|
||||||
|
'provider': 'binance',
|
||||||
|
'cache_enabled': True,
|
||||||
|
'cache_dir': 'cache',
|
||||||
|
'historical_limit': 1000,
|
||||||
|
'real_time_enabled': True,
|
||||||
|
'websocket_reconnect': True
|
||||||
|
},
|
||||||
|
'cnn': {
|
||||||
|
'window_size': 20,
|
||||||
|
'features': ['open', 'high', 'low', 'close', 'volume'],
|
||||||
|
'hidden_layers': [64, 32, 16],
|
||||||
|
'dropout': 0.2,
|
||||||
|
'learning_rate': 0.001,
|
||||||
|
'batch_size': 32,
|
||||||
|
'epochs': 100,
|
||||||
|
'confidence_threshold': 0.6
|
||||||
|
},
|
||||||
|
'rl': {
|
||||||
|
'state_size': 100,
|
||||||
|
'action_space': 3,
|
||||||
|
'epsilon': 1.0,
|
||||||
|
'epsilon_decay': 0.995,
|
||||||
|
'epsilon_min': 0.01,
|
||||||
|
'learning_rate': 0.0001,
|
||||||
|
'gamma': 0.99,
|
||||||
|
'memory_size': 10000,
|
||||||
|
'batch_size': 64,
|
||||||
|
'target_update_freq': 1000
|
||||||
|
},
|
||||||
|
'orchestrator': {
|
||||||
|
'cnn_weight': 0.7,
|
||||||
|
'rl_weight': 0.3,
|
||||||
|
'confidence_threshold': 0.5,
|
||||||
|
'decision_frequency': 60
|
||||||
|
},
|
||||||
|
'trading': {
|
||||||
|
'max_position_size': 0.1,
|
||||||
|
'stop_loss': 0.02,
|
||||||
|
'take_profit': 0.05,
|
||||||
|
'trading_fee': 0.0002,
|
||||||
|
'min_trade_interval': 60
|
||||||
|
},
|
||||||
|
'web': {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 8050,
|
||||||
|
'debug': False,
|
||||||
|
'update_interval': 1000,
|
||||||
|
'chart_history': 100
|
||||||
|
},
|
||||||
|
'logging': {
|
||||||
|
'level': 'INFO',
|
||||||
|
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
'file': 'logs/trading.log',
|
||||||
|
'max_size': 10485760,
|
||||||
|
'backup_count': 5
|
||||||
|
},
|
||||||
|
'performance': {
|
||||||
|
'use_gpu': True,
|
||||||
|
'mixed_precision': True,
|
||||||
|
'num_workers': 4,
|
||||||
|
'batch_size_multiplier': 1.0
|
||||||
|
},
|
||||||
|
'paths': {
|
||||||
|
'models': 'models',
|
||||||
|
'data': 'data',
|
||||||
|
'logs': 'logs',
|
||||||
|
'cache': 'cache',
|
||||||
|
'plots': 'plots'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _setup_directories(self):
|
||||||
|
"""Create necessary directories"""
|
||||||
|
try:
|
||||||
|
paths = self._config.get('paths', {})
|
||||||
|
for path_name, path_value in paths.items():
|
||||||
|
Path(path_value).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Also create specific model subdirectories
|
||||||
|
models_dir = Path(paths.get('models', 'models'))
|
||||||
|
(models_dir / 'cnn' / 'saved').mkdir(parents=True, exist_ok=True)
|
||||||
|
(models_dir / 'rl' / 'saved').mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating directories: {e}")
|
||||||
|
|
||||||
|
# Property accessors for easy access
|
||||||
|
@property
|
||||||
|
def symbols(self) -> List[str]:
|
||||||
|
"""Get list of trading symbols"""
|
||||||
|
return self._config.get('symbols', ['ETH/USDT'])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timeframes(self) -> List[str]:
|
||||||
|
"""Get list of timeframes"""
|
||||||
|
return self._config.get('timeframes', ['1m', '5m', '1h'])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Dict[str, Any]:
|
||||||
|
"""Get data provider settings"""
|
||||||
|
return self._config.get('data', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cnn(self) -> Dict[str, Any]:
|
||||||
|
"""Get CNN model settings"""
|
||||||
|
return self._config.get('cnn', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rl(self) -> Dict[str, Any]:
|
||||||
|
"""Get RL agent settings"""
|
||||||
|
return self._config.get('rl', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def orchestrator(self) -> Dict[str, Any]:
|
||||||
|
"""Get orchestrator settings"""
|
||||||
|
return self._config.get('orchestrator', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trading(self) -> Dict[str, Any]:
|
||||||
|
"""Get trading execution settings"""
|
||||||
|
return self._config.get('trading', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def web(self) -> Dict[str, Any]:
|
||||||
|
"""Get web dashboard settings"""
|
||||||
|
return self._config.get('web', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logging(self) -> Dict[str, Any]:
|
||||||
|
"""Get logging settings"""
|
||||||
|
return self._config.get('logging', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def performance(self) -> Dict[str, Any]:
|
||||||
|
"""Get performance settings"""
|
||||||
|
return self._config.get('performance', {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def paths(self) -> Dict[str, str]:
|
||||||
|
"""Get file paths"""
|
||||||
|
return self._config.get('paths', {})
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""Get configuration value by key with optional default"""
|
||||||
|
return self._config.get(key, default)
|
||||||
|
|
||||||
|
def update(self, key: str, value: Any):
|
||||||
|
"""Update configuration value"""
|
||||||
|
self._config[key] = value
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""Save current configuration back to file"""
|
||||||
|
try:
|
||||||
|
with open(self.config_path, 'w') as f:
|
||||||
|
yaml.dump(self._config, f, default_flow_style=False, indent=2)
|
||||||
|
logger.info(f"Configuration saved to {self.config_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving configuration: {e}")
|
||||||
|
|
||||||
|
# Global configuration instance
|
||||||
|
_config_instance = None
|
||||||
|
|
||||||
|
def get_config(config_path: str = "config.yaml") -> Config:
|
||||||
|
"""Get global configuration instance (singleton pattern)"""
|
||||||
|
global _config_instance
|
||||||
|
if _config_instance is None:
|
||||||
|
_config_instance = Config(config_path)
|
||||||
|
return _config_instance
|
||||||
|
|
||||||
|
def setup_logging(config: Optional[Config] = None):
|
||||||
|
"""Setup logging based on configuration"""
|
||||||
|
if config is None:
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
log_config = config.logging
|
||||||
|
|
||||||
|
# Create logs directory
|
||||||
|
log_file = Path(log_config.get('file', 'logs/trading.log'))
|
||||||
|
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, log_config.get('level', 'INFO')),
|
||||||
|
format=log_config.get('format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s'),
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(log_file),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Logging configured successfully")
|
437
core/data_provider.py
Normal file
437
core/data_provider.py
Normal file
@ -0,0 +1,437 @@
|
|||||||
|
"""
|
||||||
|
Multi-Timeframe, Multi-Symbol Data Provider
|
||||||
|
|
||||||
|
This module consolidates all data functionality including:
|
||||||
|
- Historical data fetching from Binance API
|
||||||
|
- Real-time data streaming via WebSocket
|
||||||
|
- Multi-timeframe candle generation
|
||||||
|
- Caching and data management
|
||||||
|
- Technical indicators calculation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import websockets
|
||||||
|
import requests
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
import ta
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DataProvider:
|
||||||
|
"""Unified data provider for historical and real-time market data"""
|
||||||
|
|
||||||
|
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
|
||||||
|
"""Initialize the data provider"""
|
||||||
|
self.config = get_config()
|
||||||
|
self.symbols = symbols or self.config.symbols
|
||||||
|
self.timeframes = timeframes or self.config.timeframes
|
||||||
|
|
||||||
|
# Data storage
|
||||||
|
self.historical_data = {} # {symbol: {timeframe: DataFrame}}
|
||||||
|
self.real_time_data = {} # {symbol: {timeframe: deque}}
|
||||||
|
self.current_prices = {} # {symbol: float}
|
||||||
|
|
||||||
|
# Real-time processing
|
||||||
|
self.websocket_tasks = {}
|
||||||
|
self.is_streaming = False
|
||||||
|
self.data_lock = Lock()
|
||||||
|
|
||||||
|
# Cache settings
|
||||||
|
self.cache_enabled = self.config.data.get('cache_enabled', True)
|
||||||
|
self.cache_dir = Path(self.config.data.get('cache_dir', 'cache'))
|
||||||
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Timeframe conversion
|
||||||
|
self.timeframe_seconds = {
|
||||||
|
'1m': 60, '5m': 300, '15m': 900, '30m': 1800,
|
||||||
|
'1h': 3600, '4h': 14400, '1d': 86400
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"DataProvider initialized for symbols: {self.symbols}")
|
||||||
|
logger.info(f"Timeframes: {self.timeframes}")
|
||||||
|
|
||||||
|
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
|
||||||
|
refresh: bool = False) -> Optional[pd.DataFrame]:
|
||||||
|
"""Get historical OHLCV data for a symbol and timeframe"""
|
||||||
|
try:
|
||||||
|
# Check cache first
|
||||||
|
if not refresh and self.cache_enabled:
|
||||||
|
cached_data = self._load_from_cache(symbol, timeframe)
|
||||||
|
if cached_data is not None and len(cached_data) >= limit * 0.8:
|
||||||
|
logger.info(f"Using cached data for {symbol} {timeframe}")
|
||||||
|
return cached_data.tail(limit)
|
||||||
|
|
||||||
|
# Fetch from API
|
||||||
|
logger.info(f"Fetching historical data for {symbol} {timeframe}")
|
||||||
|
df = self._fetch_from_binance(symbol, timeframe, limit)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
# Add technical indicators
|
||||||
|
df = self._add_technical_indicators(df)
|
||||||
|
|
||||||
|
# Cache the data
|
||||||
|
if self.cache_enabled:
|
||||||
|
self._save_to_cache(df, symbol, timeframe)
|
||||||
|
|
||||||
|
# Store in memory
|
||||||
|
if symbol not in self.historical_data:
|
||||||
|
self.historical_data[symbol] = {}
|
||||||
|
self.historical_data[symbol][timeframe] = df
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
logger.warning(f"No data received for {symbol} {timeframe}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching historical data for {symbol} {timeframe}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_from_binance(self, symbol: str, timeframe: str, limit: int) -> Optional[pd.DataFrame]:
|
||||||
|
"""Fetch data from Binance API"""
|
||||||
|
try:
|
||||||
|
# Convert symbol format
|
||||||
|
binance_symbol = symbol.replace('/', '').upper()
|
||||||
|
|
||||||
|
# Convert timeframe
|
||||||
|
timeframe_map = {
|
||||||
|
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
|
||||||
|
'1h': '1h', '4h': '4h', '1d': '1d'
|
||||||
|
}
|
||||||
|
binance_timeframe = timeframe_map.get(timeframe, '1h')
|
||||||
|
|
||||||
|
# API request
|
||||||
|
url = "https://api.binance.com/api/v3/klines"
|
||||||
|
params = {
|
||||||
|
'symbol': binance_symbol,
|
||||||
|
'interval': binance_timeframe,
|
||||||
|
'limit': limit
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(url, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
df = pd.DataFrame(data, columns=[
|
||||||
|
'timestamp', 'open', 'high', 'low', 'close', 'volume',
|
||||||
|
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
|
||||||
|
'taker_buy_quote', 'ignore'
|
||||||
|
])
|
||||||
|
|
||||||
|
# Process columns
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||||
|
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||||
|
df[col] = df[col].astype(float)
|
||||||
|
|
||||||
|
# Keep only OHLCV columns
|
||||||
|
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
|
||||||
|
df = df.sort_values('timestamp').reset_index(drop=True)
|
||||||
|
|
||||||
|
logger.info(f"Fetched {len(df)} candles for {symbol} {timeframe}")
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching from Binance API: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Add technical indicators to the DataFrame"""
|
||||||
|
try:
|
||||||
|
df = df.copy()
|
||||||
|
|
||||||
|
# Moving averages
|
||||||
|
df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
|
||||||
|
df['sma_50'] = ta.trend.sma_indicator(df['close'], window=50)
|
||||||
|
df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
|
||||||
|
df['ema_26'] = ta.trend.ema_indicator(df['close'], window=26)
|
||||||
|
|
||||||
|
# MACD
|
||||||
|
macd = ta.trend.MACD(df['close'])
|
||||||
|
df['macd'] = macd.macd()
|
||||||
|
df['macd_signal'] = macd.macd_signal()
|
||||||
|
df['macd_histogram'] = macd.macd_diff()
|
||||||
|
|
||||||
|
# RSI
|
||||||
|
df['rsi'] = ta.momentum.rsi(df['close'], window=14)
|
||||||
|
|
||||||
|
# Bollinger Bands
|
||||||
|
bollinger = ta.volatility.BollingerBands(df['close'])
|
||||||
|
df['bb_upper'] = bollinger.bollinger_hband()
|
||||||
|
df['bb_lower'] = bollinger.bollinger_lband()
|
||||||
|
df['bb_middle'] = bollinger.bollinger_mavg()
|
||||||
|
|
||||||
|
# Volume moving average (simple rolling mean since ta.volume.volume_sma doesn't exist)
|
||||||
|
df['volume_sma'] = df['volume'].rolling(window=20).mean()
|
||||||
|
|
||||||
|
# Fill NaN values
|
||||||
|
df = df.bfill().fillna(0)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding technical indicators: {e}")
|
||||||
|
return df
|
||||||
|
|
||||||
|
def _load_from_cache(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""Load data from cache"""
|
||||||
|
try:
|
||||||
|
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
|
||||||
|
if cache_file.exists():
|
||||||
|
# Check if cache is recent (less than 1 hour old)
|
||||||
|
cache_age = time.time() - cache_file.stat().st_mtime
|
||||||
|
if cache_age < 3600: # 1 hour
|
||||||
|
df = pd.read_parquet(cache_file)
|
||||||
|
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe}")
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
logger.debug(f"Cache for {symbol} {timeframe} is too old ({cache_age/3600:.1f}h)")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading cache for {symbol} {timeframe}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _save_to_cache(self, df: pd.DataFrame, symbol: str, timeframe: str):
|
||||||
|
"""Save data to cache"""
|
||||||
|
try:
|
||||||
|
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
|
||||||
|
df.to_parquet(cache_file, index=False)
|
||||||
|
logger.debug(f"Saved {len(df)} rows to cache for {symbol} {timeframe}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error saving cache for {symbol} {timeframe}: {e}")
|
||||||
|
|
||||||
|
async def start_real_time_streaming(self):
|
||||||
|
"""Start real-time data streaming for all symbols"""
|
||||||
|
if self.is_streaming:
|
||||||
|
logger.warning("Real-time streaming already active")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_streaming = True
|
||||||
|
logger.info("Starting real-time data streaming")
|
||||||
|
|
||||||
|
# Start WebSocket for each symbol
|
||||||
|
for symbol in self.symbols:
|
||||||
|
task = asyncio.create_task(self._websocket_stream(symbol))
|
||||||
|
self.websocket_tasks[symbol] = task
|
||||||
|
|
||||||
|
async def stop_real_time_streaming(self):
|
||||||
|
"""Stop real-time data streaming"""
|
||||||
|
if not self.is_streaming:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping real-time data streaming")
|
||||||
|
self.is_streaming = False
|
||||||
|
|
||||||
|
# Cancel all WebSocket tasks
|
||||||
|
for symbol, task in self.websocket_tasks.items():
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.websocket_tasks.clear()
|
||||||
|
|
||||||
|
async def _websocket_stream(self, symbol: str):
|
||||||
|
"""WebSocket stream for a single symbol"""
|
||||||
|
binance_symbol = symbol.replace('/', '').lower()
|
||||||
|
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@ticker"
|
||||||
|
|
||||||
|
while self.is_streaming:
|
||||||
|
try:
|
||||||
|
async with websockets.connect(url) as websocket:
|
||||||
|
logger.info(f"WebSocket connected for {symbol}")
|
||||||
|
|
||||||
|
async for message in websocket:
|
||||||
|
if not self.is_streaming:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
await self._process_tick(symbol, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing tick for {symbol}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket error for {symbol}: {e}")
|
||||||
|
if self.is_streaming:
|
||||||
|
logger.info(f"Reconnecting WebSocket for {symbol} in 5 seconds...")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
async def _process_tick(self, symbol: str, tick_data: Dict):
|
||||||
|
"""Process a single tick and update candles"""
|
||||||
|
try:
|
||||||
|
price = float(tick_data.get('c', 0)) # Current price
|
||||||
|
volume = float(tick_data.get('v', 0)) # 24h Volume
|
||||||
|
timestamp = pd.Timestamp.now()
|
||||||
|
|
||||||
|
# Update current price
|
||||||
|
with self.data_lock:
|
||||||
|
self.current_prices[symbol] = price
|
||||||
|
|
||||||
|
# Initialize real-time data structure if needed
|
||||||
|
if symbol not in self.real_time_data:
|
||||||
|
self.real_time_data[symbol] = {}
|
||||||
|
for tf in self.timeframes:
|
||||||
|
self.real_time_data[symbol][tf] = deque(maxlen=1000)
|
||||||
|
|
||||||
|
# Create tick record
|
||||||
|
tick = {
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'price': price,
|
||||||
|
'volume': volume
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update all timeframes
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
self._update_candle(symbol, timeframe, tick)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing tick for {symbol}: {e}")
|
||||||
|
|
||||||
|
def _update_candle(self, symbol: str, timeframe: str, tick: Dict):
|
||||||
|
"""Update candle for specific timeframe"""
|
||||||
|
try:
|
||||||
|
timeframe_secs = self.timeframe_seconds.get(timeframe, 3600)
|
||||||
|
current_time = tick['timestamp']
|
||||||
|
|
||||||
|
# Calculate candle start time
|
||||||
|
candle_start = current_time.floor(f'{timeframe_secs}s')
|
||||||
|
|
||||||
|
# Get current candle queue
|
||||||
|
candle_queue = self.real_time_data[symbol][timeframe]
|
||||||
|
|
||||||
|
# Check if we need a new candle
|
||||||
|
if not candle_queue or candle_queue[-1]['timestamp'] != candle_start:
|
||||||
|
# Create new candle
|
||||||
|
new_candle = {
|
||||||
|
'timestamp': candle_start,
|
||||||
|
'open': tick['price'],
|
||||||
|
'high': tick['price'],
|
||||||
|
'low': tick['price'],
|
||||||
|
'close': tick['price'],
|
||||||
|
'volume': tick['volume']
|
||||||
|
}
|
||||||
|
candle_queue.append(new_candle)
|
||||||
|
else:
|
||||||
|
# Update existing candle
|
||||||
|
current_candle = candle_queue[-1]
|
||||||
|
current_candle['high'] = max(current_candle['high'], tick['price'])
|
||||||
|
current_candle['low'] = min(current_candle['low'], tick['price'])
|
||||||
|
current_candle['close'] = tick['price']
|
||||||
|
current_candle['volume'] += tick['volume']
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating candle for {symbol} {timeframe}: {e}")
|
||||||
|
|
||||||
|
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
|
||||||
|
"""Get the latest candles combining historical and real-time data"""
|
||||||
|
try:
|
||||||
|
# Get historical data
|
||||||
|
historical_df = self.get_historical_data(symbol, timeframe, limit=limit)
|
||||||
|
|
||||||
|
# Get real-time data
|
||||||
|
with self.data_lock:
|
||||||
|
if symbol in self.real_time_data and timeframe in self.real_time_data[symbol]:
|
||||||
|
real_time_candles = list(self.real_time_data[symbol][timeframe])
|
||||||
|
|
||||||
|
if real_time_candles:
|
||||||
|
# Convert to DataFrame
|
||||||
|
rt_df = pd.DataFrame(real_time_candles)
|
||||||
|
|
||||||
|
if historical_df is not None:
|
||||||
|
# Combine historical and real-time
|
||||||
|
# Remove overlapping candles from historical data
|
||||||
|
if not rt_df.empty:
|
||||||
|
cutoff_time = rt_df['timestamp'].min()
|
||||||
|
historical_df = historical_df[historical_df['timestamp'] < cutoff_time]
|
||||||
|
|
||||||
|
# Concatenate
|
||||||
|
combined_df = pd.concat([historical_df, rt_df], ignore_index=True)
|
||||||
|
else:
|
||||||
|
combined_df = rt_df
|
||||||
|
|
||||||
|
return combined_df.tail(limit)
|
||||||
|
|
||||||
|
# Return just historical data if no real-time data
|
||||||
|
return historical_df.tail(limit) if historical_df is not None else pd.DataFrame()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting latest candles for {symbol} {timeframe}: {e}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def get_current_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""Get current price for a symbol"""
|
||||||
|
with self.data_lock:
|
||||||
|
return self.current_prices.get(symbol)
|
||||||
|
|
||||||
|
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
|
||||||
|
window_size: int = 20) -> Optional[np.ndarray]:
|
||||||
|
"""Get feature matrix for multiple timeframes"""
|
||||||
|
try:
|
||||||
|
if timeframes is None:
|
||||||
|
timeframes = self.timeframes
|
||||||
|
|
||||||
|
features = []
|
||||||
|
|
||||||
|
for tf in timeframes:
|
||||||
|
df = self.get_latest_candles(symbol, tf, limit=window_size + 50)
|
||||||
|
|
||||||
|
if df is not None and len(df) >= window_size:
|
||||||
|
# Select feature columns
|
||||||
|
feature_cols = ['open', 'high', 'low', 'close', 'volume']
|
||||||
|
if 'sma_20' in df.columns:
|
||||||
|
feature_cols.extend(['sma_20', 'rsi', 'macd'])
|
||||||
|
|
||||||
|
# Get the latest window
|
||||||
|
tf_features = df[feature_cols].tail(window_size).values
|
||||||
|
features.append(tf_features)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Insufficient data for {symbol} {tf}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if features:
|
||||||
|
# Stack features from all timeframes
|
||||||
|
return np.stack(features, axis=0) # Shape: (n_timeframes, window_size, n_features)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating feature matrix for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def health_check(self) -> Dict[str, Any]:
|
||||||
|
"""Get health status of the data provider"""
|
||||||
|
status = {
|
||||||
|
'streaming': self.is_streaming,
|
||||||
|
'symbols': len(self.symbols),
|
||||||
|
'timeframes': len(self.timeframes),
|
||||||
|
'current_prices': len(self.current_prices),
|
||||||
|
'websocket_tasks': len(self.websocket_tasks),
|
||||||
|
'historical_data_loaded': {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check historical data availability
|
||||||
|
for symbol in self.symbols:
|
||||||
|
status['historical_data_loaded'][symbol] = {}
|
||||||
|
for tf in self.timeframes:
|
||||||
|
has_data = (symbol in self.historical_data and
|
||||||
|
tf in self.historical_data[symbol] and
|
||||||
|
not self.historical_data[symbol][tf].empty)
|
||||||
|
status['historical_data_loaded'][symbol][tf] = has_data
|
||||||
|
|
||||||
|
return status
|
516
core/orchestrator.py
Normal file
516
core/orchestrator.py
Normal file
@ -0,0 +1,516 @@
|
|||||||
|
"""
|
||||||
|
Trading Orchestrator - Main Decision Making Module
|
||||||
|
|
||||||
|
This is the core orchestrator that:
|
||||||
|
1. Coordinates CNN and RL modules via model registry
|
||||||
|
2. Combines their outputs with confidence weighting
|
||||||
|
3. Makes final trading decisions (BUY/SELL/HOLD)
|
||||||
|
4. Manages the learning loop between components
|
||||||
|
5. Ensures memory efficiency (8GB constraint)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
from .data_provider import DataProvider
|
||||||
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Prediction:
|
||||||
|
"""Represents a prediction from a model"""
|
||||||
|
action: str # 'BUY', 'SELL', 'HOLD'
|
||||||
|
confidence: float # 0.0 to 1.0
|
||||||
|
probabilities: Dict[str, float] # Probabilities for each action
|
||||||
|
timeframe: str # Timeframe this prediction is for
|
||||||
|
timestamp: datetime
|
||||||
|
model_name: str # Name of the model that made this prediction
|
||||||
|
metadata: Dict[str, Any] = None # Additional model-specific data
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TradingDecision:
|
||||||
|
"""Final trading decision from the orchestrator"""
|
||||||
|
action: str # 'BUY', 'SELL', 'HOLD'
|
||||||
|
confidence: float # Combined confidence
|
||||||
|
symbol: str
|
||||||
|
price: float
|
||||||
|
timestamp: datetime
|
||||||
|
reasoning: Dict[str, Any] # Why this decision was made
|
||||||
|
memory_usage: Dict[str, int] # Memory usage of models
|
||||||
|
|
||||||
|
class TradingOrchestrator:
|
||||||
|
"""
|
||||||
|
Main orchestrator that coordinates multiple AI models for trading decisions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_provider: DataProvider = None):
|
||||||
|
"""Initialize the orchestrator"""
|
||||||
|
self.config = get_config()
|
||||||
|
self.data_provider = data_provider or DataProvider()
|
||||||
|
self.model_registry = get_model_registry()
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.5)
|
||||||
|
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 60)
|
||||||
|
|
||||||
|
# Dynamic weights (will be adapted based on performance)
|
||||||
|
self.model_weights = {} # {model_name: weight}
|
||||||
|
self._initialize_default_weights()
|
||||||
|
|
||||||
|
# State tracking
|
||||||
|
self.last_decision_time = {} # {symbol: datetime}
|
||||||
|
self.recent_decisions = {} # {symbol: List[TradingDecision]}
|
||||||
|
self.model_performance = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
||||||
|
|
||||||
|
# Decision callbacks
|
||||||
|
self.decision_callbacks = []
|
||||||
|
|
||||||
|
logger.info("TradingOrchestrator initialized with modular model system")
|
||||||
|
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
||||||
|
logger.info(f"Decision frequency: {self.decision_frequency}s")
|
||||||
|
|
||||||
|
def _initialize_default_weights(self):
|
||||||
|
"""Initialize default model weights from config"""
|
||||||
|
self.model_weights = {
|
||||||
|
'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
|
||||||
|
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
||||||
|
}
|
||||||
|
|
||||||
|
def register_model(self, model: ModelInterface, weight: float = None) -> bool:
|
||||||
|
"""Register a new model with the orchestrator"""
|
||||||
|
try:
|
||||||
|
# Register with model registry
|
||||||
|
if not self.model_registry.register_model(model):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Set weight
|
||||||
|
if weight is not None:
|
||||||
|
self.model_weights[model.name] = weight
|
||||||
|
elif model.name not in self.model_weights:
|
||||||
|
self.model_weights[model.name] = 0.1 # Default low weight for new models
|
||||||
|
|
||||||
|
# Initialize performance tracking
|
||||||
|
if model.name not in self.model_performance:
|
||||||
|
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||||
|
|
||||||
|
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
||||||
|
self._normalize_weights()
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error registering model {model.name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def unregister_model(self, model_name: str) -> bool:
|
||||||
|
"""Unregister a model"""
|
||||||
|
try:
|
||||||
|
if self.model_registry.unregister_model(model_name):
|
||||||
|
if model_name in self.model_weights:
|
||||||
|
del self.model_weights[model_name]
|
||||||
|
if model_name in self.model_performance:
|
||||||
|
del self.model_performance[model_name]
|
||||||
|
|
||||||
|
self._normalize_weights()
|
||||||
|
logger.info(f"Unregistered {model_name} model")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error unregistering model {model_name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _normalize_weights(self):
|
||||||
|
"""Normalize model weights to sum to 1.0"""
|
||||||
|
total_weight = sum(self.model_weights.values())
|
||||||
|
if total_weight > 0:
|
||||||
|
for model_name in self.model_weights:
|
||||||
|
self.model_weights[model_name] /= total_weight
|
||||||
|
|
||||||
|
def add_decision_callback(self, callback):
|
||||||
|
"""Add a callback function to be called when decisions are made"""
|
||||||
|
self.decision_callbacks.append(callback)
|
||||||
|
|
||||||
|
async def make_trading_decision(self, symbol: str) -> Optional[TradingDecision]:
|
||||||
|
"""
|
||||||
|
Make a trading decision for a symbol by combining all registered model outputs
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_time = datetime.now()
|
||||||
|
|
||||||
|
# Check if enough time has passed since last decision
|
||||||
|
if symbol in self.last_decision_time:
|
||||||
|
time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds()
|
||||||
|
if time_since_last < self.decision_frequency:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get current market data
|
||||||
|
current_price = self.data_provider.get_current_price(symbol)
|
||||||
|
if current_price is None:
|
||||||
|
logger.warning(f"No current price available for {symbol}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get predictions from all registered models
|
||||||
|
predictions = await self._get_all_predictions(symbol)
|
||||||
|
|
||||||
|
if not predictions:
|
||||||
|
logger.warning(f"No predictions available for {symbol}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Combine predictions
|
||||||
|
decision = self._combine_predictions(
|
||||||
|
symbol=symbol,
|
||||||
|
price=current_price,
|
||||||
|
predictions=predictions,
|
||||||
|
timestamp=current_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
self.last_decision_time[symbol] = current_time
|
||||||
|
if symbol not in self.recent_decisions:
|
||||||
|
self.recent_decisions[symbol] = []
|
||||||
|
self.recent_decisions[symbol].append(decision)
|
||||||
|
|
||||||
|
# Keep only recent decisions (last 100)
|
||||||
|
if len(self.recent_decisions[symbol]) > 100:
|
||||||
|
self.recent_decisions[symbol] = self.recent_decisions[symbol][-100:]
|
||||||
|
|
||||||
|
# Call decision callbacks
|
||||||
|
for callback in self.decision_callbacks:
|
||||||
|
try:
|
||||||
|
await callback(decision)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in decision callback: {e}")
|
||||||
|
|
||||||
|
# Clean up memory periodically
|
||||||
|
if len(self.recent_decisions[symbol]) % 50 == 0:
|
||||||
|
self.model_registry.cleanup_all_models()
|
||||||
|
|
||||||
|
return decision
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making trading decision for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||||
|
"""Get predictions from all registered models"""
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for model_name, model in self.model_registry.models.items():
|
||||||
|
try:
|
||||||
|
if isinstance(model, CNNModelInterface):
|
||||||
|
# Get CNN predictions for each timeframe
|
||||||
|
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
||||||
|
predictions.extend(cnn_predictions)
|
||||||
|
|
||||||
|
elif isinstance(model, RLAgentInterface):
|
||||||
|
# Get RL prediction
|
||||||
|
rl_prediction = await self._get_rl_prediction(model, symbol)
|
||||||
|
if rl_prediction:
|
||||||
|
predictions.append(rl_prediction)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Generic model interface
|
||||||
|
generic_prediction = await self._get_generic_prediction(model, symbol)
|
||||||
|
if generic_prediction:
|
||||||
|
predictions.append(generic_prediction)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||||
|
"""Get predictions from CNN model for all timeframes"""
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for timeframe in self.config.timeframes:
|
||||||
|
# Get feature matrix for this timeframe
|
||||||
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframes=[timeframe],
|
||||||
|
window_size=model.window_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_matrix is not None:
|
||||||
|
# Get CNN prediction
|
||||||
|
try:
|
||||||
|
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
|
||||||
|
except AttributeError:
|
||||||
|
# Fallback to generic predict method
|
||||||
|
action_probs, confidence = model.predict(feature_matrix)
|
||||||
|
|
||||||
|
if action_probs is not None:
|
||||||
|
# Convert to prediction object
|
||||||
|
action_names = ['SELL', 'HOLD', 'BUY']
|
||||||
|
best_action_idx = np.argmax(action_probs)
|
||||||
|
best_action = action_names[best_action_idx]
|
||||||
|
|
||||||
|
prediction = Prediction(
|
||||||
|
action=best_action,
|
||||||
|
confidence=float(confidence) if confidence is not None else float(action_probs[best_action_idx]),
|
||||||
|
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
|
||||||
|
timeframe=timeframe,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
model_name=model.name,
|
||||||
|
metadata={'timeframe_specific': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
predictions.append(prediction)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting CNN predictions: {e}")
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
|
||||||
|
"""Get prediction from RL agent"""
|
||||||
|
try:
|
||||||
|
# Get current state for RL agent
|
||||||
|
state = self._get_rl_state(symbol)
|
||||||
|
if state is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get RL agent's action and confidence
|
||||||
|
action_idx, confidence = model.act_with_confidence(state)
|
||||||
|
|
||||||
|
action_names = ['SELL', 'HOLD', 'BUY']
|
||||||
|
action = action_names[action_idx]
|
||||||
|
|
||||||
|
# Create prediction object
|
||||||
|
prediction = Prediction(
|
||||||
|
action=action,
|
||||||
|
confidence=float(confidence),
|
||||||
|
probabilities={action: float(confidence), 'HOLD': 1.0 - float(confidence)},
|
||||||
|
timeframe='mixed', # RL uses mixed timeframes
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
model_name=model.name,
|
||||||
|
metadata={'state_size': len(state)}
|
||||||
|
)
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting RL prediction: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
|
||||||
|
"""Get prediction from generic model"""
|
||||||
|
try:
|
||||||
|
# Get feature matrix for the model
|
||||||
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframes=self.config.timeframes[:3], # Use first 3 timeframes
|
||||||
|
window_size=20
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_matrix is not None:
|
||||||
|
action_probs, confidence = model.predict(feature_matrix)
|
||||||
|
|
||||||
|
if action_probs is not None:
|
||||||
|
action_names = ['SELL', 'HOLD', 'BUY']
|
||||||
|
best_action_idx = np.argmax(action_probs)
|
||||||
|
best_action = action_names[best_action_idx]
|
||||||
|
|
||||||
|
prediction = Prediction(
|
||||||
|
action=best_action,
|
||||||
|
confidence=float(confidence),
|
||||||
|
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
|
||||||
|
timeframe='mixed',
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
model_name=model.name,
|
||||||
|
metadata={'generic_model': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting generic prediction: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||||
|
"""Get current state for RL agent"""
|
||||||
|
try:
|
||||||
|
# Get feature matrix for all timeframes
|
||||||
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframes=self.config.timeframes,
|
||||||
|
window_size=self.config.rl.get('window_size', 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_matrix is not None:
|
||||||
|
# Flatten the feature matrix for RL agent
|
||||||
|
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
|
||||||
|
state = feature_matrix.flatten()
|
||||||
|
|
||||||
|
# Add additional state information (position, balance, etc.)
|
||||||
|
# This would come from a portfolio manager in a real implementation
|
||||||
|
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
|
||||||
|
|
||||||
|
return np.concatenate([state, additional_state])
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating RL state for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _combine_predictions(self, symbol: str, price: float,
|
||||||
|
predictions: List[Prediction],
|
||||||
|
timestamp: datetime) -> TradingDecision:
|
||||||
|
"""Combine all predictions into a final decision"""
|
||||||
|
try:
|
||||||
|
reasoning = {
|
||||||
|
'predictions': len(predictions),
|
||||||
|
'weights': self.model_weights.copy(),
|
||||||
|
'models_used': [pred.model_name for pred in predictions]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize action scores
|
||||||
|
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||||
|
total_weight = 0.0
|
||||||
|
|
||||||
|
# Process all predictions
|
||||||
|
for pred in predictions:
|
||||||
|
# Get model weight
|
||||||
|
model_weight = self.model_weights.get(pred.model_name, 0.1)
|
||||||
|
|
||||||
|
# Weight by confidence and timeframe importance
|
||||||
|
timeframe_weight = self._get_timeframe_weight(pred.timeframe)
|
||||||
|
weighted_confidence = pred.confidence * timeframe_weight * model_weight
|
||||||
|
|
||||||
|
action_scores[pred.action] += weighted_confidence
|
||||||
|
total_weight += weighted_confidence
|
||||||
|
|
||||||
|
# Normalize scores
|
||||||
|
if total_weight > 0:
|
||||||
|
for action in action_scores:
|
||||||
|
action_scores[action] /= total_weight
|
||||||
|
|
||||||
|
# Choose best action
|
||||||
|
best_action = max(action_scores, key=action_scores.get)
|
||||||
|
best_confidence = action_scores[best_action]
|
||||||
|
|
||||||
|
# Apply confidence threshold
|
||||||
|
if best_confidence < self.confidence_threshold:
|
||||||
|
best_action = 'HOLD'
|
||||||
|
reasoning['threshold_applied'] = True
|
||||||
|
|
||||||
|
# Get memory usage stats
|
||||||
|
memory_usage = self.model_registry.get_memory_stats()
|
||||||
|
|
||||||
|
# Create final decision
|
||||||
|
decision = TradingDecision(
|
||||||
|
action=best_action,
|
||||||
|
confidence=best_confidence,
|
||||||
|
symbol=symbol,
|
||||||
|
price=price,
|
||||||
|
timestamp=timestamp,
|
||||||
|
reasoning=reasoning,
|
||||||
|
memory_usage=memory_usage['models']
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Decision for {symbol}: {best_action} (confidence: {best_confidence:.3f})")
|
||||||
|
logger.debug(f"Memory usage: {memory_usage['total_used_mb']:.1f}MB / {memory_usage['total_limit_mb']:.1f}MB")
|
||||||
|
|
||||||
|
return decision
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error combining predictions for {symbol}: {e}")
|
||||||
|
# Return safe default
|
||||||
|
return TradingDecision(
|
||||||
|
action='HOLD',
|
||||||
|
confidence=0.0,
|
||||||
|
symbol=symbol,
|
||||||
|
price=price,
|
||||||
|
timestamp=timestamp,
|
||||||
|
reasoning={'error': str(e)},
|
||||||
|
memory_usage={}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_timeframe_weight(self, timeframe: str) -> float:
|
||||||
|
"""Get importance weight for a timeframe"""
|
||||||
|
# Higher timeframes get more weight in decision making
|
||||||
|
weights = {
|
||||||
|
'1m': 0.1, '5m': 0.2, '15m': 0.3, '30m': 0.4,
|
||||||
|
'1h': 0.6, '4h': 0.8, '1d': 1.0
|
||||||
|
}
|
||||||
|
return weights.get(timeframe, 0.5)
|
||||||
|
|
||||||
|
def update_model_performance(self, model_name: str, was_correct: bool):
|
||||||
|
"""Update performance tracking for a model"""
|
||||||
|
if model_name in self.model_performance:
|
||||||
|
self.model_performance[model_name]['total'] += 1
|
||||||
|
if was_correct:
|
||||||
|
self.model_performance[model_name]['correct'] += 1
|
||||||
|
|
||||||
|
# Update accuracy
|
||||||
|
total = self.model_performance[model_name]['total']
|
||||||
|
correct = self.model_performance[model_name]['correct']
|
||||||
|
self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
def adapt_weights(self):
|
||||||
|
"""Dynamically adapt model weights based on performance"""
|
||||||
|
try:
|
||||||
|
for model_name, performance in self.model_performance.items():
|
||||||
|
if performance['total'] > 0:
|
||||||
|
# Adjust weight based on relative performance
|
||||||
|
accuracy = performance['correct'] / performance['total']
|
||||||
|
self.model_weights[model_name] = accuracy
|
||||||
|
|
||||||
|
logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adapting weights: {e}")
|
||||||
|
|
||||||
|
def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
|
||||||
|
"""Get recent decisions for a symbol"""
|
||||||
|
if symbol in self.recent_decisions:
|
||||||
|
return self.recent_decisions[symbol][-limit:]
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_performance_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get performance metrics for the orchestrator"""
|
||||||
|
return {
|
||||||
|
'model_performance': self.model_performance.copy(),
|
||||||
|
'weights': self.model_weights.copy(),
|
||||||
|
'configuration': {
|
||||||
|
'confidence_threshold': self.confidence_threshold,
|
||||||
|
'decision_frequency': self.decision_frequency
|
||||||
|
},
|
||||||
|
'recent_activity': {
|
||||||
|
symbol: len(decisions) for symbol, decisions in self.recent_decisions.items()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start_continuous_trading(self, symbols: List[str] = None):
|
||||||
|
"""Start continuous trading decisions for specified symbols"""
|
||||||
|
if symbols is None:
|
||||||
|
symbols = self.config.symbols
|
||||||
|
|
||||||
|
logger.info(f"Starting continuous trading for symbols: {symbols}")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Make decisions for all symbols
|
||||||
|
for symbol in symbols:
|
||||||
|
decision = await self.make_trading_decision(symbol)
|
||||||
|
if decision and decision.action != 'HOLD':
|
||||||
|
logger.info(f"Trading decision: {decision.action} {symbol} at {decision.price}")
|
||||||
|
|
||||||
|
# Wait before next decision cycle
|
||||||
|
await asyncio.sleep(self.decision_frequency)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in continuous trading loop: {e}")
|
||||||
|
await asyncio.sleep(10) # Wait before retrying
|
221
main_clean.py
Normal file
221
main_clean.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Clean Trading System - Main Entry Point
|
||||||
|
|
||||||
|
This is the new clean entry point that demonstrates the consolidated architecture:
|
||||||
|
- Single configuration system
|
||||||
|
- Clean data provider
|
||||||
|
- Modular CNN and RL components
|
||||||
|
- Centralized orchestrator
|
||||||
|
- Simple web dashboard
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python main_clean.py --mode [train|trade|web] --symbol ETH/USDT
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def run_data_test():
|
||||||
|
"""Test the data provider functionality"""
|
||||||
|
try:
|
||||||
|
config = get_config()
|
||||||
|
logger.info("Testing Data Provider...")
|
||||||
|
|
||||||
|
# Test data provider
|
||||||
|
data_provider = DataProvider(
|
||||||
|
symbols=['ETH/USDT'],
|
||||||
|
timeframes=['1h', '4h']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test historical data
|
||||||
|
logger.info("Testing historical data fetching...")
|
||||||
|
df = data_provider.get_historical_data('ETH/USDT', '1h', limit=100)
|
||||||
|
if df is not None:
|
||||||
|
logger.info(f"[SUCCESS] Historical data: {len(df)} candles loaded")
|
||||||
|
logger.info(f" Columns: {list(df.columns)}")
|
||||||
|
logger.info(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
|
||||||
|
else:
|
||||||
|
logger.error("[FAILED] Failed to load historical data")
|
||||||
|
|
||||||
|
# Test feature matrix
|
||||||
|
logger.info("Testing feature matrix...")
|
||||||
|
feature_matrix = data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
|
||||||
|
if feature_matrix is not None:
|
||||||
|
logger.info(f"[SUCCESS] Feature matrix shape: {feature_matrix.shape}")
|
||||||
|
else:
|
||||||
|
logger.error("[FAILED] Failed to create feature matrix")
|
||||||
|
|
||||||
|
# Test health check
|
||||||
|
health = data_provider.health_check()
|
||||||
|
logger.info(f"[SUCCESS] Data provider health: {health}")
|
||||||
|
|
||||||
|
logger.info("Data provider test completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in data test: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
def run_orchestrator_test():
|
||||||
|
"""Test the modular orchestrator system"""
|
||||||
|
try:
|
||||||
|
from models import get_model_registry, ModelInterface
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger.info("Testing Modular Orchestrator System...")
|
||||||
|
|
||||||
|
# Test model registry
|
||||||
|
registry = get_model_registry()
|
||||||
|
logger.info(f"[SUCCESS] Model registry initialized with {registry.total_memory_limit_mb}MB limit")
|
||||||
|
|
||||||
|
# Create a mock model for testing
|
||||||
|
class MockCNNModel(ModelInterface):
|
||||||
|
def __init__(self):
|
||||||
|
config = {'max_memory_mb': 500} # 500MB limit
|
||||||
|
super().__init__('MockCNN', config)
|
||||||
|
self.model_params = torch.randn(1000, 100) # Small mock model
|
||||||
|
|
||||||
|
def predict(self, features):
|
||||||
|
# Mock prediction: random but consistent
|
||||||
|
np.random.seed(42)
|
||||||
|
action_probs = np.random.dirichlet([1, 1, 1]) # Random probabilities that sum to 1
|
||||||
|
confidence = np.random.uniform(0.5, 0.9)
|
||||||
|
return action_probs, confidence
|
||||||
|
|
||||||
|
def get_memory_usage(self):
|
||||||
|
# Estimate memory usage
|
||||||
|
if hasattr(self, 'model_params'):
|
||||||
|
return int(self.model_params.numel() * 4 / (1024*1024)) # 4 bytes per float, convert to MB
|
||||||
|
return 0
|
||||||
|
|
||||||
|
class MockRLAgent(ModelInterface):
|
||||||
|
def __init__(self):
|
||||||
|
config = {'max_memory_mb': 300} # 300MB limit
|
||||||
|
super().__init__('MockRL', config)
|
||||||
|
self.q_network = torch.randn(500, 50) # Smaller mock RL model
|
||||||
|
|
||||||
|
def predict(self, features):
|
||||||
|
# Mock RL prediction
|
||||||
|
np.random.seed(123)
|
||||||
|
action_probs = np.random.dirichlet([2, 1, 2]) # Favor BUY/SELL over HOLD
|
||||||
|
confidence = np.random.uniform(0.6, 0.8)
|
||||||
|
return action_probs, confidence
|
||||||
|
|
||||||
|
def act_with_confidence(self, state):
|
||||||
|
action_probs, confidence = self.predict(state)
|
||||||
|
action = np.argmax(action_probs)
|
||||||
|
return action, confidence
|
||||||
|
|
||||||
|
def get_memory_usage(self):
|
||||||
|
if hasattr(self, 'q_network'):
|
||||||
|
return int(self.q_network.numel() * 4 / (1024*1024))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def act(self, state):
|
||||||
|
return self.act_with_confidence(state)[0]
|
||||||
|
|
||||||
|
def remember(self, state, action, reward, next_state, done):
|
||||||
|
pass # Mock implementation
|
||||||
|
|
||||||
|
def replay(self):
|
||||||
|
return 0.0 # Mock implementation
|
||||||
|
|
||||||
|
# Test model registration
|
||||||
|
logger.info("Testing model registration...")
|
||||||
|
mock_cnn = MockCNNModel()
|
||||||
|
mock_rl = MockRLAgent()
|
||||||
|
|
||||||
|
success1 = registry.register_model(mock_cnn)
|
||||||
|
success2 = registry.register_model(mock_rl)
|
||||||
|
|
||||||
|
if success1 and success2:
|
||||||
|
logger.info("[SUCCESS] Both models registered successfully")
|
||||||
|
else:
|
||||||
|
logger.error(f"[FAILED] Model registration failed: CNN={success1}, RL={success2}")
|
||||||
|
|
||||||
|
# Test memory stats
|
||||||
|
memory_stats = registry.get_memory_stats()
|
||||||
|
logger.info(f"[SUCCESS] Memory stats: {memory_stats}")
|
||||||
|
|
||||||
|
# Test orchestrator
|
||||||
|
logger.info("Testing orchestrator integration...")
|
||||||
|
data_provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1h'])
|
||||||
|
orchestrator = TradingOrchestrator(data_provider)
|
||||||
|
|
||||||
|
# Register models with orchestrator
|
||||||
|
success1 = orchestrator.register_model(mock_cnn, weight=0.7)
|
||||||
|
success2 = orchestrator.register_model(mock_rl, weight=0.3)
|
||||||
|
|
||||||
|
if success1 and success2:
|
||||||
|
logger.info("[SUCCESS] Models registered with orchestrator")
|
||||||
|
else:
|
||||||
|
logger.error(f"[FAILED] Orchestrator registration failed")
|
||||||
|
|
||||||
|
# Test orchestrator metrics
|
||||||
|
metrics = orchestrator.get_performance_metrics()
|
||||||
|
logger.info(f"[SUCCESS] Orchestrator metrics: {metrics}")
|
||||||
|
|
||||||
|
logger.info("Modular orchestrator test completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in orchestrator test: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
parser = argparse.ArgumentParser(description='Clean Trading System')
|
||||||
|
parser.add_argument('--mode', choices=['trade', 'train', 'web', 'test', 'orchestrator'],
|
||||||
|
default='test', help='Mode to run the system in')
|
||||||
|
parser.add_argument('--symbol', type=str, help='Override default symbol')
|
||||||
|
parser.add_argument('--config', type=str, default='config.yaml',
|
||||||
|
help='Configuration file path')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("CLEAN TRADING SYSTEM STARTING")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Run appropriate mode
|
||||||
|
if args.mode == 'test':
|
||||||
|
run_data_test()
|
||||||
|
elif args.mode == 'orchestrator':
|
||||||
|
run_orchestrator_test()
|
||||||
|
else:
|
||||||
|
logger.info(f"Mode '{args.mode}' not yet implemented in clean architecture")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("System shutdown requested by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Fatal error: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return 1
|
||||||
|
|
||||||
|
logger.info("Clean Trading System finished")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(asyncio.run(main()))
|
0
trading/__init__.py
Normal file
0
trading/__init__.py
Normal file
@ -1,19 +0,0 @@
|
|||||||
"""
|
|
||||||
Utility functions for port management, launching services, and debug tools.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from utils.port_manager import (
|
|
||||||
is_port_in_use,
|
|
||||||
find_available_port,
|
|
||||||
kill_process_by_port,
|
|
||||||
kill_stale_debug_instances,
|
|
||||||
get_port_with_fallback
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'is_port_in_use',
|
|
||||||
'find_available_port',
|
|
||||||
'kill_process_by_port',
|
|
||||||
'kill_stale_debug_instances',
|
|
||||||
'get_port_with_fallback'
|
|
||||||
]
|
|
0
web/__init__.py
Normal file
0
web/__init__.py
Normal file
Loading…
x
Reference in New Issue
Block a user