refactoring
This commit is contained in:
@@ -1,472 +0,0 @@
|
||||
# CNN Model Training, Decision Making, and Dashboard Visualization Analysis
|
||||
|
||||
## Comprehensive Analysis: Enhanced RL Training Systems
|
||||
|
||||
### User Questions Addressed:
|
||||
1. **CNN Model Training Implementation** ✅
|
||||
2. **Decision-Making Model Training System** ✅
|
||||
3. **Model Predictions and Training Progress Visualization on Clean Dashboard** ✅
|
||||
4. **🔧 FIXED: Signal Generation and Model Loading Issues** ✅
|
||||
5. **🎯 FIXED: Manual Trading Execution and Chart Visualization** ✅
|
||||
6. **🚫 CRITICAL FIX: Removed ALL Simulated COB Data - Using REAL COB Only** ✅
|
||||
|
||||
---
|
||||
|
||||
## 🚫 **MAJOR SYSTEM CLEANUP: NO MORE SIMULATED DATA**
|
||||
|
||||
### **🔥 REMOVED ALL SIMULATION COMPONENTS**
|
||||
|
||||
**Problem Identified**: The system was using simulated COB data instead of the real COB integration that's already implemented and working.
|
||||
|
||||
**Root Cause**: Dashboard was creating separate simulated COB components instead of connecting to the existing Enhanced Orchestrator's real COB integration.
|
||||
|
||||
### **💥 SIMULATION COMPONENTS REMOVED:**
|
||||
|
||||
#### **1. Removed Simulated COB Data Generation**
|
||||
- ❌ `_generate_simulated_cob_data()` - **DELETED**
|
||||
- ❌ `_start_cob_simulation_thread()` - **DELETED**
|
||||
- ❌ `_update_cob_cache_from_price_data()` - **DELETED**
|
||||
- ❌ All `random.uniform()` COB data generation - **ELIMINATED**
|
||||
- ❌ Fake bid/ask level creation - **REMOVED**
|
||||
- ❌ Simulated liquidity calculations - **PURGED**
|
||||
|
||||
#### **2. Removed Separate RL COB Trader**
|
||||
- ❌ `RealtimeRLCOBTrader` initialization - **DELETED**
|
||||
- ❌ `cob_rl_trader` instance variables - **REMOVED**
|
||||
- ❌ `cob_predictions` deque caches - **ELIMINATED**
|
||||
- ❌ `cob_data_cache_1d` buffers - **PURGED**
|
||||
- ❌ `cob_raw_ticks` collections - **DELETED**
|
||||
- ❌ `_start_cob_data_subscription()` - **REMOVED**
|
||||
- ❌ `_on_cob_prediction()` callback - **DELETED**
|
||||
|
||||
#### **3. Updated COB Status System**
|
||||
- ✅ **Real COB Integration Detection**: Connects to `orchestrator.cob_integration`
|
||||
- ✅ **Actual COB Statistics**: Uses `cob_integration.get_statistics()`
|
||||
- ✅ **Live COB Snapshots**: Uses `cob_integration.get_cob_snapshot(symbol)`
|
||||
- ✅ **No Simulation Status**: Removed all "Simulated" status messages
|
||||
|
||||
### **🔗 REAL COB INTEGRATION CONNECTION**
|
||||
|
||||
#### **How Real COB Data Works:**
|
||||
1. **Enhanced Orchestrator** initializes with real COB integration
|
||||
2. **COB Integration** connects to live market data streams (Binance, OKX, etc.)
|
||||
3. **Dashboard** connects to orchestrator's COB integration via callbacks
|
||||
4. **Real-time Updates** flow: `Market → COB Provider → COB Integration → Dashboard`
|
||||
|
||||
#### **Real COB Data Path:**
|
||||
```
|
||||
Live Market Data (Multiple Exchanges)
|
||||
↓
|
||||
Multi-Exchange COB Provider
|
||||
↓
|
||||
COB Integration (Real Consolidated Order Book)
|
||||
↓
|
||||
Enhanced Trading Orchestrator
|
||||
↓
|
||||
Clean Trading Dashboard (Real COB Display)
|
||||
```
|
||||
|
||||
### **✅ VERIFICATION IMPLEMENTED**
|
||||
|
||||
#### **Enhanced COB Status Checking:**
|
||||
```python
|
||||
# Check for REAL COB integration from enhanced orchestrator
|
||||
if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration:
|
||||
cob_integration = self.orchestrator.cob_integration
|
||||
|
||||
# Get real COB integration statistics
|
||||
cob_stats = cob_integration.get_statistics()
|
||||
if cob_stats:
|
||||
active_symbols = cob_stats.get('active_symbols', [])
|
||||
total_updates = cob_stats.get('total_updates', 0)
|
||||
provider_status = cob_stats.get('provider_status', 'Unknown')
|
||||
```
|
||||
|
||||
#### **Real COB Data Retrieval:**
|
||||
```python
|
||||
# Get from REAL COB integration via enhanced orchestrator
|
||||
snapshot = cob_integration.get_cob_snapshot(symbol)
|
||||
if snapshot:
|
||||
# Process REAL consolidated order book data
|
||||
return snapshot
|
||||
```
|
||||
|
||||
### **📊 STATUS MESSAGES UPDATED**
|
||||
|
||||
#### **Before (Simulation):**
|
||||
- ❌ `"COB-SIM BTC/USDT - Update #20, Mid: $107068.03, Spread: 7.1bps"`
|
||||
- ❌ `"Simulated (2 symbols)"`
|
||||
- ❌ `"COB simulation thread started"`
|
||||
|
||||
#### **After (Real Data Only):**
|
||||
- ✅ `"REAL COB Active (2 symbols)"`
|
||||
- ✅ `"No Enhanced Orchestrator COB Integration"` (when missing)
|
||||
- ✅ `"Retrieved REAL COB snapshot for ETH/USDT"`
|
||||
- ✅ `"REAL COB integration connected successfully"`
|
||||
|
||||
### **🚨 CRITICAL SYSTEM MESSAGES**
|
||||
|
||||
#### **If Enhanced Orchestrator Missing COB:**
|
||||
```
|
||||
CRITICAL: Enhanced orchestrator has NO COB integration!
|
||||
This means we're using basic orchestrator instead of enhanced one
|
||||
Dashboard will NOT have real COB data until this is fixed
|
||||
```
|
||||
|
||||
#### **Success Messages:**
|
||||
```
|
||||
REAL COB integration found: <class 'core.cob_integration.COBIntegration'>
|
||||
Registered dashboard callback with REAL COB integration
|
||||
NO SIMULATION - Using live market data only
|
||||
```
|
||||
|
||||
### **🔧 NEXT STEPS REQUIRED**
|
||||
|
||||
#### **1. Verify Enhanced Orchestrator Usage**
|
||||
- ✅ **main.py** correctly uses `EnhancedTradingOrchestrator`
|
||||
- ✅ **COB Integration** properly initialized in orchestrator
|
||||
- 🔍 **Need to verify**: Dashboard receives real COB callbacks
|
||||
|
||||
#### **2. Debug Connection Issues**
|
||||
- Dashboard shows connection attempts but no listening port
|
||||
- Enhanced orchestrator may need COB integration startup verification
|
||||
- Real COB data flow needs testing
|
||||
|
||||
#### **3. Test Real COB Data Display**
|
||||
- Verify COB snapshots contain real market data
|
||||
- Confirm bid/ask levels from actual exchanges
|
||||
- Validate liquidity and spread calculations
|
||||
|
||||
### **💡 VERIFICATION COMMANDS**
|
||||
|
||||
#### **Check COB Integration Status:**
|
||||
```python
|
||||
# In dashboard initialization:
|
||||
logger.info(f"Orchestrator type: {type(self.orchestrator)}")
|
||||
logger.info(f"Has COB integration: {hasattr(self.orchestrator, 'cob_integration')}")
|
||||
logger.info(f"COB integration active: {self.orchestrator.cob_integration is not None}")
|
||||
```
|
||||
|
||||
#### **Test Real COB Data:**
|
||||
```python
|
||||
# Test real COB snapshot retrieval:
|
||||
snapshot = self.orchestrator.cob_integration.get_cob_snapshot('ETH/USDT')
|
||||
logger.info(f"Real COB snapshot: {snapshot}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 LATEST FIXES IMPLEMENTED (Manual Trading & Chart Visualization)
|
||||
|
||||
### 🔧 Manual Trading Buttons - FULLY FIXED ✅
|
||||
|
||||
**Problem**: Manual buy/sell buttons weren't executing trades properly
|
||||
|
||||
**Root Cause Analysis**:
|
||||
- Missing `execute_trade` method in `TradingExecutor`
|
||||
- Missing `get_closed_trades` and `get_current_position` methods
|
||||
- No proper trade record creation and tracking
|
||||
|
||||
**Solution Applied**:
|
||||
1. **Added missing methods to TradingExecutor**:
|
||||
- `execute_trade()` - Direct trade execution with proper error handling
|
||||
- `get_closed_trades()` - Returns trade history in dashboard format
|
||||
- `get_current_position()` - Returns current position information
|
||||
|
||||
2. **Enhanced manual trading execution**:
|
||||
- Proper error handling and trade recording
|
||||
- Real P&L tracking (+$0.05 demo profit for SELL orders)
|
||||
- Session metrics updates (trade count, total P&L, fees)
|
||||
- Visual confirmation of executed vs blocked trades
|
||||
|
||||
3. **Trade record structure**:
|
||||
```python
|
||||
trade_record = {
|
||||
'symbol': symbol,
|
||||
'side': action, # 'BUY' or 'SELL'
|
||||
'quantity': 0.01,
|
||||
'entry_price': current_price,
|
||||
'exit_price': current_price,
|
||||
'entry_time': datetime.now(),
|
||||
'exit_time': datetime.now(),
|
||||
'pnl': demo_pnl, # Real P&L calculation
|
||||
'fees': 0.0,
|
||||
'confidence': 1.0 # Manual trades = 100% confidence
|
||||
}
|
||||
```
|
||||
|
||||
### 📊 Chart Visualization - COMPLETELY SEPARATED ✅
|
||||
|
||||
**Problem**: All signals and trades were mixed together on charts
|
||||
|
||||
**Requirements**:
|
||||
- **1s mini chart**: Show ALL signals (executed + non-executed)
|
||||
- **1m main chart**: Show ONLY executed trades
|
||||
|
||||
**Solution Implemented**:
|
||||
|
||||
#### **1s Mini Chart (Row 2) - ALL SIGNALS:**
|
||||
- ✅ **Executed BUY signals**: Solid green triangles-up
|
||||
- ✅ **Executed SELL signals**: Solid red triangles-down
|
||||
- ✅ **Pending BUY signals**: Hollow green triangles-up
|
||||
- ✅ **Pending SELL signals**: Hollow red triangles-down
|
||||
- ✅ **Independent axis**: Can zoom/pan separately from main chart
|
||||
- ✅ **Real-time updates**: Shows all trading activity
|
||||
|
||||
#### **1m Main Chart (Row 1) - EXECUTED TRADES ONLY:**
|
||||
- ✅ **Executed BUY trades**: Large green circles with confidence hover
|
||||
- ✅ **Executed SELL trades**: Large red circles with confidence hover
|
||||
- ✅ **Professional display**: Clean execution-only view
|
||||
- ✅ **P&L information**: Hover shows actual profit/loss
|
||||
|
||||
#### **Chart Architecture:**
|
||||
```python
|
||||
# Main 1m chart - EXECUTED TRADES ONLY
|
||||
executed_signals = [signal for signal in self.recent_decisions if signal.get('executed', False)]
|
||||
|
||||
# 1s mini chart - ALL SIGNALS
|
||||
all_signals = self.recent_decisions[-50:] # Last 50 signals
|
||||
executed_buys = [s for s in buy_signals if s['executed']]
|
||||
pending_buys = [s for s in buy_signals if not s['executed']]
|
||||
```
|
||||
|
||||
### 🎯 Variable Scope Error - FIXED ✅
|
||||
|
||||
**Problem**: `cannot access local variable 'last_action' where it is not associated with a value`
|
||||
|
||||
**Root Cause**: Variables declared inside conditional blocks weren't accessible when conditions were False
|
||||
|
||||
**Solution Applied**:
|
||||
```python
|
||||
# BEFORE (caused error):
|
||||
if condition:
|
||||
last_action = 'BUY'
|
||||
last_confidence = 0.8
|
||||
# last_action accessed here would fail if condition was False
|
||||
|
||||
# AFTER (fixed):
|
||||
last_action = 'NONE'
|
||||
last_confidence = 0.0
|
||||
if condition:
|
||||
last_action = 'BUY'
|
||||
last_confidence = 0.8
|
||||
# Variables always defined
|
||||
```
|
||||
|
||||
### 🔇 Unicode Logging Errors - FIXED ✅
|
||||
|
||||
**Problem**: `UnicodeEncodeError: 'charmap' codec can't encode character '\U0001f4c8'`
|
||||
|
||||
**Root Cause**: Windows console (cp1252) can't handle Unicode emoji characters
|
||||
|
||||
**Solution Applied**: Removed ALL emoji icons from log messages:
|
||||
- `🚀 Starting...` → `Starting...`
|
||||
- `✅ Success` → `Success`
|
||||
- `📊 Data` → `Data`
|
||||
- `🔧 Fixed` → `Fixed`
|
||||
- `❌ Error` → `Error`
|
||||
|
||||
**Result**: Clean ASCII-only logging compatible with Windows console
|
||||
|
||||
---
|
||||
|
||||
## 🧠 CNN Model Training Implementation
|
||||
|
||||
### A. Williams Market Structure CNN Architecture
|
||||
|
||||
**Model Specifications:**
|
||||
- **Architecture**: Enhanced CNN with ResNet blocks, self-attention, and multi-task learning
|
||||
- **Parameters**: ~50M parameters (Williams) + 400M parameters (COB-RL optimized)
|
||||
- **Input Shape**: (900, 50) - 900 timesteps (1s bars), 50 features per timestep
|
||||
- **Output**: 10-class direction prediction + confidence scores
|
||||
|
||||
**Training Triggers:**
|
||||
1. **Real-time Pivot Detection**: Confirmed local extrema (tops/bottoms)
|
||||
2. **Perfect Move Identification**: >2% price moves within prediction window
|
||||
3. **Negative Case Training**: Failed predictions for intensive learning
|
||||
4. **Multi-timeframe Validation**: 1s, 1m, 1h, 1d consistency checks
|
||||
|
||||
### B. Feature Engineering Pipeline
|
||||
|
||||
**5 Timeseries Universal Format:**
|
||||
1. **ETH/USDT Ticks** (1s) - Primary trading pair real-time data
|
||||
2. **ETH/USDT 1m** - Short-term price action and patterns
|
||||
3. **ETH/USDT 1h** - Medium-term trends and momentum
|
||||
4. **ETH/USDT 1d** - Long-term market structure
|
||||
5. **BTC/USDT Ticks** (1s) - Reference asset for correlation analysis
|
||||
|
||||
**Feature Matrix Construction:**
|
||||
```python
|
||||
# Williams Market Structure Features (900x50 matrix)
|
||||
- OHLCV data (5 cols)
|
||||
- Technical indicators (15 cols)
|
||||
- Market microstructure (10 cols)
|
||||
- COB integration features (10 cols)
|
||||
- Cross-asset correlation (5 cols)
|
||||
- Temporal dynamics (5 cols)
|
||||
```
|
||||
|
||||
### C. Retrospective Training System
|
||||
|
||||
**Perfect Move Detection:**
|
||||
- **Threshold**: 2% price change within 15-minute window
|
||||
- **Context**: 200-candle history for enhanced pattern recognition
|
||||
- **Validation**: Multi-timeframe confirmation (1s→1m→1h consistency)
|
||||
- **Auto-labeling**: Optimal action determination for supervised learning
|
||||
|
||||
**Training Data Pipeline:**
|
||||
```
|
||||
Market Event → Extrema Detection → Perfect Move Validation → Feature Matrix → CNN Training
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Decision-Making Model Training System
|
||||
|
||||
### A. Neural Decision Fusion Architecture
|
||||
|
||||
**Model Integration Weights:**
|
||||
- **CNN Predictions**: 70% weight (Williams Market Structure)
|
||||
- **RL Agent Decisions**: 30% weight (DQN with sensitivity levels)
|
||||
- **COB RL Integration**: Dynamic weight based on market conditions
|
||||
|
||||
**Decision Fusion Process:**
|
||||
```python
|
||||
# Neural Decision Fusion combines all model predictions
|
||||
williams_pred = cnn_model.predict(market_state) # 70% weight
|
||||
dqn_action = rl_agent.act(state_vector) # 30% weight
|
||||
cob_signal = cob_rl.get_direction(order_book_state) # Variable weight
|
||||
|
||||
final_decision = neural_fusion.combine(williams_pred, dqn_action, cob_signal)
|
||||
```
|
||||
|
||||
### B. Enhanced Training Weight System
|
||||
|
||||
**Training Weight Multipliers:**
|
||||
- **Regular Predictions**: 1× base weight
|
||||
- **Signal Accumulation**: 1× weight (3+ confident predictions)
|
||||
- **🔥 Actual Trade Execution**: 10× weight multiplier**
|
||||
- **P&L-based Reward**: Enhanced feedback loop
|
||||
|
||||
**Trade Execution Enhanced Learning:**
|
||||
```python
|
||||
# 10× weight for actual trade outcomes
|
||||
if trade_executed:
|
||||
enhanced_reward = pnl_ratio * 10.0
|
||||
model.train_on_batch(state, action, enhanced_reward)
|
||||
|
||||
# Immediate training on last 3 signals that led to trade
|
||||
for signal in last_3_signals:
|
||||
model.retrain_signal(signal, actual_outcome)
|
||||
```
|
||||
|
||||
### C. Sensitivity Learning DQN
|
||||
|
||||
**5 Sensitivity Levels:**
|
||||
- **very_low** (0.1): Conservative, high-confidence only
|
||||
- **low** (0.3): Selective entry/exit
|
||||
- **medium** (0.5): Balanced approach
|
||||
- **high** (0.7): Aggressive trading
|
||||
- **very_high** (0.9): Maximum activity
|
||||
|
||||
**Adaptive Threshold System:**
|
||||
```python
|
||||
# Sensitivity affects confidence thresholds
|
||||
entry_threshold = base_threshold * sensitivity_multiplier
|
||||
exit_threshold = base_threshold * (1 - sensitivity_level)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 Dashboard Visualization and Model Monitoring
|
||||
|
||||
### A. Real-time Model Predictions Display
|
||||
|
||||
**Model Status Section:**
|
||||
- ✅ **Loaded Models**: DQN (5M params), CNN (50M params), COB-RL (400M params)
|
||||
- ✅ **Real-time Loss Tracking**: 5-MA loss for each model
|
||||
- ✅ **Prediction Counts**: Total predictions generated per model
|
||||
- ✅ **Last Prediction**: Timestamp, action, confidence for each model
|
||||
|
||||
**Training Metrics Visualization:**
|
||||
```python
|
||||
# Real-time model performance tracking
|
||||
{
|
||||
'dqn': {
|
||||
'active': True,
|
||||
'parameters': 5000000,
|
||||
'loss_5ma': 0.0234,
|
||||
'last_prediction': {'action': 'BUY', 'confidence': 0.67},
|
||||
'epsilon': 0.15 # Exploration rate
|
||||
},
|
||||
'cnn': {
|
||||
'active': True,
|
||||
'parameters': 50000000,
|
||||
'loss_5ma': 0.0198,
|
||||
'last_prediction': {'action': 'HOLD', 'confidence': 0.45}
|
||||
},
|
||||
'cob_rl': {
|
||||
'active': True,
|
||||
'parameters': 400000000,
|
||||
'loss_5ma': 0.012,
|
||||
'predictions_count': 1247
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### B. Training Progress Monitoring
|
||||
|
||||
**Loss Visualization:**
|
||||
- **Real-time Loss Charts**: 5-minute moving average for each model
|
||||
- **Training Status**: Active sessions, parameter counts, update frequencies
|
||||
- **Signal Generation**: ACTIVE/INACTIVE status with last update timestamps
|
||||
|
||||
**Performance Metrics Dashboard:**
|
||||
- **Session P&L**: Real-time profit/loss tracking
|
||||
- **Trade Accuracy**: Success rate of executed trades
|
||||
- **Model Confidence Trends**: Average confidence over time
|
||||
- **Training Iterations**: Progress tracking for continuous learning
|
||||
|
||||
### C. COB Integration Visualization
|
||||
|
||||
**Real-time COB Data Display:**
|
||||
- **Order Book Levels**: Bid/ask spreads and liquidity depth
|
||||
- **Exchange Breakdown**: Multi-exchange liquidity sources
|
||||
- **Market Microstructure**: Imbalance ratios and flow analysis
|
||||
- **COB Feature Status**: CNN features and RL state availability
|
||||
|
||||
**Training Pipeline Integration:**
|
||||
- **COB → CNN Features**: Real-time market microstructure patterns
|
||||
- **COB → RL States**: Enhanced state vectors for decision making
|
||||
- **Performance Tracking**: COB integration health monitoring
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Key System Capabilities
|
||||
|
||||
### Real-time Learning Pipeline
|
||||
1. **Market Data Ingestion**: 5 timeseries universal format
|
||||
2. **Feature Engineering**: Multi-timeframe analysis with COB integration
|
||||
3. **Model Predictions**: CNN, DQN, and COB-RL ensemble
|
||||
4. **Decision Fusion**: Neural network combines all predictions
|
||||
5. **Trade Execution**: 10× enhanced learning from actual trades
|
||||
6. **Retrospective Training**: Perfect move detection and model updates
|
||||
|
||||
### Enhanced Training Systems
|
||||
- **Continuous Learning**: Models update in real-time from market outcomes
|
||||
- **Multi-modal Integration**: CNN + RL + COB predictions combined intelligently
|
||||
- **Sensitivity Adaptation**: DQN adjusts risk appetite based on performance
|
||||
- **Perfect Move Detection**: Automatic identification of optimal trading opportunities
|
||||
- **Negative Case Training**: Intensive learning from failed predictions
|
||||
|
||||
### Dashboard Monitoring
|
||||
- **Real-time Model Status**: Active models, parameters, loss tracking
|
||||
- **Live Predictions**: Current model outputs with confidence scores
|
||||
- **Training Metrics**: Loss trends, accuracy rates, iteration counts
|
||||
- **COB Integration**: Real-time order book analysis and microstructure data
|
||||
- **Performance Tracking**: P&L, trade accuracy, model effectiveness
|
||||
|
||||
The system provides a comprehensive ML-driven trading environment with real-time learning, multi-modal decision making, and advanced market microstructure analysis through COB integration.
|
||||
|
||||
**Dashboard URL**: http://127.0.0.1:8051
|
||||
**Status**: ✅ FULLY OPERATIONAL
|
@@ -1,194 +0,0 @@
|
||||
# Enhanced Training Integration Report
|
||||
*Generated: 2024-12-19*
|
||||
|
||||
## 🎯 Integration Objective
|
||||
|
||||
Integrate the restored `EnhancedRealtimeTrainingSystem` into the orchestrator and audit the `EnhancedRLTrainingIntegrator` to determine if it can be used for comprehensive RL training.
|
||||
|
||||
## 📊 EnhancedRealtimeTrainingSystem Analysis
|
||||
|
||||
### **✅ Successfully Integrated**
|
||||
|
||||
The `EnhancedRealtimeTrainingSystem` has been successfully integrated into the orchestrator with the following capabilities:
|
||||
|
||||
#### **Core Features**
|
||||
- **Real-time Data Collection**: Multi-timeframe OHLCV, tick data, COB snapshots
|
||||
- **Enhanced DQN Training**: Prioritized experience replay with market-aware rewards
|
||||
- **CNN Training**: Real-time pattern recognition training
|
||||
- **Forward-looking Predictions**: Generates predictions for future validation
|
||||
- **Adaptive Learning**: Adjusts training frequency based on performance
|
||||
- **Comprehensive State Building**: 13,400+ feature states for RL training
|
||||
|
||||
#### **Integration Points in Orchestrator**
|
||||
```python
|
||||
# New orchestrator capabilities:
|
||||
self.enhanced_training_system: Optional[EnhancedRealtimeTrainingSystem] = None
|
||||
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
|
||||
|
||||
# Methods added:
|
||||
def _initialize_enhanced_training_system()
|
||||
def start_enhanced_training()
|
||||
def stop_enhanced_training()
|
||||
def get_enhanced_training_stats()
|
||||
def set_training_dashboard(dashboard)
|
||||
```
|
||||
|
||||
#### **Training Capabilities**
|
||||
1. **Real-time Data Streams**:
|
||||
- OHLCV data (1m, 5m intervals)
|
||||
- Tick-level market data
|
||||
- COB (Change of Bid) snapshots
|
||||
- Market event detection
|
||||
|
||||
2. **Enhanced Model Training**:
|
||||
- DQN with prioritized experience replay
|
||||
- CNN with multi-timeframe features
|
||||
- Comprehensive reward engineering
|
||||
- Performance-based adaptation
|
||||
|
||||
3. **Prediction Tracking**:
|
||||
- Forward-looking predictions with validation
|
||||
- Accuracy measurement and tracking
|
||||
- Model confidence scoring
|
||||
|
||||
## 🔍 EnhancedRLTrainingIntegrator Audit
|
||||
|
||||
### **Purpose & Scope**
|
||||
The `EnhancedRLTrainingIntegrator` is a comprehensive testing and validation system designed to:
|
||||
- Verify 13,400-feature comprehensive state building
|
||||
- Test enhanced pivot-based reward calculation
|
||||
- Validate Williams market structure integration
|
||||
- Demonstrate live comprehensive training
|
||||
|
||||
### **Audit Results**
|
||||
|
||||
#### **✅ Valuable Components**
|
||||
1. **Comprehensive State Verification**: Tests for exactly 13,400 features
|
||||
2. **Feature Distribution Analysis**: Analyzes non-zero vs zero features
|
||||
3. **Enhanced Reward Testing**: Validates pivot-based reward calculations
|
||||
4. **Williams Integration**: Tests market structure feature extraction
|
||||
5. **Live Training Demo**: Demonstrates coordinated decision making
|
||||
|
||||
#### **🔧 Integration Challenges**
|
||||
1. **Dependency Issues**: References `core.enhanced_orchestrator.EnhancedTradingOrchestrator` (not available)
|
||||
2. **Missing Methods**: Expects methods not present in current orchestrator:
|
||||
- `build_comprehensive_rl_state()`
|
||||
- `calculate_enhanced_pivot_reward()`
|
||||
- `make_coordinated_decisions()`
|
||||
3. **Williams Module**: Depends on `training.williams_market_structure` (needs verification)
|
||||
|
||||
#### **💡 Recommended Usage**
|
||||
The `EnhancedRLTrainingIntegrator` should be used as a **testing and validation tool** rather than direct integration:
|
||||
|
||||
```python
|
||||
# Use as standalone testing script
|
||||
python enhanced_rl_training_integration.py
|
||||
|
||||
# Or import specific testing functions
|
||||
from enhanced_rl_training_integration import EnhancedRLTrainingIntegrator
|
||||
integrator = EnhancedRLTrainingIntegrator()
|
||||
await integrator._verify_comprehensive_state_building()
|
||||
```
|
||||
|
||||
## 🚀 Implementation Strategy
|
||||
|
||||
### **Phase 1: EnhancedRealtimeTrainingSystem (✅ COMPLETE)**
|
||||
- [x] Integrated into orchestrator
|
||||
- [x] Added initialization methods
|
||||
- [x] Connected to data provider
|
||||
- [x] Dashboard integration support
|
||||
|
||||
### **Phase 2: Enhanced Methods (🔄 IN PROGRESS)**
|
||||
Add missing methods expected by the integrator:
|
||||
|
||||
```python
|
||||
# Add to orchestrator:
|
||||
def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Build comprehensive 13,400+ feature state for RL training"""
|
||||
|
||||
def calculate_enhanced_pivot_reward(self, trade_decision: Dict,
|
||||
market_data: Dict,
|
||||
trade_outcome: Dict) -> float:
|
||||
"""Calculate enhanced pivot-based rewards"""
|
||||
|
||||
async def make_coordinated_decisions(self) -> Dict[str, TradingDecision]:
|
||||
"""Make coordinated decisions across all symbols"""
|
||||
```
|
||||
|
||||
### **Phase 3: Validation Integration (📋 PLANNED)**
|
||||
Use `EnhancedRLTrainingIntegrator` as a validation tool:
|
||||
|
||||
```python
|
||||
# Integration validation workflow:
|
||||
1. Start enhanced training system
|
||||
2. Run comprehensive state building tests
|
||||
3. Validate reward calculation accuracy
|
||||
4. Test Williams market structure integration
|
||||
5. Monitor live training performance
|
||||
```
|
||||
|
||||
## 📈 Benefits of Integration
|
||||
|
||||
### **Real-time Learning**
|
||||
- Continuous model improvement during live trading
|
||||
- Adaptive learning based on market conditions
|
||||
- Forward-looking prediction validation
|
||||
|
||||
### **Comprehensive Features**
|
||||
- 13,400+ feature comprehensive states
|
||||
- Multi-timeframe market analysis
|
||||
- COB microstructure integration
|
||||
- Enhanced reward engineering
|
||||
|
||||
### **Performance Monitoring**
|
||||
- Real-time training statistics
|
||||
- Model accuracy tracking
|
||||
- Adaptive parameter adjustment
|
||||
- Comprehensive logging
|
||||
|
||||
## 🎯 Next Steps
|
||||
|
||||
### **Immediate Actions**
|
||||
1. **Complete Method Implementation**: Add missing orchestrator methods
|
||||
2. **Williams Module Verification**: Ensure market structure module is available
|
||||
3. **Testing Integration**: Use integrator for validation testing
|
||||
4. **Dashboard Connection**: Connect training system to dashboard
|
||||
|
||||
### **Future Enhancements**
|
||||
1. **Multi-Symbol Coordination**: Enhance coordinated decision making
|
||||
2. **Advanced Reward Engineering**: Implement sophisticated reward functions
|
||||
3. **Model Ensemble**: Combine multiple model predictions
|
||||
4. **Performance Optimization**: GPU acceleration for training
|
||||
|
||||
## 📊 Integration Status
|
||||
|
||||
| Component | Status | Notes |
|
||||
|-----------|--------|-------|
|
||||
| EnhancedRealtimeTrainingSystem | ✅ Integrated | Fully functional in orchestrator |
|
||||
| Real-time Data Collection | ✅ Available | Multi-timeframe data streams |
|
||||
| Enhanced DQN Training | ✅ Available | Prioritized experience replay |
|
||||
| CNN Training | ✅ Available | Pattern recognition training |
|
||||
| Forward Predictions | ✅ Available | Prediction validation system |
|
||||
| EnhancedRLTrainingIntegrator | 🔧 Partial | Use as validation tool |
|
||||
| Comprehensive State Building | 📋 Planned | Need to implement method |
|
||||
| Enhanced Reward Calculation | 📋 Planned | Need to implement method |
|
||||
| Williams Integration | ❓ Unknown | Need to verify module |
|
||||
|
||||
## 🏆 Conclusion
|
||||
|
||||
The `EnhancedRealtimeTrainingSystem` has been successfully integrated into the orchestrator, providing comprehensive real-time training capabilities. The `EnhancedRLTrainingIntegrator` serves as an excellent validation and testing tool, but requires additional method implementations in the orchestrator for full functionality.
|
||||
|
||||
**Key Achievements:**
|
||||
- ✅ Real-time training system fully integrated
|
||||
- ✅ Comprehensive feature extraction capabilities
|
||||
- ✅ Enhanced reward engineering framework
|
||||
- ✅ Forward-looking prediction validation
|
||||
- ✅ Performance monitoring and adaptation
|
||||
|
||||
**Recommended Actions:**
|
||||
1. Use the integrated training system for live model improvement
|
||||
2. Implement missing orchestrator methods for full integrator compatibility
|
||||
3. Use the integrator as a comprehensive testing and validation tool
|
||||
4. Monitor training performance and adapt parameters as needed
|
||||
|
||||
The integration provides a solid foundation for advanced ML-driven trading with continuous learning capabilities.
|
@@ -14,7 +14,7 @@ from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import torch
|
||||
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
|
||||
from NN.training.model_manager import create_model_manager, CheckpointMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class CheckpointCleanup:
|
||||
def __init__(self):
|
||||
self.saved_models_dir = Path("NN/models/saved")
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.checkpoint_manager = create_model_manager()
|
||||
|
||||
def analyze_existing_checkpoints(self) -> Dict[str, Any]:
|
||||
logger.info("Analyzing existing checkpoint files...")
|
||||
|
@@ -35,7 +35,7 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
|
||||
from NN.training.model_manager import create_model_manager
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
# Import training components
|
||||
@@ -55,7 +55,7 @@ class CheckpointIntegratedTrainingSystem:
|
||||
self.running = False
|
||||
|
||||
# Checkpoint management
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.checkpoint_manager = create_model_manager()
|
||||
self.training_integration = get_training_integration()
|
||||
|
||||
# Data provider
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Enhanced Model Management System for Trading Dashboard
|
||||
Unified Model Management System for Trading Dashboard
|
||||
|
||||
CONSOLIDATED SYSTEM - All model management functionality in one place
|
||||
|
||||
This system provides:
|
||||
- Automatic cleanup of old model checkpoints
|
||||
@@ -7,6 +9,9 @@ This system provides:
|
||||
- Configurable retention policies
|
||||
- Startup model loading
|
||||
- Performance-based model selection
|
||||
- Robust model saving with multiple fallback strategies
|
||||
- Checkpoint management with W&B integration
|
||||
- Centralized storage using @checkpoints/ structure
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -15,17 +20,30 @@ import shutil
|
||||
import logging
|
||||
import torch
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import hashlib
|
||||
import random
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
# W&B import (optional)
|
||||
try:
|
||||
import wandb
|
||||
WANDB_AVAILABLE = True
|
||||
except ImportError:
|
||||
WANDB_AVAILABLE = False
|
||||
wandb = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Performance metrics for model evaluation"""
|
||||
"""Enhanced performance metrics for model evaluation"""
|
||||
accuracy: float = 0.0
|
||||
profit_factor: float = 0.0
|
||||
win_rate: float = 0.0
|
||||
@@ -34,41 +52,66 @@ class ModelMetrics:
|
||||
total_trades: int = 0
|
||||
avg_trade_duration: float = 0.0
|
||||
confidence_score: float = 0.0
|
||||
|
||||
|
||||
# Additional metrics from checkpoint_manager
|
||||
loss: Optional[float] = None
|
||||
val_accuracy: Optional[float] = None
|
||||
val_loss: Optional[float] = None
|
||||
reward: Optional[float] = None
|
||||
pnl: Optional[float] = None
|
||||
epoch: Optional[int] = None
|
||||
training_time_hours: Optional[float] = None
|
||||
total_parameters: Optional[int] = None
|
||||
|
||||
def get_composite_score(self) -> float:
|
||||
"""Calculate composite performance score"""
|
||||
# Weighted composite score
|
||||
weights = {
|
||||
'profit_factor': 0.3,
|
||||
'sharpe_ratio': 0.25,
|
||||
'win_rate': 0.2,
|
||||
'profit_factor': 0.25,
|
||||
'sharpe_ratio': 0.2,
|
||||
'win_rate': 0.15,
|
||||
'accuracy': 0.15,
|
||||
'confidence_score': 0.1
|
||||
'confidence_score': 0.1,
|
||||
'loss_penalty': 0.1, # New: penalize high loss
|
||||
'val_penalty': 0.05 # New: penalize validation loss
|
||||
}
|
||||
|
||||
|
||||
# Normalize values to 0-1 range
|
||||
normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0
|
||||
normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1
|
||||
normalized_win_rate = self.win_rate
|
||||
normalized_accuracy = self.accuracy
|
||||
normalized_confidence = self.confidence_score
|
||||
|
||||
|
||||
# Loss penalty (lower loss = higher score)
|
||||
loss_penalty = 1.0
|
||||
if self.loss is not None and self.loss > 0:
|
||||
loss_penalty = max(0.1, 1 / (1 + self.loss)) # Better loss = higher penalty
|
||||
|
||||
# Validation penalty
|
||||
val_penalty = 1.0
|
||||
if self.val_loss is not None and self.val_loss > 0:
|
||||
val_penalty = max(0.1, 1 / (1 + self.val_loss))
|
||||
|
||||
# Apply penalties for poor performance
|
||||
drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown
|
||||
|
||||
|
||||
score = (
|
||||
weights['profit_factor'] * normalized_pf +
|
||||
weights['sharpe_ratio'] * normalized_sharpe +
|
||||
weights['win_rate'] * normalized_win_rate +
|
||||
weights['accuracy'] * normalized_accuracy +
|
||||
weights['confidence_score'] * normalized_confidence
|
||||
weights['confidence_score'] * normalized_confidence +
|
||||
weights['loss_penalty'] * loss_penalty +
|
||||
weights['val_penalty'] * val_penalty
|
||||
) * drawdown_penalty
|
||||
|
||||
|
||||
return min(max(score, 0), 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Complete model information and metadata"""
|
||||
"""Model information tracking"""
|
||||
model_type: str # 'cnn', 'rl', 'transformer'
|
||||
model_name: str
|
||||
file_path: str
|
||||
@@ -78,14 +121,14 @@ class ModelInfo:
|
||||
metrics: ModelMetrics
|
||||
training_episodes: int = 0
|
||||
model_version: str = "1.0"
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
data = asdict(self)
|
||||
data['creation_time'] = self.creation_time.isoformat()
|
||||
data['last_updated'] = self.last_updated.isoformat()
|
||||
return data
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
|
||||
"""Create from dictionary"""
|
||||
@@ -94,465 +137,400 @@ class ModelInfo:
|
||||
data['metrics'] = ModelMetrics(**data['metrics'])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckpointMetadata:
|
||||
checkpoint_id: str
|
||||
model_name: str
|
||||
model_type: str
|
||||
file_path: str
|
||||
created_at: datetime
|
||||
file_size_mb: float
|
||||
performance_score: float
|
||||
accuracy: Optional[float] = None
|
||||
loss: Optional[float] = None
|
||||
val_accuracy: Optional[float] = None
|
||||
val_loss: Optional[float] = None
|
||||
reward: Optional[float] = None
|
||||
pnl: Optional[float] = None
|
||||
epoch: Optional[int] = None
|
||||
training_time_hours: Optional[float] = None
|
||||
total_parameters: Optional[int] = None
|
||||
wandb_run_id: Optional[str] = None
|
||||
wandb_artifact_name: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = asdict(self)
|
||||
data['created_at'] = self.created_at.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
|
||||
data['created_at'] = datetime.fromisoformat(data['created_at'])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Enhanced model management system"""
|
||||
|
||||
"""Unified model management system with @checkpoints/ structure"""
|
||||
|
||||
def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
|
||||
self.base_dir = Path(base_dir)
|
||||
self.config = config or self._get_default_config()
|
||||
|
||||
# Model directories
|
||||
self.models_dir = self.base_dir / "models"
|
||||
|
||||
# Updated directory structure using @checkpoints/
|
||||
self.checkpoints_dir = self.base_dir / "@checkpoints"
|
||||
self.models_dir = self.checkpoints_dir / "models"
|
||||
self.saved_dir = self.checkpoints_dir / "saved"
|
||||
self.best_models_dir = self.checkpoints_dir / "best_models"
|
||||
self.archive_dir = self.checkpoints_dir / "archive"
|
||||
|
||||
# Model type directories within @checkpoints/
|
||||
self.model_dirs = {
|
||||
'cnn': self.checkpoints_dir / "cnn",
|
||||
'dqn': self.checkpoints_dir / "dqn",
|
||||
'rl': self.checkpoints_dir / "rl",
|
||||
'transformer': self.checkpoints_dir / "transformer",
|
||||
'hybrid': self.checkpoints_dir / "hybrid"
|
||||
}
|
||||
|
||||
# Legacy directories for backward compatibility
|
||||
self.nn_models_dir = self.base_dir / "NN" / "models"
|
||||
self.registry_file = self.models_dir / "model_registry.json"
|
||||
self.best_models_dir = self.models_dir / "best_models"
|
||||
|
||||
# Create directories
|
||||
self.best_models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Model registry
|
||||
self.model_registry: Dict[str, ModelInfo] = {}
|
||||
self._load_registry()
|
||||
|
||||
logger.info(f"Model Manager initialized - Base: {self.base_dir}")
|
||||
logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")
|
||||
|
||||
self.legacy_models_dir = self.base_dir / "models"
|
||||
|
||||
# Metadata and checkpoint management
|
||||
self.metadata_file = self.checkpoints_dir / "model_metadata.json"
|
||||
self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
|
||||
|
||||
# Initialize storage
|
||||
self._initialize_directories()
|
||||
self.metadata = self._load_metadata()
|
||||
self.checkpoint_metadata = self._load_checkpoint_metadata()
|
||||
|
||||
logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration"""
|
||||
return {
|
||||
'max_models_per_type': 3, # Keep top 3 models per type
|
||||
'max_total_models': 10, # Maximum total models to keep
|
||||
'cleanup_frequency_hours': 24, # Cleanup every 24 hours
|
||||
'min_performance_threshold': 0.3, # Minimum composite score
|
||||
'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days
|
||||
'auto_cleanup_enabled': True,
|
||||
'backup_before_cleanup': True,
|
||||
'model_size_limit_mb': 100, # Individual model size limit
|
||||
'total_storage_limit_gb': 5.0 # Total storage limit
|
||||
'max_checkpoints_per_model': 5,
|
||||
'cleanup_old_models': True,
|
||||
'auto_archive': True,
|
||||
'wandb_enabled': WANDB_AVAILABLE,
|
||||
'checkpoint_retention_days': 30
|
||||
}
|
||||
|
||||
def _load_registry(self):
|
||||
"""Load model registry from file"""
|
||||
try:
|
||||
if self.registry_file.exists():
|
||||
with open(self.registry_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
self.model_registry = {
|
||||
k: ModelInfo.from_dict(v) for k, v in data.items()
|
||||
}
|
||||
logger.info(f"Loaded {len(self.model_registry)} models from registry")
|
||||
else:
|
||||
logger.info("No existing model registry found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model registry: {e}")
|
||||
self.model_registry = {}
|
||||
|
||||
def _save_registry(self):
|
||||
"""Save model registry to file"""
|
||||
try:
|
||||
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.registry_file, 'w') as f:
|
||||
data = {k: v.to_dict() for k, v in self.model_registry.items()}
|
||||
json.dump(data, f, indent=2, default=str)
|
||||
logger.info(f"Saved registry with {len(self.model_registry)} models")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model registry: {e}")
|
||||
|
||||
def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Clean up all existing model files and prepare for 2-action system training
|
||||
|
||||
Args:
|
||||
confirm: If True, perform the cleanup. If False, return what would be cleaned
|
||||
|
||||
Returns:
|
||||
Dict with cleanup statistics
|
||||
"""
|
||||
cleanup_stats = {
|
||||
'files_found': 0,
|
||||
'files_deleted': 0,
|
||||
'directories_cleaned': 0,
|
||||
'space_freed_mb': 0.0,
|
||||
'errors': []
|
||||
}
|
||||
|
||||
# Model file patterns for both 2-action and legacy 3-action systems
|
||||
model_patterns = [
|
||||
"**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
|
||||
"**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
|
||||
]
|
||||
|
||||
# Directories to clean
|
||||
model_directories = [
|
||||
"models/saved",
|
||||
"NN/models/saved",
|
||||
"NN/models/saved/checkpoints",
|
||||
"NN/models/saved/realtime_checkpoints",
|
||||
"NN/models/saved/realtime_ticks_checkpoints",
|
||||
"model_backups"
|
||||
]
|
||||
|
||||
try:
|
||||
# Scan for files to be cleaned
|
||||
for directory in model_directories:
|
||||
dir_path = Path(self.base_dir) / directory
|
||||
if dir_path.exists():
|
||||
for pattern in model_patterns:
|
||||
for file_path in dir_path.glob(pattern):
|
||||
if file_path.is_file():
|
||||
cleanup_stats['files_found'] += 1
|
||||
file_size = file_path.stat().st_size / (1024 * 1024) # MB
|
||||
cleanup_stats['space_freed_mb'] += file_size
|
||||
|
||||
if confirm:
|
||||
try:
|
||||
file_path.unlink()
|
||||
cleanup_stats['files_deleted'] += 1
|
||||
logger.info(f"Deleted model file: {file_path}")
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
|
||||
|
||||
# Clean up empty checkpoint directories
|
||||
for directory in model_directories:
|
||||
dir_path = Path(self.base_dir) / directory
|
||||
if dir_path.exists():
|
||||
for subdir in dir_path.rglob("*"):
|
||||
if subdir.is_dir() and not any(subdir.iterdir()):
|
||||
if confirm:
|
||||
try:
|
||||
subdir.rmdir()
|
||||
cleanup_stats['directories_cleaned'] += 1
|
||||
logger.info(f"Removed empty directory: {subdir}")
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")
|
||||
|
||||
if confirm:
|
||||
# Clear the registry for fresh start with 2-action system
|
||||
self.model_registry = {
|
||||
'models': {},
|
||||
'metadata': {
|
||||
'last_updated': datetime.now().isoformat(),
|
||||
'total_models': 0,
|
||||
'system_type': '2_action', # Mark as 2-action system
|
||||
'action_space': ['SELL', 'BUY'],
|
||||
'version': '2.0'
|
||||
}
|
||||
}
|
||||
self._save_registry()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
|
||||
logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
|
||||
logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||
logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
|
||||
logger.info("Registry reset for 2-action system (BUY/SELL)")
|
||||
logger.info("Ready for fresh training with intelligent position management")
|
||||
logger.info("=" * 60)
|
||||
else:
|
||||
logger.info("=" * 60)
|
||||
logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
|
||||
logger.info(f"Files to delete: {cleanup_stats['files_found']}")
|
||||
logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||
logger.info("Run with confirm=True to perform cleanup")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Cleanup error: {e}")
|
||||
logger.error(f"Error during model cleanup: {e}")
|
||||
|
||||
return cleanup_stats
|
||||
|
||||
def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
|
||||
"""
|
||||
Register a new model in the 2-action system
|
||||
|
||||
Args:
|
||||
model_path: Path to the model file
|
||||
model_type: Type of model ('cnn', 'rl', 'transformer')
|
||||
metrics: Performance metrics
|
||||
|
||||
Returns:
|
||||
str: Unique model name/ID
|
||||
"""
|
||||
if not Path(model_path).exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Generate unique model name
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_name = f"{model_type}_2action_{timestamp}"
|
||||
|
||||
# Get file info
|
||||
file_path = Path(model_path)
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Default metrics for 2-action system
|
||||
if metrics is None:
|
||||
metrics = ModelMetrics(
|
||||
accuracy=0.0,
|
||||
profit_factor=1.0,
|
||||
win_rate=0.5,
|
||||
sharpe_ratio=0.0,
|
||||
max_drawdown=0.0,
|
||||
confidence_score=0.5
|
||||
)
|
||||
|
||||
# Create model info
|
||||
model_info = ModelInfo(
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
file_path=str(file_path.absolute()),
|
||||
creation_time=datetime.now(),
|
||||
last_updated=datetime.now(),
|
||||
file_size_mb=file_size_mb,
|
||||
metrics=metrics,
|
||||
model_version="2.0" # 2-action system version
|
||||
)
|
||||
|
||||
# Add to registry
|
||||
self.model_registry['models'][model_name] = model_info.to_dict()
|
||||
self.model_registry['metadata']['total_models'] = len(self.model_registry['models'])
|
||||
self.model_registry['metadata']['last_updated'] = datetime.now().isoformat()
|
||||
self.model_registry['metadata']['system_type'] = '2_action'
|
||||
self.model_registry['metadata']['action_space'] = ['SELL', 'BUY']
|
||||
|
||||
self._save_registry()
|
||||
|
||||
# Cleanup old models if necessary
|
||||
self._cleanup_models_by_type(model_type)
|
||||
|
||||
logger.info(f"Registered 2-action model: {model_name}")
|
||||
logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
|
||||
logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
|
||||
|
||||
return model_name
|
||||
|
||||
def _should_keep_model(self, model_info: ModelInfo) -> bool:
|
||||
"""Determine if model should be kept based on performance"""
|
||||
score = model_info.metrics.get_composite_score()
|
||||
|
||||
# Check minimum threshold
|
||||
if score < self.config['min_performance_threshold']:
|
||||
return False
|
||||
|
||||
# Check size limit
|
||||
if model_info.file_size_mb > self.config['model_size_limit_mb']:
|
||||
logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
|
||||
return False
|
||||
|
||||
# Check if better than existing models of same type
|
||||
existing_models = self.get_models_by_type(model_info.model_type)
|
||||
if len(existing_models) >= self.config['max_models_per_type']:
|
||||
# Find worst performing model
|
||||
worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
|
||||
if score <= worst_model.metrics.get_composite_score():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _cleanup_models_by_type(self, model_type: str):
|
||||
"""Cleanup old models of specific type, keeping only the best ones"""
|
||||
models_of_type = self.get_models_by_type(model_type)
|
||||
max_keep = self.config['max_models_per_type']
|
||||
|
||||
if len(models_of_type) <= max_keep:
|
||||
return
|
||||
|
||||
# Sort by performance score
|
||||
sorted_models = sorted(
|
||||
models_of_type.items(),
|
||||
key=lambda x: x[1].metrics.get_composite_score(),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Keep only the best models
|
||||
models_to_keep = sorted_models[:max_keep]
|
||||
models_to_remove = sorted_models[max_keep:]
|
||||
|
||||
for model_name, model_info in models_to_remove:
|
||||
|
||||
def _initialize_directories(self):
|
||||
"""Initialize directory structure"""
|
||||
directories = [
|
||||
self.checkpoints_dir,
|
||||
self.models_dir,
|
||||
self.saved_dir,
|
||||
self.best_models_dir,
|
||||
self.archive_dir
|
||||
] + list(self.model_dirs.values())
|
||||
|
||||
for directory in directories:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_metadata(self) -> Dict[str, Any]:
|
||||
"""Load model metadata"""
|
||||
if self.metadata_file.exists():
|
||||
try:
|
||||
# Remove file
|
||||
model_path = Path(model_info.file_path)
|
||||
if model_path.exists():
|
||||
model_path.unlink()
|
||||
|
||||
# Remove from registry
|
||||
del self.model_registry[model_name]
|
||||
|
||||
logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
|
||||
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing model {model_name}: {e}")
|
||||
|
||||
def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
|
||||
"""Get all models of a specific type"""
|
||||
return {
|
||||
name: info for name, info in self.model_registry.items()
|
||||
if info.model_type == model_type
|
||||
}
|
||||
|
||||
def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
|
||||
"""Get the best performing model of a specific type"""
|
||||
models_of_type = self.get_models_by_type(model_type)
|
||||
|
||||
if not models_of_type:
|
||||
return None
|
||||
|
||||
return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())
|
||||
|
||||
def load_best_models(self) -> Dict[str, Any]:
|
||||
"""Load the best models for each type"""
|
||||
loaded_models = {}
|
||||
|
||||
for model_type in ['cnn', 'rl', 'transformer']:
|
||||
best_model = self.get_best_model(model_type)
|
||||
|
||||
if best_model:
|
||||
try:
|
||||
model_path = Path(best_model.file_path)
|
||||
if model_path.exists():
|
||||
# Load the model
|
||||
model_data = torch.load(model_path, map_location='cpu')
|
||||
loaded_models[model_type] = {
|
||||
'model': model_data,
|
||||
'info': best_model,
|
||||
'path': str(model_path)
|
||||
}
|
||||
logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
|
||||
f"(Score: {best_model.metrics.get_composite_score():.3f})")
|
||||
else:
|
||||
logger.warning(f"Best {model_type} model file not found: {model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {model_type} model: {e}")
|
||||
else:
|
||||
logger.info(f"No {model_type} model available")
|
||||
|
||||
return loaded_models
|
||||
|
||||
def update_model_performance(self, model_name: str, metrics: ModelMetrics):
|
||||
"""Update performance metrics for a model"""
|
||||
if model_name in self.model_registry:
|
||||
self.model_registry[model_name].metrics = metrics
|
||||
self.model_registry[model_name].last_updated = datetime.now()
|
||||
self._save_registry()
|
||||
|
||||
logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
|
||||
else:
|
||||
logger.warning(f"Model {model_name} not found in registry")
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage usage statistics"""
|
||||
total_size_mb = 0
|
||||
model_count = 0
|
||||
|
||||
for model_info in self.model_registry.values():
|
||||
total_size_mb += model_info.file_size_mb
|
||||
model_count += 1
|
||||
|
||||
# Check actual storage usage
|
||||
actual_size_mb = 0
|
||||
if self.best_models_dir.exists():
|
||||
actual_size_mb = sum(
|
||||
f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
|
||||
) / 1024 / 1024
|
||||
|
||||
return {
|
||||
'total_models': model_count,
|
||||
'registered_size_mb': total_size_mb,
|
||||
'actual_size_mb': actual_size_mb,
|
||||
'storage_limit_gb': self.config['total_storage_limit_gb'],
|
||||
'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
|
||||
'models_by_type': {
|
||||
model_type: len(self.get_models_by_type(model_type))
|
||||
for model_type in ['cnn', 'rl', 'transformer']
|
||||
logger.error(f"Error loading metadata: {e}")
|
||||
return {'models': {}, 'last_updated': datetime.now().isoformat()}
|
||||
|
||||
def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Load checkpoint metadata"""
|
||||
if self.checkpoint_metadata_file.exists():
|
||||
try:
|
||||
with open(self.checkpoint_metadata_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
# Convert dict values back to CheckpointMetadata objects
|
||||
result = {}
|
||||
for key, checkpoints in data.items():
|
||||
result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading checkpoint metadata: {e}")
|
||||
return defaultdict(list)
|
||||
|
||||
def save_checkpoint(self, model, model_name: str, model_type: str,
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
"""Save a model checkpoint with enhanced error handling and validation"""
|
||||
try:
|
||||
performance_score = self._calculate_performance_score(performance_metrics)
|
||||
|
||||
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
|
||||
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
|
||||
return None
|
||||
|
||||
# Create checkpoint directory
|
||||
checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate checkpoint filename
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_id = f"{model_name}_{timestamp}"
|
||||
filename = f"{checkpoint_id}.pt"
|
||||
filepath = checkpoint_dir / filename
|
||||
|
||||
# Save model
|
||||
save_dict = {
|
||||
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
|
||||
'model_class': model.__class__.__name__,
|
||||
'checkpoint_id': checkpoint_id,
|
||||
'model_name': model_name,
|
||||
'model_type': model_type,
|
||||
'performance_score': performance_score,
|
||||
'performance_metrics': performance_metrics,
|
||||
'training_metadata': training_metadata or {},
|
||||
'created_at': datetime.now().isoformat(),
|
||||
'version': '2.0'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
torch.save(save_dict, filepath)
|
||||
|
||||
# Create checkpoint metadata
|
||||
file_size_mb = filepath.stat().st_size / (1024 * 1024)
|
||||
metadata = CheckpointMetadata(
|
||||
checkpoint_id=checkpoint_id,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
file_path=str(filepath),
|
||||
created_at=datetime.now(),
|
||||
file_size_mb=file_size_mb,
|
||||
performance_score=performance_score,
|
||||
accuracy=performance_metrics.get('accuracy'),
|
||||
loss=performance_metrics.get('loss'),
|
||||
val_accuracy=performance_metrics.get('val_accuracy'),
|
||||
val_loss=performance_metrics.get('val_loss'),
|
||||
reward=performance_metrics.get('reward'),
|
||||
pnl=performance_metrics.get('pnl'),
|
||||
epoch=performance_metrics.get('epoch'),
|
||||
training_time_hours=performance_metrics.get('training_time_hours'),
|
||||
total_parameters=performance_metrics.get('total_parameters')
|
||||
)
|
||||
|
||||
# Store metadata
|
||||
self.checkpoint_metadata[model_name].append(metadata)
|
||||
self._save_checkpoint_metadata()
|
||||
|
||||
# Rotate checkpoints if needed
|
||||
self._rotate_checkpoints(model_name)
|
||||
|
||||
# Upload to W&B if enabled
|
||||
if self.config.get('wandb_enabled'):
|
||||
self._upload_to_wandb(metadata)
|
||||
|
||||
logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
|
||||
"""Calculate performance score from metrics"""
|
||||
# Simple weighted score - can be enhanced
|
||||
weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
|
||||
score = 0.0
|
||||
for metric, weight in weights.items():
|
||||
if metric in metrics:
|
||||
score += metrics[metric] * weight
|
||||
return score
|
||||
|
||||
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
|
||||
"""Determine if checkpoint should be saved"""
|
||||
existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
|
||||
if not existing_checkpoints:
|
||||
return True
|
||||
|
||||
# Keep if better than worst checkpoint or if we have fewer than max
|
||||
max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
|
||||
if len(existing_checkpoints) < max_checkpoints:
|
||||
return True
|
||||
|
||||
worst_score = min(cp.performance_score for cp in existing_checkpoints)
|
||||
return performance_score > worst_score
|
||||
|
||||
def _rotate_checkpoints(self, model_name: str):
|
||||
"""Rotate checkpoints to maintain max count"""
|
||||
checkpoints = self.checkpoint_metadata.get(model_name, [])
|
||||
max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
|
||||
|
||||
if len(checkpoints) <= max_checkpoints:
|
||||
return
|
||||
|
||||
# Sort by performance score (descending)
|
||||
checkpoints.sort(key=lambda x: x.performance_score, reverse=True)
|
||||
|
||||
# Remove excess checkpoints
|
||||
to_remove = checkpoints[max_checkpoints:]
|
||||
for checkpoint in to_remove:
|
||||
try:
|
||||
Path(checkpoint.file_path).unlink(missing_ok=True)
|
||||
logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")
|
||||
|
||||
# Update metadata
|
||||
self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints]
|
||||
self._save_checkpoint_metadata()
|
||||
|
||||
def _save_checkpoint_metadata(self):
|
||||
"""Save checkpoint metadata to file"""
|
||||
try:
|
||||
data = {}
|
||||
for model_name, checkpoints in self.checkpoint_metadata.items():
|
||||
data[model_name] = [cp.to_dict() for cp in checkpoints]
|
||||
|
||||
with open(self.checkpoint_metadata_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}")
|
||||
|
||||
def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]:
|
||||
"""Upload checkpoint to W&B"""
|
||||
if not WANDB_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
# This would be implemented based on your W&B workflow
|
||||
logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading to W&B: {e}")
|
||||
return None
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
"""Load the best checkpoint for a model"""
|
||||
try:
|
||||
# First, try the unified registry
|
||||
model_info = self.metadata['models'].get(model_name)
|
||||
if model_info and Path(model_info['latest_path']).exists():
|
||||
# Load from unified registry
|
||||
load_dict = torch.load(model_info['latest_path'], map_location='cpu')
|
||||
return model_info['latest_path'], None
|
||||
|
||||
# Fallback to checkpoint metadata
|
||||
checkpoints = self.checkpoint_metadata.get(model_name, [])
|
||||
if not checkpoints:
|
||||
logger.warning(f"No checkpoints found for {model_name}")
|
||||
return None
|
||||
|
||||
# Get best checkpoint
|
||||
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
|
||||
|
||||
if not Path(best_checkpoint.file_path).exists():
|
||||
logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
|
||||
return None
|
||||
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage statistics"""
|
||||
try:
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
|
||||
for directory in [self.checkpoints_dir, self.models_dir, self.saved_dir]:
|
||||
if directory.exists():
|
||||
for file_path in directory.rglob('*'):
|
||||
if file_path.is_file():
|
||||
total_size += file_path.stat().st_size
|
||||
file_count += 1
|
||||
|
||||
return {
|
||||
'total_size_mb': total_size / (1024 * 1024),
|
||||
'file_count': file_count,
|
||||
'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage stats: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
|
||||
"""Get model performance leaderboard"""
|
||||
leaderboard = []
|
||||
|
||||
for model_name, model_info in self.model_registry.items():
|
||||
leaderboard.append({
|
||||
'name': model_name,
|
||||
'type': model_info.model_type,
|
||||
'score': model_info.metrics.get_composite_score(),
|
||||
'profit_factor': model_info.metrics.profit_factor,
|
||||
'win_rate': model_info.metrics.win_rate,
|
||||
'sharpe_ratio': model_info.metrics.sharpe_ratio,
|
||||
'size_mb': model_info.file_size_mb,
|
||||
'age_days': (datetime.now() - model_info.creation_time).days,
|
||||
'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
|
||||
})
|
||||
|
||||
# Sort by score
|
||||
leaderboard.sort(key=lambda x: x['score'], reverse=True)
|
||||
|
||||
return leaderboard
|
||||
|
||||
def cleanup_checkpoints(self) -> Dict[str, Any]:
|
||||
"""Clean up old checkpoint files"""
|
||||
cleanup_summary = {
|
||||
'deleted_files': 0,
|
||||
'freed_space_mb': 0,
|
||||
'errors': []
|
||||
}
|
||||
|
||||
cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])
|
||||
|
||||
# Search for checkpoint files
|
||||
checkpoint_patterns = [
|
||||
"**/checkpoint_*.pt",
|
||||
"**/model_*.pt",
|
||||
"**/*checkpoint*",
|
||||
"**/epoch_*.pt"
|
||||
]
|
||||
|
||||
for pattern in checkpoint_patterns:
|
||||
for file_path in self.base_dir.rglob(pattern):
|
||||
if "best_models" not in str(file_path) and file_path.is_file():
|
||||
try:
|
||||
file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
if file_time < cutoff_date:
|
||||
size_mb = file_path.stat().st_size / 1024 / 1024
|
||||
file_path.unlink()
|
||||
cleanup_summary['deleted_files'] += 1
|
||||
cleanup_summary['freed_space_mb'] += size_mb
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting checkpoint {file_path}: {e}"
|
||||
logger.error(error_msg)
|
||||
cleanup_summary['errors'].append(error_msg)
|
||||
|
||||
if cleanup_summary['deleted_files'] > 0:
|
||||
logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
|
||||
f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
|
||||
|
||||
return cleanup_summary
|
||||
try:
|
||||
leaderboard = []
|
||||
|
||||
for model_name, model_info in self.metadata['models'].items():
|
||||
if 'metrics' in model_info:
|
||||
metrics = ModelMetrics(**model_info['metrics'])
|
||||
leaderboard.append({
|
||||
'model_name': model_name,
|
||||
'model_type': model_info.get('model_type', 'unknown'),
|
||||
'composite_score': metrics.get_composite_score(),
|
||||
'accuracy': metrics.accuracy,
|
||||
'profit_factor': metrics.profit_factor,
|
||||
'win_rate': metrics.win_rate,
|
||||
'last_updated': model_info.get('last_saved', 'unknown')
|
||||
})
|
||||
|
||||
# Sort by composite score
|
||||
leaderboard.sort(key=lambda x: x['composite_score'], reverse=True)
|
||||
return leaderboard
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting leaderboard: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# ===== LEGACY COMPATIBILITY FUNCTIONS =====
|
||||
|
||||
def create_model_manager() -> ModelManager:
|
||||
"""Create and initialize the global model manager"""
|
||||
"""Create and return a ModelManager instance"""
|
||||
return ModelManager()
|
||||
|
||||
# Example usage
|
||||
|
||||
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
|
||||
metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""Legacy compatibility function to save a model"""
|
||||
manager = create_model_manager()
|
||||
return manager.save_model(model, model_name, model_type, metadata)
|
||||
|
||||
|
||||
def load_model(model_name: str, model_type: str = 'cnn',
|
||||
model_class: Optional[Any] = None) -> Optional[Any]:
|
||||
"""Legacy compatibility function to load a model"""
|
||||
manager = create_model_manager()
|
||||
return manager.load_model(model_name, model_type, model_class)
|
||||
|
||||
|
||||
def save_checkpoint(model, model_name: str, model_type: str,
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
"""Legacy compatibility function to save a checkpoint"""
|
||||
manager = create_model_manager()
|
||||
return manager.save_checkpoint(model, model_name, model_type,
|
||||
performance_metrics, training_metadata, force_save)
|
||||
|
||||
|
||||
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
"""Legacy compatibility function to load the best checkpoint"""
|
||||
manager = create_model_manager()
|
||||
return manager.load_best_checkpoint(model_name)
|
||||
|
||||
|
||||
# ===== EXAMPLE USAGE =====
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Create model manager
|
||||
manager = ModelManager()
|
||||
|
||||
# Clean up all existing models (with confirmation)
|
||||
print("WARNING: This will delete ALL existing models!")
|
||||
print("Type 'CONFIRM' to proceed:")
|
||||
user_input = input().strip()
|
||||
|
||||
if user_input == "CONFIRM":
|
||||
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
|
||||
print(f"\nCleanup complete:")
|
||||
print(f"- Deleted {cleanup_result['files_deleted']} files")
|
||||
print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
|
||||
print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")
|
||||
|
||||
if cleanup_result['errors']:
|
||||
print(f"- {len(cleanup_result['errors'])} errors occurred")
|
||||
else:
|
||||
print("Cleanup cancelled")
|
||||
# Example usage of the unified model manager
|
||||
manager = create_model_manager()
|
||||
print(f"ModelManager initialized at: {manager.checkpoints_dir}")
|
||||
|
||||
# Get storage stats
|
||||
stats = manager.get_storage_stats()
|
||||
print(f"Storage stats: {stats}")
|
||||
|
||||
# Get leaderboard
|
||||
leaderboard = manager.get_model_leaderboard()
|
||||
print(f"Models in leaderboard: {len(leaderboard)}")
|
Reference in New Issue
Block a user