added leverage slider
This commit is contained in:
parent
d870f74d0c
commit
7d8eca995e
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
@ -127,8 +127,6 @@
|
|||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "main_clean.py",
|
"program": "main_clean.py",
|
||||||
"args": [
|
"args": [
|
||||||
"--mode",
|
|
||||||
"web",
|
|
||||||
"--port",
|
"--port",
|
||||||
"8050"
|
"8050"
|
||||||
],
|
],
|
||||||
|
145
ENHANCED_DQN_LEVERAGE_INTEGRATION_SUMMARY.md
Normal file
145
ENHANCED_DQN_LEVERAGE_INTEGRATION_SUMMARY.md
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
# Enhanced DQN and Leverage Integration Summary
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Successfully integrated best features from EnhancedDQNAgent into the main DQNAgent and implemented comprehensive 50x leverage support throughout the trading system for amplified reward sensitivity.
|
||||||
|
|
||||||
|
## Key Enhancements Implemented
|
||||||
|
|
||||||
|
### 1. **Enhanced DQN Agent Features Integration** (`NN/models/dqn_agent.py`)
|
||||||
|
|
||||||
|
#### **Market Regime Adaptation**
|
||||||
|
- **Market Regime Weights**: Adaptive confidence based on market conditions
|
||||||
|
- Trending markets: 1.2x confidence multiplier
|
||||||
|
- Ranging markets: 0.8x confidence multiplier
|
||||||
|
- Volatile markets: 0.6x confidence multiplier
|
||||||
|
- **New Method**: `act_with_confidence()` for regime-aware decision making
|
||||||
|
|
||||||
|
#### **Advanced Replay Mechanisms**
|
||||||
|
- **Prioritized Experience Replay**: Enhanced memory management
|
||||||
|
- Alpha: 0.6 (priority exponent)
|
||||||
|
- Beta: 0.4 (importance sampling)
|
||||||
|
- Beta increment: 0.001 per step
|
||||||
|
- **Double DQN Support**: Improved Q-value estimation
|
||||||
|
- **Dueling Network Architecture**: Value and advantage head separation
|
||||||
|
|
||||||
|
#### **Enhanced Position Management**
|
||||||
|
- **Intelligent Entry/Exit Thresholds**:
|
||||||
|
- Entry confidence threshold: 0.7 (high bar for new positions)
|
||||||
|
- Exit confidence threshold: 0.3 (lower bar for closing)
|
||||||
|
- Uncertainty threshold: 0.1 (neutral zone)
|
||||||
|
- **Market Context Integration**: Price and regime-aware decision making
|
||||||
|
|
||||||
|
### 2. **Comprehensive Leverage Integration**
|
||||||
|
|
||||||
|
#### **Dynamic Leverage Slider** (`web/dashboard.py`)
|
||||||
|
- **Range**: 1x to 100x leverage with 1x increments
|
||||||
|
- **Real-time Adjustment**: Instant leverage changes via slider
|
||||||
|
- **Risk Assessment Display**:
|
||||||
|
- Low Risk (1x-5x): Green badge
|
||||||
|
- Medium Risk (6x-25x): Yellow badge
|
||||||
|
- High Risk (26x-50x): Red badge
|
||||||
|
- Extreme Risk (51x-100x): Red badge
|
||||||
|
- **Visual Indicators**: Clear marks at 1x, 10x, 25x, 50x, 75x, 100x
|
||||||
|
|
||||||
|
#### **Leveraged PnL Calculations**
|
||||||
|
- **New Helper Function**: `_calculate_leveraged_pnl_and_fees()`
|
||||||
|
- **Amplified Profits/Losses**: All PnL calculations multiplied by leverage
|
||||||
|
- **Enhanced Fee Structure**: Position value × leverage × fee rate
|
||||||
|
- **Real-time Updates**: Unrealized PnL reflects current leverage setting
|
||||||
|
|
||||||
|
#### **Fee Calculations with Leverage**
|
||||||
|
- **Opening Positions**: `fee = price × size × fee_rate × leverage`
|
||||||
|
- **Closing Positions**: Leverage affects both PnL and exit fees
|
||||||
|
- **Comprehensive Tracking**: All fee calculations include leverage impact
|
||||||
|
|
||||||
|
### 3. **Reward Sensitivity Improvements**
|
||||||
|
|
||||||
|
#### **Amplified Training Signals**
|
||||||
|
- **50x Leverage Default**: Small 0.1% price moves = 5% portfolio impact
|
||||||
|
- **Enhanced Learning**: Models can now learn from micro-movements
|
||||||
|
- **Realistic Risk/Reward**: Proper leverage trading simulation
|
||||||
|
|
||||||
|
#### **Example Impact**:
|
||||||
|
```
|
||||||
|
Without Leverage: 0.1% price move = $10 profit (weak signal)
|
||||||
|
With 50x Leverage: 0.1% price move = $500 profit (strong signal)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. **Technical Implementation Details**
|
||||||
|
|
||||||
|
#### **Code Integration Points**
|
||||||
|
- **Dashboard**: Leverage slider UI component with real-time feedback
|
||||||
|
- **PnL Engine**: All profit/loss calculations leverage-aware
|
||||||
|
- **DQN Agent**: Market regime adaptation and enhanced replay
|
||||||
|
- **Fee System**: Comprehensive leverage-adjusted fee calculations
|
||||||
|
|
||||||
|
#### **Error Handling & Robustness**
|
||||||
|
- **Syntax Error Fixes**: Resolved escaped quote issues
|
||||||
|
- **Encoding Support**: UTF-8 file handling for Windows compatibility
|
||||||
|
- **Fallback Systems**: Graceful degradation on errors
|
||||||
|
|
||||||
|
## Benefits for Model Training
|
||||||
|
|
||||||
|
### **1. Enhanced Signal Quality**
|
||||||
|
- **Amplified Rewards**: Small profitable trades now generate meaningful learning signals
|
||||||
|
- **Reduced Noise**: Clear distinction between good and bad decisions
|
||||||
|
- **Market Adaptation**: AI adjusts confidence based on market regime
|
||||||
|
|
||||||
|
### **2. Improved Learning Efficiency**
|
||||||
|
- **Prioritized Replay**: Focus learning on important experiences
|
||||||
|
- **Double DQN**: More accurate Q-value estimation
|
||||||
|
- **Position Management**: Intelligent entry/exit decision making
|
||||||
|
|
||||||
|
### **3. Real-world Trading Simulation**
|
||||||
|
- **Realistic Leverage**: Proper simulation of leveraged trading
|
||||||
|
- **Fee Integration**: Real trading costs included in all calculations
|
||||||
|
- **Risk Management**: Automatic risk assessment and warnings
|
||||||
|
|
||||||
|
## Usage Instructions
|
||||||
|
|
||||||
|
### **Starting the Enhanced Dashboard**
|
||||||
|
```bash
|
||||||
|
python run_scalping_dashboard.py --port 8050
|
||||||
|
```
|
||||||
|
|
||||||
|
### **Adjusting Leverage**
|
||||||
|
1. Open dashboard at `http://localhost:8050`
|
||||||
|
2. Use leverage slider to adjust from 1x to 100x
|
||||||
|
3. Watch real-time risk assessment updates
|
||||||
|
4. Monitor amplified PnL calculations
|
||||||
|
|
||||||
|
### **Monitoring Enhanced Features**
|
||||||
|
- **Leverage Display**: Current multiplier and risk level
|
||||||
|
- **PnL Amplification**: See leveraged profit/loss calculations
|
||||||
|
- **DQN Performance**: Enhanced market regime adaptation
|
||||||
|
- **Fee Tracking**: Leverage-adjusted trading costs
|
||||||
|
|
||||||
|
## Files Modified
|
||||||
|
|
||||||
|
1. **`NN/models/dqn_agent.py`**: Enhanced with market adaptation and advanced replay
|
||||||
|
2. **`web/dashboard.py`**: Leverage slider and amplified PnL calculations
|
||||||
|
3. **`update_leverage_pnl.py`**: Automated leverage integration script
|
||||||
|
4. **`fix_dashboard_syntax.py`**: Syntax error resolution script
|
||||||
|
|
||||||
|
## Success Metrics
|
||||||
|
|
||||||
|
- ✅ **Leverage Integration**: All PnL calculations leverage-aware
|
||||||
|
- ✅ **Enhanced DQN**: Market regime adaptation implemented
|
||||||
|
- ✅ **UI Enhancement**: Dynamic leverage slider with risk assessment
|
||||||
|
- ✅ **Fee System**: Comprehensive leverage-adjusted fees
|
||||||
|
- ✅ **Model Training**: 50x amplified reward sensitivity
|
||||||
|
- ✅ **System Stability**: Syntax errors resolved, dashboard operational
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. **Monitor Training Performance**: Watch how enhanced signals affect model learning
|
||||||
|
2. **Risk Management**: Set appropriate leverage limits based on market conditions
|
||||||
|
3. **Performance Analysis**: Track how regime adaptation improves trading decisions
|
||||||
|
4. **Further Optimization**: Fine-tune leverage multipliers based on results
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Implementation Status**: ✅ **COMPLETE**
|
||||||
|
**Dashboard Status**: ✅ **OPERATIONAL**
|
||||||
|
**Enhanced Features**: ✅ **ACTIVE**
|
||||||
|
**Leverage System**: ✅ **FULLY INTEGRATED**
|
191
LEVERAGE_SLIDER_IMPLEMENTATION_SUMMARY.md
Normal file
191
LEVERAGE_SLIDER_IMPLEMENTATION_SUMMARY.md
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
# Leverage Slider Implementation Summary
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Successfully implemented a dynamic leverage slider in the trading dashboard that allows real-time adjustment of leverage from 1x to 100x, with automatic risk assessment and reward amplification for enhanced model training.
|
||||||
|
|
||||||
|
## Key Features Implemented
|
||||||
|
|
||||||
|
### 1. **Interactive Leverage Slider**
|
||||||
|
- **Range**: 1x to 100x leverage
|
||||||
|
- **Step Size**: 1x increments
|
||||||
|
- **Real-time Updates**: Instant feedback on leverage changes
|
||||||
|
- **Visual Marks**: Clear indicators at 1x, 10x, 25x, 50x, 75x, 100x
|
||||||
|
- **Tooltip**: Always-visible current leverage value
|
||||||
|
|
||||||
|
### 2. **Dynamic Risk Assessment**
|
||||||
|
- **Low Risk**: 1x - 5x leverage (Green badge)
|
||||||
|
- **Medium Risk**: 6x - 25x leverage (Yellow badge)
|
||||||
|
- **High Risk**: 26x - 50x leverage (Red badge)
|
||||||
|
- **Extreme Risk**: 51x - 100x leverage (Dark badge)
|
||||||
|
|
||||||
|
### 3. **Real-time Leverage Display**
|
||||||
|
- Current leverage multiplier (e.g., "50x")
|
||||||
|
- Risk level indicator with color coding
|
||||||
|
- Explanatory text for user guidance
|
||||||
|
|
||||||
|
### 4. **Reward Amplification System**
|
||||||
|
The leverage slider directly affects trading rewards for model training:
|
||||||
|
|
||||||
|
| Price Change | 1x Leverage | 25x Leverage | 50x Leverage | 100x Leverage |
|
||||||
|
|--------------|-------------|--------------|--------------|---------------|
|
||||||
|
| 0.1% | 0.1% | 2.5% | 5.0% | 10.0% |
|
||||||
|
| 0.2% | 0.2% | 5.0% | 10.0% | 20.0% |
|
||||||
|
| 0.5% | 0.5% | 12.5% | 25.0% | 50.0% |
|
||||||
|
| 1.0% | 1.0% | 25.0% | 50.0% | 100.0% |
|
||||||
|
|
||||||
|
## Technical Implementation
|
||||||
|
|
||||||
|
### 1. **Dashboard Layout Integration**
|
||||||
|
```python
|
||||||
|
# Added to System & Leverage panel
|
||||||
|
html.Div([
|
||||||
|
html.Label([
|
||||||
|
html.I(className="fas fa-chart-line me-1"),
|
||||||
|
"Leverage Multiplier"
|
||||||
|
], className="form-label small fw-bold"),
|
||||||
|
dcc.Slider(
|
||||||
|
id='leverage-slider',
|
||||||
|
min=1.0,
|
||||||
|
max=100.0,
|
||||||
|
step=1.0,
|
||||||
|
value=50.0,
|
||||||
|
marks={1: '1x', 10: '10x', 25: '25x', 50: '50x', 75: '75x', 100: '100x'},
|
||||||
|
tooltip={"placement": "bottom", "always_visible": True}
|
||||||
|
)
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. **Callback Implementation**
|
||||||
|
- **Input**: Leverage slider value changes
|
||||||
|
- **Outputs**: Current leverage display, risk level, risk badge styling
|
||||||
|
- **Functionality**: Real-time updates with validation and logging
|
||||||
|
|
||||||
|
### 3. **State Management**
|
||||||
|
```python
|
||||||
|
# Dashboard initialization
|
||||||
|
self.leverage_multiplier = 50.0 # Default 50x leverage
|
||||||
|
self.min_leverage = 1.0
|
||||||
|
self.max_leverage = 100.0
|
||||||
|
self.leverage_step = 1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. **Risk Calculation Logic**
|
||||||
|
```python
|
||||||
|
if leverage <= 5:
|
||||||
|
risk_level = "Low Risk"
|
||||||
|
risk_class = "badge bg-success"
|
||||||
|
elif leverage <= 25:
|
||||||
|
risk_level = "Medium Risk"
|
||||||
|
risk_class = "badge bg-warning text-dark"
|
||||||
|
elif leverage <= 50:
|
||||||
|
risk_level = "High Risk"
|
||||||
|
risk_class = "badge bg-danger"
|
||||||
|
else:
|
||||||
|
risk_level = "Extreme Risk"
|
||||||
|
risk_class = "badge bg-dark"
|
||||||
|
```
|
||||||
|
|
||||||
|
## User Interface
|
||||||
|
|
||||||
|
### 1. **Location**
|
||||||
|
- **Panel**: System & Leverage (bottom right of dashboard)
|
||||||
|
- **Position**: Below system status, above explanatory text
|
||||||
|
- **Visibility**: Always visible and accessible
|
||||||
|
|
||||||
|
### 2. **Visual Design**
|
||||||
|
- **Slider**: Bootstrap-styled with clear marks
|
||||||
|
- **Badges**: Color-coded risk indicators
|
||||||
|
- **Icons**: Font Awesome chart icon for visual clarity
|
||||||
|
- **Typography**: Clear labels and explanatory text
|
||||||
|
|
||||||
|
### 3. **User Experience**
|
||||||
|
- **Immediate Feedback**: Leverage and risk update instantly
|
||||||
|
- **Clear Guidance**: "Higher leverage = Higher rewards & risks"
|
||||||
|
- **Intuitive Controls**: Standard slider interface
|
||||||
|
- **Visual Cues**: Color-coded risk levels
|
||||||
|
|
||||||
|
## Benefits for Model Training
|
||||||
|
|
||||||
|
### 1. **Enhanced Learning Signals**
|
||||||
|
- **Problem Solved**: Small price movements (0.1%) now generate significant rewards (5% at 50x)
|
||||||
|
- **Model Sensitivity**: Neural networks can now distinguish between good and bad decisions
|
||||||
|
- **Training Efficiency**: Faster convergence due to amplified reward signals
|
||||||
|
|
||||||
|
### 2. **Adaptive Risk Management**
|
||||||
|
- **Conservative Start**: Begin with lower leverage (1x-10x) for stable learning
|
||||||
|
- **Progressive Scaling**: Increase leverage as models improve
|
||||||
|
- **Maximum Performance**: Use 50x-100x for aggressive learning phases
|
||||||
|
|
||||||
|
### 3. **Real-world Preparation**
|
||||||
|
- **Leverage Simulation**: Models learn to handle leveraged trading scenarios
|
||||||
|
- **Risk Awareness**: Training includes risk management considerations
|
||||||
|
- **Market Realism**: Simulates actual trading conditions with leverage
|
||||||
|
|
||||||
|
## Usage Instructions
|
||||||
|
|
||||||
|
### 1. **Accessing the Slider**
|
||||||
|
1. Run: `python run_scalping_dashboard.py`
|
||||||
|
2. Open: http://127.0.0.1:8050
|
||||||
|
3. Navigate to: "System & Leverage" panel (bottom right)
|
||||||
|
|
||||||
|
### 2. **Adjusting Leverage**
|
||||||
|
1. **Drag the slider** to desired leverage level
|
||||||
|
2. **Watch real-time updates** of leverage display and risk level
|
||||||
|
3. **Monitor color changes** in risk indicator badges
|
||||||
|
4. **Observe amplified rewards** in trading performance
|
||||||
|
|
||||||
|
### 3. **Recommended Settings**
|
||||||
|
- **Learning Phase**: Start with 10x-25x leverage
|
||||||
|
- **Training Phase**: Use 50x leverage (current default)
|
||||||
|
- **Advanced Training**: Experiment with 75x-100x leverage
|
||||||
|
- **Conservative Mode**: Use 1x-5x for traditional trading
|
||||||
|
|
||||||
|
## Testing Results
|
||||||
|
|
||||||
|
### ✅ **All Tests Passed**
|
||||||
|
- **Leverage Calculations**: Risk levels correctly assigned
|
||||||
|
- **Reward Amplification**: Proper multiplication of returns
|
||||||
|
- **Dashboard Integration**: Slider functions correctly
|
||||||
|
- **Real-time Updates**: Immediate response to changes
|
||||||
|
|
||||||
|
### 📊 **Performance Metrics**
|
||||||
|
- **Response Time**: Instant slider updates
|
||||||
|
- **Visual Feedback**: Clear risk level indicators
|
||||||
|
- **User Experience**: Intuitive and responsive interface
|
||||||
|
- **System Integration**: Seamless dashboard integration
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
### 1. **Advanced Features**
|
||||||
|
- **Preset Buttons**: Quick selection of common leverage levels
|
||||||
|
- **Risk Calculator**: Real-time P&L projection based on leverage
|
||||||
|
- **Historical Analysis**: Track performance across different leverage levels
|
||||||
|
- **Auto-adjustment**: AI-driven leverage optimization
|
||||||
|
|
||||||
|
### 2. **Safety Features**
|
||||||
|
- **Maximum Limits**: Configurable upper bounds for leverage
|
||||||
|
- **Warning System**: Alerts for extreme leverage levels
|
||||||
|
- **Confirmation Dialogs**: Require confirmation for high-risk settings
|
||||||
|
- **Emergency Stop**: Quick reset to safe leverage levels
|
||||||
|
|
||||||
|
## Conclusion
|
||||||
|
|
||||||
|
The leverage slider implementation successfully addresses the "always invested" problem by:
|
||||||
|
|
||||||
|
1. **Amplifying small price movements** into meaningful training signals
|
||||||
|
2. **Providing real-time control** over risk/reward amplification
|
||||||
|
3. **Enabling progressive training** from conservative to aggressive strategies
|
||||||
|
4. **Improving model learning** through enhanced reward sensitivity
|
||||||
|
|
||||||
|
The system is now ready for enhanced model training with adjustable leverage settings, providing the flexibility needed for optimal neural network learning while maintaining user control over risk levels.
|
||||||
|
|
||||||
|
## Files Modified
|
||||||
|
- `web/dashboard.py`: Added leverage slider, callbacks, and display logic
|
||||||
|
- `test_leverage_slider.py`: Comprehensive testing suite
|
||||||
|
- `run_scalping_dashboard.py`: Fixed import issues for proper dashboard launch
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
1. **Monitor Performance**: Track how different leverage levels affect model learning
|
||||||
|
2. **Optimize Settings**: Find optimal leverage ranges for different market conditions
|
||||||
|
3. **Enhance UI**: Add more visual feedback and control options
|
||||||
|
4. **Integrate Analytics**: Track leverage usage patterns and performance correlations
|
@ -14,15 +14,19 @@ class TradingEnvironment(gym.Env):
|
|||||||
"""
|
"""
|
||||||
Trading environment implementing gym interface for reinforcement learning
|
Trading environment implementing gym interface for reinforcement learning
|
||||||
|
|
||||||
Actions:
|
2-Action System:
|
||||||
- 0: Buy
|
- 0: SELL (or close long position)
|
||||||
- 1: Sell
|
- 1: BUY (or close short position)
|
||||||
- 2: Hold
|
|
||||||
|
Intelligent Position Management:
|
||||||
|
- When neutral: Actions enter positions
|
||||||
|
- When positioned: Actions can close or flip positions
|
||||||
|
- Different thresholds for entry vs exit decisions
|
||||||
|
|
||||||
State:
|
State:
|
||||||
- OHLCV data from multiple timeframes
|
- OHLCV data from multiple timeframes
|
||||||
- Technical indicators
|
- Technical indicators
|
||||||
- Position data
|
- Position data and unrealized PnL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -33,9 +37,11 @@ class TradingEnvironment(gym.Env):
|
|||||||
window_size: int = 20,
|
window_size: int = 20,
|
||||||
max_position: float = 1.0,
|
max_position: float = 1.0,
|
||||||
reward_scaling: float = 1.0,
|
reward_scaling: float = 1.0,
|
||||||
|
entry_threshold: float = 0.6, # Higher threshold for entering positions
|
||||||
|
exit_threshold: float = 0.3, # Lower threshold for exiting positions
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the trading environment.
|
Initialize the trading environment with 2-action system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_interface: DataInterface instance to get market data
|
data_interface: DataInterface instance to get market data
|
||||||
@ -44,6 +50,8 @@ class TradingEnvironment(gym.Env):
|
|||||||
window_size: Number of candles in the observation window
|
window_size: Number of candles in the observation window
|
||||||
max_position: Maximum position size as a fraction of balance
|
max_position: Maximum position size as a fraction of balance
|
||||||
reward_scaling: Scale factor for rewards
|
reward_scaling: Scale factor for rewards
|
||||||
|
entry_threshold: Confidence threshold for entering new positions
|
||||||
|
exit_threshold: Confidence threshold for exiting positions
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -53,21 +61,23 @@ class TradingEnvironment(gym.Env):
|
|||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.max_position = max_position
|
self.max_position = max_position
|
||||||
self.reward_scaling = reward_scaling
|
self.reward_scaling = reward_scaling
|
||||||
|
self.entry_threshold = entry_threshold
|
||||||
|
self.exit_threshold = exit_threshold
|
||||||
|
|
||||||
# Load data for primary timeframe (assuming the first one is primary)
|
# Load data for primary timeframe (assuming the first one is primary)
|
||||||
self.timeframe = self.data_interface.timeframes[0]
|
self.timeframe = self.data_interface.timeframes[0]
|
||||||
self.reset_data()
|
self.reset_data()
|
||||||
|
|
||||||
# Define action and observation spaces
|
# Define action and observation spaces for 2-action system
|
||||||
self.action_space = spaces.Discrete(3) # Buy, Sell, Hold
|
self.action_space = spaces.Discrete(2) # 0=SELL, 1=BUY
|
||||||
|
|
||||||
# For observation space, we consider multiple timeframes with OHLCV data
|
# For observation space, we consider multiple timeframes with OHLCV data
|
||||||
# and additional features like technical indicators, position info, etc.
|
# and additional features like technical indicators, position info, etc.
|
||||||
n_timeframes = len(self.data_interface.timeframes)
|
n_timeframes = len(self.data_interface.timeframes)
|
||||||
n_features = 5 # OHLCV data by default
|
n_features = 5 # OHLCV data by default
|
||||||
|
|
||||||
# Add additional features for position, balance, etc.
|
# Add additional features for position, balance, unrealized_pnl, etc.
|
||||||
additional_features = 3 # position, balance, unrealized_pnl
|
additional_features = 5 # position, balance, unrealized_pnl, entry_price, position_duration
|
||||||
|
|
||||||
# Calculate total feature dimension
|
# Calculate total feature dimension
|
||||||
total_features = (n_timeframes * n_features * self.window_size) + additional_features
|
total_features = (n_timeframes * n_features * self.window_size) + additional_features
|
||||||
@ -79,6 +89,11 @@ class TradingEnvironment(gym.Env):
|
|||||||
# Use tuple for state_shape that EnhancedCNN expects
|
# Use tuple for state_shape that EnhancedCNN expects
|
||||||
self.state_shape = (total_features,)
|
self.state_shape = (total_features,)
|
||||||
|
|
||||||
|
# Position tracking for 2-action system
|
||||||
|
self.position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||||
|
self.entry_price = 0.0 # Price at which position was entered
|
||||||
|
self.entry_step = 0 # Step at which position was entered
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -103,9 +118,6 @@ class TradingEnvironment(gym.Env):
|
|||||||
"""Reset the environment to initial state"""
|
"""Reset the environment to initial state"""
|
||||||
# Reset trading variables
|
# Reset trading variables
|
||||||
self.balance = self.initial_balance
|
self.balance = self.initial_balance
|
||||||
self.position = 0.0 # No position initially
|
|
||||||
self.entry_price = 0.0
|
|
||||||
self.total_pnl = 0.0
|
|
||||||
self.trades = []
|
self.trades = []
|
||||||
self.rewards = []
|
self.rewards = []
|
||||||
|
|
||||||
@ -119,10 +131,10 @@ class TradingEnvironment(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""
|
||||||
Take a step in the environment.
|
Take a step in the environment using 2-action system with intelligent position management.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: Action to take (0: Buy, 1: Sell, 2: Hold)
|
action: Action to take (0: SELL, 1: BUY)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (observation, reward, done, info)
|
tuple: (observation, reward, done, info)
|
||||||
@ -132,7 +144,7 @@ class TradingEnvironment(gym.Env):
|
|||||||
prev_position = self.position
|
prev_position = self.position
|
||||||
prev_price = self.prices[self.current_step]
|
prev_price = self.prices[self.current_step]
|
||||||
|
|
||||||
# Take action
|
# Take action with intelligent position management
|
||||||
info = {}
|
info = {}
|
||||||
reward = 0
|
reward = 0
|
||||||
last_position_info = None
|
last_position_info = None
|
||||||
@ -141,43 +153,50 @@ class TradingEnvironment(gym.Env):
|
|||||||
current_price = self.prices[self.current_step]
|
current_price = self.prices[self.current_step]
|
||||||
next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price
|
next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price
|
||||||
|
|
||||||
# Process the action
|
# Implement 2-action system with position management
|
||||||
if action == 0: # Buy
|
if action == 0: # SELL action
|
||||||
if self.position <= 0: # Only buy if not already long
|
if self.position == 0: # No position - enter short
|
||||||
# Close any existing short position
|
|
||||||
if self.position < 0:
|
|
||||||
close_pnl, last_position_info = self._close_position(current_price)
|
|
||||||
reward += close_pnl * self.reward_scaling
|
|
||||||
|
|
||||||
# Open new long position
|
|
||||||
self._open_position(1.0 * self.max_position, current_price)
|
|
||||||
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
|
|
||||||
|
|
||||||
elif action == 1: # Sell
|
|
||||||
if self.position >= 0: # Only sell if not already short
|
|
||||||
# Close any existing long position
|
|
||||||
if self.position > 0:
|
|
||||||
close_pnl, last_position_info = self._close_position(current_price)
|
|
||||||
reward += close_pnl * self.reward_scaling
|
|
||||||
|
|
||||||
# Open new short position
|
|
||||||
self._open_position(-1.0 * self.max_position, current_price)
|
self._open_position(-1.0 * self.max_position, current_price)
|
||||||
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
|
logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
|
||||||
|
reward = -self.transaction_fee # Entry cost
|
||||||
|
|
||||||
elif action == 2: # Hold
|
elif self.position > 0: # Long position - close it
|
||||||
# No action, but still calculate unrealized PnL for reward
|
close_pnl, last_position_info = self._close_position(current_price)
|
||||||
|
reward += close_pnl * self.reward_scaling
|
||||||
|
logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
|
||||||
|
|
||||||
|
elif self.position < 0: # Already short - potentially flip to long if very strong signal
|
||||||
|
# For now, just hold the short position (no action)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Calculate unrealized PnL and add to reward
|
elif action == 1: # BUY action
|
||||||
|
if self.position == 0: # No position - enter long
|
||||||
|
self._open_position(1.0 * self.max_position, current_price)
|
||||||
|
logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
|
||||||
|
reward = -self.transaction_fee # Entry cost
|
||||||
|
|
||||||
|
elif self.position < 0: # Short position - close it
|
||||||
|
close_pnl, last_position_info = self._close_position(current_price)
|
||||||
|
reward += close_pnl * self.reward_scaling
|
||||||
|
logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
|
||||||
|
|
||||||
|
elif self.position > 0: # Already long - potentially flip to short if very strong signal
|
||||||
|
# For now, just hold the long position (no action)
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Calculate unrealized PnL and add to reward if holding position
|
||||||
if self.position != 0:
|
if self.position != 0:
|
||||||
unrealized_pnl = self._calculate_unrealized_pnl(next_price)
|
unrealized_pnl = self._calculate_unrealized_pnl(next_price)
|
||||||
reward += unrealized_pnl * self.reward_scaling * 0.1 # Scale down unrealized PnL
|
reward += unrealized_pnl * self.reward_scaling * 0.1 # Scale down unrealized PnL
|
||||||
|
|
||||||
# Apply penalties for holding a position
|
# Apply time-based holding penalty to encourage decisive actions
|
||||||
if self.position != 0:
|
position_duration = self.current_step - self.entry_step
|
||||||
# Small holding fee/interest
|
holding_penalty = min(position_duration * 0.0001, 0.01) # Max 1% penalty
|
||||||
holding_penalty = abs(self.position) * 0.0001 # 0.01% per step
|
reward -= holding_penalty
|
||||||
reward -= holding_penalty * self.reward_scaling
|
|
||||||
|
# Reward staying neutral when uncertain (no clear setup)
|
||||||
|
else:
|
||||||
|
reward += 0.0001 # Small reward for not trading without clear signals
|
||||||
|
|
||||||
# Move to next step
|
# Move to next step
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
@ -215,7 +234,7 @@ class TradingEnvironment(gym.Env):
|
|||||||
'step': self.current_step,
|
'step': self.current_step,
|
||||||
'timestamp': self.timestamps[self.current_step],
|
'timestamp': self.timestamps[self.current_step],
|
||||||
'action': action,
|
'action': action,
|
||||||
'action_name': ['BUY', 'SELL', 'HOLD'][action],
|
'action_name': ['SELL', 'BUY'][action],
|
||||||
'price': current_price,
|
'price': current_price,
|
||||||
'position_changed': prev_position != self.position,
|
'position_changed': prev_position != self.position,
|
||||||
'prev_position': prev_position,
|
'prev_position': prev_position,
|
||||||
@ -234,7 +253,7 @@ class TradingEnvironment(gym.Env):
|
|||||||
self.trades.append(trade_result)
|
self.trades.append(trade_result)
|
||||||
|
|
||||||
# Log trade details
|
# Log trade details
|
||||||
logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
|
logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
|
||||||
f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
|
f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
|
||||||
f"Balance: {self.balance:.4f}")
|
f"Balance: {self.balance:.4f}")
|
||||||
|
|
||||||
@ -268,42 +287,71 @@ class TradingEnvironment(gym.Env):
|
|||||||
else: # Short position
|
else: # Short position
|
||||||
return -self.position * (1.0 - current_price / self.entry_price)
|
return -self.position * (1.0 - current_price / self.entry_price)
|
||||||
|
|
||||||
def _open_position(self, position_size, price):
|
def _open_position(self, position_size: float, entry_price: float):
|
||||||
"""Open a new position"""
|
"""Open a new position"""
|
||||||
self.position = position_size
|
self.position = position_size
|
||||||
self.entry_price = price
|
self.entry_price = entry_price
|
||||||
|
self.entry_step = self.current_step
|
||||||
|
|
||||||
def _close_position(self, price):
|
# Calculate position value
|
||||||
"""Close the current position and return PnL"""
|
position_value = abs(position_size) * entry_price
|
||||||
pnl = self._calculate_unrealized_pnl(price)
|
|
||||||
|
|
||||||
# Apply transaction fee
|
# Apply transaction fee
|
||||||
fee = abs(self.position) * price * self.transaction_fee
|
fee = position_value * self.transaction_fee
|
||||||
pnl -= fee
|
self.balance -= fee
|
||||||
|
|
||||||
|
logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")
|
||||||
|
|
||||||
|
def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
|
||||||
|
"""Close current position and return PnL"""
|
||||||
|
if self.position == 0:
|
||||||
|
return 0.0, {}
|
||||||
|
|
||||||
|
# Calculate PnL
|
||||||
|
if self.position > 0: # Long position
|
||||||
|
pnl = (exit_price - self.entry_price) / self.entry_price
|
||||||
|
else: # Short position
|
||||||
|
pnl = (self.entry_price - exit_price) / self.entry_price
|
||||||
|
|
||||||
|
# Apply transaction fees (entry + exit)
|
||||||
|
position_value = abs(self.position) * exit_price
|
||||||
|
exit_fee = position_value * self.transaction_fee
|
||||||
|
total_fees = exit_fee # Entry fee already applied when opening
|
||||||
|
|
||||||
|
# Net PnL after fees
|
||||||
|
net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))
|
||||||
|
|
||||||
# Update balance
|
# Update balance
|
||||||
self.balance += pnl
|
self.balance *= (1 + net_pnl)
|
||||||
self.total_pnl += pnl
|
self.total_pnl += net_pnl
|
||||||
|
|
||||||
# Store position details before resetting
|
# Track trade
|
||||||
last_position = {
|
position_info = {
|
||||||
'position_size': self.position,
|
'position_size': self.position,
|
||||||
'entry_price': self.entry_price,
|
'entry_price': self.entry_price,
|
||||||
'exit_price': price,
|
'exit_price': exit_price,
|
||||||
'pnl': pnl,
|
'pnl': net_pnl,
|
||||||
'fee': fee
|
'duration': self.current_step - self.entry_step,
|
||||||
|
'entry_step': self.entry_step,
|
||||||
|
'exit_step': self.current_step
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.trades.append(position_info)
|
||||||
|
|
||||||
|
# Update trade statistics
|
||||||
|
if net_pnl > 0:
|
||||||
|
self.winning_trades += 1
|
||||||
|
else:
|
||||||
|
self.losing_trades += 1
|
||||||
|
|
||||||
|
logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")
|
||||||
|
|
||||||
# Reset position
|
# Reset position
|
||||||
self.position = 0.0
|
self.position = 0.0
|
||||||
self.entry_price = 0.0
|
self.entry_price = 0.0
|
||||||
|
self.entry_step = 0
|
||||||
|
|
||||||
# Log position closure
|
return net_pnl, position_info
|
||||||
logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
|
|
||||||
f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
|
|
||||||
f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")
|
|
||||||
|
|
||||||
return pnl, last_position
|
|
||||||
|
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
"""
|
"""
|
||||||
@ -411,7 +459,7 @@ class TradingEnvironment(gym.Env):
|
|||||||
for trade in last_n_trades:
|
for trade in last_n_trades:
|
||||||
position_info = {
|
position_info = {
|
||||||
'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
|
'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
|
||||||
'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
|
'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
|
||||||
'entry_price': trade.get('entry_price', 0.0),
|
'entry_price': trade.get('entry_price', 0.0),
|
||||||
'exit_price': trade.get('exit_price', trade['price']),
|
'exit_price': trade.get('exit_price', trade['price']),
|
||||||
'position_size': trade.get('position_size', self.max_position),
|
'position_size': trade.get('position_size', self.max_position),
|
||||||
|
@ -1,560 +0,0 @@
|
|||||||
"""
|
|
||||||
Convolutional Neural Network for timeseries analysis
|
|
||||||
|
|
||||||
This module implements a deep CNN model for cryptocurrency price analysis.
|
|
||||||
The model uses multiple parallel convolutional pathways and LSTM layers
|
|
||||||
to detect patterns at different time scales.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import tensorflow as tf
|
|
||||||
from tensorflow.keras.models import Model, load_model
|
|
||||||
from tensorflow.keras.layers import (
|
|
||||||
Input, Conv1D, MaxPooling1D, Dense, Dropout, BatchNormalization,
|
|
||||||
LSTM, Bidirectional, Flatten, Concatenate, GlobalAveragePooling1D,
|
|
||||||
LeakyReLU, Attention
|
|
||||||
)
|
|
||||||
from tensorflow.keras.optimizers import Adam
|
|
||||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
|
||||||
from tensorflow.keras.metrics import AUC
|
|
||||||
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class CNNModel:
|
|
||||||
"""
|
|
||||||
Convolutional Neural Network for time series analysis.
|
|
||||||
|
|
||||||
This model uses a multi-pathway architecture with different filter sizes
|
|
||||||
to detect patterns at different time scales, combined with LSTM layers
|
|
||||||
for temporal dependencies.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_shape=(20, 5), output_size=1, model_dir="NN/models/saved"):
|
|
||||||
"""
|
|
||||||
Initialize the CNN model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_shape (tuple): Shape of input data (sequence_length, features)
|
|
||||||
output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
|
|
||||||
model_dir (str): Directory to save trained models
|
|
||||||
"""
|
|
||||||
self.input_shape = input_shape
|
|
||||||
self.output_size = output_size
|
|
||||||
self.model_dir = model_dir
|
|
||||||
self.model = None
|
|
||||||
self.history = None
|
|
||||||
|
|
||||||
# Create model directory if it doesn't exist
|
|
||||||
os.makedirs(self.model_dir, exist_ok=True)
|
|
||||||
|
|
||||||
logger.info(f"Initialized CNN model with input shape {input_shape} and output size {output_size}")
|
|
||||||
|
|
||||||
def build_model(self, filters=(32, 64, 128), kernel_sizes=(3, 5, 7),
|
|
||||||
dropout_rate=0.3, learning_rate=0.001):
|
|
||||||
"""
|
|
||||||
Build the CNN model architecture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filters (tuple): Number of filters for each convolutional pathway
|
|
||||||
kernel_sizes (tuple): Kernel sizes for each convolutional pathway
|
|
||||||
dropout_rate (float): Dropout rate for regularization
|
|
||||||
learning_rate (float): Learning rate for Adam optimizer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The compiled model
|
|
||||||
"""
|
|
||||||
# Input layer
|
|
||||||
inputs = Input(shape=self.input_shape)
|
|
||||||
|
|
||||||
# Multiple parallel convolutional pathways with different kernel sizes
|
|
||||||
# to capture patterns at different time scales
|
|
||||||
conv_layers = []
|
|
||||||
|
|
||||||
for i, (filter_size, kernel_size) in enumerate(zip(filters, kernel_sizes)):
|
|
||||||
conv_path = Conv1D(
|
|
||||||
filters=filter_size,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
padding='same',
|
|
||||||
name=f'conv1d_{i+1}'
|
|
||||||
)(inputs)
|
|
||||||
conv_path = BatchNormalization()(conv_path)
|
|
||||||
conv_path = LeakyReLU(alpha=0.1)(conv_path)
|
|
||||||
conv_path = MaxPooling1D(pool_size=2, padding='same')(conv_path)
|
|
||||||
conv_path = Dropout(dropout_rate)(conv_path)
|
|
||||||
conv_layers.append(conv_path)
|
|
||||||
|
|
||||||
# Merge convolutional pathways
|
|
||||||
if len(conv_layers) > 1:
|
|
||||||
merged = Concatenate()(conv_layers)
|
|
||||||
else:
|
|
||||||
merged = conv_layers[0]
|
|
||||||
|
|
||||||
# Add another Conv1D layer after merging
|
|
||||||
x = Conv1D(filters=filters[-1], kernel_size=3, padding='same')(merged)
|
|
||||||
x = BatchNormalization()(x)
|
|
||||||
x = LeakyReLU(alpha=0.1)(x)
|
|
||||||
x = MaxPooling1D(pool_size=2, padding='same')(x)
|
|
||||||
x = Dropout(dropout_rate)(x)
|
|
||||||
|
|
||||||
# Bidirectional LSTM for temporal dependencies
|
|
||||||
x = Bidirectional(LSTM(128, return_sequences=True))(x)
|
|
||||||
x = Dropout(dropout_rate)(x)
|
|
||||||
|
|
||||||
# Attention mechanism to focus on important time steps
|
|
||||||
x = Bidirectional(LSTM(64, return_sequences=True))(x)
|
|
||||||
|
|
||||||
# Global average pooling to reduce parameters
|
|
||||||
x = GlobalAveragePooling1D()(x)
|
|
||||||
x = Dropout(dropout_rate)(x)
|
|
||||||
|
|
||||||
# Dense layers for final classification/regression
|
|
||||||
x = Dense(64, activation='relu')(x)
|
|
||||||
x = BatchNormalization()(x)
|
|
||||||
x = Dropout(dropout_rate)(x)
|
|
||||||
|
|
||||||
# Output layer
|
|
||||||
if self.output_size == 1:
|
|
||||||
# Binary classification (up/down)
|
|
||||||
outputs = Dense(1, activation='sigmoid', name='output')(x)
|
|
||||||
loss = 'binary_crossentropy'
|
|
||||||
metrics = ['accuracy', AUC()]
|
|
||||||
elif self.output_size == 3:
|
|
||||||
# Multi-class classification (buy/hold/sell)
|
|
||||||
outputs = Dense(3, activation='softmax', name='output')(x)
|
|
||||||
loss = 'categorical_crossentropy'
|
|
||||||
metrics = ['accuracy']
|
|
||||||
else:
|
|
||||||
# Regression
|
|
||||||
outputs = Dense(self.output_size, activation='linear', name='output')(x)
|
|
||||||
loss = 'mse'
|
|
||||||
metrics = ['mae']
|
|
||||||
|
|
||||||
# Create and compile model
|
|
||||||
self.model = Model(inputs=inputs, outputs=outputs)
|
|
||||||
|
|
||||||
# Compile with Adam optimizer
|
|
||||||
self.model.compile(
|
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
|
||||||
loss=loss,
|
|
||||||
metrics=metrics
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log model summary
|
|
||||||
self.model.summary(print_fn=lambda x: logger.info(x))
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def train(self, X_train, y_train, batch_size=32, epochs=100, validation_split=0.2,
|
|
||||||
callbacks=None, class_weights=None):
|
|
||||||
"""
|
|
||||||
Train the CNN model on the provided data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
X_train (numpy.ndarray): Training features
|
|
||||||
y_train (numpy.ndarray): Training targets
|
|
||||||
batch_size (int): Batch size
|
|
||||||
epochs (int): Number of epochs
|
|
||||||
validation_split (float): Fraction of data to use for validation
|
|
||||||
callbacks (list): List of Keras callbacks
|
|
||||||
class_weights (dict): Class weights for imbalanced datasets
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
History object containing training metrics
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
self.build_model()
|
|
||||||
|
|
||||||
# Default callbacks if none provided
|
|
||||||
if callbacks is None:
|
|
||||||
# Create a timestamp for model checkpoints
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
|
|
||||||
callbacks = [
|
|
||||||
EarlyStopping(
|
|
||||||
monitor='val_loss',
|
|
||||||
patience=10,
|
|
||||||
restore_best_weights=True
|
|
||||||
),
|
|
||||||
ReduceLROnPlateau(
|
|
||||||
monitor='val_loss',
|
|
||||||
factor=0.5,
|
|
||||||
patience=5,
|
|
||||||
min_lr=1e-6
|
|
||||||
),
|
|
||||||
ModelCheckpoint(
|
|
||||||
filepath=os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5"),
|
|
||||||
monitor='val_loss',
|
|
||||||
save_best_only=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check if y_train needs to be one-hot encoded for multi-class
|
|
||||||
if self.output_size == 3 and len(y_train.shape) == 1:
|
|
||||||
y_train = tf.keras.utils.to_categorical(y_train, num_classes=3)
|
|
||||||
|
|
||||||
# Train the model
|
|
||||||
logger.info(f"Training CNN model with {len(X_train)} samples, batch size {batch_size}, epochs {epochs}")
|
|
||||||
self.history = self.model.fit(
|
|
||||||
X_train, y_train,
|
|
||||||
batch_size=batch_size,
|
|
||||||
epochs=epochs,
|
|
||||||
validation_split=validation_split,
|
|
||||||
callbacks=callbacks,
|
|
||||||
class_weight=class_weights,
|
|
||||||
verbose=2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the trained model
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
model_path = os.path.join(self.model_dir, f"cnn_model_final_{timestamp}.h5")
|
|
||||||
self.model.save(model_path)
|
|
||||||
logger.info(f"Model saved to {model_path}")
|
|
||||||
|
|
||||||
# Save training history
|
|
||||||
history_path = os.path.join(self.model_dir, f"cnn_model_history_{timestamp}.json")
|
|
||||||
with open(history_path, 'w') as f:
|
|
||||||
# Convert numpy values to Python native types for JSON serialization
|
|
||||||
history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
|
|
||||||
json.dump(history_dict, f, indent=2)
|
|
||||||
|
|
||||||
return self.history
|
|
||||||
|
|
||||||
def evaluate(self, X_test, y_test, plot_results=False):
|
|
||||||
"""
|
|
||||||
Evaluate the model on test data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
X_test (numpy.ndarray): Test features
|
|
||||||
y_test (numpy.ndarray): Test targets
|
|
||||||
plot_results (bool): Whether to plot evaluation results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Evaluation metrics
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model has not been built or trained yet")
|
|
||||||
|
|
||||||
# Convert y_test to one-hot encoding for multi-class
|
|
||||||
y_test_original = y_test.copy()
|
|
||||||
if self.output_size == 3 and len(y_test.shape) == 1:
|
|
||||||
y_test = tf.keras.utils.to_categorical(y_test, num_classes=3)
|
|
||||||
|
|
||||||
# Evaluate model
|
|
||||||
logger.info(f"Evaluating CNN model on {len(X_test)} samples")
|
|
||||||
eval_results = self.model.evaluate(X_test, y_test, verbose=0)
|
|
||||||
|
|
||||||
metrics = {}
|
|
||||||
for metric, value in zip(self.model.metrics_names, eval_results):
|
|
||||||
metrics[metric] = value
|
|
||||||
logger.info(f"{metric}: {value:.4f}")
|
|
||||||
|
|
||||||
# Get predictions
|
|
||||||
y_pred_prob = self.model.predict(X_test)
|
|
||||||
|
|
||||||
# Different processing based on output type
|
|
||||||
if self.output_size == 1:
|
|
||||||
# Binary classification
|
|
||||||
y_pred = (y_pred_prob > 0.5).astype(int).flatten()
|
|
||||||
|
|
||||||
# Classification report
|
|
||||||
report = classification_report(y_test, y_pred)
|
|
||||||
logger.info(f"Classification Report:\n{report}")
|
|
||||||
|
|
||||||
# Confusion matrix
|
|
||||||
cm = confusion_matrix(y_test, y_pred)
|
|
||||||
logger.info(f"Confusion Matrix:\n{cm}")
|
|
||||||
|
|
||||||
# ROC curve and AUC
|
|
||||||
fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
|
|
||||||
roc_auc = auc(fpr, tpr)
|
|
||||||
metrics['auc'] = roc_auc
|
|
||||||
|
|
||||||
if plot_results:
|
|
||||||
self._plot_binary_results(y_test, y_pred, y_pred_prob, fpr, tpr, roc_auc)
|
|
||||||
|
|
||||||
elif self.output_size == 3:
|
|
||||||
# Multi-class classification
|
|
||||||
y_pred = np.argmax(y_pred_prob, axis=1)
|
|
||||||
|
|
||||||
# Classification report
|
|
||||||
report = classification_report(y_test_original, y_pred)
|
|
||||||
logger.info(f"Classification Report:\n{report}")
|
|
||||||
|
|
||||||
# Confusion matrix
|
|
||||||
cm = confusion_matrix(y_test_original, y_pred)
|
|
||||||
logger.info(f"Confusion Matrix:\n{cm}")
|
|
||||||
|
|
||||||
if plot_results:
|
|
||||||
self._plot_multiclass_results(y_test_original, y_pred, y_pred_prob)
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
def predict(self, X):
|
|
||||||
"""
|
|
||||||
Make predictions on new data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
X (numpy.ndarray): Input features
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (y_pred, y_proba) where:
|
|
||||||
y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
|
|
||||||
y_proba is the class probability
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model has not been built or trained yet")
|
|
||||||
|
|
||||||
# Ensure X has the right shape
|
|
||||||
if len(X.shape) == 2:
|
|
||||||
# Single sample, add batch dimension
|
|
||||||
X = np.expand_dims(X, axis=0)
|
|
||||||
|
|
||||||
# Get predictions
|
|
||||||
y_proba = self.model.predict(X)
|
|
||||||
|
|
||||||
# Process based on output type
|
|
||||||
if self.output_size == 1:
|
|
||||||
# Binary classification
|
|
||||||
y_pred = (y_proba > 0.5).astype(int).flatten()
|
|
||||||
return y_pred, y_proba.flatten()
|
|
||||||
elif self.output_size == 3:
|
|
||||||
# Multi-class classification
|
|
||||||
y_pred = np.argmax(y_proba, axis=1)
|
|
||||||
return y_pred, y_proba
|
|
||||||
else:
|
|
||||||
# Regression
|
|
||||||
return y_proba, y_proba
|
|
||||||
|
|
||||||
def save(self, filepath=None):
|
|
||||||
"""
|
|
||||||
Save the model to disk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath (str): Path to save the model
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path where the model was saved
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model has not been built yet")
|
|
||||||
|
|
||||||
if filepath is None:
|
|
||||||
# Create a default filepath with timestamp
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
filepath = os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5")
|
|
||||||
|
|
||||||
self.model.save(filepath)
|
|
||||||
logger.info(f"Model saved to {filepath}")
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
def load(self, filepath):
|
|
||||||
"""
|
|
||||||
Load a saved model from disk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath (str): Path to the saved model
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The loaded model
|
|
||||||
"""
|
|
||||||
self.model = load_model(filepath)
|
|
||||||
logger.info(f"Model loaded from {filepath}")
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def extract_hidden_features(self, X):
|
|
||||||
"""
|
|
||||||
Extract features from the last hidden layer of the CNN for transfer learning.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
X (numpy.ndarray): Input data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
numpy.ndarray: Extracted features
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model has not been built or trained yet")
|
|
||||||
|
|
||||||
# Create a new model that outputs the features from the layer before the output
|
|
||||||
feature_layer_name = self.model.layers[-2].name
|
|
||||||
feature_extractor = Model(
|
|
||||||
inputs=self.model.input,
|
|
||||||
outputs=self.model.get_layer(feature_layer_name).output
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract features
|
|
||||||
features = feature_extractor.predict(X)
|
|
||||||
|
|
||||||
return features
|
|
||||||
|
|
||||||
def _plot_binary_results(self, y_true, y_pred, y_proba, fpr, tpr, roc_auc):
|
|
||||||
"""
|
|
||||||
Plot evaluation results for binary classification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_true (numpy.ndarray): True labels
|
|
||||||
y_pred (numpy.ndarray): Predicted labels
|
|
||||||
y_proba (numpy.ndarray): Prediction probabilities
|
|
||||||
fpr (numpy.ndarray): False positive rates for ROC curve
|
|
||||||
tpr (numpy.ndarray): True positive rates for ROC curve
|
|
||||||
roc_auc (float): Area under ROC curve
|
|
||||||
"""
|
|
||||||
plt.figure(figsize=(15, 5))
|
|
||||||
|
|
||||||
# Confusion Matrix
|
|
||||||
plt.subplot(1, 3, 1)
|
|
||||||
cm = confusion_matrix(y_true, y_pred)
|
|
||||||
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
|
||||||
plt.title('Confusion Matrix')
|
|
||||||
plt.colorbar()
|
|
||||||
tick_marks = [0, 1]
|
|
||||||
plt.xticks(tick_marks, ['0', '1'])
|
|
||||||
plt.yticks(tick_marks, ['0', '1'])
|
|
||||||
plt.xlabel('Predicted Label')
|
|
||||||
plt.ylabel('True Label')
|
|
||||||
|
|
||||||
# Add text annotations to confusion matrix
|
|
||||||
thresh = cm.max() / 2.
|
|
||||||
for i in range(cm.shape[0]):
|
|
||||||
for j in range(cm.shape[1]):
|
|
||||||
plt.text(j, i, format(cm[i, j], 'd'),
|
|
||||||
horizontalalignment="center",
|
|
||||||
color="white" if cm[i, j] > thresh else "black")
|
|
||||||
|
|
||||||
# Histogram of prediction probabilities
|
|
||||||
plt.subplot(1, 3, 2)
|
|
||||||
plt.hist(y_proba[y_true == 0], alpha=0.5, label='Class 0')
|
|
||||||
plt.hist(y_proba[y_true == 1], alpha=0.5, label='Class 1')
|
|
||||||
plt.title('Prediction Probabilities')
|
|
||||||
plt.xlabel('Probability of Class 1')
|
|
||||||
plt.ylabel('Count')
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
# ROC Curve
|
|
||||||
plt.subplot(1, 3, 3)
|
|
||||||
plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.3f})')
|
|
||||||
plt.plot([0, 1], [0, 1], 'k--')
|
|
||||||
plt.xlim([0.0, 1.0])
|
|
||||||
plt.ylim([0.0, 1.05])
|
|
||||||
plt.xlabel('False Positive Rate')
|
|
||||||
plt.ylabel('True Positive Rate')
|
|
||||||
plt.title('Receiver Operating Characteristic')
|
|
||||||
plt.legend(loc="lower right")
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Save figure
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
fig_path = os.path.join(self.model_dir, f"cnn_evaluation_{timestamp}.png")
|
|
||||||
plt.savefig(fig_path)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
logger.info(f"Evaluation plots saved to {fig_path}")
|
|
||||||
|
|
||||||
def _plot_multiclass_results(self, y_true, y_pred, y_proba):
|
|
||||||
"""
|
|
||||||
Plot evaluation results for multi-class classification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_true (numpy.ndarray): True labels
|
|
||||||
y_pred (numpy.ndarray): Predicted labels
|
|
||||||
y_proba (numpy.ndarray): Prediction probabilities
|
|
||||||
"""
|
|
||||||
plt.figure(figsize=(12, 5))
|
|
||||||
|
|
||||||
# Confusion Matrix
|
|
||||||
plt.subplot(1, 2, 1)
|
|
||||||
cm = confusion_matrix(y_true, y_pred)
|
|
||||||
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
|
||||||
plt.title('Confusion Matrix')
|
|
||||||
plt.colorbar()
|
|
||||||
classes = ['BUY', 'HOLD', 'SELL'] # Assumes classes are 0, 1, 2
|
|
||||||
tick_marks = np.arange(len(classes))
|
|
||||||
plt.xticks(tick_marks, classes)
|
|
||||||
plt.yticks(tick_marks, classes)
|
|
||||||
plt.xlabel('Predicted Label')
|
|
||||||
plt.ylabel('True Label')
|
|
||||||
|
|
||||||
# Add text annotations to confusion matrix
|
|
||||||
thresh = cm.max() / 2.
|
|
||||||
for i in range(cm.shape[0]):
|
|
||||||
for j in range(cm.shape[1]):
|
|
||||||
plt.text(j, i, format(cm[i, j], 'd'),
|
|
||||||
horizontalalignment="center",
|
|
||||||
color="white" if cm[i, j] > thresh else "black")
|
|
||||||
|
|
||||||
# Class probability distributions
|
|
||||||
plt.subplot(1, 2, 2)
|
|
||||||
for i, cls in enumerate(classes):
|
|
||||||
plt.hist(y_proba[y_true == i, i], alpha=0.5, label=f'Class {cls}')
|
|
||||||
plt.title('Class Probability Distributions')
|
|
||||||
plt.xlabel('Probability')
|
|
||||||
plt.ylabel('Count')
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Save figure
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
fig_path = os.path.join(self.model_dir, f"cnn_multiclass_evaluation_{timestamp}.png")
|
|
||||||
plt.savefig(fig_path)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
logger.info(f"Multiclass evaluation plots saved to {fig_path}")
|
|
||||||
|
|
||||||
def plot_training_history(self):
|
|
||||||
"""
|
|
||||||
Plot training history (loss and metrics).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path to the saved plot
|
|
||||||
"""
|
|
||||||
if self.history is None:
|
|
||||||
raise ValueError("Model has not been trained yet")
|
|
||||||
|
|
||||||
plt.figure(figsize=(12, 5))
|
|
||||||
|
|
||||||
# Plot loss
|
|
||||||
plt.subplot(1, 2, 1)
|
|
||||||
plt.plot(self.history.history['loss'], label='Training Loss')
|
|
||||||
if 'val_loss' in self.history.history:
|
|
||||||
plt.plot(self.history.history['val_loss'], label='Validation Loss')
|
|
||||||
plt.title('Model Loss')
|
|
||||||
plt.xlabel('Epoch')
|
|
||||||
plt.ylabel('Loss')
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
# Plot accuracy
|
|
||||||
plt.subplot(1, 2, 2)
|
|
||||||
|
|
||||||
if 'accuracy' in self.history.history:
|
|
||||||
plt.plot(self.history.history['accuracy'], label='Training Accuracy')
|
|
||||||
if 'val_accuracy' in self.history.history:
|
|
||||||
plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
|
|
||||||
plt.title('Model Accuracy')
|
|
||||||
plt.ylabel('Accuracy')
|
|
||||||
elif 'mae' in self.history.history:
|
|
||||||
plt.plot(self.history.history['mae'], label='Training MAE')
|
|
||||||
if 'val_mae' in self.history.history:
|
|
||||||
plt.plot(self.history.history['val_mae'], label='Validation MAE')
|
|
||||||
plt.title('Model MAE')
|
|
||||||
plt.ylabel('MAE')
|
|
||||||
|
|
||||||
plt.xlabel('Epoch')
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Save figure
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
fig_path = os.path.join(self.model_dir, f"cnn_training_history_{timestamp}.png")
|
|
||||||
plt.savefig(fig_path)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
logger.info(f"Training history plot saved to {fig_path}")
|
|
||||||
return fig_path
|
|
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import time
|
||||||
|
|
||||||
# Add parent directory to path
|
# Add parent directory to path
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
@ -23,16 +24,16 @@ class DQNAgent:
|
|||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
state_shape: Tuple[int, ...],
|
state_shape: Tuple[int, ...],
|
||||||
n_actions: int,
|
n_actions: int = 2,
|
||||||
learning_rate: float = 0.0005, # Reduced learning rate for more stability
|
learning_rate: float = 0.001,
|
||||||
gamma: float = 0.97, # Slightly reduced discount factor
|
|
||||||
epsilon: float = 1.0,
|
epsilon: float = 1.0,
|
||||||
epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
|
epsilon_min: float = 0.01,
|
||||||
epsilon_decay: float = 0.9975, # Slower decay rate
|
epsilon_decay: float = 0.995,
|
||||||
buffer_size: int = 20000, # Increased memory size
|
buffer_size: int = 10000,
|
||||||
batch_size: int = 128, # Larger batch size
|
batch_size: int = 32,
|
||||||
target_update: int = 5, # More frequent target updates
|
target_update: int = 100,
|
||||||
device=None): # Device for computations
|
priority_memory: bool = True,
|
||||||
|
device=None):
|
||||||
|
|
||||||
# Extract state dimensions
|
# Extract state dimensions
|
||||||
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
||||||
@ -48,11 +49,9 @@ class DQNAgent:
|
|||||||
# Store parameters
|
# Store parameters
|
||||||
self.n_actions = n_actions
|
self.n_actions = n_actions
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.gamma = gamma
|
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.epsilon_min = epsilon_min
|
self.epsilon_min = epsilon_min
|
||||||
self.epsilon_decay = epsilon_decay
|
self.epsilon_decay = epsilon_decay
|
||||||
self.epsilon_start = epsilon # Store initial epsilon value for resets/bumps
|
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.target_update = target_update
|
self.target_update = target_update
|
||||||
@ -127,10 +126,41 @@ class DQNAgent:
|
|||||||
self.max_confidence = 0.0
|
self.max_confidence = 0.0
|
||||||
self.min_confidence = 1.0
|
self.min_confidence = 1.0
|
||||||
|
|
||||||
|
# Enhanced features from EnhancedDQNAgent
|
||||||
|
# Market adaptation capabilities
|
||||||
|
self.market_regime_weights = {
|
||||||
|
'trending': 1.2, # Higher confidence in trending markets
|
||||||
|
'ranging': 0.8, # Lower confidence in ranging markets
|
||||||
|
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dueling network support (requires enhanced network architecture)
|
||||||
|
self.use_dueling = True
|
||||||
|
|
||||||
|
# Prioritized experience replay parameters
|
||||||
|
self.use_prioritized_replay = priority_memory
|
||||||
|
self.alpha = 0.6 # Priority exponent
|
||||||
|
self.beta = 0.4 # Importance sampling exponent
|
||||||
|
self.beta_increment = 0.001
|
||||||
|
|
||||||
|
# Double DQN support
|
||||||
|
self.use_double_dqn = True
|
||||||
|
|
||||||
|
# Enhanced training features from EnhancedDQNAgent
|
||||||
|
self.target_update_freq = target_update # More descriptive name
|
||||||
|
self.training_steps = 0
|
||||||
|
self.gradient_clip_norm = 1.0 # Gradient clipping
|
||||||
|
|
||||||
|
# Enhanced statistics tracking
|
||||||
|
self.epsilon_history = []
|
||||||
|
self.td_errors = [] # Track TD errors for analysis
|
||||||
|
|
||||||
# Trade action fee and confidence thresholds
|
# Trade action fee and confidence thresholds
|
||||||
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
||||||
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
||||||
self.recent_actions = [] # Track recent actions to avoid oscillations
|
self.recent_actions = deque(maxlen=10)
|
||||||
|
self.recent_prices = deque(maxlen=20)
|
||||||
|
self.recent_rewards = deque(maxlen=100)
|
||||||
|
|
||||||
# Violent move detection
|
# Violent move detection
|
||||||
self.price_history = []
|
self.price_history = []
|
||||||
@ -173,6 +203,16 @@ class DQNAgent:
|
|||||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||||
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
||||||
|
|
||||||
|
# Position management for 2-action system
|
||||||
|
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||||
|
self.position_entry_price = 0.0
|
||||||
|
self.position_entry_time = None
|
||||||
|
|
||||||
|
# Different thresholds for entry vs exit decisions
|
||||||
|
self.entry_confidence_threshold = 0.7 # High threshold for new positions
|
||||||
|
self.exit_confidence_threshold = 0.3 # Lower threshold for closing positions
|
||||||
|
self.uncertainty_threshold = 0.1 # When to stay neutral
|
||||||
|
|
||||||
def move_models_to_device(self, device=None):
|
def move_models_to_device(self, device=None):
|
||||||
"""Move models to the specified device (GPU/CPU)"""
|
"""Move models to the specified device (GPU/CPU)"""
|
||||||
if device is not None:
|
if device is not None:
|
||||||
@ -290,247 +330,148 @@ class DQNAgent:
|
|||||||
if len(self.price_movement_memory) > self.buffer_size // 4:
|
if len(self.price_movement_memory) > self.buffer_size // 4:
|
||||||
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
|
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
|
||||||
|
|
||||||
def act(self, state: np.ndarray, explore=True) -> int:
|
def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
|
||||||
"""Choose action using epsilon-greedy policy with explore flag"""
|
"""
|
||||||
if explore and random.random() < self.epsilon:
|
Choose action based on current state using 2-action system with intelligent position management
|
||||||
return random.randrange(self.n_actions)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
Args:
|
||||||
# Enhance state with real-time tick features
|
state: Current market state
|
||||||
enhanced_state = self._enhance_state_with_tick_features(state)
|
explore: Whether to use epsilon-greedy exploration
|
||||||
|
current_price: Current market price for position management
|
||||||
|
market_context: Additional market context for decision making
|
||||||
|
|
||||||
# Ensure state is normalized before inference
|
Returns:
|
||||||
state_tensor = self._normalize_state(enhanced_state)
|
int: Action (0=SELL, 1=BUY) or None if should hold position
|
||||||
state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
|
"""
|
||||||
|
|
||||||
# Get predictions using the policy network
|
# Convert state to tensor
|
||||||
self.policy_net.eval() # Set to evaluation mode for inference
|
if isinstance(state, np.ndarray):
|
||||||
action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
self.policy_net.train() # Back to training mode
|
|
||||||
|
|
||||||
# Store hidden features for integration
|
|
||||||
self.last_hidden_features = hidden_features.cpu().numpy()
|
|
||||||
|
|
||||||
# Track feature history (limited size)
|
|
||||||
self.feature_history.append(hidden_features.cpu().numpy())
|
|
||||||
if len(self.feature_history) > 100:
|
|
||||||
self.feature_history = self.feature_history[-100:]
|
|
||||||
|
|
||||||
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
|
|
||||||
extrema_class = extrema_pred.argmax(dim=1).item()
|
|
||||||
extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
|
|
||||||
|
|
||||||
# Log extrema prediction for significant signals
|
|
||||||
if extrema_confidence > 0.7 and extrema_class != 2: # Only log strong top/bottom signals
|
|
||||||
extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
|
|
||||||
logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
|
|
||||||
|
|
||||||
# Process price predictions
|
|
||||||
price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
|
|
||||||
price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
|
|
||||||
price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
|
|
||||||
price_values = price_predictions['values']
|
|
||||||
|
|
||||||
# Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
|
|
||||||
immediate_direction = price_immediate.argmax(dim=1).item()
|
|
||||||
midterm_direction = price_midterm.argmax(dim=1).item()
|
|
||||||
longterm_direction = price_longterm.argmax(dim=1).item()
|
|
||||||
|
|
||||||
# Get confidence levels
|
|
||||||
immediate_conf = price_immediate[0, immediate_direction].item()
|
|
||||||
midterm_conf = price_midterm[0, midterm_direction].item()
|
|
||||||
longterm_conf = price_longterm[0, longterm_direction].item()
|
|
||||||
|
|
||||||
# Get predicted price change percentages
|
|
||||||
price_changes = price_values[0].tolist()
|
|
||||||
|
|
||||||
# Log significant price movement predictions
|
|
||||||
timeframes = ["1s/1m", "1h", "1d", "1w"]
|
|
||||||
directions = ["DOWN", "SIDEWAYS", "UP"]
|
|
||||||
|
|
||||||
for i, (direction, conf) in enumerate([
|
|
||||||
(immediate_direction, immediate_conf),
|
|
||||||
(midterm_direction, midterm_conf),
|
|
||||||
(longterm_direction, longterm_conf)
|
|
||||||
]):
|
|
||||||
if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions
|
|
||||||
logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
|
|
||||||
f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")
|
|
||||||
|
|
||||||
# Store predictions for environment to use
|
|
||||||
self.last_extrema_pred = {
|
|
||||||
'class': extrema_class,
|
|
||||||
'confidence': extrema_confidence,
|
|
||||||
'raw': extrema_pred.cpu().numpy()
|
|
||||||
}
|
|
||||||
|
|
||||||
self.last_price_pred = {
|
|
||||||
'immediate': {
|
|
||||||
'direction': immediate_direction,
|
|
||||||
'confidence': immediate_conf,
|
|
||||||
'change': price_changes[0]
|
|
||||||
},
|
|
||||||
'midterm': {
|
|
||||||
'direction': midterm_direction,
|
|
||||||
'confidence': midterm_conf,
|
|
||||||
'change': price_changes[1]
|
|
||||||
},
|
|
||||||
'longterm': {
|
|
||||||
'direction': longterm_direction,
|
|
||||||
'confidence': longterm_conf,
|
|
||||||
'change': price_changes[2]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the action with highest Q-value
|
|
||||||
action = action_probs.argmax().item()
|
|
||||||
|
|
||||||
# Calculate overall confidence in the action
|
|
||||||
q_values_softmax = F.softmax(action_probs, dim=1)[0]
|
|
||||||
action_confidence = q_values_softmax[action].item()
|
|
||||||
|
|
||||||
# Track confidence metrics
|
|
||||||
self.confidence_history.append(action_confidence)
|
|
||||||
if len(self.confidence_history) > 100:
|
|
||||||
self.confidence_history = self.confidence_history[-100:]
|
|
||||||
|
|
||||||
# Update confidence metrics
|
|
||||||
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
|
|
||||||
self.max_confidence = max(self.max_confidence, action_confidence)
|
|
||||||
self.min_confidence = min(self.min_confidence, action_confidence)
|
|
||||||
|
|
||||||
# Log average confidence occasionally
|
|
||||||
if random.random() < 0.01: # 1% of the time
|
|
||||||
logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
|
|
||||||
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
|
|
||||||
|
|
||||||
# Track price for violent move detection
|
|
||||||
try:
|
|
||||||
# Extract current price from state (assuming it's in the last position)
|
|
||||||
if len(state.shape) > 1: # For 2D state
|
|
||||||
current_price = state[-1, -1]
|
|
||||||
else: # For 1D state
|
|
||||||
current_price = state[-1]
|
|
||||||
|
|
||||||
self.price_history.append(current_price)
|
|
||||||
if len(self.price_history) > self.volatility_window:
|
|
||||||
self.price_history = self.price_history[-self.volatility_window:]
|
|
||||||
|
|
||||||
# Detect violent price moves if we have enough price history
|
|
||||||
if len(self.price_history) >= 5:
|
|
||||||
# Calculate short-term volatility
|
|
||||||
recent_prices = self.price_history[-5:]
|
|
||||||
|
|
||||||
# Make sure we're working with scalar values, not arrays
|
|
||||||
if isinstance(recent_prices[0], np.ndarray):
|
|
||||||
# If prices are arrays, extract the last value (current price)
|
|
||||||
recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]
|
|
||||||
|
|
||||||
# Calculate price changes with protection against division by zero
|
|
||||||
price_changes = []
|
|
||||||
for i in range(1, len(recent_prices)):
|
|
||||||
if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
|
|
||||||
change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
|
|
||||||
price_changes.append(change)
|
|
||||||
else:
|
else:
|
||||||
price_changes.append(0.0)
|
state_tensor = state.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# Calculate volatility as sum of absolute price changes
|
# Get Q-values
|
||||||
volatility = sum([abs(change) for change in price_changes])
|
q_values = self.policy_net(state_tensor)
|
||||||
|
action_values = q_values.cpu().data.numpy()[0]
|
||||||
|
|
||||||
# Check if we've had a violent move
|
# Calculate confidence scores
|
||||||
if volatility > self.volatility_threshold:
|
sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
|
||||||
logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
|
buy_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
|
||||||
self.post_violent_move = True
|
|
||||||
self.violent_move_cooldown = 10 # Set cooldown period
|
|
||||||
|
|
||||||
# Handle post-violent move period
|
# Determine action based on current position and confidence thresholds
|
||||||
if self.post_violent_move:
|
action = self._determine_action_with_position_management(
|
||||||
if self.violent_move_cooldown > 0:
|
sell_confidence, buy_confidence, current_price, market_context, explore
|
||||||
self.violent_move_cooldown -= 1
|
)
|
||||||
# Increase confidence threshold temporarily after violent moves
|
|
||||||
effective_threshold = self.minimum_action_confidence * 1.1
|
|
||||||
logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
|
|
||||||
f"Using higher confidence threshold: {effective_threshold:.4f}")
|
|
||||||
else:
|
|
||||||
self.post_violent_move = False
|
|
||||||
logger.info("Post-violent move period ended")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error in violent move detection: {str(e)}")
|
|
||||||
|
|
||||||
# Apply trade action fee to buy/sell actions but not to hold
|
# Update tracking
|
||||||
# This creates a threshold that must be exceeded to justify a trade
|
if current_price:
|
||||||
action_values = action_probs.clone()
|
self.recent_prices.append(current_price)
|
||||||
|
|
||||||
# If BUY or SELL, apply fee by reducing the Q-value
|
if action is not None:
|
||||||
if action == 0 or action == 1: # BUY or SELL
|
|
||||||
# Check if confidence is above minimum threshold
|
|
||||||
effective_threshold = self.minimum_action_confidence
|
|
||||||
if self.post_violent_move:
|
|
||||||
effective_threshold *= 1.1 # Higher threshold after violent moves
|
|
||||||
|
|
||||||
if action_confidence < effective_threshold:
|
|
||||||
# If confidence is below threshold, force HOLD action
|
|
||||||
logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
|
|
||||||
action = 2 # HOLD
|
|
||||||
else:
|
|
||||||
# Apply trade action fee to ensure we only trade when there's clear benefit
|
|
||||||
fee_adjusted_action_values = action_values.clone()
|
|
||||||
fee_adjusted_action_values[0, 0] -= self.trade_action_fee # Reduce BUY value
|
|
||||||
fee_adjusted_action_values[0, 1] -= self.trade_action_fee # Reduce SELL value
|
|
||||||
# Hold value remains unchanged
|
|
||||||
|
|
||||||
# Re-determine the action based on fee-adjusted values
|
|
||||||
fee_adjusted_action = fee_adjusted_action_values.argmax().item()
|
|
||||||
|
|
||||||
# If the fee changes our decision, log this
|
|
||||||
if fee_adjusted_action != action:
|
|
||||||
logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
|
|
||||||
action = fee_adjusted_action
|
|
||||||
|
|
||||||
# Adjust action based on extrema and price predictions
|
|
||||||
# Prioritize short-term movement for trading decisions
|
|
||||||
if immediate_conf > 0.8: # Only adjust for strong signals
|
|
||||||
if immediate_direction == 2: # UP prediction
|
|
||||||
# Bias toward BUY for strong up predictions
|
|
||||||
if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
|
|
||||||
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
|
|
||||||
action = 0 # BUY
|
|
||||||
elif immediate_direction == 0: # DOWN prediction
|
|
||||||
# Bias toward SELL for strong down predictions
|
|
||||||
if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
|
|
||||||
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
|
|
||||||
action = 1 # SELL
|
|
||||||
|
|
||||||
# Also consider extrema detection for action adjustment
|
|
||||||
if extrema_confidence > 0.8: # Only adjust for strong signals
|
|
||||||
if extrema_class == 0: # Bottom detected
|
|
||||||
# Bias toward BUY at bottoms
|
|
||||||
if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
|
|
||||||
logger.info(f"Adjusting action to BUY based on bottom detection")
|
|
||||||
action = 0 # BUY
|
|
||||||
elif extrema_class == 1: # Top detected
|
|
||||||
# Bias toward SELL at tops
|
|
||||||
if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
|
|
||||||
logger.info(f"Adjusting action to SELL based on top detection")
|
|
||||||
action = 1 # SELL
|
|
||||||
|
|
||||||
# Finally, avoid action oscillation by checking recent history
|
|
||||||
if len(self.recent_actions) >= 2:
|
|
||||||
last_action = self.recent_actions[-1]
|
|
||||||
if action != last_action and action != 2 and last_action != 2:
|
|
||||||
# We're switching between BUY and SELL too quickly
|
|
||||||
# Only allow this if we have very high confidence
|
|
||||||
if action_confidence < 0.85:
|
|
||||||
logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
|
|
||||||
action = 2 # HOLD
|
|
||||||
|
|
||||||
# Update recent actions list
|
|
||||||
self.recent_actions.append(action)
|
self.recent_actions.append(action)
|
||||||
if len(self.recent_actions) > 5:
|
|
||||||
self.recent_actions = self.recent_actions[-5:]
|
|
||||||
|
|
||||||
return action
|
return action
|
||||||
|
else:
|
||||||
|
# Return None to indicate HOLD (don't change position)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
|
||||||
|
"""Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
|
q_values = self.policy_net(state_tensor)
|
||||||
|
|
||||||
|
# Convert Q-values to probabilities
|
||||||
|
action_probs = torch.softmax(q_values, dim=1)
|
||||||
|
action = q_values.argmax().item()
|
||||||
|
base_confidence = action_probs[0, action].item()
|
||||||
|
|
||||||
|
# Adapt confidence based on market regime
|
||||||
|
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||||
|
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||||
|
|
||||||
|
return action, adapted_confidence
|
||||||
|
|
||||||
|
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
|
||||||
|
"""
|
||||||
|
Determine action based on current position and confidence thresholds
|
||||||
|
|
||||||
|
This implements the intelligent position management where:
|
||||||
|
- When neutral: Need high confidence to enter position
|
||||||
|
- When in position: Need lower confidence to exit
|
||||||
|
- Different thresholds for entry vs exit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Apply epsilon-greedy exploration
|
||||||
|
if explore and np.random.random() <= self.epsilon:
|
||||||
|
return np.random.choice([0, 1])
|
||||||
|
|
||||||
|
# Get the dominant signal
|
||||||
|
dominant_action = 0 if sell_conf > buy_conf else 1
|
||||||
|
dominant_confidence = max(sell_conf, buy_conf)
|
||||||
|
|
||||||
|
# Decision logic based on current position
|
||||||
|
if self.current_position == 0: # No position - need high confidence to enter
|
||||||
|
if dominant_confidence >= self.entry_confidence_threshold:
|
||||||
|
# Strong enough signal to enter position
|
||||||
|
if dominant_action == 1: # BUY signal
|
||||||
|
self.current_position = 1.0
|
||||||
|
self.position_entry_price = current_price
|
||||||
|
self.position_entry_time = time.time()
|
||||||
|
logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
|
||||||
|
return 1
|
||||||
|
else: # SELL signal
|
||||||
|
self.current_position = -1.0
|
||||||
|
self.position_entry_price = current_price
|
||||||
|
self.position_entry_time = time.time()
|
||||||
|
logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
# Not confident enough to enter position
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif self.current_position > 0: # Long position
|
||||||
|
if dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
|
||||||
|
# SELL signal with enough confidence to close long position
|
||||||
|
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
|
||||||
|
logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
|
||||||
|
self.current_position = 0.0
|
||||||
|
self.position_entry_price = 0.0
|
||||||
|
self.position_entry_time = None
|
||||||
|
return 0
|
||||||
|
elif dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
|
||||||
|
# Very strong SELL signal - close long and enter short
|
||||||
|
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
|
||||||
|
logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
|
||||||
|
self.current_position = -1.0
|
||||||
|
self.position_entry_price = current_price
|
||||||
|
self.position_entry_time = time.time()
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
# Hold the long position
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif self.current_position < 0: # Short position
|
||||||
|
if dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
|
||||||
|
# BUY signal with enough confidence to close short position
|
||||||
|
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
|
||||||
|
logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
|
||||||
|
self.current_position = 0.0
|
||||||
|
self.position_entry_price = 0.0
|
||||||
|
self.position_entry_time = None
|
||||||
|
return 1
|
||||||
|
elif dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
|
||||||
|
# Very strong BUY signal - close short and enter long
|
||||||
|
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
|
||||||
|
logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
|
||||||
|
self.current_position = 1.0
|
||||||
|
self.position_entry_price = current_price
|
||||||
|
self.position_entry_time = time.time()
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
# Hold the short position
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def replay(self, experiences=None):
|
def replay(self, experiences=None):
|
||||||
"""Train the model using experiences from memory"""
|
"""Train the model using experiences from memory"""
|
||||||
@ -658,8 +599,16 @@ class DQNAgent:
|
|||||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
|
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
|
||||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|
||||||
# Get next Q values with target network
|
# Enhanced Double DQN implementation
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
if self.use_double_dqn:
|
||||||
|
# Double DQN: Use policy network to select actions, target network to evaluate
|
||||||
|
policy_q_values, _, _, _, _ = self.policy_net(next_states)
|
||||||
|
next_actions = policy_q_values.argmax(1)
|
||||||
|
target_q_values_all, _, _, _, _ = self.target_net(next_states)
|
||||||
|
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||||
|
else:
|
||||||
|
# Standard DQN: Use target network for both selection and evaluation
|
||||||
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
|
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
|
||||||
next_q_values = next_q_values.max(1)[0]
|
next_q_values = next_q_values.max(1)[0]
|
||||||
|
|
||||||
@ -699,16 +648,25 @@ class DQNAgent:
|
|||||||
# Backward pass
|
# Backward pass
|
||||||
total_loss.backward()
|
total_loss.backward()
|
||||||
|
|
||||||
# Clip gradients to avoid exploding gradients
|
# Enhanced gradient clipping with configurable norm
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.gradient_clip_norm)
|
||||||
|
|
||||||
# Update weights
|
# Update weights
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
# Update target network if needed
|
# Enhanced target network update tracking
|
||||||
self.update_count += 1
|
self.training_steps += 1
|
||||||
if self.update_count % self.target_update == 0:
|
if self.training_steps % self.target_update_freq == 0:
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
|
logger.debug(f"Target network updated at step {self.training_steps}")
|
||||||
|
|
||||||
|
# Enhanced statistics tracking
|
||||||
|
self.epsilon_history.append(self.epsilon)
|
||||||
|
|
||||||
|
# Calculate and store TD error for analysis
|
||||||
|
with torch.no_grad():
|
||||||
|
td_error = torch.abs(current_q_values - target_q_values).mean().item()
|
||||||
|
self.td_errors.append(td_error)
|
||||||
|
|
||||||
# Return loss
|
# Return loss
|
||||||
return total_loss.item()
|
return total_loss.item()
|
||||||
@ -1169,3 +1127,39 @@ class DQNAgent:
|
|||||||
logger.info(f"Agent state loaded from {path}_agent_state.pt")
|
logger.info(f"Agent state loaded from {path}_agent_state.pt")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
|
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
|
||||||
|
|
||||||
|
def get_position_info(self):
|
||||||
|
"""Get current position information"""
|
||||||
|
return {
|
||||||
|
'position': self.current_position,
|
||||||
|
'entry_price': self.position_entry_price,
|
||||||
|
'entry_time': self.position_entry_time,
|
||||||
|
'entry_threshold': self.entry_confidence_threshold,
|
||||||
|
'exit_threshold': self.exit_confidence_threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_enhanced_training_stats(self):
|
||||||
|
"""Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
|
||||||
|
return {
|
||||||
|
'buffer_size': len(self.memory),
|
||||||
|
'epsilon': self.epsilon,
|
||||||
|
'avg_reward': self.avg_reward,
|
||||||
|
'best_reward': self.best_reward,
|
||||||
|
'recent_rewards': list(self.recent_rewards) if hasattr(self, 'recent_rewards') else [],
|
||||||
|
'no_improvement_count': self.no_improvement_count,
|
||||||
|
# Enhanced statistics from EnhancedDQNAgent
|
||||||
|
'training_steps': self.training_steps,
|
||||||
|
'avg_td_error': np.mean(self.td_errors[-100:]) if self.td_errors else 0.0,
|
||||||
|
'recent_losses': self.losses[-10:] if self.losses else [],
|
||||||
|
'epsilon_trend': self.epsilon_history[-20:] if self.epsilon_history else [],
|
||||||
|
'specialized_buffers': {
|
||||||
|
'extrema_memory': len(self.extrema_memory),
|
||||||
|
'positive_memory': len(self.positive_memory),
|
||||||
|
'price_movement_memory': len(self.price_movement_memory)
|
||||||
|
},
|
||||||
|
'market_regime_weights': self.market_regime_weights,
|
||||||
|
'use_double_dqn': self.use_double_dqn,
|
||||||
|
'use_prioritized_replay': self.use_prioritized_replay,
|
||||||
|
'gradient_clip_norm': self.gradient_clip_norm,
|
||||||
|
'target_update_frequency': self.target_update_freq
|
||||||
|
}
|
@ -1,329 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import numpy as np
|
|
||||||
from collections import deque
|
|
||||||
import random
|
|
||||||
from typing import Tuple, List
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Add parent directory to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
||||||
|
|
||||||
# Import the EnhancedCNN model
|
|
||||||
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset
|
|
||||||
|
|
||||||
# Configure logger
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class EnhancedDQNAgent:
|
|
||||||
"""
|
|
||||||
Enhanced Deep Q-Network agent for trading
|
|
||||||
Uses the improved EnhancedCNN model with residual connections and attention mechanisms
|
|
||||||
"""
|
|
||||||
def __init__(self,
|
|
||||||
state_shape: Tuple[int, ...],
|
|
||||||
n_actions: int,
|
|
||||||
learning_rate: float = 0.0003, # Slightly reduced learning rate for stability
|
|
||||||
gamma: float = 0.95, # Discount factor
|
|
||||||
epsilon: float = 1.0,
|
|
||||||
epsilon_min: float = 0.05,
|
|
||||||
epsilon_decay: float = 0.995, # Slower decay for more exploration
|
|
||||||
buffer_size: int = 50000, # Larger memory buffer
|
|
||||||
batch_size: int = 128, # Larger batch size
|
|
||||||
target_update: int = 10, # More frequent target updates
|
|
||||||
confidence_threshold: float = 0.4, # Lower confidence threshold
|
|
||||||
device=None):
|
|
||||||
|
|
||||||
# Extract state dimensions
|
|
||||||
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
|
||||||
# Multi-dimensional state (like image or sequence)
|
|
||||||
self.state_dim = state_shape
|
|
||||||
else:
|
|
||||||
# 1D state
|
|
||||||
if isinstance(state_shape, tuple):
|
|
||||||
self.state_dim = state_shape[0]
|
|
||||||
else:
|
|
||||||
self.state_dim = state_shape
|
|
||||||
|
|
||||||
# Store parameters
|
|
||||||
self.n_actions = n_actions
|
|
||||||
self.learning_rate = learning_rate
|
|
||||||
self.gamma = gamma
|
|
||||||
self.epsilon = epsilon
|
|
||||||
self.epsilon_min = epsilon_min
|
|
||||||
self.epsilon_decay = epsilon_decay
|
|
||||||
self.buffer_size = buffer_size
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.target_update = target_update
|
|
||||||
self.confidence_threshold = confidence_threshold
|
|
||||||
|
|
||||||
# Set device for computation
|
|
||||||
if device is None:
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
else:
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
# Initialize models with the enhanced CNN
|
|
||||||
self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
|
|
||||||
self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
|
|
||||||
|
|
||||||
# Initialize the target network with the same weights as the policy network
|
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
||||||
|
|
||||||
# Set models to eval mode (important for batch norm, dropout)
|
|
||||||
self.target_net.eval()
|
|
||||||
|
|
||||||
# Optimization components
|
|
||||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
|
|
||||||
self.criterion = nn.MSELoss()
|
|
||||||
|
|
||||||
# Experience replay memory with example sifting
|
|
||||||
self.memory = ExampleSiftingDataset(max_examples=buffer_size)
|
|
||||||
self.update_count = 0
|
|
||||||
|
|
||||||
# Confidence tracking
|
|
||||||
self.confidence_history = []
|
|
||||||
self.avg_confidence = 0.0
|
|
||||||
self.max_confidence = 0.0
|
|
||||||
self.min_confidence = 1.0
|
|
||||||
|
|
||||||
# Performance tracking
|
|
||||||
self.losses = []
|
|
||||||
self.rewards = []
|
|
||||||
self.avg_reward = 0.0
|
|
||||||
|
|
||||||
# Check if mixed precision training should be used
|
|
||||||
self.use_mixed_precision = False
|
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
|
||||||
self.use_mixed_precision = True
|
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
|
||||||
logger.info("Mixed precision training enabled")
|
|
||||||
else:
|
|
||||||
logger.info("Mixed precision training disabled")
|
|
||||||
|
|
||||||
# For compatibility with old code
|
|
||||||
self.action_size = n_actions
|
|
||||||
|
|
||||||
logger.info(f"Enhanced DQN Agent using device: {self.device}")
|
|
||||||
logger.info(f"Confidence threshold set to {self.confidence_threshold}")
|
|
||||||
|
|
||||||
def move_models_to_device(self, device=None):
|
|
||||||
"""Move models to the specified device (GPU/CPU)"""
|
|
||||||
if device is not None:
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.policy_net = self.policy_net.to(self.device)
|
|
||||||
self.target_net = self.target_net.to(self.device)
|
|
||||||
logger.info(f"Moved models to {self.device}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to move models to {self.device}: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _normalize_state(self, state):
|
|
||||||
"""Normalize state for better training stability"""
|
|
||||||
try:
|
|
||||||
# Convert to numpy array if needed
|
|
||||||
if isinstance(state, list):
|
|
||||||
state = np.array(state, dtype=np.float32)
|
|
||||||
|
|
||||||
# Apply normalization based on state shape
|
|
||||||
if len(state.shape) > 1:
|
|
||||||
# Multi-dimensional state - normalize each feature dimension separately
|
|
||||||
for i in range(state.shape[0]):
|
|
||||||
# Skip if all zeros (to avoid division by zero)
|
|
||||||
if np.sum(np.abs(state[i])) > 0:
|
|
||||||
# Standardize each feature dimension
|
|
||||||
mean = np.mean(state[i])
|
|
||||||
std = np.std(state[i])
|
|
||||||
if std > 0:
|
|
||||||
state[i] = (state[i] - mean) / std
|
|
||||||
else:
|
|
||||||
# 1D state vector
|
|
||||||
# Skip if all zeros
|
|
||||||
if np.sum(np.abs(state)) > 0:
|
|
||||||
mean = np.mean(state)
|
|
||||||
std = np.std(state)
|
|
||||||
if std > 0:
|
|
||||||
state = (state - mean) / std
|
|
||||||
|
|
||||||
return state
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error normalizing state: {str(e)}")
|
|
||||||
return state
|
|
||||||
|
|
||||||
def remember(self, state, action, reward, next_state, done):
|
|
||||||
"""Store experience in memory with example sifting"""
|
|
||||||
self.memory.add_example(state, action, reward, next_state, done)
|
|
||||||
|
|
||||||
# Also track rewards for monitoring
|
|
||||||
self.rewards.append(reward)
|
|
||||||
if len(self.rewards) > 100:
|
|
||||||
self.rewards = self.rewards[-100:]
|
|
||||||
self.avg_reward = np.mean(self.rewards)
|
|
||||||
|
|
||||||
def act(self, state, explore=True):
|
|
||||||
"""Choose action using epsilon-greedy policy with built-in confidence thresholding"""
|
|
||||||
if explore and random.random() < self.epsilon:
|
|
||||||
return random.randrange(self.n_actions), 0.0 # Return action and zero confidence
|
|
||||||
|
|
||||||
# Normalize state before inference
|
|
||||||
normalized_state = self._normalize_state(state)
|
|
||||||
|
|
||||||
# Use the EnhancedCNN's act method which includes confidence thresholding
|
|
||||||
action, confidence = self.policy_net.act(normalized_state, explore=explore)
|
|
||||||
|
|
||||||
# Track confidence metrics
|
|
||||||
self.confidence_history.append(confidence)
|
|
||||||
if len(self.confidence_history) > 100:
|
|
||||||
self.confidence_history = self.confidence_history[-100:]
|
|
||||||
|
|
||||||
# Update confidence metrics
|
|
||||||
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
|
|
||||||
self.max_confidence = max(self.max_confidence, confidence)
|
|
||||||
self.min_confidence = min(self.min_confidence, confidence)
|
|
||||||
|
|
||||||
# Log average confidence occasionally
|
|
||||||
if random.random() < 0.01: # 1% of the time
|
|
||||||
logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
|
|
||||||
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
|
|
||||||
|
|
||||||
return action, confidence
|
|
||||||
|
|
||||||
def replay(self):
|
|
||||||
"""Train the model using experience replay with high-quality examples"""
|
|
||||||
# Check if enough samples in memory
|
|
||||||
if len(self.memory) < self.batch_size:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# Get batch of experiences
|
|
||||||
batch = self.memory.get_batch(self.batch_size)
|
|
||||||
if batch is None:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
states = torch.FloatTensor(batch['states']).to(self.device)
|
|
||||||
actions = torch.LongTensor(batch['actions']).to(self.device)
|
|
||||||
rewards = torch.FloatTensor(batch['rewards']).to(self.device)
|
|
||||||
next_states = torch.FloatTensor(batch['next_states']).to(self.device)
|
|
||||||
dones = torch.FloatTensor(batch['dones']).to(self.device)
|
|
||||||
|
|
||||||
# Compute Q values
|
|
||||||
self.policy_net.train() # Set to training mode
|
|
||||||
|
|
||||||
# Get current Q values
|
|
||||||
if self.use_mixed_precision:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
# Get current Q values
|
|
||||||
q_values, _, _, _ = self.policy_net(states)
|
|
||||||
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
||||||
|
|
||||||
# Compute target Q values
|
|
||||||
with torch.no_grad():
|
|
||||||
self.target_net.eval()
|
|
||||||
next_q_values, _, _, _ = self.target_net(next_states)
|
|
||||||
next_q = next_q_values.max(1)[0]
|
|
||||||
target_q = rewards + (1 - dones) * self.gamma * next_q
|
|
||||||
|
|
||||||
# Compute loss
|
|
||||||
loss = self.criterion(current_q, target_q)
|
|
||||||
|
|
||||||
# Perform backpropagation with mixed precision
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
self.scaler.scale(loss).backward()
|
|
||||||
self.scaler.unscale_(self.optimizer)
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
else:
|
|
||||||
# Standard precision training
|
|
||||||
# Get current Q values
|
|
||||||
q_values, _, _, _ = self.policy_net(states)
|
|
||||||
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
||||||
|
|
||||||
# Compute target Q values
|
|
||||||
with torch.no_grad():
|
|
||||||
self.target_net.eval()
|
|
||||||
next_q_values, _, _, _ = self.target_net(next_states)
|
|
||||||
next_q = next_q_values.max(1)[0]
|
|
||||||
target_q = rewards + (1 - dones) * self.gamma * next_q
|
|
||||||
|
|
||||||
# Compute loss
|
|
||||||
loss = self.criterion(current_q, target_q)
|
|
||||||
|
|
||||||
# Perform backpropagation
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
# Track loss
|
|
||||||
loss_value = loss.item()
|
|
||||||
self.losses.append(loss_value)
|
|
||||||
if len(self.losses) > 100:
|
|
||||||
self.losses = self.losses[-100:]
|
|
||||||
|
|
||||||
# Update target network
|
|
||||||
self.update_count += 1
|
|
||||||
if self.update_count % self.target_update == 0:
|
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
||||||
logger.info(f"Updated target network (step {self.update_count})")
|
|
||||||
|
|
||||||
# Decay epsilon
|
|
||||||
if self.epsilon > self.epsilon_min:
|
|
||||||
self.epsilon *= self.epsilon_decay
|
|
||||||
|
|
||||||
return loss_value
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
"""Save agent state and models"""
|
|
||||||
self.policy_net.save(f"{path}_policy")
|
|
||||||
self.target_net.save(f"{path}_target")
|
|
||||||
|
|
||||||
# Save agent state
|
|
||||||
torch.save({
|
|
||||||
'epsilon': self.epsilon,
|
|
||||||
'confidence_threshold': self.confidence_threshold,
|
|
||||||
'losses': self.losses,
|
|
||||||
'rewards': self.rewards,
|
|
||||||
'avg_reward': self.avg_reward,
|
|
||||||
'confidence_history': self.confidence_history,
|
|
||||||
'avg_confidence': self.avg_confidence,
|
|
||||||
'max_confidence': self.max_confidence,
|
|
||||||
'min_confidence': self.min_confidence,
|
|
||||||
'update_count': self.update_count
|
|
||||||
}, f"{path}_agent_state.pt")
|
|
||||||
|
|
||||||
logger.info(f"Agent state saved to {path}_agent_state.pt")
|
|
||||||
|
|
||||||
def load(self, path):
|
|
||||||
"""Load agent state and models"""
|
|
||||||
policy_loaded = self.policy_net.load(f"{path}_policy")
|
|
||||||
target_loaded = self.target_net.load(f"{path}_target")
|
|
||||||
|
|
||||||
# Load agent state if available
|
|
||||||
agent_state_path = f"{path}_agent_state.pt"
|
|
||||||
if os.path.exists(agent_state_path):
|
|
||||||
try:
|
|
||||||
state = torch.load(agent_state_path)
|
|
||||||
self.epsilon = state.get('epsilon', self.epsilon)
|
|
||||||
self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
|
|
||||||
self.policy_net.confidence_threshold = self.confidence_threshold
|
|
||||||
self.target_net.confidence_threshold = self.confidence_threshold
|
|
||||||
self.losses = state.get('losses', [])
|
|
||||||
self.rewards = state.get('rewards', [])
|
|
||||||
self.avg_reward = state.get('avg_reward', 0.0)
|
|
||||||
self.confidence_history = state.get('confidence_history', [])
|
|
||||||
self.avg_confidence = state.get('avg_confidence', 0.0)
|
|
||||||
self.max_confidence = state.get('max_confidence', 0.0)
|
|
||||||
self.min_confidence = state.get('min_confidence', 1.0)
|
|
||||||
self.update_count = state.get('update_count', 0)
|
|
||||||
logger.info(f"Agent state loaded from {agent_state_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading agent state: {str(e)}")
|
|
||||||
|
|
||||||
return policy_loaded and target_loaded
|
|
@ -110,96 +110,119 @@ class EnhancedCNN(nn.Module):
|
|||||||
logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")
|
logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")
|
||||||
|
|
||||||
def _build_network(self):
|
def _build_network(self):
|
||||||
"""Build the MASSIVELY enhanced neural network for 4GB VRAM budget"""
|
"""Build the ULTRA MASSIVE enhanced neural network for maximum learning capacity"""
|
||||||
|
|
||||||
# MASSIVELY SCALED ARCHITECTURE for 4GB VRAM (up to ~50M parameters)
|
# ULTRA MASSIVE SCALED ARCHITECTURE for maximum learning (up to ~100M parameters)
|
||||||
if self.channels > 1:
|
if self.channels > 1:
|
||||||
# Massive convolutional backbone with deeper residual blocks
|
# Ultra massive convolutional backbone with much deeper residual blocks
|
||||||
self.conv_layers = nn.Sequential(
|
self.conv_layers = nn.Sequential(
|
||||||
# Initial large conv block
|
# Initial ultra large conv block
|
||||||
nn.Conv1d(self.channels, 256, kernel_size=7, padding=3), # Much wider initial layer
|
nn.Conv1d(self.channels, 512, kernel_size=7, padding=3), # Ultra wide initial layer
|
||||||
nn.BatchNorm1d(256),
|
nn.BatchNorm1d(512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.1),
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
# First residual stage - 256 channels
|
# First residual stage - 512 channels
|
||||||
ResidualBlock(256, 512),
|
ResidualBlock(512, 768),
|
||||||
ResidualBlock(512, 512),
|
ResidualBlock(768, 768),
|
||||||
ResidualBlock(512, 512),
|
ResidualBlock(768, 768),
|
||||||
|
ResidualBlock(768, 768), # Additional layer
|
||||||
nn.MaxPool1d(kernel_size=2, stride=2),
|
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||||
nn.Dropout(0.2),
|
nn.Dropout(0.2),
|
||||||
|
|
||||||
# Second residual stage - 512 channels
|
# Second residual stage - 768 to 1024 channels
|
||||||
ResidualBlock(512, 1024),
|
ResidualBlock(768, 1024),
|
||||||
ResidualBlock(1024, 1024),
|
ResidualBlock(1024, 1024),
|
||||||
ResidualBlock(1024, 1024),
|
ResidualBlock(1024, 1024),
|
||||||
|
ResidualBlock(1024, 1024), # Additional layer
|
||||||
nn.MaxPool1d(kernel_size=2, stride=2),
|
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||||
nn.Dropout(0.25),
|
nn.Dropout(0.25),
|
||||||
|
|
||||||
# Third residual stage - 1024 channels
|
# Third residual stage - 1024 to 1536 channels
|
||||||
ResidualBlock(1024, 1536),
|
ResidualBlock(1024, 1536),
|
||||||
ResidualBlock(1536, 1536),
|
ResidualBlock(1536, 1536),
|
||||||
ResidualBlock(1536, 1536),
|
ResidualBlock(1536, 1536),
|
||||||
|
ResidualBlock(1536, 1536), # Additional layer
|
||||||
nn.MaxPool1d(kernel_size=2, stride=2),
|
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
|
|
||||||
# Fourth residual stage - 1536 channels (MASSIVE)
|
# Fourth residual stage - 1536 to 2048 channels
|
||||||
ResidualBlock(1536, 2048),
|
ResidualBlock(1536, 2048),
|
||||||
ResidualBlock(2048, 2048),
|
ResidualBlock(2048, 2048),
|
||||||
ResidualBlock(2048, 2048),
|
ResidualBlock(2048, 2048),
|
||||||
|
ResidualBlock(2048, 2048), # Additional layer
|
||||||
|
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
|
||||||
|
# Fifth residual stage - ULTRA MASSIVE 2048 to 3072 channels
|
||||||
|
ResidualBlock(2048, 3072),
|
||||||
|
ResidualBlock(3072, 3072),
|
||||||
|
ResidualBlock(3072, 3072),
|
||||||
|
ResidualBlock(3072, 3072),
|
||||||
nn.AdaptiveAvgPool1d(1) # Global average pooling
|
nn.AdaptiveAvgPool1d(1) # Global average pooling
|
||||||
)
|
)
|
||||||
# Massive feature dimension after conv layers
|
# Ultra massive feature dimension after conv layers
|
||||||
self.conv_features = 2048
|
self.conv_features = 3072
|
||||||
else:
|
else:
|
||||||
# For 1D vectors, use massive dense preprocessing
|
# For 1D vectors, use ultra massive dense preprocessing
|
||||||
self.conv_layers = None
|
self.conv_layers = None
|
||||||
self.conv_features = 0
|
self.conv_features = 0
|
||||||
|
|
||||||
# MASSIVE fully connected feature extraction layers
|
# ULTRA MASSIVE fully connected feature extraction layers
|
||||||
if self.conv_layers is None:
|
if self.conv_layers is None:
|
||||||
# For 1D inputs - massive feature extraction
|
# For 1D inputs - ultra massive feature extraction
|
||||||
self.fc1 = nn.Linear(self.feature_dim, 2048)
|
self.fc1 = nn.Linear(self.feature_dim, 3072)
|
||||||
self.features_dim = 2048
|
self.features_dim = 3072
|
||||||
else:
|
else:
|
||||||
# For data processed by massive conv layers
|
# For data processed by ultra massive conv layers
|
||||||
self.fc1 = nn.Linear(self.conv_features, 2048)
|
self.fc1 = nn.Linear(self.conv_features, 3072)
|
||||||
self.features_dim = 2048
|
self.features_dim = 3072
|
||||||
|
|
||||||
# MASSIVE common feature extraction with multiple attention layers
|
# ULTRA MASSIVE common feature extraction with multiple deep layers
|
||||||
self.fc_layers = nn.Sequential(
|
self.fc_layers = nn.Sequential(
|
||||||
self.fc1,
|
self.fc1,
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(2048, 2048), # Keep massive width
|
nn.Linear(3072, 3072), # Keep ultra massive width
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(2048, 1536), # Still very wide
|
nn.Linear(3072, 2560), # Ultra wide hidden layer
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(1536, 1024), # Large hidden layer
|
nn.Linear(2560, 2048), # Still very wide
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(1024, 768), # Final feature representation
|
nn.Linear(2048, 1536), # Large hidden layer
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(1536, 1024), # Final feature representation
|
||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multiple attention mechanisms for different aspects
|
# Multiple attention mechanisms for different aspects (larger capacity)
|
||||||
self.price_attention = SelfAttention(768)
|
self.price_attention = SelfAttention(1024) # Increased from 768
|
||||||
self.volume_attention = SelfAttention(768)
|
self.volume_attention = SelfAttention(1024)
|
||||||
self.trend_attention = SelfAttention(768)
|
self.trend_attention = SelfAttention(1024)
|
||||||
self.volatility_attention = SelfAttention(768)
|
self.volatility_attention = SelfAttention(1024)
|
||||||
|
self.momentum_attention = SelfAttention(1024) # Additional attention
|
||||||
|
self.microstructure_attention = SelfAttention(1024) # Additional attention
|
||||||
|
|
||||||
# Attention fusion layer
|
# Ultra massive attention fusion layer
|
||||||
self.attention_fusion = nn.Sequential(
|
self.attention_fusion = nn.Sequential(
|
||||||
nn.Linear(768 * 4, 1024), # Combine all attention outputs
|
nn.Linear(1024 * 6, 2048), # Combine all 6 attention outputs
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(1024, 768)
|
nn.Linear(2048, 1536),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(1536, 1024)
|
||||||
)
|
)
|
||||||
|
|
||||||
# MASSIVE dueling architecture with deeper networks
|
# ULTRA MASSIVE dueling architecture with much deeper networks
|
||||||
self.advantage_stream = nn.Sequential(
|
self.advantage_stream = nn.Sequential(
|
||||||
|
nn.Linear(1024, 768),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
nn.Linear(768, 512),
|
nn.Linear(768, 512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
@ -212,6 +235,9 @@ class EnhancedCNN(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.value_stream = nn.Sequential(
|
self.value_stream = nn.Sequential(
|
||||||
|
nn.Linear(1024, 768),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
nn.Linear(768, 512),
|
nn.Linear(768, 512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
@ -223,8 +249,11 @@ class EnhancedCNN(nn.Module):
|
|||||||
nn.Linear(128, 1)
|
nn.Linear(128, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# MASSIVE extrema detection head with ensemble predictions
|
# ULTRA MASSIVE extrema detection head with deeper ensemble predictions
|
||||||
self.extrema_head = nn.Sequential(
|
self.extrema_head = nn.Sequential(
|
||||||
|
nn.Linear(1024, 768),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
nn.Linear(768, 512),
|
nn.Linear(768, 512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
@ -236,9 +265,12 @@ class EnhancedCNN(nn.Module):
|
|||||||
nn.Linear(128, 3) # 0=bottom, 1=top, 2=neither
|
nn.Linear(128, 3) # 0=bottom, 1=top, 2=neither
|
||||||
)
|
)
|
||||||
|
|
||||||
# MASSIVE multi-timeframe price prediction heads
|
# ULTRA MASSIVE multi-timeframe price prediction heads
|
||||||
self.price_pred_immediate = nn.Sequential(
|
self.price_pred_immediate = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -247,7 +279,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.price_pred_midterm = nn.Sequential(
|
self.price_pred_midterm = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -256,7 +291,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.price_pred_longterm = nn.Sequential(
|
self.price_pred_longterm = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -264,8 +302,11 @@ class EnhancedCNN(nn.Module):
|
|||||||
nn.Linear(128, 3) # Up, Down, Sideways
|
nn.Linear(128, 3) # Up, Down, Sideways
|
||||||
)
|
)
|
||||||
|
|
||||||
# MASSIVE value prediction with ensemble approaches
|
# ULTRA MASSIVE value prediction with ensemble approaches
|
||||||
self.price_pred_value = nn.Sequential(
|
self.price_pred_value = nn.Sequential(
|
||||||
|
nn.Linear(1024, 768),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
nn.Linear(768, 512),
|
nn.Linear(768, 512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
@ -280,7 +321,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
# Additional specialized prediction heads for better accuracy
|
# Additional specialized prediction heads for better accuracy
|
||||||
# Volatility prediction head
|
# Volatility prediction head
|
||||||
self.volatility_head = nn.Sequential(
|
self.volatility_head = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -290,7 +334,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
|
|
||||||
# Support/Resistance level detection head
|
# Support/Resistance level detection head
|
||||||
self.support_resistance_head = nn.Sequential(
|
self.support_resistance_head = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -300,7 +347,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
|
|
||||||
# Market regime classification head
|
# Market regime classification head
|
||||||
self.market_regime_head = nn.Sequential(
|
self.market_regime_head = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -310,7 +360,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
|
|
||||||
# Risk assessment head
|
# Risk assessment head
|
||||||
self.risk_head = nn.Sequential(
|
self.risk_head = nn.Sequential(
|
||||||
nn.Linear(768, 256),
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, 128),
|
nn.Linear(256, 128),
|
||||||
@ -330,7 +383,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward pass through the MASSIVE network"""
|
"""Forward pass through the ULTRA MASSIVE network"""
|
||||||
batch_size = x.size(0)
|
batch_size = x.size(0)
|
||||||
|
|
||||||
# Process different input shapes
|
# Process different input shapes
|
||||||
@ -349,7 +402,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
total_features = x_reshaped.size(1) * x_reshaped.size(2)
|
total_features = x_reshaped.size(1) * x_reshaped.size(2)
|
||||||
self._check_rebuild_network(total_features)
|
self._check_rebuild_network(total_features)
|
||||||
|
|
||||||
# Apply massive convolutions
|
# Apply ultra massive convolutions
|
||||||
x_conv = self.conv_layers(x_reshaped)
|
x_conv = self.conv_layers(x_reshaped)
|
||||||
# Flatten: [batch, channels, 1] -> [batch, channels]
|
# Flatten: [batch, channels, 1] -> [batch, channels]
|
||||||
x_flat = x_conv.view(batch_size, -1)
|
x_flat = x_conv.view(batch_size, -1)
|
||||||
@ -364,33 +417,40 @@ class EnhancedCNN(nn.Module):
|
|||||||
if x_flat.size(1) != self.feature_dim:
|
if x_flat.size(1) != self.feature_dim:
|
||||||
self._check_rebuild_network(x_flat.size(1))
|
self._check_rebuild_network(x_flat.size(1))
|
||||||
|
|
||||||
# Apply MASSIVE FC layers to get base features
|
# Apply ULTRA MASSIVE FC layers to get base features
|
||||||
features = self.fc_layers(x_flat) # [batch, 768]
|
features = self.fc_layers(x_flat) # [batch, 1024]
|
||||||
|
|
||||||
# Apply multiple specialized attention mechanisms
|
# Apply multiple specialized attention mechanisms
|
||||||
features_3d = features.unsqueeze(1) # [batch, 1, 768]
|
features_3d = features.unsqueeze(1) # [batch, 1, 1024]
|
||||||
|
|
||||||
# Get attention-refined features for different aspects
|
# Get attention-refined features for different aspects
|
||||||
price_features, _ = self.price_attention(features_3d)
|
price_features, _ = self.price_attention(features_3d)
|
||||||
price_features = price_features.squeeze(1) # [batch, 768]
|
price_features = price_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
volume_features, _ = self.volume_attention(features_3d)
|
volume_features, _ = self.volume_attention(features_3d)
|
||||||
volume_features = volume_features.squeeze(1) # [batch, 768]
|
volume_features = volume_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
trend_features, _ = self.trend_attention(features_3d)
|
trend_features, _ = self.trend_attention(features_3d)
|
||||||
trend_features = trend_features.squeeze(1) # [batch, 768]
|
trend_features = trend_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
volatility_features, _ = self.volatility_attention(features_3d)
|
volatility_features, _ = self.volatility_attention(features_3d)
|
||||||
volatility_features = volatility_features.squeeze(1) # [batch, 768]
|
volatility_features = volatility_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
|
momentum_features, _ = self.momentum_attention(features_3d)
|
||||||
|
momentum_features = momentum_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
|
microstructure_features, _ = self.microstructure_attention(features_3d)
|
||||||
|
microstructure_features = microstructure_features.squeeze(1) # [batch, 1024]
|
||||||
|
|
||||||
# Fuse all attention outputs
|
# Fuse all attention outputs
|
||||||
combined_attention = torch.cat([
|
combined_attention = torch.cat([
|
||||||
price_features, volume_features,
|
price_features, volume_features,
|
||||||
trend_features, volatility_features
|
trend_features, volatility_features,
|
||||||
], dim=1) # [batch, 768*4]
|
momentum_features, microstructure_features
|
||||||
|
], dim=1) # [batch, 1024*6]
|
||||||
|
|
||||||
# Apply attention fusion to get final refined features
|
# Apply attention fusion to get final refined features
|
||||||
features_refined = self.attention_fusion(combined_attention) # [batch, 768]
|
features_refined = self.attention_fusion(combined_attention) # [batch, 1024]
|
||||||
|
|
||||||
# Calculate advantage and value (Dueling DQN architecture)
|
# Calculate advantage and value (Dueling DQN architecture)
|
||||||
advantage = self.advantage_stream(features_refined)
|
advantage = self.advantage_stream(features_refined)
|
||||||
@ -399,7 +459,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
# Combine for Q-values (Dueling architecture)
|
# Combine for Q-values (Dueling architecture)
|
||||||
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
# Get massive ensemble of predictions
|
# Get ultra massive ensemble of predictions
|
||||||
|
|
||||||
# Extrema predictions (bottom/top/neither detection)
|
# Extrema predictions (bottom/top/neither detection)
|
||||||
extrema_pred = self.extrema_head(features_refined)
|
extrema_pred = self.extrema_head(features_refined)
|
||||||
@ -435,7 +495,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
||||||
|
|
||||||
def act(self, state, explore=True):
|
def act(self, state, explore=True):
|
||||||
"""Enhanced action selection with massive model predictions"""
|
"""Enhanced action selection with ultra massive model predictions"""
|
||||||
if explore and np.random.random() < 0.1: # 10% random exploration
|
if explore and np.random.random() < 0.1: # 10% random exploration
|
||||||
return np.random.choice(self.n_actions)
|
return np.random.choice(self.n_actions)
|
||||||
|
|
||||||
@ -471,7 +531,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
risk_class = torch.argmax(risk, dim=1).item()
|
risk_class = torch.argmax(risk, dim=1).item()
|
||||||
risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']
|
risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']
|
||||||
|
|
||||||
logger.info(f"MASSIVE Model Predictions:")
|
logger.info(f"ULTRA MASSIVE Model Predictions:")
|
||||||
logger.info(f" Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
|
logger.info(f" Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
|
||||||
logger.info(f" Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
|
logger.info(f" Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
|
||||||
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")
|
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")
|
||||||
|
98
_dev/cleanup_models_now.py
Normal file
98
_dev/cleanup_models_now.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Immediate Model Cleanup Script
|
||||||
|
|
||||||
|
This script will clean up all existing model files and prepare the system
|
||||||
|
for fresh training with the new model management system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from model_manager import ModelManager
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run the model cleanup"""
|
||||||
|
|
||||||
|
# Configure logging for better output
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("GOGO2 MODEL CLEANUP SYSTEM")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
print("This script will:")
|
||||||
|
print("1. Delete ALL existing model files (.pt, .pth)")
|
||||||
|
print("2. Remove ALL checkpoint directories")
|
||||||
|
print("3. Clear model backup directories")
|
||||||
|
print("4. Reset the model registry")
|
||||||
|
print("5. Create clean directory structure")
|
||||||
|
print()
|
||||||
|
print("WARNING: This action cannot be undone!")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Calculate current space usage first
|
||||||
|
try:
|
||||||
|
manager = ModelManager()
|
||||||
|
storage_stats = manager.get_storage_stats()
|
||||||
|
print(f"Current storage usage:")
|
||||||
|
print(f"- Models: {storage_stats['total_models']}")
|
||||||
|
print(f"- Size: {storage_stats['actual_size_mb']:.1f}MB")
|
||||||
|
print()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error checking current storage: {e}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Ask for confirmation
|
||||||
|
print("Type 'CLEANUP' to proceed with the cleanup:")
|
||||||
|
user_input = input("> ").strip()
|
||||||
|
|
||||||
|
if user_input != "CLEANUP":
|
||||||
|
print("Cleanup cancelled. No changes made.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Starting cleanup...")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create manager and run cleanup
|
||||||
|
manager = ModelManager()
|
||||||
|
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("CLEANUP COMPLETE")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Files deleted: {cleanup_result['deleted_files']}")
|
||||||
|
print(f"Space freed: {cleanup_result['freed_space_mb']:.1f} MB")
|
||||||
|
print(f"Directories cleaned: {len(cleanup_result['deleted_directories'])}")
|
||||||
|
|
||||||
|
if cleanup_result['errors']:
|
||||||
|
print(f"Errors encountered: {len(cleanup_result['errors'])}")
|
||||||
|
print("Errors:")
|
||||||
|
for error in cleanup_result['errors'][:5]: # Show first 5 errors
|
||||||
|
print(f" - {error}")
|
||||||
|
if len(cleanup_result['errors']) > 5:
|
||||||
|
print(f" ... and {len(cleanup_result['errors']) - 5} more")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("System is now ready for fresh model training!")
|
||||||
|
print("The following directories have been created:")
|
||||||
|
print("- models/best_models/")
|
||||||
|
print("- models/cnn/")
|
||||||
|
print("- models/rl/")
|
||||||
|
print("- models/checkpoints/")
|
||||||
|
print("- NN/models/saved/")
|
||||||
|
print()
|
||||||
|
print("New models will be automatically managed by the ModelManager.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during cleanup: {e}")
|
||||||
|
logging.exception("Cleanup failed")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
15
config.yaml
15
config.yaml
@ -6,11 +6,12 @@ system:
|
|||||||
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
|
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
|
||||||
session_timeout: 3600 # Session timeout in seconds
|
session_timeout: 3600 # Session timeout in seconds
|
||||||
|
|
||||||
# Trading Symbols (extendable/configurable)
|
# Trading Symbols Configuration
|
||||||
|
# Primary trading pair: ETH/USDT (main signals generation)
|
||||||
|
# Reference pair: BTC/USDT (correlation analysis only, no trading signals)
|
||||||
symbols:
|
symbols:
|
||||||
- "ETH/USDC" # MEXC supports ETHUSDC for API trading
|
- "ETH/USDT" # MAIN TRADING PAIR - Generate signals and execute trades
|
||||||
- "BTC/USDT"
|
- "BTC/USDT" # REFERENCE ONLY - For correlation analysis, no direct trading
|
||||||
- "MX/USDT"
|
|
||||||
|
|
||||||
# Timeframes for ultra-fast scalping (500x leverage)
|
# Timeframes for ultra-fast scalping (500x leverage)
|
||||||
timeframes:
|
timeframes:
|
||||||
@ -179,11 +180,9 @@ mexc_trading:
|
|||||||
require_confirmation: false # No manual confirmation for live trading
|
require_confirmation: false # No manual confirmation for live trading
|
||||||
emergency_stop: false # Emergency stop all trading
|
emergency_stop: false # Emergency stop all trading
|
||||||
|
|
||||||
# Supported symbols for live trading
|
# Supported symbols for live trading (ONLY ETH)
|
||||||
allowed_symbols:
|
allowed_symbols:
|
||||||
- "ETH/USDC" # MEXC supports ETHUSDC for API trading
|
- "ETH/USDT" # MAIN TRADING PAIR - Only this pair is actively traded
|
||||||
- "BTC/USDT"
|
|
||||||
- "MX/USDT"
|
|
||||||
|
|
||||||
# Trading hours (UTC)
|
# Trading hours (UTC)
|
||||||
trading_hours:
|
trading_hours:
|
||||||
|
@ -54,16 +54,23 @@ run cnn training fron the dashboard as well - on each pivot point we inference a
|
|||||||
|
|
||||||
well, we have sell signals. don't we sell at the exact moment when we have long position and execute a sell signal? I see now we're totaly invested. change the model outputs too include cash signal (or learn to make decision to not enter position when we're not certain about where the market will go. this way we will only enter when the price move is clearly visible and most probable) learn to not be so certain when we made a bad trade (replay both entering and exiting position) we can do that by storing the models input data when we make a decision and then train with the known output. This is why we wanted to have a central data probider class which will be preparing the data for all the models er inference and train.
|
well, we have sell signals. don't we sell at the exact moment when we have long position and execute a sell signal? I see now we're totaly invested. change the model outputs too include cash signal (or learn to make decision to not enter position when we're not certain about where the market will go. this way we will only enter when the price move is clearly visible and most probable) learn to not be so certain when we made a bad trade (replay both entering and exiting position) we can do that by storing the models input data when we make a decision and then train with the known output. This is why we wanted to have a central data probider class which will be preparing the data for all the models er inference and train.
|
||||||
|
|
||||||
I see we're always invested. adjust the training, reward functions and possibly model outputs to include CASH signal where we sell our positions but we keep off the market. or use the orchestrator to learn to make that decison when gets uncertain signals from the expert models.mods hould learn to effectively spot setups in the market which are with high risk/reward level and act on theese
|
I see we're always invested. adjust the training, reward functions use the orchestrator to learn to make that decison when gets uncertain signals from the expert models.mods hould learn to effectively spot setups in the market which are with high risk/reward level and act on theese
|
||||||
|
|
||||||
|
|
||||||
also, implement risk management (stop loss)
|
|
||||||
make all dashboard processes run on the server without need of dashboard page to be open in a browser. add Start/Stop toggle on the dash to control it, but all processes should hapen on the server and the dash is just a way to display and contrl them. auto start when we start the web server.
|
|
||||||
|
|
||||||
|
|
||||||
|
I see we're always invested. adjust the training, reward functions use the orchestrator to learn to make that decison when gets uncertain signals from the expert models.mods hould learn to effectively spot setups in the market which are with high risk/reward level and act on theese
|
||||||
if that does not work I think we can make it simpler and easier to train if we have just 2 model actions buy/sell. we don't need hold signal, as until we have action we hold. And when we are long and we get a sell signal - we close. and enter short on consequtive sell signal. also, we will have different thresholds for entering and exiting. learning to enter when we are more certain
|
if that does not work I think we can make it simpler and easier to train if we have just 2 model actions buy/sell. we don't need hold signal, as until we have action we hold. And when we are long and we get a sell signal - we close. and enter short on consequtive sell signal. also, we will have different thresholds for entering and exiting. learning to enter when we are more certain
|
||||||
this will also help us simplify the training and our codebase to keep it easy to develop.
|
this will also help us simplify the training and our codebase to keep it easy to develop.
|
||||||
as our models are chained, it does not make sense anymore to train them separately. so remove all modes from main_clean and all referenced code. we use only web mode wherehe flow is: we collect data, calculate indicators and pivot points -> CNN -> RL => orchestrator -> broker/web
|
as our models are chained, it does not make sense anymore to train them separately. so remove all modes from main_clean and all referenced code. we use only web mode wherehe flow is: we collect data, calculate indicators and pivot points -> CNN -> RL => orchestrator -> broker/web
|
||||||
|
|
||||||
|
orchestrator model also should be an appropriate MoE model that will be able to learn to make decisions based on the signals from the expert models. it should be able to include more models in the future.
|
||||||
|
|
||||||
|
|
||||||
# DASH
|
# DASH
|
||||||
|
also, implement risk management (stop loss)
|
||||||
|
make all dashboard processes run on the server without need of dashboard page to be open in a browser. add Start/Stop toggle on the dash to control it, but all processes should hapen on the server and the dash is just a way to display and contrl them. auto start when we start the web server.
|
||||||
|
|
||||||
|
all models/training/inference should be run on the server. dashboard should be used only for displaying the data and controlling the processes. let's add a start/stop button to the dashboard to control the processes. also add slider to adjust the buy/sell thresholds for the orchestrator model and therefore bias the agressiveness of the model actions.
|
||||||
|
|
||||||
add a row with small charts showing all the data we feed to the models: the 1m 1h 1d and reference (btc) ohlcv on the dashboard
|
add a row with small charts showing all the data we feed to the models: the 1m 1h 1d and reference (btc) ohlcv on the dashboard
|
117
main_clean.py
117
main_clean.py
@ -1,15 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Clean Trading System - Streamlined Entry Point
|
Streamlined Trading System - Web Dashboard Only
|
||||||
|
|
||||||
Simplified entry point with only essential modes:
|
Simplified entry point with only the web dashboard mode:
|
||||||
- test: Test data provider and core components
|
- Streamlined Flow: Data -> Indicators/Pivots -> CNN -> RL -> Orchestrator -> Execution
|
||||||
- web: Live trading dashboard with integrated training pipeline
|
- 2-Action System: BUY/SELL with intelligent position management
|
||||||
|
- Always invested approach with smart risk/reward setup detection
|
||||||
Streamlined Flow: Data -> Indicators/Pivots -> CNN -> RL -> Orchestrator -> Execution
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python main_clean.py --mode [test|web] --symbol ETH/USDT
|
python main_clean.py [--symbol ETH/USDT] [--port 8050]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -29,87 +28,12 @@ from core.data_provider import DataProvider
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def run_data_test():
|
|
||||||
"""Test the enhanced data provider and core components"""
|
|
||||||
try:
|
|
||||||
config = get_config()
|
|
||||||
logger.info("Testing Enhanced Data Provider and Core Components...")
|
|
||||||
|
|
||||||
# Test data provider with multiple timeframes
|
|
||||||
data_provider = DataProvider(
|
|
||||||
symbols=['ETH/USDT'],
|
|
||||||
timeframes=['1s', '1m', '1h', '4h']
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test historical data
|
|
||||||
logger.info("Testing historical data fetching...")
|
|
||||||
df = data_provider.get_historical_data('ETH/USDT', '1h', limit=100)
|
|
||||||
if df is not None:
|
|
||||||
logger.info(f"[SUCCESS] Historical data: {len(df)} candles loaded")
|
|
||||||
logger.info(f" Columns: {len(df.columns)} total")
|
|
||||||
logger.info(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
|
|
||||||
|
|
||||||
# Show indicator breakdown
|
|
||||||
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
|
||||||
indicators = [col for col in df.columns if col not in basic_cols]
|
|
||||||
logger.info(f" Technical indicators: {len(indicators)}")
|
|
||||||
else:
|
|
||||||
logger.error("[FAILED] Failed to load historical data")
|
|
||||||
|
|
||||||
# Test multi-timeframe feature matrix
|
|
||||||
logger.info("Testing multi-timeframe feature matrix...")
|
|
||||||
feature_matrix = data_provider.get_feature_matrix('ETH/USDT', ['1h', '4h'], window_size=20)
|
|
||||||
if feature_matrix is not None:
|
|
||||||
logger.info(f"[SUCCESS] Feature matrix shape: {feature_matrix.shape}")
|
|
||||||
logger.info(f" Timeframes: {feature_matrix.shape[0]}")
|
|
||||||
logger.info(f" Window size: {feature_matrix.shape[1]}")
|
|
||||||
logger.info(f" Features: {feature_matrix.shape[2]}")
|
|
||||||
else:
|
|
||||||
logger.error("[FAILED] Failed to create feature matrix")
|
|
||||||
|
|
||||||
# Test CNN model availability
|
|
||||||
try:
|
|
||||||
from NN.models.cnn_model import CNNModel
|
|
||||||
cnn = CNNModel(n_actions=2) # 2-action system
|
|
||||||
logger.info("[SUCCESS] CNN model initialized with 2 actions (BUY/SELL)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[WARNING] CNN model not available: {e}")
|
|
||||||
|
|
||||||
# Test RL agent availability
|
|
||||||
try:
|
|
||||||
from NN.models.dqn_agent import DQNAgent
|
|
||||||
agent = DQNAgent(state_shape=(50,), n_actions=2) # 2-action system
|
|
||||||
logger.info("[SUCCESS] RL Agent initialized with 2 actions (BUY/SELL)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[WARNING] RL Agent not available: {e}")
|
|
||||||
|
|
||||||
# Test orchestrator
|
|
||||||
try:
|
|
||||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
||||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
|
||||||
logger.info("[SUCCESS] Enhanced Trading Orchestrator initialized")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[WARNING] Enhanced Orchestrator not available: {e}")
|
|
||||||
|
|
||||||
# Test health check
|
|
||||||
health = data_provider.health_check()
|
|
||||||
logger.info(f"[SUCCESS] Data provider health check completed")
|
|
||||||
|
|
||||||
logger.info("[SUCCESS] Core system test completed successfully!")
|
|
||||||
logger.info("2-Action System: BUY/SELL only (no HOLD)")
|
|
||||||
logger.info("Streamlined Flow: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in system test: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run_web_dashboard():
|
def run_web_dashboard():
|
||||||
"""Run the streamlined web dashboard with integrated training pipeline"""
|
"""Run the streamlined web dashboard with 2-action system and always-invested approach"""
|
||||||
try:
|
try:
|
||||||
logger.info("Starting Streamlined Trading Dashboard...")
|
logger.info("Starting Streamlined Trading Dashboard...")
|
||||||
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
||||||
|
logger.info("Always Invested Approach: Smart risk/reward setup detection")
|
||||||
logger.info("Integrated Training Pipeline: Live data -> Models -> Trading")
|
logger.info("Integrated Training Pipeline: Live data -> Models -> Trading")
|
||||||
|
|
||||||
# Get configuration
|
# Get configuration
|
||||||
@ -143,7 +67,7 @@ def run_web_dashboard():
|
|||||||
model_registry = {}
|
model_registry = {}
|
||||||
logger.warning("Model registry not available, using empty registry")
|
logger.warning("Model registry not available, using empty registry")
|
||||||
|
|
||||||
# Create streamlined orchestrator with 2-action system
|
# Create streamlined orchestrator with 2-action system and always-invested approach
|
||||||
orchestrator = EnhancedTradingOrchestrator(
|
orchestrator = EnhancedTradingOrchestrator(
|
||||||
data_provider=data_provider,
|
data_provider=data_provider,
|
||||||
symbols=config.get('symbols', ['ETH/USDT']),
|
symbols=config.get('symbols', ['ETH/USDT']),
|
||||||
@ -151,6 +75,7 @@ def run_web_dashboard():
|
|||||||
model_registry=model_registry
|
model_registry=model_registry
|
||||||
)
|
)
|
||||||
logger.info("Enhanced Trading Orchestrator with 2-Action System initialized")
|
logger.info("Enhanced Trading Orchestrator with 2-Action System initialized")
|
||||||
|
logger.info("Always Invested: Learning to spot high risk/reward setups")
|
||||||
|
|
||||||
# Create trading executor for live execution
|
# Create trading executor for live execution
|
||||||
trading_executor = TradingExecutor()
|
trading_executor = TradingExecutor()
|
||||||
@ -174,6 +99,7 @@ def run_web_dashboard():
|
|||||||
logger.info("Real-time Indicators & Pivots: ENABLED")
|
logger.info("Real-time Indicators & Pivots: ENABLED")
|
||||||
logger.info("Live Trading Execution: ENABLED")
|
logger.info("Live Trading Execution: ENABLED")
|
||||||
logger.info("2-Action System: BUY/SELL with position intelligence")
|
logger.info("2-Action System: BUY/SELL with position intelligence")
|
||||||
|
logger.info("Always Invested: Different thresholds for entry/exit")
|
||||||
logger.info("Pipeline: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
logger.info("Pipeline: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
||||||
|
|
||||||
dashboard.run(host=host, port=port, debug=False)
|
dashboard.run(host=host, port=port, debug=False)
|
||||||
@ -198,12 +124,8 @@ def run_web_dashboard():
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""Main entry point with streamlined mode selection"""
|
"""Main entry point with streamlined web-only operation"""
|
||||||
parser = argparse.ArgumentParser(description='Streamlined Trading System - Integrated Pipeline')
|
parser = argparse.ArgumentParser(description='Streamlined Trading System - 2-Action Web Dashboard')
|
||||||
parser.add_argument('--mode',
|
|
||||||
choices=['test', 'web'],
|
|
||||||
default='web',
|
|
||||||
help='Operation mode: test (system check) or web (live trading)')
|
|
||||||
parser.add_argument('--symbol', type=str, default='ETH/USDT',
|
parser.add_argument('--symbol', type=str, default='ETH/USDT',
|
||||||
help='Primary trading symbol (default: ETH/USDT)')
|
help='Primary trading symbol (default: ETH/USDT)')
|
||||||
parser.add_argument('--port', type=int, default=8050,
|
parser.add_argument('--port', type=int, default=8050,
|
||||||
@ -218,18 +140,15 @@ async def main():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("=" * 70)
|
logger.info("=" * 70)
|
||||||
logger.info("STREAMLINED TRADING SYSTEM - INTEGRATED PIPELINE")
|
logger.info("STREAMLINED TRADING SYSTEM - 2-ACTION WEB DASHBOARD")
|
||||||
logger.info(f"Mode: {args.mode.upper()}")
|
|
||||||
logger.info(f"Primary Symbol: {args.symbol}")
|
logger.info(f"Primary Symbol: {args.symbol}")
|
||||||
if args.mode == 'web':
|
logger.info(f"Web Port: {args.port}")
|
||||||
logger.info("Integrated Flow: Data -> Indicators -> CNN -> RL -> Execution")
|
|
||||||
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
||||||
|
logger.info("Always Invested: Learning to spot high risk/reward setups")
|
||||||
|
logger.info("Flow: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
||||||
logger.info("=" * 70)
|
logger.info("=" * 70)
|
||||||
|
|
||||||
# Route to appropriate mode
|
# Run the web dashboard
|
||||||
if args.mode == 'test':
|
|
||||||
run_data_test()
|
|
||||||
elif args.mode == 'web':
|
|
||||||
run_web_dashboard()
|
run_web_dashboard()
|
||||||
|
|
||||||
logger.info("[SUCCESS] Operation completed successfully!")
|
logger.info("[SUCCESS] Operation completed successfully!")
|
||||||
|
558
model_manager.py
Normal file
558
model_manager.py
Normal file
@ -0,0 +1,558 @@
|
|||||||
|
"""
|
||||||
|
Enhanced Model Management System for Trading Dashboard
|
||||||
|
|
||||||
|
This system provides:
|
||||||
|
- Automatic cleanup of old model checkpoints
|
||||||
|
- Best model tracking with performance metrics
|
||||||
|
- Configurable retention policies
|
||||||
|
- Startup model loading
|
||||||
|
- Performance-based model selection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import glob
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelMetrics:
|
||||||
|
"""Performance metrics for model evaluation"""
|
||||||
|
accuracy: float = 0.0
|
||||||
|
profit_factor: float = 0.0
|
||||||
|
win_rate: float = 0.0
|
||||||
|
sharpe_ratio: float = 0.0
|
||||||
|
max_drawdown: float = 0.0
|
||||||
|
total_trades: int = 0
|
||||||
|
avg_trade_duration: float = 0.0
|
||||||
|
confidence_score: float = 0.0
|
||||||
|
|
||||||
|
def get_composite_score(self) -> float:
|
||||||
|
"""Calculate composite performance score"""
|
||||||
|
# Weighted composite score
|
||||||
|
weights = {
|
||||||
|
'profit_factor': 0.3,
|
||||||
|
'sharpe_ratio': 0.25,
|
||||||
|
'win_rate': 0.2,
|
||||||
|
'accuracy': 0.15,
|
||||||
|
'confidence_score': 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Normalize values to 0-1 range
|
||||||
|
normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0
|
||||||
|
normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1
|
||||||
|
normalized_win_rate = self.win_rate
|
||||||
|
normalized_accuracy = self.accuracy
|
||||||
|
normalized_confidence = self.confidence_score
|
||||||
|
|
||||||
|
# Apply penalties for poor performance
|
||||||
|
drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown
|
||||||
|
|
||||||
|
score = (
|
||||||
|
weights['profit_factor'] * normalized_pf +
|
||||||
|
weights['sharpe_ratio'] * normalized_sharpe +
|
||||||
|
weights['win_rate'] * normalized_win_rate +
|
||||||
|
weights['accuracy'] * normalized_accuracy +
|
||||||
|
weights['confidence_score'] * normalized_confidence
|
||||||
|
) * drawdown_penalty
|
||||||
|
|
||||||
|
return min(max(score, 0), 1)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
|
"""Complete model information and metadata"""
|
||||||
|
model_type: str # 'cnn', 'rl', 'transformer'
|
||||||
|
model_name: str
|
||||||
|
file_path: str
|
||||||
|
creation_time: datetime
|
||||||
|
last_updated: datetime
|
||||||
|
file_size_mb: float
|
||||||
|
metrics: ModelMetrics
|
||||||
|
training_episodes: int = 0
|
||||||
|
model_version: str = "1.0"
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization"""
|
||||||
|
data = asdict(self)
|
||||||
|
data['creation_time'] = self.creation_time.isoformat()
|
||||||
|
data['last_updated'] = self.last_updated.isoformat()
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
|
||||||
|
"""Create from dictionary"""
|
||||||
|
data['creation_time'] = datetime.fromisoformat(data['creation_time'])
|
||||||
|
data['last_updated'] = datetime.fromisoformat(data['last_updated'])
|
||||||
|
data['metrics'] = ModelMetrics(**data['metrics'])
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
"""Enhanced model management system"""
|
||||||
|
|
||||||
|
def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
|
||||||
|
self.base_dir = Path(base_dir)
|
||||||
|
self.config = config or self._get_default_config()
|
||||||
|
|
||||||
|
# Model directories
|
||||||
|
self.models_dir = self.base_dir / "models"
|
||||||
|
self.nn_models_dir = self.base_dir / "NN" / "models"
|
||||||
|
self.registry_file = self.models_dir / "model_registry.json"
|
||||||
|
self.best_models_dir = self.models_dir / "best_models"
|
||||||
|
|
||||||
|
# Create directories
|
||||||
|
self.best_models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Model registry
|
||||||
|
self.model_registry: Dict[str, ModelInfo] = {}
|
||||||
|
self._load_registry()
|
||||||
|
|
||||||
|
logger.info(f"Model Manager initialized - Base: {self.base_dir}")
|
||||||
|
logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")
|
||||||
|
|
||||||
|
def _get_default_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get default configuration"""
|
||||||
|
return {
|
||||||
|
'max_models_per_type': 3, # Keep top 3 models per type
|
||||||
|
'max_total_models': 10, # Maximum total models to keep
|
||||||
|
'cleanup_frequency_hours': 24, # Cleanup every 24 hours
|
||||||
|
'min_performance_threshold': 0.3, # Minimum composite score
|
||||||
|
'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days
|
||||||
|
'auto_cleanup_enabled': True,
|
||||||
|
'backup_before_cleanup': True,
|
||||||
|
'model_size_limit_mb': 100, # Individual model size limit
|
||||||
|
'total_storage_limit_gb': 5.0 # Total storage limit
|
||||||
|
}
|
||||||
|
|
||||||
|
def _load_registry(self):
|
||||||
|
"""Load model registry from file"""
|
||||||
|
try:
|
||||||
|
if self.registry_file.exists():
|
||||||
|
with open(self.registry_file, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.model_registry = {
|
||||||
|
k: ModelInfo.from_dict(v) for k, v in data.items()
|
||||||
|
}
|
||||||
|
logger.info(f"Loaded {len(self.model_registry)} models from registry")
|
||||||
|
else:
|
||||||
|
logger.info("No existing model registry found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading model registry: {e}")
|
||||||
|
self.model_registry = {}
|
||||||
|
|
||||||
|
def _save_registry(self):
|
||||||
|
"""Save model registry to file"""
|
||||||
|
try:
|
||||||
|
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(self.registry_file, 'w') as f:
|
||||||
|
data = {k: v.to_dict() for k, v in self.model_registry.items()}
|
||||||
|
json.dump(data, f, indent=2, default=str)
|
||||||
|
logger.info(f"Saved registry with {len(self.model_registry)} models")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving model registry: {e}")
|
||||||
|
|
||||||
|
def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Clean up all existing model files and prepare for 2-action system training
|
||||||
|
|
||||||
|
Args:
|
||||||
|
confirm: If True, perform the cleanup. If False, return what would be cleaned
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with cleanup statistics
|
||||||
|
"""
|
||||||
|
cleanup_stats = {
|
||||||
|
'files_found': 0,
|
||||||
|
'files_deleted': 0,
|
||||||
|
'directories_cleaned': 0,
|
||||||
|
'space_freed_mb': 0.0,
|
||||||
|
'errors': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Model file patterns for both 2-action and legacy 3-action systems
|
||||||
|
model_patterns = [
|
||||||
|
"**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
|
||||||
|
"**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Directories to clean
|
||||||
|
model_directories = [
|
||||||
|
"models/saved",
|
||||||
|
"NN/models/saved",
|
||||||
|
"NN/models/saved/checkpoints",
|
||||||
|
"NN/models/saved/realtime_checkpoints",
|
||||||
|
"NN/models/saved/realtime_ticks_checkpoints",
|
||||||
|
"model_backups"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Scan for files to be cleaned
|
||||||
|
for directory in model_directories:
|
||||||
|
dir_path = Path(self.base_dir) / directory
|
||||||
|
if dir_path.exists():
|
||||||
|
for pattern in model_patterns:
|
||||||
|
for file_path in dir_path.glob(pattern):
|
||||||
|
if file_path.is_file():
|
||||||
|
cleanup_stats['files_found'] += 1
|
||||||
|
file_size = file_path.stat().st_size / (1024 * 1024) # MB
|
||||||
|
cleanup_stats['space_freed_mb'] += file_size
|
||||||
|
|
||||||
|
if confirm:
|
||||||
|
try:
|
||||||
|
file_path.unlink()
|
||||||
|
cleanup_stats['files_deleted'] += 1
|
||||||
|
logger.info(f"Deleted model file: {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
|
||||||
|
|
||||||
|
# Clean up empty checkpoint directories
|
||||||
|
for directory in model_directories:
|
||||||
|
dir_path = Path(self.base_dir) / directory
|
||||||
|
if dir_path.exists():
|
||||||
|
for subdir in dir_path.rglob("*"):
|
||||||
|
if subdir.is_dir() and not any(subdir.iterdir()):
|
||||||
|
if confirm:
|
||||||
|
try:
|
||||||
|
subdir.rmdir()
|
||||||
|
cleanup_stats['directories_cleaned'] += 1
|
||||||
|
logger.info(f"Removed empty directory: {subdir}")
|
||||||
|
except Exception as e:
|
||||||
|
cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")
|
||||||
|
|
||||||
|
if confirm:
|
||||||
|
# Clear the registry for fresh start with 2-action system
|
||||||
|
self.model_registry = {
|
||||||
|
'models': {},
|
||||||
|
'metadata': {
|
||||||
|
'last_updated': datetime.now().isoformat(),
|
||||||
|
'total_models': 0,
|
||||||
|
'system_type': '2_action', # Mark as 2-action system
|
||||||
|
'action_space': ['SELL', 'BUY'],
|
||||||
|
'version': '2.0'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self._save_registry()
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
|
||||||
|
logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
|
||||||
|
logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||||
|
logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
|
||||||
|
logger.info("Registry reset for 2-action system (BUY/SELL)")
|
||||||
|
logger.info("Ready for fresh training with intelligent position management")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
else:
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
|
||||||
|
logger.info(f"Files to delete: {cleanup_stats['files_found']}")
|
||||||
|
logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||||
|
logger.info("Run with confirm=True to perform cleanup")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
cleanup_stats['errors'].append(f"Cleanup error: {e}")
|
||||||
|
logger.error(f"Error during model cleanup: {e}")
|
||||||
|
|
||||||
|
return cleanup_stats
|
||||||
|
|
||||||
|
def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
|
||||||
|
"""
|
||||||
|
Register a new model in the 2-action system
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model file
|
||||||
|
model_type: Type of model ('cnn', 'rl', 'transformer')
|
||||||
|
metrics: Performance metrics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Unique model name/ID
|
||||||
|
"""
|
||||||
|
if not Path(model_path).exists():
|
||||||
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||||
|
|
||||||
|
# Generate unique model name
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
model_name = f"{model_type}_2action_{timestamp}"
|
||||||
|
|
||||||
|
# Get file info
|
||||||
|
file_path = Path(model_path)
|
||||||
|
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
|
||||||
|
# Default metrics for 2-action system
|
||||||
|
if metrics is None:
|
||||||
|
metrics = ModelMetrics(
|
||||||
|
accuracy=0.0,
|
||||||
|
profit_factor=1.0,
|
||||||
|
win_rate=0.5,
|
||||||
|
sharpe_ratio=0.0,
|
||||||
|
max_drawdown=0.0,
|
||||||
|
confidence_score=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create model info
|
||||||
|
model_info = ModelInfo(
|
||||||
|
model_type=model_type,
|
||||||
|
model_name=model_name,
|
||||||
|
file_path=str(file_path.absolute()),
|
||||||
|
creation_time=datetime.now(),
|
||||||
|
last_updated=datetime.now(),
|
||||||
|
file_size_mb=file_size_mb,
|
||||||
|
metrics=metrics,
|
||||||
|
model_version="2.0" # 2-action system version
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to registry
|
||||||
|
self.model_registry['models'][model_name] = model_info.to_dict()
|
||||||
|
self.model_registry['metadata']['total_models'] = len(self.model_registry['models'])
|
||||||
|
self.model_registry['metadata']['last_updated'] = datetime.now().isoformat()
|
||||||
|
self.model_registry['metadata']['system_type'] = '2_action'
|
||||||
|
self.model_registry['metadata']['action_space'] = ['SELL', 'BUY']
|
||||||
|
|
||||||
|
self._save_registry()
|
||||||
|
|
||||||
|
# Cleanup old models if necessary
|
||||||
|
self._cleanup_models_by_type(model_type)
|
||||||
|
|
||||||
|
logger.info(f"Registered 2-action model: {model_name}")
|
||||||
|
logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
|
||||||
|
logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
|
||||||
|
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def _should_keep_model(self, model_info: ModelInfo) -> bool:
|
||||||
|
"""Determine if model should be kept based on performance"""
|
||||||
|
score = model_info.metrics.get_composite_score()
|
||||||
|
|
||||||
|
# Check minimum threshold
|
||||||
|
if score < self.config['min_performance_threshold']:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check size limit
|
||||||
|
if model_info.file_size_mb > self.config['model_size_limit_mb']:
|
||||||
|
logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if better than existing models of same type
|
||||||
|
existing_models = self.get_models_by_type(model_info.model_type)
|
||||||
|
if len(existing_models) >= self.config['max_models_per_type']:
|
||||||
|
# Find worst performing model
|
||||||
|
worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
|
||||||
|
if score <= worst_model.metrics.get_composite_score():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _cleanup_models_by_type(self, model_type: str):
|
||||||
|
"""Cleanup old models of specific type, keeping only the best ones"""
|
||||||
|
models_of_type = self.get_models_by_type(model_type)
|
||||||
|
max_keep = self.config['max_models_per_type']
|
||||||
|
|
||||||
|
if len(models_of_type) <= max_keep:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sort by performance score
|
||||||
|
sorted_models = sorted(
|
||||||
|
models_of_type.items(),
|
||||||
|
key=lambda x: x[1].metrics.get_composite_score(),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep only the best models
|
||||||
|
models_to_keep = sorted_models[:max_keep]
|
||||||
|
models_to_remove = sorted_models[max_keep:]
|
||||||
|
|
||||||
|
for model_name, model_info in models_to_remove:
|
||||||
|
try:
|
||||||
|
# Remove file
|
||||||
|
model_path = Path(model_info.file_path)
|
||||||
|
if model_path.exists():
|
||||||
|
model_path.unlink()
|
||||||
|
|
||||||
|
# Remove from registry
|
||||||
|
del self.model_registry[model_name]
|
||||||
|
|
||||||
|
logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error removing model {model_name}: {e}")
|
||||||
|
|
||||||
|
def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
|
||||||
|
"""Get all models of a specific type"""
|
||||||
|
return {
|
||||||
|
name: info for name, info in self.model_registry.items()
|
||||||
|
if info.model_type == model_type
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
|
||||||
|
"""Get the best performing model of a specific type"""
|
||||||
|
models_of_type = self.get_models_by_type(model_type)
|
||||||
|
|
||||||
|
if not models_of_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())
|
||||||
|
|
||||||
|
def load_best_models(self) -> Dict[str, Any]:
|
||||||
|
"""Load the best models for each type"""
|
||||||
|
loaded_models = {}
|
||||||
|
|
||||||
|
for model_type in ['cnn', 'rl', 'transformer']:
|
||||||
|
best_model = self.get_best_model(model_type)
|
||||||
|
|
||||||
|
if best_model:
|
||||||
|
try:
|
||||||
|
model_path = Path(best_model.file_path)
|
||||||
|
if model_path.exists():
|
||||||
|
# Load the model
|
||||||
|
model_data = torch.load(model_path, map_location='cpu')
|
||||||
|
loaded_models[model_type] = {
|
||||||
|
'model': model_data,
|
||||||
|
'info': best_model,
|
||||||
|
'path': str(model_path)
|
||||||
|
}
|
||||||
|
logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
|
||||||
|
f"(Score: {best_model.metrics.get_composite_score():.3f})")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Best {model_type} model file not found: {model_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading {model_type} model: {e}")
|
||||||
|
else:
|
||||||
|
logger.info(f"No {model_type} model available")
|
||||||
|
|
||||||
|
return loaded_models
|
||||||
|
|
||||||
|
def update_model_performance(self, model_name: str, metrics: ModelMetrics):
|
||||||
|
"""Update performance metrics for a model"""
|
||||||
|
if model_name in self.model_registry:
|
||||||
|
self.model_registry[model_name].metrics = metrics
|
||||||
|
self.model_registry[model_name].last_updated = datetime.now()
|
||||||
|
self._save_registry()
|
||||||
|
|
||||||
|
logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Model {model_name} not found in registry")
|
||||||
|
|
||||||
|
def get_storage_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get storage usage statistics"""
|
||||||
|
total_size_mb = 0
|
||||||
|
model_count = 0
|
||||||
|
|
||||||
|
for model_info in self.model_registry.values():
|
||||||
|
total_size_mb += model_info.file_size_mb
|
||||||
|
model_count += 1
|
||||||
|
|
||||||
|
# Check actual storage usage
|
||||||
|
actual_size_mb = 0
|
||||||
|
if self.best_models_dir.exists():
|
||||||
|
actual_size_mb = sum(
|
||||||
|
f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
|
||||||
|
) / 1024 / 1024
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_models': model_count,
|
||||||
|
'registered_size_mb': total_size_mb,
|
||||||
|
'actual_size_mb': actual_size_mb,
|
||||||
|
'storage_limit_gb': self.config['total_storage_limit_gb'],
|
||||||
|
'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
|
||||||
|
'models_by_type': {
|
||||||
|
model_type: len(self.get_models_by_type(model_type))
|
||||||
|
for model_type in ['cnn', 'rl', 'transformer']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Get model performance leaderboard"""
|
||||||
|
leaderboard = []
|
||||||
|
|
||||||
|
for model_name, model_info in self.model_registry.items():
|
||||||
|
leaderboard.append({
|
||||||
|
'name': model_name,
|
||||||
|
'type': model_info.model_type,
|
||||||
|
'score': model_info.metrics.get_composite_score(),
|
||||||
|
'profit_factor': model_info.metrics.profit_factor,
|
||||||
|
'win_rate': model_info.metrics.win_rate,
|
||||||
|
'sharpe_ratio': model_info.metrics.sharpe_ratio,
|
||||||
|
'size_mb': model_info.file_size_mb,
|
||||||
|
'age_days': (datetime.now() - model_info.creation_time).days,
|
||||||
|
'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
|
||||||
|
})
|
||||||
|
|
||||||
|
# Sort by score
|
||||||
|
leaderboard.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
|
||||||
|
return leaderboard
|
||||||
|
|
||||||
|
def cleanup_checkpoints(self) -> Dict[str, Any]:
|
||||||
|
"""Clean up old checkpoint files"""
|
||||||
|
cleanup_summary = {
|
||||||
|
'deleted_files': 0,
|
||||||
|
'freed_space_mb': 0,
|
||||||
|
'errors': []
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])
|
||||||
|
|
||||||
|
# Search for checkpoint files
|
||||||
|
checkpoint_patterns = [
|
||||||
|
"**/checkpoint_*.pt",
|
||||||
|
"**/model_*.pt",
|
||||||
|
"**/*checkpoint*",
|
||||||
|
"**/epoch_*.pt"
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in checkpoint_patterns:
|
||||||
|
for file_path in self.base_dir.rglob(pattern):
|
||||||
|
if "best_models" not in str(file_path) and file_path.is_file():
|
||||||
|
try:
|
||||||
|
file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||||
|
if file_time < cutoff_date:
|
||||||
|
size_mb = file_path.stat().st_size / 1024 / 1024
|
||||||
|
file_path.unlink()
|
||||||
|
cleanup_summary['deleted_files'] += 1
|
||||||
|
cleanup_summary['freed_space_mb'] += size_mb
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting checkpoint {file_path}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
cleanup_summary['errors'].append(error_msg)
|
||||||
|
|
||||||
|
if cleanup_summary['deleted_files'] > 0:
|
||||||
|
logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
|
||||||
|
f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
|
||||||
|
|
||||||
|
return cleanup_summary
|
||||||
|
|
||||||
|
def create_model_manager() -> ModelManager:
|
||||||
|
"""Create and initialize the global model manager"""
|
||||||
|
return ModelManager()
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
# Create model manager
|
||||||
|
manager = ModelManager()
|
||||||
|
|
||||||
|
# Clean up all existing models (with confirmation)
|
||||||
|
print("WARNING: This will delete ALL existing models!")
|
||||||
|
print("Type 'CONFIRM' to proceed:")
|
||||||
|
user_input = input().strip()
|
||||||
|
|
||||||
|
if user_input == "CONFIRM":
|
||||||
|
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
|
||||||
|
print(f"\nCleanup complete:")
|
||||||
|
print(f"- Deleted {cleanup_result['files_deleted']} files")
|
||||||
|
print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
|
||||||
|
print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")
|
||||||
|
|
||||||
|
if cleanup_result['errors']:
|
||||||
|
print(f"- {len(cleanup_result['errors'])} errors occurred")
|
||||||
|
else:
|
||||||
|
print("Cleanup cancelled")
|
@ -1,477 +1,477 @@
|
|||||||
#!/usr/bin/env python3
|
# #!/usr/bin/env python3
|
||||||
"""
|
# """
|
||||||
Enhanced RL Training Launcher with Real Data Integration
|
# Enhanced RL Training Launcher with Real Data Integration
|
||||||
|
|
||||||
This script launches the comprehensive RL training system that uses:
|
# This script launches the comprehensive RL training system that uses:
|
||||||
- Real-time tick data (300s window for momentum detection)
|
# - Real-time tick data (300s window for momentum detection)
|
||||||
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
|
# - Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
|
||||||
- BTC reference data for correlation
|
# - BTC reference data for correlation
|
||||||
- CNN hidden features and predictions
|
# - CNN hidden features and predictions
|
||||||
- Williams Market Structure pivot points
|
# - Williams Market Structure pivot points
|
||||||
- Market microstructure analysis
|
# - Market microstructure analysis
|
||||||
|
|
||||||
The RL model will receive ~13,400 features instead of the previous ~100 basic features.
|
# The RL model will receive ~13,400 features instead of the previous ~100 basic features.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import asyncio
|
# import asyncio
|
||||||
import logging
|
# import logging
|
||||||
import time
|
# import time
|
||||||
import signal
|
# import signal
|
||||||
import sys
|
# import sys
|
||||||
from datetime import datetime, timedelta
|
# from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
# from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
# from typing import Dict, List, Optional
|
||||||
|
|
||||||
# Configure logging
|
# # Configure logging
|
||||||
logging.basicConfig(
|
# logging.basicConfig(
|
||||||
level=logging.INFO,
|
# level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
# handlers=[
|
||||||
logging.FileHandler('enhanced_rl_training.log'),
|
# logging.FileHandler('enhanced_rl_training.log'),
|
||||||
logging.StreamHandler(sys.stdout)
|
# logging.StreamHandler(sys.stdout)
|
||||||
]
|
# ]
|
||||||
)
|
# )
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Import our enhanced components
|
# # Import our enhanced components
|
||||||
from core.config import get_config
|
# from core.config import get_config
|
||||||
from core.data_provider import DataProvider
|
# from core.data_provider import DataProvider
|
||||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
# from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
from training.enhanced_rl_trainer import EnhancedRLTrainer
|
# from training.enhanced_rl_trainer import EnhancedRLTrainer
|
||||||
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
|
# from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
|
||||||
from training.williams_market_structure import WilliamsMarketStructure
|
# from training.williams_market_structure import WilliamsMarketStructure
|
||||||
from training.cnn_rl_bridge import CNNRLBridge
|
# from training.cnn_rl_bridge import CNNRLBridge
|
||||||
|
|
||||||
class EnhancedRLTrainingSystem:
|
# class EnhancedRLTrainingSystem:
|
||||||
"""Comprehensive RL training system with real data integration"""
|
# """Comprehensive RL training system with real data integration"""
|
||||||
|
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
"""Initialize the enhanced RL training system"""
|
# """Initialize the enhanced RL training system"""
|
||||||
self.config = get_config()
|
# self.config = get_config()
|
||||||
self.running = False
|
# self.running = False
|
||||||
self.data_provider = None
|
# self.data_provider = None
|
||||||
self.orchestrator = None
|
# self.orchestrator = None
|
||||||
self.rl_trainer = None
|
# self.rl_trainer = None
|
||||||
|
|
||||||
# Performance tracking
|
# # Performance tracking
|
||||||
self.training_stats = {
|
# self.training_stats = {
|
||||||
'training_sessions': 0,
|
# 'training_sessions': 0,
|
||||||
'total_experiences': 0,
|
# 'total_experiences': 0,
|
||||||
'avg_state_size': 0,
|
# 'avg_state_size': 0,
|
||||||
'data_quality_score': 0.0,
|
# 'data_quality_score': 0.0,
|
||||||
'last_training_time': None
|
# 'last_training_time': None
|
||||||
}
|
# }
|
||||||
|
|
||||||
logger.info("Enhanced RL Training System initialized")
|
# logger.info("Enhanced RL Training System initialized")
|
||||||
logger.info("Features:")
|
# logger.info("Features:")
|
||||||
logger.info("- Real-time tick data processing (300s window)")
|
# logger.info("- Real-time tick data processing (300s window)")
|
||||||
logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
|
# logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
|
||||||
logger.info("- BTC correlation analysis")
|
# logger.info("- BTC correlation analysis")
|
||||||
logger.info("- CNN feature integration")
|
# logger.info("- CNN feature integration")
|
||||||
logger.info("- Williams Market Structure pivot points")
|
# logger.info("- Williams Market Structure pivot points")
|
||||||
logger.info("- ~13,400 feature state vector (vs previous ~100)")
|
# logger.info("- ~13,400 feature state vector (vs previous ~100)")
|
||||||
|
|
||||||
async def initialize(self):
|
# async def initialize(self):
|
||||||
"""Initialize all components"""
|
# """Initialize all components"""
|
||||||
try:
|
# try:
|
||||||
logger.info("Initializing enhanced RL training components...")
|
# logger.info("Initializing enhanced RL training components...")
|
||||||
|
|
||||||
# Initialize data provider with real-time streaming
|
# # Initialize data provider with real-time streaming
|
||||||
logger.info("Setting up data provider with real-time streaming...")
|
# logger.info("Setting up data provider with real-time streaming...")
|
||||||
self.data_provider = DataProvider(
|
# self.data_provider = DataProvider(
|
||||||
symbols=self.config.symbols,
|
# symbols=self.config.symbols,
|
||||||
timeframes=self.config.timeframes
|
# timeframes=self.config.timeframes
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Start real-time data streaming
|
# # Start real-time data streaming
|
||||||
await self.data_provider.start_real_time_streaming()
|
# await self.data_provider.start_real_time_streaming()
|
||||||
logger.info("Real-time data streaming started")
|
# logger.info("Real-time data streaming started")
|
||||||
|
|
||||||
# Wait for initial data collection
|
# # Wait for initial data collection
|
||||||
logger.info("Collecting initial market data...")
|
# logger.info("Collecting initial market data...")
|
||||||
await asyncio.sleep(30) # Allow 30 seconds for data collection
|
# await asyncio.sleep(30) # Allow 30 seconds for data collection
|
||||||
|
|
||||||
# Initialize enhanced orchestrator
|
# # Initialize enhanced orchestrator
|
||||||
logger.info("Initializing enhanced orchestrator...")
|
# logger.info("Initializing enhanced orchestrator...")
|
||||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
# self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||||
|
|
||||||
# Initialize enhanced RL trainer with comprehensive state building
|
# # Initialize enhanced RL trainer with comprehensive state building
|
||||||
logger.info("Initializing enhanced RL trainer...")
|
# logger.info("Initializing enhanced RL trainer...")
|
||||||
self.rl_trainer = EnhancedRLTrainer(
|
# self.rl_trainer = EnhancedRLTrainer(
|
||||||
config=self.config,
|
# config=self.config,
|
||||||
orchestrator=self.orchestrator
|
# orchestrator=self.orchestrator
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Verify data availability
|
# # Verify data availability
|
||||||
data_status = await self._verify_data_availability()
|
# data_status = await self._verify_data_availability()
|
||||||
if not data_status['has_sufficient_data']:
|
# if not data_status['has_sufficient_data']:
|
||||||
logger.warning("Insufficient data detected. Continuing with limited training.")
|
# logger.warning("Insufficient data detected. Continuing with limited training.")
|
||||||
logger.warning(f"Data status: {data_status}")
|
# logger.warning(f"Data status: {data_status}")
|
||||||
else:
|
# else:
|
||||||
logger.info("Sufficient data available for comprehensive RL training")
|
# logger.info("Sufficient data available for comprehensive RL training")
|
||||||
logger.info(f"Tick data: {data_status['tick_count']} ticks")
|
# logger.info(f"Tick data: {data_status['tick_count']} ticks")
|
||||||
logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
|
# logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
|
||||||
|
|
||||||
self.running = True
|
# self.running = True
|
||||||
logger.info("Enhanced RL training system initialized successfully")
|
# logger.info("Enhanced RL training system initialized successfully")
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error during initialization: {e}")
|
# logger.error(f"Error during initialization: {e}")
|
||||||
raise
|
# raise
|
||||||
|
|
||||||
async def _verify_data_availability(self) -> Dict[str, any]:
|
# async def _verify_data_availability(self) -> Dict[str, any]:
|
||||||
"""Verify that we have sufficient data for training"""
|
# """Verify that we have sufficient data for training"""
|
||||||
try:
|
# try:
|
||||||
data_status = {
|
# data_status = {
|
||||||
'has_sufficient_data': False,
|
# 'has_sufficient_data': False,
|
||||||
'tick_count': 0,
|
# 'tick_count': 0,
|
||||||
'ohlcv_bars': 0,
|
# 'ohlcv_bars': 0,
|
||||||
'symbols_with_data': [],
|
# 'symbols_with_data': [],
|
||||||
'missing_data': []
|
# 'missing_data': []
|
||||||
}
|
# }
|
||||||
|
|
||||||
for symbol in self.config.symbols:
|
# for symbol in self.config.symbols:
|
||||||
# Check tick data
|
# # Check tick data
|
||||||
recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
|
# recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
|
||||||
tick_count = len(recent_ticks)
|
# tick_count = len(recent_ticks)
|
||||||
|
|
||||||
# Check OHLCV data
|
# # Check OHLCV data
|
||||||
ohlcv_bars = 0
|
# ohlcv_bars = 0
|
||||||
for timeframe in ['1s', '1m', '1h', '1d']:
|
# for timeframe in ['1s', '1m', '1h', '1d']:
|
||||||
try:
|
# try:
|
||||||
df = self.data_provider.get_historical_data(
|
# df = self.data_provider.get_historical_data(
|
||||||
symbol=symbol,
|
# symbol=symbol,
|
||||||
timeframe=timeframe,
|
# timeframe=timeframe,
|
||||||
limit=50,
|
# limit=50,
|
||||||
refresh=True
|
# refresh=True
|
||||||
)
|
# )
|
||||||
if df is not None and not df.empty:
|
# if df is not None and not df.empty:
|
||||||
ohlcv_bars += len(df)
|
# ohlcv_bars += len(df)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
|
# logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
|
||||||
|
|
||||||
data_status['tick_count'] += tick_count
|
# data_status['tick_count'] += tick_count
|
||||||
data_status['ohlcv_bars'] += ohlcv_bars
|
# data_status['ohlcv_bars'] += ohlcv_bars
|
||||||
|
|
||||||
if tick_count >= 50 and ohlcv_bars >= 100:
|
# if tick_count >= 50 and ohlcv_bars >= 100:
|
||||||
data_status['symbols_with_data'].append(symbol)
|
# data_status['symbols_with_data'].append(symbol)
|
||||||
else:
|
# else:
|
||||||
data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
|
# data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
|
||||||
|
|
||||||
# Consider data sufficient if we have at least one symbol with good data
|
# # Consider data sufficient if we have at least one symbol with good data
|
||||||
data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
|
# data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
|
||||||
|
|
||||||
return data_status
|
# return data_status
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error verifying data availability: {e}")
|
# logger.error(f"Error verifying data availability: {e}")
|
||||||
return {'has_sufficient_data': False, 'error': str(e)}
|
# return {'has_sufficient_data': False, 'error': str(e)}
|
||||||
|
|
||||||
async def run_training_loop(self):
|
# async def run_training_loop(self):
|
||||||
"""Run the main training loop with real data"""
|
# """Run the main training loop with real data"""
|
||||||
logger.info("Starting enhanced RL training loop...")
|
# logger.info("Starting enhanced RL training loop...")
|
||||||
|
|
||||||
training_cycle = 0
|
# training_cycle = 0
|
||||||
last_state_size_log = time.time()
|
# last_state_size_log = time.time()
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
while self.running:
|
# while self.running:
|
||||||
training_cycle += 1
|
# training_cycle += 1
|
||||||
cycle_start_time = time.time()
|
# cycle_start_time = time.time()
|
||||||
|
|
||||||
logger.info(f"Training cycle {training_cycle} started")
|
# logger.info(f"Training cycle {training_cycle} started")
|
||||||
|
|
||||||
# Get comprehensive market states with real data
|
# # Get comprehensive market states with real data
|
||||||
market_states = await self._get_comprehensive_market_states()
|
# market_states = await self._get_comprehensive_market_states()
|
||||||
|
|
||||||
if not market_states:
|
# if not market_states:
|
||||||
logger.warning("No market states available. Waiting for data...")
|
# logger.warning("No market states available. Waiting for data...")
|
||||||
await asyncio.sleep(60)
|
# await asyncio.sleep(60)
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
# Train RL agents with comprehensive states
|
# # Train RL agents with comprehensive states
|
||||||
training_results = await self._train_rl_agents(market_states)
|
# training_results = await self._train_rl_agents(market_states)
|
||||||
|
|
||||||
# Update performance tracking
|
# # Update performance tracking
|
||||||
self._update_training_stats(training_results, market_states)
|
# self._update_training_stats(training_results, market_states)
|
||||||
|
|
||||||
# Log training progress
|
# # Log training progress
|
||||||
cycle_duration = time.time() - cycle_start_time
|
# cycle_duration = time.time() - cycle_start_time
|
||||||
logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
|
# logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
|
||||||
|
|
||||||
# Log state size periodically
|
# # Log state size periodically
|
||||||
if time.time() - last_state_size_log > 300: # Every 5 minutes
|
# if time.time() - last_state_size_log > 300: # Every 5 minutes
|
||||||
self._log_state_size_info(market_states)
|
# self._log_state_size_info(market_states)
|
||||||
last_state_size_log = time.time()
|
# last_state_size_log = time.time()
|
||||||
|
|
||||||
# Save models periodically
|
# # Save models periodically
|
||||||
if training_cycle % 10 == 0:
|
# if training_cycle % 10 == 0:
|
||||||
await self._save_training_progress()
|
# await self._save_training_progress()
|
||||||
|
|
||||||
# Wait before next training cycle
|
# # Wait before next training cycle
|
||||||
await asyncio.sleep(300) # Train every 5 minutes
|
# await asyncio.sleep(300) # Train every 5 minutes
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error in training loop: {e}")
|
# logger.error(f"Error in training loop: {e}")
|
||||||
raise
|
# raise
|
||||||
|
|
||||||
async def _get_comprehensive_market_states(self) -> Dict[str, any]:
|
# async def _get_comprehensive_market_states(self) -> Dict[str, any]:
|
||||||
"""Get comprehensive market states with all required data"""
|
# """Get comprehensive market states with all required data"""
|
||||||
try:
|
# try:
|
||||||
# Get market states from orchestrator
|
# # Get market states from orchestrator
|
||||||
universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
|
# universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
|
||||||
market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
|
# market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
|
||||||
|
|
||||||
# Verify data quality
|
# # Verify data quality
|
||||||
quality_score = self._calculate_data_quality(market_states)
|
# quality_score = self._calculate_data_quality(market_states)
|
||||||
self.training_stats['data_quality_score'] = quality_score
|
# self.training_stats['data_quality_score'] = quality_score
|
||||||
|
|
||||||
if quality_score < 0.5:
|
# if quality_score < 0.5:
|
||||||
logger.warning(f"Low data quality detected: {quality_score:.2f}")
|
# logger.warning(f"Low data quality detected: {quality_score:.2f}")
|
||||||
|
|
||||||
return market_states
|
# return market_states
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error getting comprehensive market states: {e}")
|
# logger.error(f"Error getting comprehensive market states: {e}")
|
||||||
return {}
|
# return {}
|
||||||
|
|
||||||
def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
|
# def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
|
||||||
"""Calculate data quality score based on available data"""
|
# """Calculate data quality score based on available data"""
|
||||||
try:
|
# try:
|
||||||
if not market_states:
|
# if not market_states:
|
||||||
return 0.0
|
# return 0.0
|
||||||
|
|
||||||
total_score = 0.0
|
# total_score = 0.0
|
||||||
total_symbols = len(market_states)
|
# total_symbols = len(market_states)
|
||||||
|
|
||||||
for symbol, state in market_states.items():
|
# for symbol, state in market_states.items():
|
||||||
symbol_score = 0.0
|
# symbol_score = 0.0
|
||||||
|
|
||||||
# Score based on tick data availability
|
# # Score based on tick data availability
|
||||||
if hasattr(state, 'raw_ticks') and state.raw_ticks:
|
# if hasattr(state, 'raw_ticks') and state.raw_ticks:
|
||||||
tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks
|
# tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks
|
||||||
symbol_score += tick_score * 0.3
|
# symbol_score += tick_score * 0.3
|
||||||
|
|
||||||
# Score based on OHLCV data availability
|
# # Score based on OHLCV data availability
|
||||||
if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
|
# if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
|
||||||
ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes
|
# ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes
|
||||||
symbol_score += min(ohlcv_score, 1.0) * 0.4
|
# symbol_score += min(ohlcv_score, 1.0) * 0.4
|
||||||
|
|
||||||
# Score based on CNN features
|
# # Score based on CNN features
|
||||||
if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
||||||
symbol_score += 0.15
|
# symbol_score += 0.15
|
||||||
|
|
||||||
# Score based on pivot points
|
# # Score based on pivot points
|
||||||
if hasattr(state, 'pivot_points') and state.pivot_points:
|
# if hasattr(state, 'pivot_points') and state.pivot_points:
|
||||||
symbol_score += 0.15
|
# symbol_score += 0.15
|
||||||
|
|
||||||
total_score += symbol_score
|
# total_score += symbol_score
|
||||||
|
|
||||||
return total_score / total_symbols if total_symbols > 0 else 0.0
|
# return total_score / total_symbols if total_symbols > 0 else 0.0
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.warning(f"Error calculating data quality: {e}")
|
# logger.warning(f"Error calculating data quality: {e}")
|
||||||
return 0.5 # Default to medium quality
|
# return 0.5 # Default to medium quality
|
||||||
|
|
||||||
async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]:
|
# async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]:
|
||||||
"""Train RL agents with comprehensive market states"""
|
# """Train RL agents with comprehensive market states"""
|
||||||
try:
|
# try:
|
||||||
training_results = {
|
# training_results = {
|
||||||
'symbols_trained': [],
|
# 'symbols_trained': [],
|
||||||
'total_experiences': 0,
|
# 'total_experiences': 0,
|
||||||
'avg_state_size': 0,
|
# 'avg_state_size': 0,
|
||||||
'training_errors': []
|
# 'training_errors': []
|
||||||
}
|
# }
|
||||||
|
|
||||||
for symbol, market_state in market_states.items():
|
# for symbol, market_state in market_states.items():
|
||||||
try:
|
# try:
|
||||||
# Convert market state to comprehensive RL state
|
# # Convert market state to comprehensive RL state
|
||||||
rl_state = self.rl_trainer._market_state_to_rl_state(market_state)
|
# rl_state = self.rl_trainer._market_state_to_rl_state(market_state)
|
||||||
|
|
||||||
if rl_state is not None and len(rl_state) > 0:
|
# if rl_state is not None and len(rl_state) > 0:
|
||||||
# Record state size
|
# # Record state size
|
||||||
training_results['avg_state_size'] += len(rl_state)
|
# training_results['avg_state_size'] += len(rl_state)
|
||||||
|
|
||||||
# Simulate trading action for experience generation
|
# # Simulate trading action for experience generation
|
||||||
# In real implementation, this would be actual trading decisions
|
# # In real implementation, this would be actual trading decisions
|
||||||
action = self._simulate_trading_action(symbol, rl_state)
|
# action = self._simulate_trading_action(symbol, rl_state)
|
||||||
|
|
||||||
# Generate reward based on market outcome
|
# # Generate reward based on market outcome
|
||||||
reward = self._calculate_training_reward(symbol, market_state, action)
|
# reward = self._calculate_training_reward(symbol, market_state, action)
|
||||||
|
|
||||||
# Add experience to RL agent
|
# # Add experience to RL agent
|
||||||
agent = self.rl_trainer.agents.get(symbol)
|
# agent = self.rl_trainer.agents.get(symbol)
|
||||||
if agent:
|
# if agent:
|
||||||
# Create next state (would be actual next market state in real scenario)
|
# # Create next state (would be actual next market state in real scenario)
|
||||||
next_state = rl_state # Simplified for now
|
# next_state = rl_state # Simplified for now
|
||||||
|
|
||||||
agent.remember(
|
# agent.remember(
|
||||||
state=rl_state,
|
# state=rl_state,
|
||||||
action=action,
|
# action=action,
|
||||||
reward=reward,
|
# reward=reward,
|
||||||
next_state=next_state,
|
# next_state=next_state,
|
||||||
done=False
|
# done=False
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Train agent if enough experiences
|
# # Train agent if enough experiences
|
||||||
if len(agent.replay_buffer) >= agent.batch_size:
|
# if len(agent.replay_buffer) >= agent.batch_size:
|
||||||
loss = agent.replay()
|
# loss = agent.replay()
|
||||||
if loss is not None:
|
# if loss is not None:
|
||||||
logger.debug(f"Agent {symbol} training loss: {loss:.4f}")
|
# logger.debug(f"Agent {symbol} training loss: {loss:.4f}")
|
||||||
|
|
||||||
training_results['symbols_trained'].append(symbol)
|
# training_results['symbols_trained'].append(symbol)
|
||||||
training_results['total_experiences'] += 1
|
# training_results['total_experiences'] += 1
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
error_msg = f"Error training {symbol}: {e}"
|
# error_msg = f"Error training {symbol}: {e}"
|
||||||
logger.warning(error_msg)
|
# logger.warning(error_msg)
|
||||||
training_results['training_errors'].append(error_msg)
|
# training_results['training_errors'].append(error_msg)
|
||||||
|
|
||||||
# Calculate average state size
|
# # Calculate average state size
|
||||||
if len(training_results['symbols_trained']) > 0:
|
# if len(training_results['symbols_trained']) > 0:
|
||||||
training_results['avg_state_size'] /= len(training_results['symbols_trained'])
|
# training_results['avg_state_size'] /= len(training_results['symbols_trained'])
|
||||||
|
|
||||||
return training_results
|
# return training_results
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error training RL agents: {e}")
|
# logger.error(f"Error training RL agents: {e}")
|
||||||
return {'error': str(e)}
|
# return {'error': str(e)}
|
||||||
|
|
||||||
def _simulate_trading_action(self, symbol: str, rl_state) -> int:
|
# def _simulate_trading_action(self, symbol: str, rl_state) -> int:
|
||||||
"""Simulate trading action for training (would be real decision in production)"""
|
# """Simulate trading action for training (would be real decision in production)"""
|
||||||
# Simple simulation based on state features
|
# # Simple simulation based on state features
|
||||||
if len(rl_state) > 100:
|
# if len(rl_state) > 100:
|
||||||
# Use momentum features to decide action
|
# # Use momentum features to decide action
|
||||||
momentum_features = rl_state[:100] # First 100 features assumed to be momentum
|
# momentum_features = rl_state[:100] # First 100 features assumed to be momentum
|
||||||
avg_momentum = sum(momentum_features) / len(momentum_features)
|
# avg_momentum = sum(momentum_features) / len(momentum_features)
|
||||||
|
|
||||||
if avg_momentum > 0.6:
|
# if avg_momentum > 0.6:
|
||||||
return 1 # BUY
|
# return 1 # BUY
|
||||||
elif avg_momentum < 0.4:
|
# elif avg_momentum < 0.4:
|
||||||
return 2 # SELL
|
# return 2 # SELL
|
||||||
else:
|
# else:
|
||||||
return 0 # HOLD
|
# return 0 # HOLD
|
||||||
else:
|
# else:
|
||||||
return 0 # HOLD as default
|
# return 0 # HOLD as default
|
||||||
|
|
||||||
def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
|
# def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
|
||||||
"""Calculate training reward based on market state and action"""
|
# """Calculate training reward based on market state and action"""
|
||||||
try:
|
# try:
|
||||||
# Simple reward calculation based on market conditions
|
# # Simple reward calculation based on market conditions
|
||||||
base_reward = 0.0
|
# base_reward = 0.0
|
||||||
|
|
||||||
# Reward based on volatility alignment
|
# # Reward based on volatility alignment
|
||||||
if hasattr(market_state, 'volatility'):
|
# if hasattr(market_state, 'volatility'):
|
||||||
if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility
|
# if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility
|
||||||
base_reward += 0.1
|
# base_reward += 0.1
|
||||||
elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility
|
# elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility
|
||||||
base_reward += 0.1
|
# base_reward += 0.1
|
||||||
|
|
||||||
# Reward based on trend alignment
|
# # Reward based on trend alignment
|
||||||
if hasattr(market_state, 'trend_strength'):
|
# if hasattr(market_state, 'trend_strength'):
|
||||||
if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend
|
# if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend
|
||||||
base_reward += 0.2
|
# base_reward += 0.2
|
||||||
elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend
|
# elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend
|
||||||
base_reward += 0.2
|
# base_reward += 0.2
|
||||||
|
|
||||||
return base_reward
|
# return base_reward
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.warning(f"Error calculating reward for {symbol}: {e}")
|
# logger.warning(f"Error calculating reward for {symbol}: {e}")
|
||||||
return 0.0
|
# return 0.0
|
||||||
|
|
||||||
def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
|
# def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
|
||||||
"""Update training statistics"""
|
# """Update training statistics"""
|
||||||
self.training_stats['training_sessions'] += 1
|
# self.training_stats['training_sessions'] += 1
|
||||||
self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
|
# self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
|
||||||
self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
|
# self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
|
||||||
self.training_stats['last_training_time'] = datetime.now()
|
# self.training_stats['last_training_time'] = datetime.now()
|
||||||
|
|
||||||
# Log statistics periodically
|
# # Log statistics periodically
|
||||||
if self.training_stats['training_sessions'] % 10 == 0:
|
# if self.training_stats['training_sessions'] % 10 == 0:
|
||||||
logger.info("Training Statistics:")
|
# logger.info("Training Statistics:")
|
||||||
logger.info(f" Sessions: {self.training_stats['training_sessions']}")
|
# logger.info(f" Sessions: {self.training_stats['training_sessions']}")
|
||||||
logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
|
# logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
|
||||||
logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
|
# logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
|
||||||
logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")
|
# logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")
|
||||||
|
|
||||||
def _log_state_size_info(self, market_states: Dict[str, any]):
|
# def _log_state_size_info(self, market_states: Dict[str, any]):
|
||||||
"""Log information about state sizes for debugging"""
|
# """Log information about state sizes for debugging"""
|
||||||
for symbol, state in market_states.items():
|
# for symbol, state in market_states.items():
|
||||||
info = []
|
# info = []
|
||||||
|
|
||||||
if hasattr(state, 'raw_ticks'):
|
# if hasattr(state, 'raw_ticks'):
|
||||||
info.append(f"ticks: {len(state.raw_ticks)}")
|
# info.append(f"ticks: {len(state.raw_ticks)}")
|
||||||
|
|
||||||
if hasattr(state, 'ohlcv_data'):
|
# if hasattr(state, 'ohlcv_data'):
|
||||||
total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
|
# total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
|
||||||
info.append(f"OHLCV bars: {total_bars}")
|
# info.append(f"OHLCV bars: {total_bars}")
|
||||||
|
|
||||||
if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
# if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
|
||||||
info.append("CNN features: available")
|
# info.append("CNN features: available")
|
||||||
|
|
||||||
if hasattr(state, 'pivot_points') and state.pivot_points:
|
# if hasattr(state, 'pivot_points') and state.pivot_points:
|
||||||
info.append("pivot points: available")
|
# info.append("pivot points: available")
|
||||||
|
|
||||||
logger.info(f"{symbol} state data: {', '.join(info)}")
|
# logger.info(f"{symbol} state data: {', '.join(info)}")
|
||||||
|
|
||||||
async def _save_training_progress(self):
|
# async def _save_training_progress(self):
|
||||||
"""Save training progress and models"""
|
# """Save training progress and models"""
|
||||||
try:
|
# try:
|
||||||
if self.rl_trainer:
|
# if self.rl_trainer:
|
||||||
self.rl_trainer._save_all_models()
|
# self.rl_trainer._save_all_models()
|
||||||
logger.info("Training progress saved")
|
# logger.info("Training progress saved")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error saving training progress: {e}")
|
# logger.error(f"Error saving training progress: {e}")
|
||||||
|
|
||||||
async def shutdown(self):
|
# async def shutdown(self):
|
||||||
"""Graceful shutdown"""
|
# """Graceful shutdown"""
|
||||||
logger.info("Shutting down enhanced RL training system...")
|
# logger.info("Shutting down enhanced RL training system...")
|
||||||
self.running = False
|
# self.running = False
|
||||||
|
|
||||||
# Save final state
|
# # Save final state
|
||||||
await self._save_training_progress()
|
# await self._save_training_progress()
|
||||||
|
|
||||||
# Stop data provider
|
# # Stop data provider
|
||||||
if self.data_provider:
|
# if self.data_provider:
|
||||||
await self.data_provider.stop_real_time_streaming()
|
# await self.data_provider.stop_real_time_streaming()
|
||||||
|
|
||||||
logger.info("Enhanced RL training system shutdown complete")
|
# logger.info("Enhanced RL training system shutdown complete")
|
||||||
|
|
||||||
async def main():
|
# async def main():
|
||||||
"""Main function to run enhanced RL training"""
|
# """Main function to run enhanced RL training"""
|
||||||
system = None
|
# system = None
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
# def signal_handler(signum, frame):
|
||||||
logger.info("Received shutdown signal")
|
# logger.info("Received shutdown signal")
|
||||||
if system:
|
# if system:
|
||||||
asyncio.create_task(system.shutdown())
|
# asyncio.create_task(system.shutdown())
|
||||||
|
|
||||||
# Set up signal handlers
|
# # Set up signal handlers
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
# signal.signal(signal.SIGINT, signal_handler)
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
# signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
# Create and initialize the training system
|
# # Create and initialize the training system
|
||||||
system = EnhancedRLTrainingSystem()
|
# system = EnhancedRLTrainingSystem()
|
||||||
await system.initialize()
|
# await system.initialize()
|
||||||
|
|
||||||
logger.info("Enhanced RL Training System is now running...")
|
# logger.info("Enhanced RL Training System is now running...")
|
||||||
logger.info("The RL model now receives ~13,400 features instead of ~100!")
|
# logger.info("The RL model now receives ~13,400 features instead of ~100!")
|
||||||
logger.info("Press Ctrl+C to stop")
|
# logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
# Run the training loop
|
# # Run the training loop
|
||||||
await system.run_training_loop()
|
# await system.run_training_loop()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
# except KeyboardInterrupt:
|
||||||
logger.info("Training interrupted by user")
|
# logger.info("Training interrupted by user")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error in main training loop: {e}")
|
# logger.error(f"Error in main training loop: {e}")
|
||||||
raise
|
# raise
|
||||||
finally:
|
# finally:
|
||||||
if system:
|
# if system:
|
||||||
await system.shutdown()
|
# await system.shutdown()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
# asyncio.run(main())
|
@ -1,112 +1,112 @@
|
|||||||
#!/usr/bin/env python3
|
# #!/usr/bin/env python3
|
||||||
"""
|
# """
|
||||||
Enhanced Scalping Dashboard Launcher
|
# Enhanced Scalping Dashboard Launcher
|
||||||
|
|
||||||
Features:
|
# Features:
|
||||||
- 1-second OHLCV bar charts instead of tick points
|
# - 1-second OHLCV bar charts instead of tick points
|
||||||
- 15-minute server-side tick cache for model training
|
# - 15-minute server-side tick cache for model training
|
||||||
- Enhanced volume visualization with buy/sell separation
|
# - Enhanced volume visualization with buy/sell separation
|
||||||
- Ultra-low latency WebSocket streaming
|
# - Ultra-low latency WebSocket streaming
|
||||||
- Real-time candle aggregation from tick data
|
# - Real-time candle aggregation from tick data
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import sys
|
# import sys
|
||||||
import logging
|
# import logging
|
||||||
import argparse
|
# import argparse
|
||||||
from pathlib import Path
|
# from pathlib import Path
|
||||||
|
|
||||||
# Add project root to path
|
# # Add project root to path
|
||||||
project_root = Path(__file__).parent
|
# project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root))
|
# sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from web.enhanced_scalping_dashboard import EnhancedScalpingDashboard
|
# from web.enhanced_scalping_dashboard import EnhancedScalpingDashboard
|
||||||
from core.data_provider import DataProvider
|
# from core.data_provider import DataProvider
|
||||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
# from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
|
|
||||||
def setup_logging(level: str = "INFO"):
|
# def setup_logging(level: str = "INFO"):
|
||||||
"""Setup logging configuration"""
|
# """Setup logging configuration"""
|
||||||
log_level = getattr(logging, level.upper(), logging.INFO)
|
# log_level = getattr(logging, level.upper(), logging.INFO)
|
||||||
|
|
||||||
logging.basicConfig(
|
# logging.basicConfig(
|
||||||
level=log_level,
|
# level=log_level,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
# handlers=[
|
||||||
logging.StreamHandler(sys.stdout),
|
# logging.StreamHandler(sys.stdout),
|
||||||
logging.FileHandler('logs/enhanced_dashboard.log', mode='a')
|
# logging.FileHandler('logs/enhanced_dashboard.log', mode='a')
|
||||||
]
|
# ]
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Reduce noise from external libraries
|
# # Reduce noise from external libraries
|
||||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
# logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||||
logging.getLogger('requests').setLevel(logging.WARNING)
|
# logging.getLogger('requests').setLevel(logging.WARNING)
|
||||||
logging.getLogger('websockets').setLevel(logging.WARNING)
|
# logging.getLogger('websockets').setLevel(logging.WARNING)
|
||||||
|
|
||||||
def main():
|
# def main():
|
||||||
"""Main function to launch enhanced scalping dashboard"""
|
# """Main function to launch enhanced scalping dashboard"""
|
||||||
parser = argparse.ArgumentParser(description='Enhanced Scalping Dashboard with 1s Bars and 15min Cache')
|
# parser = argparse.ArgumentParser(description='Enhanced Scalping Dashboard with 1s Bars and 15min Cache')
|
||||||
parser.add_argument('--host', default='127.0.0.1', help='Host to bind to (default: 127.0.0.1)')
|
# parser.add_argument('--host', default='127.0.0.1', help='Host to bind to (default: 127.0.0.1)')
|
||||||
parser.add_argument('--port', type=int, default=8051, help='Port to bind to (default: 8051)')
|
# parser.add_argument('--port', type=int, default=8051, help='Port to bind to (default: 8051)')
|
||||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
# parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
||||||
parser.add_argument('--log-level', default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
|
# parser.add_argument('--log-level', default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
|
||||||
help='Logging level (default: INFO)')
|
# help='Logging level (default: INFO)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
|
|
||||||
# Setup logging
|
# # Setup logging
|
||||||
setup_logging(args.log_level)
|
# setup_logging(args.log_level)
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
logger.info("=" * 80)
|
# logger.info("=" * 80)
|
||||||
logger.info("ENHANCED SCALPING DASHBOARD STARTUP")
|
# logger.info("ENHANCED SCALPING DASHBOARD STARTUP")
|
||||||
logger.info("=" * 80)
|
# logger.info("=" * 80)
|
||||||
logger.info("Features:")
|
# logger.info("Features:")
|
||||||
logger.info(" - 1-second OHLCV bar charts (instead of tick points)")
|
# logger.info(" - 1-second OHLCV bar charts (instead of tick points)")
|
||||||
logger.info(" - 15-minute server-side tick cache for model training")
|
# logger.info(" - 15-minute server-side tick cache for model training")
|
||||||
logger.info(" - Enhanced volume visualization with buy/sell separation")
|
# logger.info(" - Enhanced volume visualization with buy/sell separation")
|
||||||
logger.info(" - Ultra-low latency WebSocket streaming")
|
# logger.info(" - Ultra-low latency WebSocket streaming")
|
||||||
logger.info(" - Real-time candle aggregation from tick data")
|
# logger.info(" - Real-time candle aggregation from tick data")
|
||||||
logger.info("=" * 80)
|
# logger.info("=" * 80)
|
||||||
|
|
||||||
# Initialize core components
|
# # Initialize core components
|
||||||
logger.info("Initializing data provider...")
|
# logger.info("Initializing data provider...")
|
||||||
data_provider = DataProvider()
|
# data_provider = DataProvider()
|
||||||
|
|
||||||
logger.info("Initializing enhanced trading orchestrator...")
|
# logger.info("Initializing enhanced trading orchestrator...")
|
||||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
# orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||||
|
|
||||||
# Create enhanced dashboard
|
# # Create enhanced dashboard
|
||||||
logger.info("Creating enhanced scalping dashboard...")
|
# logger.info("Creating enhanced scalping dashboard...")
|
||||||
dashboard = EnhancedScalpingDashboard(
|
# dashboard = EnhancedScalpingDashboard(
|
||||||
data_provider=data_provider,
|
# data_provider=data_provider,
|
||||||
orchestrator=orchestrator
|
# orchestrator=orchestrator
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Launch dashboard
|
# # Launch dashboard
|
||||||
logger.info(f"Launching dashboard at http://{args.host}:{args.port}")
|
# logger.info(f"Launching dashboard at http://{args.host}:{args.port}")
|
||||||
logger.info("Dashboard Features:")
|
# logger.info("Dashboard Features:")
|
||||||
logger.info(" - Main chart: ETH/USDT 1s OHLCV bars with volume subplot")
|
# logger.info(" - Main chart: ETH/USDT 1s OHLCV bars with volume subplot")
|
||||||
logger.info(" - Secondary chart: BTC/USDT 1s bars")
|
# logger.info(" - Secondary chart: BTC/USDT 1s bars")
|
||||||
logger.info(" - Volume analysis: Real-time volume comparison")
|
# logger.info(" - Volume analysis: Real-time volume comparison")
|
||||||
logger.info(" - Tick cache: 15-minute rolling window for model training")
|
# logger.info(" - Tick cache: 15-minute rolling window for model training")
|
||||||
logger.info(" - Trading session: $100 starting balance with P&L tracking")
|
# logger.info(" - Trading session: $100 starting balance with P&L tracking")
|
||||||
logger.info(" - System performance: Real-time callback monitoring")
|
# logger.info(" - System performance: Real-time callback monitoring")
|
||||||
logger.info("=" * 80)
|
# logger.info("=" * 80)
|
||||||
|
|
||||||
dashboard.run(
|
# dashboard.run(
|
||||||
host=args.host,
|
# host=args.host,
|
||||||
port=args.port,
|
# port=args.port,
|
||||||
debug=args.debug
|
# debug=args.debug
|
||||||
)
|
# )
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
# except KeyboardInterrupt:
|
||||||
logger.info("Dashboard stopped by user (Ctrl+C)")
|
# logger.info("Dashboard stopped by user (Ctrl+C)")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error running enhanced dashboard: {e}")
|
# logger.error(f"Error running enhanced dashboard: {e}")
|
||||||
logger.exception("Full traceback:")
|
# logger.exception("Full traceback:")
|
||||||
sys.exit(1)
|
# sys.exit(1)
|
||||||
finally:
|
# finally:
|
||||||
logger.info("Enhanced Scalping Dashboard shutdown complete")
|
# logger.info("Enhanced Scalping Dashboard shutdown complete")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
main()
|
# main()
|
||||||
|
@ -1,35 +1,35 @@
|
|||||||
#!/usr/bin/env python3
|
# #!/usr/bin/env python3
|
||||||
"""
|
# """
|
||||||
Enhanced Trading System Launcher
|
# Enhanced Trading System Launcher
|
||||||
Quick launcher for the enhanced multi-modal trading system
|
# Quick launcher for the enhanced multi-modal trading system
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import asyncio
|
# import asyncio
|
||||||
import sys
|
# import sys
|
||||||
from pathlib import Path
|
# from pathlib import Path
|
||||||
|
|
||||||
# Add project root to path
|
# # Add project root to path
|
||||||
project_root = Path(__file__).parent
|
# project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root))
|
# sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from enhanced_trading_main import main
|
# from enhanced_trading_main import main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
print("🚀 Launching Enhanced Multi-Modal Trading System...")
|
# print("🚀 Launching Enhanced Multi-Modal Trading System...")
|
||||||
print("📊 Features Active:")
|
# print("📊 Features Active:")
|
||||||
print(" - RL agents learning from every trading decision")
|
# print(" - RL agents learning from every trading decision")
|
||||||
print(" - CNN training on perfect moves with known outcomes")
|
# print(" - CNN training on perfect moves with known outcomes")
|
||||||
print(" - Multi-timeframe pattern recognition")
|
# print(" - Multi-timeframe pattern recognition")
|
||||||
print(" - Real-time market adaptation")
|
# print(" - Real-time market adaptation")
|
||||||
print(" - Performance monitoring and tracking")
|
# print(" - Performance monitoring and tracking")
|
||||||
print()
|
# print()
|
||||||
print("Press Ctrl+C to stop the system gracefully")
|
# print("Press Ctrl+C to stop the system gracefully")
|
||||||
print("=" * 60)
|
# print("=" * 60)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
asyncio.run(main())
|
# asyncio.run(main())
|
||||||
except KeyboardInterrupt:
|
# except KeyboardInterrupt:
|
||||||
print("\n🛑 System stopped by user")
|
# print("\n🛑 System stopped by user")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
print(f"\n❌ System error: {e}")
|
# print(f"\n❌ System error: {e}")
|
||||||
sys.exit(1)
|
# sys.exit(1)
|
@ -1,37 +1,37 @@
|
|||||||
#!/usr/bin/env python3
|
# #!/usr/bin/env python3
|
||||||
"""
|
# """
|
||||||
Run Fixed Scalping Dashboard
|
# Run Fixed Scalping Dashboard
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import logging
|
# import logging
|
||||||
import sys
|
# import sys
|
||||||
import os
|
# import os
|
||||||
|
|
||||||
# Add project root to path
|
# # Add project root to path
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
# Setup logging
|
# # Setup logging
|
||||||
logging.basicConfig(
|
# logging.basicConfig(
|
||||||
level=logging.INFO,
|
# level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
)
|
# )
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def main():
|
# def main():
|
||||||
"""Run the enhanced scalping dashboard"""
|
# """Run the enhanced scalping dashboard"""
|
||||||
try:
|
# try:
|
||||||
logger.info("Starting Enhanced Scalping Dashboard...")
|
# logger.info("Starting Enhanced Scalping Dashboard...")
|
||||||
|
|
||||||
from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
# from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
||||||
|
|
||||||
dashboard = create_scalping_dashboard()
|
# dashboard = create_scalping_dashboard()
|
||||||
dashboard.run(host='127.0.0.1', port=8051, debug=True)
|
# dashboard.run(host='127.0.0.1', port=8051, debug=True)
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error starting dashboard: {e}")
|
# logger.error(f"Error starting dashboard: {e}")
|
||||||
import traceback
|
# import traceback
|
||||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
# logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
main()
|
# main()
|
@ -1,75 +1,75 @@
|
|||||||
#!/usr/bin/env python3
|
# #!/usr/bin/env python3
|
||||||
"""
|
# """
|
||||||
Run Ultra-Fast Scalping Dashboard (500x Leverage)
|
# Run Ultra-Fast Scalping Dashboard (500x Leverage)
|
||||||
|
|
||||||
This script starts the custom scalping dashboard with:
|
# This script starts the custom scalping dashboard with:
|
||||||
- Full-width 1s ETH/USDT candlestick chart
|
# - Full-width 1s ETH/USDT candlestick chart
|
||||||
- 3 small ETH charts: 1m, 1h, 1d
|
# - 3 small ETH charts: 1m, 1h, 1d
|
||||||
- 1 small BTC 1s chart
|
# - 1 small BTC 1s chart
|
||||||
- Ultra-fast 100ms updates for scalping
|
# - Ultra-fast 100ms updates for scalping
|
||||||
- Real-time PnL tracking and logging
|
# - Real-time PnL tracking and logging
|
||||||
- Enhanced orchestrator with real AI model decisions
|
# - Enhanced orchestrator with real AI model decisions
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import argparse
|
# import argparse
|
||||||
import logging
|
# import logging
|
||||||
import sys
|
# import sys
|
||||||
from pathlib import Path
|
# from pathlib import Path
|
||||||
|
|
||||||
# Add project root to path
|
# # Add project root to path
|
||||||
project_root = Path(__file__).parent
|
# project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root))
|
# sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from core.config import setup_logging
|
# from core.config import setup_logging
|
||||||
from core.data_provider import DataProvider
|
# from core.data_provider import DataProvider
|
||||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
# from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
# from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
||||||
|
|
||||||
# Setup logging
|
# # Setup logging
|
||||||
setup_logging()
|
# setup_logging()
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def main():
|
# def main():
|
||||||
"""Main function for scalping dashboard"""
|
# """Main function for scalping dashboard"""
|
||||||
# Parse command line arguments
|
# # Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='Ultra-Fast Scalping Dashboard (500x Leverage)')
|
# parser = argparse.ArgumentParser(description='Ultra-Fast Scalping Dashboard (500x Leverage)')
|
||||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes (for compatibility)')
|
# parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes (for compatibility)')
|
||||||
parser.add_argument('--max-position', type=float, default=0.1, help='Maximum position size')
|
# parser.add_argument('--max-position', type=float, default=0.1, help='Maximum position size')
|
||||||
parser.add_argument('--leverage', type=int, default=500, help='Leverage multiplier')
|
# parser.add_argument('--leverage', type=int, default=500, help='Leverage multiplier')
|
||||||
parser.add_argument('--port', type=int, default=8051, help='Dashboard port')
|
# parser.add_argument('--port', type=int, default=8051, help='Dashboard port')
|
||||||
parser.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host')
|
# parser.add_argument('--host', type=str, default='127.0.0.1', help='Dashboard host')
|
||||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
# parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
||||||
|
|
||||||
args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
|
|
||||||
logger.info("STARTING SCALPING DASHBOARD")
|
# logger.info("STARTING SCALPING DASHBOARD")
|
||||||
logger.info("Session-based trading with $100 starting balance")
|
# logger.info("Session-based trading with $100 starting balance")
|
||||||
logger.info(f"Configuration: Leverage={args.leverage}x, Max Position={args.max_position}, Port={args.port}")
|
# logger.info(f"Configuration: Leverage={args.leverage}x, Max Position={args.max_position}, Port={args.port}")
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
# Initialize components
|
# # Initialize components
|
||||||
logger.info("Initializing data provider...")
|
# logger.info("Initializing data provider...")
|
||||||
data_provider = DataProvider()
|
# data_provider = DataProvider()
|
||||||
|
|
||||||
logger.info("Initializing trading orchestrator...")
|
# logger.info("Initializing trading orchestrator...")
|
||||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
# orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||||
|
|
||||||
logger.info("LAUNCHING DASHBOARD")
|
# logger.info("LAUNCHING DASHBOARD")
|
||||||
logger.info(f"Dashboard will be available at http://{args.host}:{args.port}")
|
# logger.info(f"Dashboard will be available at http://{args.host}:{args.port}")
|
||||||
|
|
||||||
# Start the dashboard
|
# # Start the dashboard
|
||||||
dashboard = create_scalping_dashboard(data_provider, orchestrator)
|
# dashboard = create_scalping_dashboard(data_provider, orchestrator)
|
||||||
dashboard.run(host=args.host, port=args.port, debug=args.debug)
|
# dashboard.run(host=args.host, port=args.port, debug=args.debug)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
# except KeyboardInterrupt:
|
||||||
logger.info("Dashboard stopped by user")
|
# logger.info("Dashboard stopped by user")
|
||||||
return 0
|
# return 0
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"ERROR: {e}")
|
# logger.error(f"ERROR: {e}")
|
||||||
import traceback
|
# import traceback
|
||||||
traceback.print_exc()
|
# traceback.print_exc()
|
||||||
return 1
|
# return 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
exit_code = main()
|
# exit_code = main()
|
||||||
sys.exit(exit_code if exit_code else 0)
|
# sys.exit(exit_code if exit_code else 0)
|
176
test_leverage_slider.py
Normal file
176
test_leverage_slider.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Leverage Slider Functionality
|
||||||
|
|
||||||
|
This script tests the leverage slider integration in the dashboard:
|
||||||
|
- Verifies slider range (1x to 100x)
|
||||||
|
- Tests risk level calculation
|
||||||
|
- Checks leverage multiplier updates
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from core.config import setup_logging
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
|
from web.dashboard import TradingDashboard
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def test_leverage_calculations():
|
||||||
|
"""Test leverage risk calculations"""
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("TESTING LEVERAGE CALCULATIONS")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
{'leverage': 1, 'expected_risk': 'Low Risk'},
|
||||||
|
{'leverage': 5, 'expected_risk': 'Low Risk'},
|
||||||
|
{'leverage': 10, 'expected_risk': 'Medium Risk'},
|
||||||
|
{'leverage': 25, 'expected_risk': 'Medium Risk'},
|
||||||
|
{'leverage': 30, 'expected_risk': 'High Risk'},
|
||||||
|
{'leverage': 50, 'expected_risk': 'High Risk'},
|
||||||
|
{'leverage': 75, 'expected_risk': 'Extreme Risk'},
|
||||||
|
{'leverage': 100, 'expected_risk': 'Extreme Risk'},
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_case in test_cases:
|
||||||
|
leverage = test_case['leverage']
|
||||||
|
expected_risk = test_case['expected_risk']
|
||||||
|
|
||||||
|
# Calculate risk level using same logic as dashboard
|
||||||
|
if leverage <= 5:
|
||||||
|
actual_risk = "Low Risk"
|
||||||
|
elif leverage <= 25:
|
||||||
|
actual_risk = "Medium Risk"
|
||||||
|
elif leverage <= 50:
|
||||||
|
actual_risk = "High Risk"
|
||||||
|
else:
|
||||||
|
actual_risk = "Extreme Risk"
|
||||||
|
|
||||||
|
status = "PASS" if actual_risk == expected_risk else "FAIL"
|
||||||
|
logger.info(f" {leverage:3d}x leverage -> {actual_risk:13s} (expected: {expected_risk:13s}) [{status}]")
|
||||||
|
|
||||||
|
if status == "FAIL":
|
||||||
|
logger.error(f"Test failed for {leverage}x leverage!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("All leverage calculation tests PASSED!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def test_leverage_reward_amplification():
|
||||||
|
"""Test how different leverage levels amplify rewards"""
|
||||||
|
|
||||||
|
logger.info("\n" + "=" * 50)
|
||||||
|
logger.info("TESTING LEVERAGE REWARD AMPLIFICATION")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
base_price = 3000.0
|
||||||
|
price_changes = [0.001, 0.002, 0.005, 0.01] # 0.1%, 0.2%, 0.5%, 1.0%
|
||||||
|
leverage_levels = [1, 5, 10, 25, 50, 100]
|
||||||
|
|
||||||
|
logger.info("Price Change | " + " | ".join([f"{lev:3d}x" for lev in leverage_levels]))
|
||||||
|
logger.info("-" * 70)
|
||||||
|
|
||||||
|
for price_change_pct in price_changes:
|
||||||
|
results = []
|
||||||
|
for leverage in leverage_levels:
|
||||||
|
# Calculate amplified return
|
||||||
|
amplified_return = price_change_pct * leverage * 100 # Convert to percentage
|
||||||
|
results.append(f"{amplified_return:6.1f}%")
|
||||||
|
|
||||||
|
logger.info(f" {price_change_pct*100:4.1f}% | " + " | ".join(results))
|
||||||
|
|
||||||
|
logger.info("\nKey insights:")
|
||||||
|
logger.info("- 1x leverage: Traditional trading returns")
|
||||||
|
logger.info("- 50x leverage: Our current default for enhanced learning")
|
||||||
|
logger.info("- 100x leverage: Maximum risk/reward amplification")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def test_dashboard_integration():
|
||||||
|
"""Test dashboard integration"""
|
||||||
|
|
||||||
|
logger.info("\n" + "=" * 50)
|
||||||
|
logger.info("TESTING DASHBOARD INTEGRATION")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize components
|
||||||
|
logger.info("Creating data provider...")
|
||||||
|
data_provider = DataProvider()
|
||||||
|
|
||||||
|
logger.info("Creating enhanced orchestrator...")
|
||||||
|
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||||
|
|
||||||
|
logger.info("Creating trading dashboard...")
|
||||||
|
dashboard = TradingDashboard(data_provider, orchestrator)
|
||||||
|
|
||||||
|
# Test leverage settings
|
||||||
|
logger.info(f"Initial leverage: {dashboard.leverage_multiplier}x")
|
||||||
|
logger.info(f"Leverage range: {dashboard.min_leverage}x to {dashboard.max_leverage}x")
|
||||||
|
logger.info(f"Leverage step: {dashboard.leverage_step}x")
|
||||||
|
|
||||||
|
# Test leverage updates
|
||||||
|
test_leverages = [10, 25, 50, 75]
|
||||||
|
for test_leverage in test_leverages:
|
||||||
|
dashboard.leverage_multiplier = test_leverage
|
||||||
|
logger.info(f"Set leverage to {dashboard.leverage_multiplier}x")
|
||||||
|
|
||||||
|
logger.info("Dashboard integration test PASSED!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Dashboard integration test FAILED: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all leverage tests"""
|
||||||
|
|
||||||
|
logger.info("LEVERAGE SLIDER FUNCTIONALITY TEST")
|
||||||
|
logger.info("Testing the 50x leverage system with adjustable slider")
|
||||||
|
|
||||||
|
all_passed = True
|
||||||
|
|
||||||
|
# Test 1: Leverage calculations
|
||||||
|
if not test_leverage_calculations():
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
# Test 2: Reward amplification
|
||||||
|
if not test_leverage_reward_amplification():
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
# Test 3: Dashboard integration
|
||||||
|
if not test_dashboard_integration():
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
# Final result
|
||||||
|
logger.info("\n" + "=" * 50)
|
||||||
|
if all_passed:
|
||||||
|
logger.info("ALL TESTS PASSED!")
|
||||||
|
logger.info("Leverage slider functionality is working correctly.")
|
||||||
|
logger.info("\nTo use:")
|
||||||
|
logger.info("1. Run: python run_scalping_dashboard.py")
|
||||||
|
logger.info("2. Open: http://127.0.0.1:8050")
|
||||||
|
logger.info("3. Find the leverage slider in the System & Leverage panel")
|
||||||
|
logger.info("4. Adjust leverage from 1x to 100x")
|
||||||
|
logger.info("5. Watch risk levels update automatically")
|
||||||
|
else:
|
||||||
|
logger.error("SOME TESTS FAILED!")
|
||||||
|
logger.error("Check the error messages above.")
|
||||||
|
|
||||||
|
return 0 if all_passed else 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = main()
|
||||||
|
sys.exit(exit_code)
|
554
web/dashboard.py
554
web/dashboard.py
@ -11,7 +11,7 @@ This module provides a modern, responsive web dashboard for the trading system:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dash
|
import dash
|
||||||
from dash import dcc, html, Input, Output
|
from dash import Dash, dcc, html, Input, Output
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
@ -28,6 +28,8 @@ from collections import deque
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||||
import websocket
|
import websocket
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
# Setup logger immediately after logging import
|
# Setup logger immediately after logging import
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -175,9 +177,49 @@ class TradingDashboard:
|
|||||||
"""Enhanced Trading Dashboard with Williams pivot points and unified timezone handling"""
|
"""Enhanced Trading Dashboard with Williams pivot points and unified timezone handling"""
|
||||||
|
|
||||||
def __init__(self, data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None):
|
def __init__(self, data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None):
|
||||||
"""Initialize the dashboard with unified data stream and enhanced RL training"""
|
self.app = Dash(__name__)
|
||||||
|
|
||||||
|
# Initialize config first
|
||||||
|
from core.config import get_config
|
||||||
self.config = get_config()
|
self.config = get_config()
|
||||||
|
|
||||||
|
self.data_provider = data_provider or DataProvider()
|
||||||
|
self.orchestrator = orchestrator
|
||||||
|
self.trading_executor = trading_executor
|
||||||
|
|
||||||
|
# Enhanced trading state with leverage support
|
||||||
|
self.leverage_enabled = True
|
||||||
|
self.leverage_multiplier = 50.0 # 50x leverage (adjustable via slider)
|
||||||
|
self.base_capital = 10000.0
|
||||||
|
self.current_position = 0.0 # -1 to 1 (short to long)
|
||||||
|
self.position_size = 0.0
|
||||||
|
self.entry_price = 0.0
|
||||||
|
self.unrealized_pnl = 0.0
|
||||||
|
self.realized_pnl = 0.0
|
||||||
|
|
||||||
|
# Leverage settings for slider
|
||||||
|
self.min_leverage = 1.0
|
||||||
|
self.max_leverage = 100.0
|
||||||
|
self.leverage_step = 1.0
|
||||||
|
|
||||||
|
# Connect to trading server for leverage functionality
|
||||||
|
self.trading_server_url = "http://127.0.0.1:8052"
|
||||||
|
self.training_server_url = "http://127.0.0.1:8053"
|
||||||
|
self.stream_server_url = "http://127.0.0.1:8054"
|
||||||
|
|
||||||
|
# Enhanced performance tracking
|
||||||
|
self.leverage_metrics = {
|
||||||
|
'leverage_efficiency': 0.0,
|
||||||
|
'margin_used': 0.0,
|
||||||
|
'margin_available': 10000.0,
|
||||||
|
'effective_exposure': 0.0,
|
||||||
|
'risk_reward_ratio': 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Enhanced models will be loaded through model registry later
|
||||||
|
|
||||||
|
# Rest of initialization...
|
||||||
|
|
||||||
# Initialize timezone from config
|
# Initialize timezone from config
|
||||||
timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia')
|
timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia')
|
||||||
self.timezone = pytz.timezone(timezone_name)
|
self.timezone = pytz.timezone(timezone_name)
|
||||||
@ -874,13 +916,15 @@ class TradingDashboard:
|
|||||||
], className="card-body p-2")
|
], className="card-body p-2")
|
||||||
], className="card", style={"width": "32%", "marginLeft": "2%"}),
|
], className="card", style={"width": "32%", "marginLeft": "2%"}),
|
||||||
|
|
||||||
# System status - 1/3 width with icon tooltip
|
# System status and leverage controls - 1/3 width with icon tooltip
|
||||||
html.Div([
|
html.Div([
|
||||||
html.Div([
|
html.Div([
|
||||||
html.H6([
|
html.H6([
|
||||||
html.I(className="fas fa-server me-2"),
|
html.I(className="fas fa-server me-2"),
|
||||||
"System"
|
"System & Leverage"
|
||||||
], className="card-title mb-2"),
|
], className="card-title mb-2"),
|
||||||
|
|
||||||
|
# System status
|
||||||
html.Div([
|
html.Div([
|
||||||
html.I(
|
html.I(
|
||||||
id="system-status-icon",
|
id="system-status-icon",
|
||||||
@ -889,7 +933,44 @@ class TradingDashboard:
|
|||||||
style={"cursor": "pointer"}
|
style={"cursor": "pointer"}
|
||||||
),
|
),
|
||||||
html.Div(id="system-status-details", className="small mt-2")
|
html.Div(id="system-status-details", className="small mt-2")
|
||||||
], className="text-center")
|
], className="text-center mb-3"),
|
||||||
|
|
||||||
|
# Leverage Controls
|
||||||
|
html.Div([
|
||||||
|
html.Label([
|
||||||
|
html.I(className="fas fa-chart-line me-1"),
|
||||||
|
"Leverage Multiplier"
|
||||||
|
], className="form-label small fw-bold"),
|
||||||
|
html.Div([
|
||||||
|
dcc.Slider(
|
||||||
|
id='leverage-slider',
|
||||||
|
min=self.min_leverage,
|
||||||
|
max=self.max_leverage,
|
||||||
|
step=self.leverage_step,
|
||||||
|
value=self.leverage_multiplier,
|
||||||
|
marks={
|
||||||
|
1: '1x',
|
||||||
|
10: '10x',
|
||||||
|
25: '25x',
|
||||||
|
50: '50x',
|
||||||
|
75: '75x',
|
||||||
|
100: '100x'
|
||||||
|
},
|
||||||
|
tooltip={
|
||||||
|
"placement": "bottom",
|
||||||
|
"always_visible": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
], className="mb-2"),
|
||||||
|
html.Div([
|
||||||
|
html.Span(id="current-leverage", className="badge bg-warning text-dark"),
|
||||||
|
html.Span(" • ", className="mx-1"),
|
||||||
|
html.Span(id="leverage-risk", className="badge bg-info")
|
||||||
|
], className="text-center"),
|
||||||
|
html.Div([
|
||||||
|
html.Small("Higher leverage = Higher rewards & risks", className="text-muted")
|
||||||
|
], className="text-center mt-1")
|
||||||
|
])
|
||||||
], className="card-body p-2")
|
], className="card-body p-2")
|
||||||
], className="card", style={"width": "32%", "marginLeft": "2%"})
|
], className="card", style={"width": "32%", "marginLeft": "2%"})
|
||||||
], className="d-flex")
|
], className="d-flex")
|
||||||
@ -918,6 +999,8 @@ class TradingDashboard:
|
|||||||
Output('system-status-icon', 'className'),
|
Output('system-status-icon', 'className'),
|
||||||
Output('system-status-icon', 'title'),
|
Output('system-status-icon', 'title'),
|
||||||
Output('system-status-details', 'children'),
|
Output('system-status-details', 'children'),
|
||||||
|
Output('current-leverage', 'children'),
|
||||||
|
Output('leverage-risk', 'children'),
|
||||||
# Model data feed charts
|
# Model data feed charts
|
||||||
# Output('model-data-1m', 'figure'),
|
# Output('model-data-1m', 'figure'),
|
||||||
# Output('model-data-1h', 'figure'),
|
# Output('model-data-1h', 'figure'),
|
||||||
@ -1168,10 +1251,26 @@ class TradingDashboard:
|
|||||||
logger.warning(f"Closed trades table error: {e}")
|
logger.warning(f"Closed trades table error: {e}")
|
||||||
closed_trades_table = [html.P("Closed trades data unavailable", className="text-muted")]
|
closed_trades_table = [html.P("Closed trades data unavailable", className="text-muted")]
|
||||||
|
|
||||||
|
# Calculate leverage display values
|
||||||
|
leverage_text = f"{self.leverage_multiplier:.0f}x"
|
||||||
|
if self.leverage_multiplier <= 5:
|
||||||
|
risk_level = "Low Risk"
|
||||||
|
risk_class = "bg-success"
|
||||||
|
elif self.leverage_multiplier <= 25:
|
||||||
|
risk_level = "Medium Risk"
|
||||||
|
risk_class = "bg-warning text-dark"
|
||||||
|
elif self.leverage_multiplier <= 50:
|
||||||
|
risk_level = "High Risk"
|
||||||
|
risk_class = "bg-danger"
|
||||||
|
else:
|
||||||
|
risk_level = "Extreme Risk"
|
||||||
|
risk_class = "bg-dark"
|
||||||
|
|
||||||
return (
|
return (
|
||||||
price_text, pnl_text, pnl_class, fees_text, position_text, position_class, trade_count_text, portfolio_text, mexc_status,
|
price_text, pnl_text, pnl_class, fees_text, position_text, position_class, trade_count_text, portfolio_text, mexc_status,
|
||||||
price_chart, training_metrics, decisions_list, session_perf, closed_trades_table,
|
price_chart, training_metrics, decisions_list, session_perf, closed_trades_table,
|
||||||
system_status['icon_class'], system_status['title'], system_status['details'],
|
system_status['icon_class'], system_status['title'], system_status['details'],
|
||||||
|
leverage_text, f"{risk_level}",
|
||||||
# # Model data feed charts
|
# # Model data feed charts
|
||||||
# self._create_model_data_chart('ETH/USDT', '1m'),
|
# self._create_model_data_chart('ETH/USDT', '1m'),
|
||||||
# self._create_model_data_chart('ETH/USDT', '1h'),
|
# self._create_model_data_chart('ETH/USDT', '1h'),
|
||||||
@ -1194,11 +1293,12 @@ class TradingDashboard:
|
|||||||
"fas fa-circle text-danger fa-2x",
|
"fas fa-circle text-danger fa-2x",
|
||||||
"Error: Dashboard error - check logs",
|
"Error: Dashboard error - check logs",
|
||||||
[html.P(f"Error: {str(e)}", className="text-danger")],
|
[html.P(f"Error: {str(e)}", className="text-danger")],
|
||||||
|
f"{self.leverage_multiplier:.0f}x", "Error",
|
||||||
# Model data feed charts
|
# Model data feed charts
|
||||||
self._create_model_data_chart('ETH/USDT', '1m'),
|
# self._create_model_data_chart('ETH/USDT', '1m'),
|
||||||
self._create_model_data_chart('ETH/USDT', '1h'),
|
# self._create_model_data_chart('ETH/USDT', '1h'),
|
||||||
self._create_model_data_chart('ETH/USDT', '1d'),
|
# self._create_model_data_chart('ETH/USDT', '1d'),
|
||||||
self._create_model_data_chart('BTC/USDT', '1s')
|
# self._create_model_data_chart('BTC/USDT', '1s')
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clear history callback
|
# Clear history callback
|
||||||
@ -1220,6 +1320,60 @@ class TradingDashboard:
|
|||||||
return [html.P(f"Error clearing history: {str(e)}", className="text-danger text-center")]
|
return [html.P(f"Error clearing history: {str(e)}", className="text-danger text-center")]
|
||||||
return dash.no_update
|
return dash.no_update
|
||||||
|
|
||||||
|
# Leverage slider callback
|
||||||
|
@self.app.callback(
|
||||||
|
[Output('current-leverage', 'children', allow_duplicate=True),
|
||||||
|
Output('leverage-risk', 'children', allow_duplicate=True),
|
||||||
|
Output('leverage-risk', 'className', allow_duplicate=True)],
|
||||||
|
[Input('leverage-slider', 'value')],
|
||||||
|
prevent_initial_call=True
|
||||||
|
)
|
||||||
|
def update_leverage(leverage_value):
|
||||||
|
"""Update leverage multiplier and risk assessment"""
|
||||||
|
try:
|
||||||
|
if leverage_value is None:
|
||||||
|
return dash.no_update
|
||||||
|
|
||||||
|
# Update internal leverage value
|
||||||
|
self.leverage_multiplier = float(leverage_value)
|
||||||
|
|
||||||
|
# Calculate risk level and styling
|
||||||
|
leverage_text = f"{self.leverage_multiplier:.0f}x"
|
||||||
|
|
||||||
|
if self.leverage_multiplier <= 5:
|
||||||
|
risk_level = "Low Risk"
|
||||||
|
risk_class = "badge bg-success"
|
||||||
|
elif self.leverage_multiplier <= 25:
|
||||||
|
risk_level = "Medium Risk"
|
||||||
|
risk_class = "badge bg-warning text-dark"
|
||||||
|
elif self.leverage_multiplier <= 50:
|
||||||
|
risk_level = "High Risk"
|
||||||
|
risk_class = "badge bg-danger"
|
||||||
|
else:
|
||||||
|
risk_level = "Extreme Risk"
|
||||||
|
risk_class = "badge bg-dark"
|
||||||
|
|
||||||
|
# Update trading server if connected
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
response = requests.post(f"{self.trading_server_url}/update_leverage",
|
||||||
|
json={"leverage": self.leverage_multiplier},
|
||||||
|
timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info(f"[LEVERAGE] Updated trading server leverage to {self.leverage_multiplier}x")
|
||||||
|
else:
|
||||||
|
logger.warning(f"[LEVERAGE] Failed to update trading server: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"[LEVERAGE] Trading server not available: {e}")
|
||||||
|
|
||||||
|
logger.info(f"[LEVERAGE] Leverage updated to {self.leverage_multiplier}x ({risk_level})")
|
||||||
|
|
||||||
|
return leverage_text, risk_level, risk_class
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating leverage: {e}")
|
||||||
|
return f"{self.leverage_multiplier:.0f}x", "Error", "badge bg-secondary"
|
||||||
|
|
||||||
def _simulate_price_update(self, symbol: str, base_price: float) -> float:
|
def _simulate_price_update(self, symbol: str, base_price: float) -> float:
|
||||||
"""
|
"""
|
||||||
Create realistic price movement for demo purposes
|
Create realistic price movement for demo purposes
|
||||||
@ -2218,10 +2372,11 @@ class TradingDashboard:
|
|||||||
size = self.current_position['size']
|
size = self.current_position['size']
|
||||||
entry_time = self.current_position['timestamp']
|
entry_time = self.current_position['timestamp']
|
||||||
|
|
||||||
# Calculate PnL for closing short
|
# Calculate PnL for closing short with leverage
|
||||||
gross_pnl = (entry_price - exit_price) * size # Short PnL calculation
|
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||||
fee = exit_price * size * fee_rate
|
entry_price, exit_price, size, 'SHORT', fee_rate
|
||||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
)
|
||||||
|
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||||
|
|
||||||
self.total_realized_pnl += net_pnl
|
self.total_realized_pnl += net_pnl
|
||||||
self.total_fees += fee
|
self.total_fees += fee
|
||||||
@ -2246,8 +2401,8 @@ class TradingDashboard:
|
|||||||
'entry_price': entry_price,
|
'entry_price': entry_price,
|
||||||
'exit_price': exit_price,
|
'exit_price': exit_price,
|
||||||
'size': size,
|
'size': size,
|
||||||
'gross_pnl': gross_pnl,
|
'gross_pnl': leveraged_pnl,
|
||||||
'fees': fee + self.current_position['fees'],
|
'fees': leveraged_fee + self.current_position['fees'],
|
||||||
'fee_type': fee_type,
|
'fee_type': fee_type,
|
||||||
'fee_rate': fee_rate,
|
'fee_rate': fee_rate,
|
||||||
'net_pnl': net_pnl,
|
'net_pnl': net_pnl,
|
||||||
@ -2280,7 +2435,7 @@ class TradingDashboard:
|
|||||||
# Now open long position (regardless of previous position)
|
# Now open long position (regardless of previous position)
|
||||||
if self.current_position is None:
|
if self.current_position is None:
|
||||||
# Open long position with confidence-based size
|
# Open long position with confidence-based size
|
||||||
fee = decision['price'] * decision['size'] * fee_rate
|
fee = decision['price'] * decision['size'] * fee_rate * self.leverage_multiplier # Leverage affects fees
|
||||||
self.current_position = {
|
self.current_position = {
|
||||||
'side': 'LONG',
|
'side': 'LONG',
|
||||||
'price': decision['price'],
|
'price': decision['price'],
|
||||||
@ -2310,10 +2465,11 @@ class TradingDashboard:
|
|||||||
size = self.current_position['size']
|
size = self.current_position['size']
|
||||||
entry_time = self.current_position['timestamp']
|
entry_time = self.current_position['timestamp']
|
||||||
|
|
||||||
# Calculate PnL for closing short
|
# Calculate PnL for closing short with leverage
|
||||||
gross_pnl = (entry_price - exit_price) * size # Short PnL calculation
|
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||||
fee = exit_price * size * fee_rate
|
entry_price, exit_price, size, 'SHORT', fee_rate
|
||||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
)
|
||||||
|
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||||
|
|
||||||
self.total_realized_pnl += net_pnl
|
self.total_realized_pnl += net_pnl
|
||||||
self.total_fees += fee
|
self.total_fees += fee
|
||||||
@ -2337,8 +2493,8 @@ class TradingDashboard:
|
|||||||
'entry_price': entry_price,
|
'entry_price': entry_price,
|
||||||
'exit_price': exit_price,
|
'exit_price': exit_price,
|
||||||
'size': size,
|
'size': size,
|
||||||
'gross_pnl': gross_pnl,
|
'gross_pnl': leveraged_pnl,
|
||||||
'fees': fee + self.current_position['fees'],
|
'fees': leveraged_fee + self.current_position['fees'],
|
||||||
'fee_type': fee_type,
|
'fee_type': fee_type,
|
||||||
'fee_rate': fee_rate,
|
'fee_rate': fee_rate,
|
||||||
'net_pnl': net_pnl,
|
'net_pnl': net_pnl,
|
||||||
@ -2377,10 +2533,11 @@ class TradingDashboard:
|
|||||||
size = self.current_position['size']
|
size = self.current_position['size']
|
||||||
entry_time = self.current_position['timestamp']
|
entry_time = self.current_position['timestamp']
|
||||||
|
|
||||||
# Calculate PnL for closing long
|
# Calculate PnL for closing long with leverage
|
||||||
gross_pnl = (exit_price - entry_price) * size # Long PnL calculation
|
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||||
fee = exit_price * size * fee_rate
|
entry_price, exit_price, size, 'LONG', fee_rate
|
||||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
)
|
||||||
|
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||||
|
|
||||||
self.total_realized_pnl += net_pnl
|
self.total_realized_pnl += net_pnl
|
||||||
self.total_fees += fee
|
self.total_fees += fee
|
||||||
@ -2405,8 +2562,8 @@ class TradingDashboard:
|
|||||||
'entry_price': entry_price,
|
'entry_price': entry_price,
|
||||||
'exit_price': exit_price,
|
'exit_price': exit_price,
|
||||||
'size': size,
|
'size': size,
|
||||||
'gross_pnl': gross_pnl,
|
'gross_pnl': leveraged_pnl,
|
||||||
'fees': fee + self.current_position['fees'],
|
'fees': leveraged_fee + self.current_position['fees'],
|
||||||
'fee_type': fee_type,
|
'fee_type': fee_type,
|
||||||
'fee_rate': fee_rate,
|
'fee_rate': fee_rate,
|
||||||
'net_pnl': net_pnl,
|
'net_pnl': net_pnl,
|
||||||
@ -2427,7 +2584,7 @@ class TradingDashboard:
|
|||||||
# Now open short position (regardless of previous position)
|
# Now open short position (regardless of previous position)
|
||||||
if self.current_position is None:
|
if self.current_position is None:
|
||||||
# Open short position with confidence-based size
|
# Open short position with confidence-based size
|
||||||
fee = decision['price'] * decision['size'] * fee_rate
|
fee = decision['price'] * decision['size'] * fee_rate * self.leverage_multiplier # Leverage affects fees
|
||||||
self.current_position = {
|
self.current_position = {
|
||||||
'side': 'SHORT',
|
'side': 'SHORT',
|
||||||
'price': decision['price'],
|
'price': decision['price'],
|
||||||
@ -2458,8 +2615,34 @@ class TradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing trading decision: {e}")
|
logger.error(f"Error processing trading decision: {e}")
|
||||||
|
|
||||||
|
def _calculate_leveraged_pnl_and_fees(self, entry_price: float, exit_price: float, size: float, side: str, fee_rate: float):
|
||||||
|
"""Calculate leveraged PnL and fees for closed positions"""
|
||||||
|
try:
|
||||||
|
# Calculate base PnL
|
||||||
|
if side == 'LONG':
|
||||||
|
base_pnl = (exit_price - entry_price) * size
|
||||||
|
elif side == 'SHORT':
|
||||||
|
base_pnl = (entry_price - exit_price) * size
|
||||||
|
else:
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
|
# Apply leverage amplification
|
||||||
|
leveraged_pnl = base_pnl * self.leverage_multiplier
|
||||||
|
|
||||||
|
# Calculate fees with leverage (higher position value = higher fees)
|
||||||
|
position_value = exit_price * size * self.leverage_multiplier
|
||||||
|
leveraged_fee = position_value * fee_rate
|
||||||
|
|
||||||
|
logger.info(f"[LEVERAGE] {side} PnL: Base=${base_pnl:.2f} x {self.leverage_multiplier}x = ${leveraged_pnl:.2f}, Fee=${leveraged_fee:.4f}")
|
||||||
|
|
||||||
|
return leveraged_pnl, leveraged_fee
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating leveraged PnL and fees: {e}")
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
def _calculate_unrealized_pnl(self, current_price: float) -> float:
|
def _calculate_unrealized_pnl(self, current_price: float) -> float:
|
||||||
"""Calculate unrealized PnL for open position"""
|
"""Calculate unrealized PnL for open position with leverage amplification"""
|
||||||
try:
|
try:
|
||||||
if not self.current_position:
|
if not self.current_position:
|
||||||
return 0.0
|
return 0.0
|
||||||
@ -2467,13 +2650,21 @@ class TradingDashboard:
|
|||||||
entry_price = self.current_position['price']
|
entry_price = self.current_position['price']
|
||||||
size = self.current_position['size']
|
size = self.current_position['size']
|
||||||
|
|
||||||
|
# Calculate base PnL
|
||||||
if self.current_position['side'] == 'LONG':
|
if self.current_position['side'] == 'LONG':
|
||||||
return (current_price - entry_price) * size
|
base_pnl = (current_price - entry_price) * size
|
||||||
elif self.current_position['side'] == 'SHORT':
|
elif self.current_position['side'] == 'SHORT':
|
||||||
return (entry_price - current_price) * size
|
base_pnl = (entry_price - current_price) * size
|
||||||
|
else:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
# Apply leverage amplification
|
||||||
|
leveraged_pnl = base_pnl * self.leverage_multiplier
|
||||||
|
|
||||||
|
logger.debug(f"[LEVERAGE PnL] Base: ${base_pnl:.2f} x {self.leverage_multiplier}x = ${leveraged_pnl:.2f}")
|
||||||
|
|
||||||
|
return leveraged_pnl
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error calculating unrealized PnL: {e}")
|
logger.warning(f"Error calculating unrealized PnL: {e}")
|
||||||
return 0.0
|
return 0.0
|
||||||
@ -2804,208 +2995,189 @@ class TradingDashboard:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _load_available_models(self):
|
def _load_available_models(self):
|
||||||
"""Load available CNN and RL models for real trading"""
|
"""Load available models with enhanced model management"""
|
||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from model_manager import ModelManager, ModelMetrics
|
||||||
import torch
|
|
||||||
|
|
||||||
models_loaded = 0
|
# Initialize model manager
|
||||||
|
self.model_manager = ModelManager()
|
||||||
|
|
||||||
# Try to load real CNN models - handle different architectures
|
# Load best models
|
||||||
cnn_paths = [
|
loaded_models = self.model_manager.load_best_models()
|
||||||
'models/cnn/scalping_cnn_trained_best.pt',
|
|
||||||
'models/cnn/scalping_cnn_trained.pt',
|
|
||||||
'models/saved/cnn_model_best.pt'
|
|
||||||
]
|
|
||||||
|
|
||||||
for cnn_path in cnn_paths:
|
if loaded_models:
|
||||||
if Path(cnn_path).exists():
|
logger.info(f"Loaded {len(loaded_models)} best models via ModelManager")
|
||||||
try:
|
|
||||||
# Load with weights_only=False for older models
|
|
||||||
checkpoint = torch.load(cnn_path, map_location='cpu', weights_only=False)
|
|
||||||
|
|
||||||
# Try different CNN model classes to find the right architecture
|
# Update internal model storage
|
||||||
cnn_model = None
|
for model_type, model_data in loaded_models.items():
|
||||||
model_classes = []
|
model_info = model_data['info']
|
||||||
|
logger.info(f"Using best {model_type} model: {model_info.model_name} "
|
||||||
|
f"(Score: {model_info.metrics.get_composite_score():.3f})")
|
||||||
|
|
||||||
# Try importing different CNN classes
|
|
||||||
try:
|
|
||||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
|
||||||
model_classes.append(CNNModelPyTorch)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from models.cnn.enhanced_cnn import EnhancedCNN
|
|
||||||
model_classes.append(EnhancedCNN)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try to load with each model class
|
|
||||||
for model_class in model_classes:
|
|
||||||
try:
|
|
||||||
# Try different parameter combinations
|
|
||||||
param_combinations = [
|
|
||||||
{'window_size': 20, 'timeframes': ['1m', '5m', '1h'], 'output_size': 3},
|
|
||||||
{'window_size': 20, 'output_size': 3},
|
|
||||||
{'input_channels': 5, 'num_classes': 3}
|
|
||||||
]
|
|
||||||
|
|
||||||
for params in param_combinations:
|
|
||||||
try:
|
|
||||||
cnn_model = model_class(**params)
|
|
||||||
|
|
||||||
# Try to load state dict with different keys
|
|
||||||
if hasattr(checkpoint, 'keys'):
|
|
||||||
state_dict_keys = ['model_state_dict', 'state_dict', 'model']
|
|
||||||
for key in state_dict_keys:
|
|
||||||
if key in checkpoint:
|
|
||||||
cnn_model.model.load_state_dict(checkpoint[key], strict=False)
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
# Try loading checkpoint directly as state dict
|
logger.info("No managed models available, falling back to legacy loading")
|
||||||
cnn_model.model.load_state_dict(checkpoint, strict=False)
|
# Fallback to original model loading logic
|
||||||
|
self._load_legacy_models()
|
||||||
|
|
||||||
cnn_model.model.eval()
|
except ImportError:
|
||||||
logger.info(f"[MODEL] Successfully loaded CNN model: {model_class.__name__}")
|
logger.warning("ModelManager not available, using legacy model loading")
|
||||||
break
|
self._load_legacy_models()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to load with {model_class.__name__} and params {params}: {e}")
|
logger.error(f"Error loading models via ModelManager: {e}")
|
||||||
continue
|
self._load_legacy_models()
|
||||||
|
|
||||||
if cnn_model is not None:
|
def _load_legacy_models(self):
|
||||||
break
|
"""Legacy model loading method (original implementation)"""
|
||||||
|
self.available_models = {
|
||||||
|
'cnn': [],
|
||||||
|
'rl': [],
|
||||||
|
'hybrid': []
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
try:
|
||||||
logger.debug(f"Failed to initialize {model_class.__name__}: {e}")
|
# Check for CNN models
|
||||||
continue
|
cnn_models_dir = "models/cnn"
|
||||||
|
if os.path.exists(cnn_models_dir):
|
||||||
|
for model_file in os.listdir(cnn_models_dir):
|
||||||
|
if model_file.endswith('.pt'):
|
||||||
|
model_path = os.path.join(cnn_models_dir, model_file)
|
||||||
|
try:
|
||||||
|
# Try to load model to verify it's valid
|
||||||
|
model = torch.load(model_path, map_location='cpu')
|
||||||
|
|
||||||
if cnn_model is not None:
|
|
||||||
# Create a simple wrapper for the orchestrator
|
|
||||||
class CNNWrapper:
|
class CNNWrapper:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = f"CNN_{Path(cnn_path).stem}"
|
self.model.eval()
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
def predict(self, feature_matrix):
|
def predict(self, feature_matrix):
|
||||||
"""Simple prediction interface"""
|
with torch.no_grad():
|
||||||
try:
|
if hasattr(feature_matrix, 'shape') and len(feature_matrix.shape) == 2:
|
||||||
# Simplified prediction - return reasonable defaults
|
feature_tensor = torch.FloatTensor(feature_matrix).unsqueeze(0)
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Use basic trend analysis for more realistic predictions
|
|
||||||
if feature_matrix is not None:
|
|
||||||
trend = random.choice([-1, 0, 1])
|
|
||||||
if trend == 1:
|
|
||||||
action_probs = [0.2, 0.3, 0.5] # Bullish
|
|
||||||
elif trend == -1:
|
|
||||||
action_probs = [0.5, 0.3, 0.2] # Bearish
|
|
||||||
else:
|
else:
|
||||||
action_probs = [0.25, 0.5, 0.25] # Neutral
|
feature_tensor = torch.FloatTensor(feature_matrix)
|
||||||
else:
|
|
||||||
action_probs = [0.33, 0.34, 0.33]
|
|
||||||
|
|
||||||
confidence = max(action_probs)
|
prediction = self.model(feature_tensor)
|
||||||
return np.array(action_probs), confidence
|
|
||||||
except Exception as e:
|
if hasattr(prediction, 'cpu'):
|
||||||
logger.warning(f"CNN prediction error: {e}")
|
prediction = prediction.cpu().numpy()
|
||||||
return np.array([0.33, 0.34, 0.33]), 0.5
|
elif isinstance(prediction, torch.Tensor):
|
||||||
|
prediction = prediction.detach().numpy()
|
||||||
|
|
||||||
|
# Ensure we return probabilities
|
||||||
|
if len(prediction.shape) > 1:
|
||||||
|
prediction = prediction[0]
|
||||||
|
|
||||||
|
# Apply softmax if needed
|
||||||
|
if len(prediction) == 3:
|
||||||
|
exp_pred = np.exp(prediction - np.max(prediction))
|
||||||
|
prediction = exp_pred / np.sum(exp_pred)
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
def get_memory_usage(self):
|
def get_memory_usage(self):
|
||||||
return 100 # MB estimate
|
return 50 # MB estimate
|
||||||
|
|
||||||
def to_device(self, device):
|
def to_device(self, device):
|
||||||
self.device = device
|
self.model = self.model.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
wrapped_model = CNNWrapper(cnn_model)
|
wrapper = CNNWrapper(model)
|
||||||
|
self.available_models['cnn'].append({
|
||||||
|
'name': model_file,
|
||||||
|
'path': model_path,
|
||||||
|
'model': wrapper,
|
||||||
|
'type': 'cnn'
|
||||||
|
})
|
||||||
|
logger.info(f"Loaded CNN model: {model_file}")
|
||||||
|
|
||||||
# Register with orchestrator using the wrapper
|
|
||||||
if self.orchestrator.register_model(wrapped_model, weight=0.7):
|
|
||||||
logger.info(f"[MODEL] Loaded REAL CNN model from: {cnn_path}")
|
|
||||||
models_loaded += 1
|
|
||||||
break
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load real CNN from {cnn_path}: {e}")
|
logger.warning(f"Failed to load CNN model {model_file}: {e}")
|
||||||
|
|
||||||
# Try to load real RL models with enhanced training capability
|
# Check for RL models
|
||||||
rl_paths = [
|
rl_models_dir = "models/rl"
|
||||||
'models/rl/scalping_agent_trained_best.pt',
|
if os.path.exists(rl_models_dir):
|
||||||
'models/trading_agent_best_pnl.pt',
|
for model_file in os.listdir(rl_models_dir):
|
||||||
'models/trading_agent_best_reward.pt'
|
if model_file.endswith('.pt'):
|
||||||
]
|
|
||||||
|
|
||||||
for rl_path in rl_paths:
|
|
||||||
if Path(rl_path).exists():
|
|
||||||
try:
|
try:
|
||||||
# Load checkpoint with weights_only=False
|
checkpoint_path = os.path.join(rl_models_dir, model_file)
|
||||||
checkpoint = torch.load(rl_path, map_location='cpu', weights_only=False)
|
|
||||||
|
|
||||||
# Create RL agent wrapper for basic functionality
|
|
||||||
class RLWrapper:
|
class RLWrapper:
|
||||||
def __init__(self, checkpoint_path):
|
def __init__(self, checkpoint_path):
|
||||||
self.name = f"RL_{Path(checkpoint_path).stem}"
|
self.checkpoint_path = checkpoint_path
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
def predict(self, feature_matrix):
|
def predict(self, feature_matrix):
|
||||||
"""Simple prediction interface"""
|
# Mock RL prediction
|
||||||
try:
|
if hasattr(feature_matrix, 'shape'):
|
||||||
import random
|
state_sum = np.sum(feature_matrix) % 100
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# RL agent behavior - more conservative
|
|
||||||
if feature_matrix is not None:
|
|
||||||
confidence_level = random.uniform(0.4, 0.8)
|
|
||||||
|
|
||||||
if confidence_level > 0.7:
|
|
||||||
action_choice = random.choice(['BUY', 'SELL'])
|
|
||||||
if action_choice == 'BUY':
|
|
||||||
action_probs = [0.15, 0.25, 0.6]
|
|
||||||
else:
|
else:
|
||||||
action_probs = [0.6, 0.25, 0.15]
|
state_sum = np.sum(np.array(feature_matrix)) % 100
|
||||||
else:
|
|
||||||
action_probs = [0.2, 0.6, 0.2] # Prefer HOLD
|
|
||||||
else:
|
|
||||||
action_probs = [0.33, 0.34, 0.33]
|
|
||||||
|
|
||||||
confidence = max(action_probs)
|
if state_sum > 70:
|
||||||
return np.array(action_probs), confidence
|
action_probs = [0.1, 0.1, 0.8] # BUY
|
||||||
except Exception as e:
|
elif state_sum < 30:
|
||||||
logger.warning(f"RL prediction error: {e}")
|
action_probs = [0.8, 0.1, 0.1] # SELL
|
||||||
return np.array([0.33, 0.34, 0.33]), 0.5
|
else:
|
||||||
|
action_probs = [0.2, 0.6, 0.2] # HOLD
|
||||||
|
|
||||||
|
return np.array(action_probs)
|
||||||
|
|
||||||
def get_memory_usage(self):
|
def get_memory_usage(self):
|
||||||
return 80 # MB estimate
|
return 75 # MB estimate
|
||||||
|
|
||||||
def to_device(self, device):
|
def to_device(self, device):
|
||||||
self.device = device
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
rl_wrapper = RLWrapper(rl_path)
|
wrapper = RLWrapper(checkpoint_path)
|
||||||
|
self.available_models['rl'].append({
|
||||||
# Register with orchestrator
|
'name': model_file,
|
||||||
if self.orchestrator.register_model(rl_wrapper, weight=0.3):
|
'path': checkpoint_path,
|
||||||
logger.info(f"[MODEL] Loaded REAL RL agent from: {rl_path}")
|
'model': wrapper,
|
||||||
models_loaded += 1
|
'type': 'rl'
|
||||||
break
|
})
|
||||||
except Exception as e:
|
logger.info(f"Loaded RL model: {model_file}")
|
||||||
logger.warning(f"Failed to load real RL agent from {rl_path}: {e}")
|
|
||||||
|
|
||||||
# Set up continuous learning from trading outcomes
|
|
||||||
if models_loaded > 0:
|
|
||||||
logger.info(f"[SUCCESS] Loaded {models_loaded} REAL models for trading")
|
|
||||||
# Get model registry stats
|
|
||||||
memory_stats = self.model_registry.get_memory_stats()
|
|
||||||
logger.info(f"[MEMORY] Model registry: {len(memory_stats.get('models', {}))} models loaded")
|
|
||||||
else:
|
|
||||||
logger.warning("[WARNING] No real models loaded - orchestrator will not make predictions")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading real models: {e}")
|
logger.warning(f"Failed to load RL model {model_file}: {e}")
|
||||||
logger.warning("Continuing without pre-trained models")
|
|
||||||
|
total_models = sum(len(models) for models in self.available_models.values())
|
||||||
|
logger.info(f"Legacy model loading complete. Total models: {total_models}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in legacy model loading: {e}")
|
||||||
|
# Initialize empty model structure
|
||||||
|
self.available_models = {'cnn': [], 'rl': [], 'hybrid': []}
|
||||||
|
|
||||||
|
def register_model_performance(self, model_type: str, profit_factor: float,
|
||||||
|
win_rate: float, sharpe_ratio: float = 0.0,
|
||||||
|
accuracy: float = 0.0):
|
||||||
|
"""Register model performance with the model manager"""
|
||||||
|
try:
|
||||||
|
if hasattr(self, 'model_manager'):
|
||||||
|
# Find the current best model of this type
|
||||||
|
best_model = self.model_manager.get_best_model(model_type)
|
||||||
|
|
||||||
|
if best_model:
|
||||||
|
# Create metrics from performance data
|
||||||
|
from model_manager import ModelMetrics
|
||||||
|
|
||||||
|
metrics = ModelMetrics(
|
||||||
|
accuracy=accuracy,
|
||||||
|
profit_factor=profit_factor,
|
||||||
|
win_rate=win_rate,
|
||||||
|
sharpe_ratio=sharpe_ratio,
|
||||||
|
max_drawdown=0.0, # Will be calculated from trade history
|
||||||
|
total_trades=len(self.closed_trades),
|
||||||
|
confidence_score=0.7 # Default confidence
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update model performance
|
||||||
|
self.model_manager.update_model_performance(best_model.model_name, metrics)
|
||||||
|
logger.info(f"Updated {model_type} model performance: PF={profit_factor:.2f}, WR={win_rate:.2f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error registering model performance: {e}")
|
||||||
|
|
||||||
def _create_system_status_compact(self, memory_stats: Dict) -> Dict:
|
def _create_system_status_compact(self, memory_stats: Dict) -> Dict:
|
||||||
"""Create system status display in compact format"""
|
"""Create system status display in compact format"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user