# Live Inference & Training Mode Guide

## Overview

The system includes an `EnhancedRealtimeTrainingSystem` class that can perform:

- **Live Inference**: Predict the next candle every second
- **Retrospective Training**: Train on the previous candle once its actual result is known
- **Multi-Timeframe**: Process 1s, 1m, 1h, and 1d candles independently

## Current Status

### ✅ Available

- `EnhancedRealtimeTrainingSystem` class exists in `NN/training/enhanced_realtime_training.py`
- Comprehensive feature engineering
- Multi-model support (DQN, CNN, COB RL)
- Prediction tracking database
- Experience replay buffers

### ❌ Not Enabled

- Not instantiated in the orchestrator
- No integration with the main trading loop
- No UI controls to start/stop

---

## Architecture

### Live Inference Flow

```
Every 1 second:
┌─────────────────────────────────────────┐
│  1. Fetch Latest Data                   │
│     - 1s candle (just closed)           │
│     - 1m candle (if minute boundary)    │
│     - 1h candle (if hour boundary)      │
│     - 1d candle (if day boundary)       │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  2. Make Predictions                    │
│     - Next 1s candle OHLCV              │
│     - Next 1m candle OHLCV (if needed)  │
│     - Trading action (BUY/SELL/HOLD)    │
│     - Confidence score                  │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  3. Store Predictions                   │
│     - Save to prediction_database       │
│     - Track prediction_id               │
│     - Wait for resolution               │
└─────────────────────────────────────────┘
```
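
The boundary checks in step 1 can be sketched as a pure function of the clock. This is a minimal illustration, assuming candles are aligned to Unix-epoch (UTC) boundaries; `TIMEFRAME_SECONDS` and `active_timeframes` are hypothetical names, not part of the existing codebase.

```python
# Hypothetical helper: which timeframes close a candle at this second?
# Assumes candle boundaries are aligned to the Unix epoch (UTC).
TIMEFRAME_SECONDS = {"1s": 1, "1m": 60, "1h": 3600, "1d": 86400}


def active_timeframes(epoch_seconds: float) -> list[str]:
    """Return the timeframes whose period divides the current whole second."""
    whole = int(epoch_seconds)
    return [tf for tf, period in TIMEFRAME_SECONDS.items() if whole % period == 0]


# On a minute boundary that is not an hour boundary, 1s and 1m are both due.
print(active_timeframes(60))
```

Keeping this check free of side effects makes it easy to unit-test the scheduling separately from the model calls.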

### Retrospective Training Flow

```
Every 1 second (after candle closes):
┌─────────────────────────────────────────┐
│  1. Get Previous Candle Result          │
│     - Actual OHLCV values               │
│     - Price change                      │
│     - Volume                            │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  2. Resolve Predictions                 │
│     - Compare predicted vs actual       │
│     - Calculate reward/loss             │
│     - Update prediction accuracy        │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  3. Create Training Experience          │
│     - State: market data before candle  │
│     - Action: predicted action          │
│     - Reward: based on accuracy         │
│     - Next State: market data after     │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  4. Train Models (if enough samples)    │
│     - Batch training (32-64 samples)    │
│     - Update model weights              │
│     - Save checkpoint                   │
└─────────────────────────────────────────┘
```
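
Step 2 of this flow (compare, score, reward) might look like the sketch below. The reward formula, the `hold_band` threshold, and the names `Resolution` and `resolve` are assumptions made for illustration; the actual system may score predictions differently.

```python
from dataclasses import dataclass


@dataclass
class Resolution:
    price_error: float    # relative error of the predicted close
    action_correct: bool  # did the action match the realized direction?
    reward: float


def resolve(predicted_close, action, prev_close, actual_close, hold_band=0.0005):
    """Score one prediction against the candle that actually closed."""
    price_error = abs(predicted_close - actual_close) / actual_close
    change = (actual_close - prev_close) / prev_close
    if action == "BUY":
        correct = change > 0
    elif action == "SELL":
        correct = change < 0
    else:
        # HOLD counts as correct when the move stays inside the band
        correct = abs(change) <= hold_band
    # Assumed reward: +1/-1 for direction, minus a penalty scaled by price error
    reward = (1.0 if correct else -1.0) - 100.0 * price_error
    return Resolution(price_error, correct, reward)
```

The resulting `reward` would fill the "Reward" slot of the training experience in step 3.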

---

## Implementation Plan

### Phase 1: Enable Realtime Training System

#### 1.1 Initialize in Orchestrator

```python
# In core/orchestrator.py __init__()

if ENHANCED_TRAINING_AVAILABLE:
    self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
        orchestrator=self,
        data_provider=self.data_provider,
        dashboard=None  # Optional dashboard integration
    )
    logger.info("EnhancedRealtimeTrainingSystem initialized")
else:
    self.enhanced_training_system = None
    logger.warning("EnhancedRealtimeTrainingSystem not available")
```
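
The snippet above relies on an `ENHANCED_TRAINING_AVAILABLE` flag. One common way to define such a flag is a guarded import at module level, sketched here. The import path is inferred from the file location `NN/training/enhanced_realtime_training.py`; whether the orchestrator already defines the flag this way is an assumption.

```python
import logging

logger = logging.getLogger(__name__)

try:
    # Import path inferred from NN/training/enhanced_realtime_training.py
    from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
    ENHANCED_TRAINING_AVAILABLE = True
except ImportError:
    # Keep the name defined so later references do not raise NameError
    EnhancedRealtimeTrainingSystem = None
    ENHANCED_TRAINING_AVAILABLE = False
    logger.warning("Enhanced training module not importable; live training disabled")
```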

#### 1.2 Add Start/Stop Methods

```python
# In core/orchestrator.py

def start_live_training(self):
    """Start live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.start_training()
        logger.info("Live training mode started")
        return True
    else:
        logger.error("Enhanced training system not available")
        return False

def stop_live_training(self):
    """Stop live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.stop_training()
        logger.info("Live training mode stopped")
        return True
    return False

def is_live_training_active(self) -> bool:
    """Check if live training is active"""
    if self.enhanced_training_system:
        return self.enhanced_training_system.is_training
    return False
```
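
The wrappers above delegate to `start_training()` and `stop_training()`. How the training system runs its loop internally is not shown in this guide; the sketch below illustrates one possible thread lifecycle using `threading.Event`. `LoopRunner` is a hypothetical stand-in, and here `is_training` is a property derived from the thread state rather than a plain attribute.

```python
import threading


class LoopRunner:
    """Hypothetical stand-in for the training system's thread handling."""

    def __init__(self, tick, interval=1.0):
        self._tick = tick            # callable executed once per iteration
        self._interval = interval    # seconds between iterations
        self._stop = threading.Event()
        self._thread = None

    @property
    def is_training(self) -> bool:
        return self._thread is not None and self._thread.is_alive()

    def start_training(self) -> bool:
        if self.is_training:
            return False  # already running; do not spawn a second loop
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return True

    def stop_training(self, timeout=5.0) -> bool:
        if not self.is_training:
            return False
        self._stop.set()
        self._thread.join(timeout)
        return not self._thread.is_alive()

    def _run(self):
        while not self._stop.is_set():
            self._tick()
            # wait() returns early when stop is requested, unlike time.sleep()
            self._stop.wait(self._interval)
```

Using an event for the wait means a stop request takes effect immediately instead of after the current sleep finishes.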

### Phase 2: Implement Prediction & Training Loop

#### 2.1 Main Loop (runs every 1 second)

```python
# In EnhancedRealtimeTrainingSystem (requires `import time`)

def _live_inference_loop(self):
    """Main loop for live inference and training"""
    while self.is_training:
        try:
            current_time = time.time()

            # 1. Check which timeframes need processing
            timeframes_to_process = self._get_active_timeframes(current_time)

            for timeframe in timeframes_to_process:
                # 2. Make prediction for next candle
                prediction = self._make_next_candle_prediction(timeframe)

                # 3. Store prediction
                if prediction:
                    self._store_prediction(prediction)

                # 4. Resolve previous predictions
                self._resolve_timeframe_predictions(timeframe)

                # 5. Train on resolved predictions
                if self._should_train(timeframe):
                    self._train_on_timeframe(timeframe)

            # Sleep for the remainder of the 1-second cycle
            elapsed = time.time() - current_time
            sleep_time = max(0, 1.0 - elapsed)
            time.sleep(sleep_time)

        except Exception as e:
            # exc_info=True keeps the traceback in the log
            logger.error(f"Error in live inference loop: {e}", exc_info=True)
            time.sleep(1)
```
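
The loop above sleeps for the remainder of a one-second cycle measured from when the iteration started, so wake-ups can drift relative to candle boundaries. If alignment matters, the sleep could instead be computed against the next boundary, as in this sketch. `seconds_until_next_boundary` is a hypothetical helper, not an existing method.

```python
import math


def seconds_until_next_boundary(now: float, period: float = 1.0) -> float:
    """Time to sleep so the loop wakes on the next multiple of `period`."""
    next_boundary = (math.floor(now / period) + 1) * period
    return next_boundary - now


# Example: at t = 10.25 s the next whole second is 0.75 s away.
print(seconds_until_next_boundary(10.25))
```

In the loop, `time.sleep(seconds_until_next_boundary(time.time()))` would replace the `1.0 - elapsed` calculation.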

#### 2.2 Prediction Method

```python
from datetime import datetime
from typing import Dict, Optional

import torch

def _make_next_candle_prediction(self, timeframe: str) -> Optional[Dict]:
    """
    Predict next candle OHLCV values

    Returns (or None if no transformer model is loaded):
    {
        'timeframe': '1s',
        'timestamp': datetime,
        'predicted_open': float,
        'predicted_high': float,
        'predicted_low': float,
        'predicted_close': float,
        'predicted_volume': float,
        'action': 'BUY'|'SELL'|'HOLD',
        'confidence': float
    }
    """
    # Get current market state (600 candles)
    market_state = self._get_market_state(timeframe)

    # Get model prediction
    if self.orchestrator.primary_transformer is not None:
        # Inference only: no gradients needed
        with torch.no_grad():
            output = self.orchestrator.primary_transformer(market_state)

        # Extract next candle prediction
        next_candle = output['next_candles'][timeframe]
        action_probs = output['action_probs']

        return {
            'timeframe': timeframe,
            'timestamp': datetime.now(),
            'predicted_open': next_candle[0].item(),
            'predicted_high': next_candle[1].item(),
            'predicted_low': next_candle[2].item(),
            'predicted_close': next_candle[3].item(),
            'predicted_volume': next_candle[4].item(),
            'action': ['HOLD', 'BUY', 'SELL'][torch.argmax(action_probs).item()],
            'confidence': torch.max(action_probs).item()
        }

    return None
```
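
The action decoding at the end of the method can be isolated and tested without a model. The sketch below works on a plain list of probabilities and uses the same `HOLD`/`BUY`/`SELL` index order as the snippet above. The `min_confidence` fallback to `HOLD` is an assumption added for illustration and is not part of the method as written.

```python
ACTIONS = ("HOLD", "BUY", "SELL")  # index order taken from the method above


def decode_action(action_probs, min_confidence=0.0):
    """Map a probability vector to (action, confidence)."""
    if len(action_probs) != len(ACTIONS):
        raise ValueError(f"expected {len(ACTIONS)} probabilities, got {len(action_probs)}")
    best = max(range(len(ACTIONS)), key=lambda i: action_probs[i])
    confidence = float(action_probs[best])
    if confidence < min_confidence:
        # Assumed policy: too uncertain to act, so hold
        return "HOLD", confidence
    return ACTIONS[best], confidence
```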

#### 2.3 Training Method

```python
def _train_on_timeframe(self, timeframe: str):
    """
    Train model on resolved predictions for this timeframe

    Process:
    1. Get resolved predictions (predicted vs actual)
    2. Create training batches
    3. Calculate loss
    4. Update model weights
    5. Save checkpoint (if needed)
    """
    # Get resolved predictions
    resolved = self._get_resolved_predictions(timeframe, limit=100)

    if len(resolved) < 32:  # Need minimum batch size
        return

    # Create training batches
    batches = self._create_training_batches(resolved)

    # Train model
    if self.orchestrator.primary_transformer_trainer is not None:
        trainer = self.orchestrator.primary_transformer_trainer

        for batch in batches:
            result = trainer.train_step(batch)
            # Count every batch so the checkpoint condition below can advance
            self.training_iteration += 1

            # Log progress
            if result:
                logger.debug(f"Trained on {timeframe}: loss={result['total_loss']:.4f}")

            # Save checkpoint every 100 batches
            if self.training_iteration % 100 == 0:
                self._save_checkpoint(timeframe)
```
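
`_create_training_batches` is referenced above but not shown. A minimal sketch of one possible batching strategy follows, assuming resolved predictions arrive as a list and that incomplete trailing batches are dropped. The function name, `batch_size` default, and `drop_last` flag are illustrative assumptions.

```python
def create_training_batches(resolved, batch_size=32, drop_last=True):
    """Split resolved predictions into consecutive fixed-size batches."""
    batches = [resolved[i:i + batch_size] for i in range(0, len(resolved), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # Discard the incomplete final batch to keep batch shapes uniform
        batches.pop()
    return batches
```

With `limit=100` and a batch size of 32, this yields three full batches and drops the remaining four samples until the next training run.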

---

## Configuration

### Training Intervals

```python
training_config = {
    # Inference intervals in seconds (how often to predict)
    'inference_1s': 1,      # Every 1 second
    'inference_1m': 60,     # Every 1 minute
    'inference_1h': 3600,   # Every 1 hour
    'inference_1d': 86400,  # Every 1 day

    # Training intervals in seconds (how often to train)
    'training_1s': 10,      # Train every 10 seconds (10 samples)
    'training_1m': 300,     # Train every 5 minutes (5 samples)
    'training_1h': 3600,    # Train every 1 hour (1 sample)
    'training_1d': 86400,   # Train every 1 day (1 sample)

    # Batch sizes
    'batch_size_1s': 32,
    'batch_size_1m': 16,
    'batch_size_1h': 8,
    'batch_size_1d': 4,

    # Buffer sizes
    'buffer_size_1s': 1000,
    'buffer_size_1m': 500,
    'buffer_size_1h': 200,
    'buffer_size_1d': 100
}
```
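
A `_should_train` check could combine the training interval and batch size from this configuration. The sketch below is one possible reading, assuming training is due when the interval has elapsed and the buffer holds at least one full batch. The function signature and the `example_config` subset are hypothetical.

```python
# Subset of the configuration above, repeated so the example is self-contained
example_config = {
    "training_1s": 10,
    "training_1m": 300,
    "batch_size_1s": 32,
    "batch_size_1m": 16,
}


def should_train(cfg, timeframe, now, last_trained_at, buffered_samples):
    """True when the interval has elapsed and a full batch is available."""
    interval = cfg[f"training_{timeframe}"]
    batch_size = cfg[f"batch_size_{timeframe}"]
    return (now - last_trained_at) >= interval and buffered_samples >= batch_size
```

Note that the 1s timeframe gains only 10 new samples per 10-second interval, so a batch of 32 necessarily draws on older samples from the buffer.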

### Performance Targets

| Timeframe | Predictions/Hour | Training Runs/Hour | GPU Load | Memory |
|-----------|------------------|--------------------|----------|--------|
| 1s | 3,600 | 360 (every 10s) | 30-50% | 2GB |
| 1m | 60 | 12 (every 5m) | 10-20% | 1GB |
| 1h | 1 | 1 (every 1h) | 5-10% | 500MB |
| 1d | 0.04 | 0.04 (every 1d) | <5% | 200MB |
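
The per-hour figures in the table follow directly from the configured intervals: events per hour equal 3,600 divided by the interval in seconds. A short check, with `per_hour` as an illustrative helper name:

```python
SECONDS_PER_HOUR = 3600


def per_hour(interval_seconds: float) -> float:
    """Number of events per hour for an event that fires every `interval_seconds`."""
    return SECONDS_PER_HOUR / interval_seconds


# Intervals in seconds, taken from the training configuration
for label, interval in [("1s inference", 1), ("1s training", 10),
                        ("1m training", 300), ("1d inference", 86400)]:
    print(f"{label}: {per_hour(interval):.2f} per hour")
```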

---

## Database Schema

### Predictions Table

```sql
CREATE TABLE predictions (
    prediction_id INTEGER PRIMARY KEY,
    model_name VARCHAR,
    symbol VARCHAR,
    timeframe VARCHAR,
    timestamp BIGINT,

    -- Predicted values
    predicted_open DOUBLE,
    predicted_high DOUBLE,
    predicted_low DOUBLE,
    predicted_close DOUBLE,
    predicted_volume DOUBLE,
    predicted_action VARCHAR,
    confidence DOUBLE,

    -- Actual values (filled when resolved)
    actual_open DOUBLE,
    actual_high DOUBLE,
    actual_low DOUBLE,
    actual_close DOUBLE,
    actual_volume DOUBLE,

    -- Accuracy metrics
    price_error DOUBLE,
    volume_error DOUBLE,
    action_correct BOOLEAN,
    reward DOUBLE,

    -- Status
    status VARCHAR,  -- 'pending', 'resolved', 'expired'
    resolved_at BIGINT
);
```
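
The `pending` to `resolved` lifecycle implied by the `status` column can be sketched with Python's built-in `sqlite3` module. SQLite is used here only for illustration (the guide does not say which database engine backs the prediction store), the table is reduced to a subset of the columns above, and `store_prediction` and `resolve_prediction` are hypothetical helper names.

```python
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE predictions (
        prediction_id INTEGER PRIMARY KEY,
        timeframe VARCHAR,
        timestamp BIGINT,
        predicted_close DOUBLE,
        actual_close DOUBLE,
        price_error DOUBLE,
        status VARCHAR,
        resolved_at BIGINT
    )
""")


def store_prediction(conn, timeframe, predicted_close):
    """Insert a pending prediction and return its prediction_id."""
    cur = conn.execute(
        "INSERT INTO predictions (timeframe, timestamp, predicted_close, status) "
        "VALUES (?, ?, ?, 'pending')",
        (timeframe, int(time.time()), predicted_close),
    )
    return cur.lastrowid


def resolve_prediction(conn, prediction_id, actual_close):
    """Fill in the actual value and error for a pending prediction."""
    conn.execute(
        "UPDATE predictions "
        "SET actual_close = ?, price_error = ABS(predicted_close - ?), "
        "    status = 'resolved', resolved_at = ? "
        "WHERE prediction_id = ? AND status = 'pending'",
        (actual_close, actual_close, int(time.time()), prediction_id),
    )
```

The `status = 'pending'` guard in the update keeps an already resolved or expired row from being overwritten.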

---

## UI Integration

### Dashboard Controls

```html
<!-- Live Training Panel -->
<div class="live-training-panel">
    <h3>Live Inference &amp; Training</h3>

    <div class="status">
        <span id="live-status">Inactive</span>
        <button id="start-live-btn">Start Live Mode</button>
        <button id="stop-live-btn" disabled>Stop Live Mode</button>
    </div>

    <div class="metrics">
        <div class="metric">
            <label>Predictions/sec:</label>
            <span id="predictions-per-sec">0</span>
        </div>
        <div class="metric">
            <label>Training batches/min:</label>
            <span id="training-per-min">0</span>
        </div>
        <div class="metric">
            <label>Accuracy (1m):</label>
            <span id="accuracy-1m">0%</span>
        </div>
        <div class="metric">
            <label>GPU Load:</label>
            <span id="gpu-load">0%</span>
        </div>
    </div>

    <div class="recent-predictions">
        <h4>Recent Predictions</h4>
        <table id="predictions-table">
            <thead>
                <tr>
                    <th>Time</th>
                    <th>TF</th>
                    <th>Predicted</th>
                    <th>Actual</th>
                    <th>Error</th>
                    <th>Action</th>
                </tr>
            </thead>
            <tbody></tbody>
        </table>
    </div>
</div>
```

### API Endpoints

```python
# In ANNOTATE/web/app.py
from flask import jsonify, request

@app.route('/api/live-training/start', methods=['POST'])
def start_live_training():
    if orchestrator.start_live_training():
        return jsonify({'status': 'started'})
    return jsonify({'error': 'Failed to start'}), 500

@app.route('/api/live-training/stop', methods=['POST'])
def stop_live_training():
    if orchestrator.stop_live_training():
        return jsonify({'status': 'stopped'})
    return jsonify({'error': 'Failed to stop'}), 500

@app.route('/api/live-training/status', methods=['GET'])
def get_live_training_status():
    if orchestrator.enhanced_training_system:
        return jsonify({
            'active': orchestrator.is_live_training_active(),
            'predictions_per_sec': orchestrator.enhanced_training_system.get_prediction_rate(),
            'training_per_min': orchestrator.enhanced_training_system.get_training_rate(),
            'accuracy': orchestrator.enhanced_training_system.get_accuracy_stats()
        })
    return jsonify({'active': False})

@app.route('/api/live-training/predictions', methods=['GET'])
def get_recent_predictions():
    # Query parameter ?limit=N, defaulting to 50
    limit = request.args.get('limit', 50, type=int)
    if orchestrator.enhanced_training_system:
        predictions = orchestrator.enhanced_training_system.get_recent_predictions(limit)
        return jsonify({'predictions': predictions})
    return jsonify({'predictions': []})
```
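
For quick manual testing, these endpoints can be exercised from a small standard-library client. The sketch below assumes the app is served at `http://localhost:5000`; the base URL and the helper names `build_request` and `call` are assumptions, not part of the project.

```python
import json
import urllib.request

BASE_URL = "http://localhost:5000"  # assumed host and port


def build_request(action: str) -> urllib.request.Request:
    """Build a request for /api/live-training/<action> with the matching HTTP method."""
    method = "POST" if action in ("start", "stop") else "GET"
    return urllib.request.Request(f"{BASE_URL}/api/live-training/{action}", method=method)


def call(action: str) -> dict:
    """Send the request and decode the JSON body (raises HTTPError on a 500 response)."""
    with urllib.request.urlopen(build_request(action), timeout=5) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

For example, `call("start")` followed by `call("status")` would start live mode and then read back the `active` flag, provided the server is running.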

---

## Summary

### To Enable Live Mode:

1. **Initialize** `EnhancedRealtimeTrainingSystem` in the orchestrator
2. **Add** start/stop methods to the orchestrator
3. **Implement** the prediction and training loops
4. **Create** UI controls and API endpoints
5. **Test** with the 1s timeframe first
6. **Scale** to the other timeframes

### Expected Behavior:

- ✅ Predict the next candle every second
- ✅ Train on the previous candle once its result is known
- ✅ Training lags inference by one candle (1 second on the 1s timeframe)
- ✅ Continuous learning from live data
- ✅ Real-time accuracy tracking
- ✅ Automatic checkpoint saving

### Performance Targets:

- **1s timeframe**: 3,600 predictions/hour, 360 training batches/hour
- **GPU load**: 30-50% during active training
- **Memory**: ~2GB for 1s, less for longer timeframes
- **Latency**: <100ms per prediction

The core system is designed and in place; it still needs to be enabled and integrated as described above.