gogo2/LIVE_INFERENCE_TRAINING_GUIDE.md
2025-10-31 03:52:41 +02:00

# Live Inference & Training Mode Guide
## Overview
The system has an `EnhancedRealtimeTrainingSystem` that can perform:
- **Live Inference**: Predict next candle every second
- **Retrospective Training**: Train on previous candle once result is known
- **Multi-Timeframe**: Process 1s, 1m, 1h, 1d candles independently
## Current Status
### ✅ Available
- `EnhancedRealtimeTrainingSystem` class exists in `NN/training/enhanced_realtime_training.py`
- Comprehensive feature engineering
- Multi-model support (DQN, CNN, COB RL)
- Prediction tracking database
- Experience replay buffers
### ❌ Not Enabled
- Not instantiated in orchestrator
- No integration with main trading loop
- No UI controls to start/stop
---
## Architecture
### Live Inference Flow
```
Every 1 second:
┌─────────────────────────────────────────┐
│ 1. Fetch Latest Data                    │
│    - 1s candle (just closed)            │
│    - 1m candle (if minute boundary)     │
│    - 1h candle (if hour boundary)       │
│    - 1d candle (if day boundary)        │
└──────────────┬──────────────────────────┘
               ▼
┌─────────────────────────────────────────┐
│ 2. Make Predictions                     │
│    - Next 1s candle OHLCV               │
│    - Next 1m candle OHLCV (if needed)   │
│    - Trading action (BUY/SELL/HOLD)     │
│    - Confidence score                   │
└──────────────┬──────────────────────────┘
               ▼
┌─────────────────────────────────────────┐
│ 3. Store Predictions                    │
│    - Save to prediction_database        │
│    - Track prediction_id                │
│    - Wait for resolution                │
└─────────────────────────────────────────┘
```
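
The "if minute/hour/day boundary" checks in step 1 can be sketched as a modulo test on the epoch second. This is a minimal illustration, not the system's actual code: `TIMEFRAME_SECONDS` and `timeframes_closing_at` are hypothetical names, and the sketch assumes candles are aligned to UTC epoch boundaries.

```python
from typing import List

# Hypothetical mapping from timeframe label to candle length in seconds.
TIMEFRAME_SECONDS = {"1s": 1, "1m": 60, "1h": 3600, "1d": 86400}


def timeframes_closing_at(epoch_second: int) -> List[str]:
    """Return the timeframes whose candle closes exactly at `epoch_second`.

    Assumes candles are aligned to multiples of their length since the epoch.
    """
    return [
        timeframe
        for timeframe, span in TIMEFRAME_SECONDS.items()
        if epoch_second % span == 0
    ]
```

On most seconds only `1s` is returned; the longer timeframes join in only when their own boundary is crossed.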
### Retrospective Training Flow
```
Every 1 second (after candle closes):
┌─────────────────────────────────────────┐
│ 1. Get Previous Candle Result           │
│    - Actual OHLCV values                │
│    - Price change                       │
│    - Volume                             │
└──────────────┬──────────────────────────┘
               ▼
┌─────────────────────────────────────────┐
│ 2. Resolve Predictions                  │
│    - Compare predicted vs actual        │
│    - Calculate reward/loss              │
│    - Update prediction accuracy         │
└──────────────┬──────────────────────────┘
               ▼
┌─────────────────────────────────────────┐
│ 3. Create Training Experience           │
│    - State: market data before candle   │
│    - Action: predicted action           │
│    - Reward: based on accuracy          │
│    - Next State: market data after      │
└──────────────┬──────────────────────────┘
               ▼
┌─────────────────────────────────────────┐
│ 4. Train Models (if enough samples)     │
│    - Batch training (32-64 samples)     │
│    - Update model weights               │
│    - Save checkpoint                    │
└─────────────────────────────────────────┘
```
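
Steps 2 and 3 can be sketched as a small function that turns a resolved prediction into a (state, action, reward, next state) tuple. The guide only says the reward is "based on accuracy"; the directional reward below is an assumption chosen for illustration, and `Experience`, `direction_reward`, and `build_experience` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Experience:
    state: List[float]       # market data before the candle
    action: str              # 'BUY', 'SELL' or 'HOLD'
    reward: float
    next_state: List[float]  # market data after the candle


def direction_reward(action: str, prev_close: float, actual_close: float) -> float:
    """Illustrative reward (an assumption): signed fractional price change.

    BUY earns the change, SELL earns its negative, HOLD earns nothing.
    """
    change = (actual_close - prev_close) / prev_close
    if action == "BUY":
        return change
    if action == "SELL":
        return -change
    return 0.0


def build_experience(state, action, prev_close, actual_close, next_state) -> Experience:
    """Package one resolved prediction as a replay-buffer entry."""
    reward = direction_reward(action, prev_close, actual_close)
    return Experience(list(state), action, reward, list(next_state))
```

A real reward would likely also account for OHLCV prediction error and confidence; the sketch only shows where such a calculation would sit in the flow.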
---
## Implementation Plan
### Phase 1: Enable Realtime Training System
#### 1.1 Initialize in Orchestrator
```python
# In core/orchestrator.py __init__()
if ENHANCED_TRAINING_AVAILABLE:
    self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
        orchestrator=self,
        data_provider=self.data_provider,
        dashboard=None  # Optional dashboard integration
    )
    logger.info("EnhancedRealtimeTrainingSystem initialized")
else:
    self.enhanced_training_system = None
    logger.warning("EnhancedRealtimeTrainingSystem not available")
```
#### 1.2 Add Start/Stop Methods
```python
# In core/orchestrator.py
def start_live_training(self):
    """Start live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.start_training()
        logger.info("Live training mode started")
        return True
    else:
        logger.error("Enhanced training system not available")
        return False

def stop_live_training(self):
    """Stop live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.stop_training()
        logger.info("Live training mode stopped")
        return True
    return False

def is_live_training_active(self) -> bool:
    """Check if live training is active"""
    if self.enhanced_training_system:
        return self.enhanced_training_system.is_training
    return False
```
### Phase 2: Implement Prediction & Training Loop
#### 2.1 Main Loop (runs every 1 second)
```python
# In EnhancedRealtimeTrainingSystem
def _live_inference_loop(self):
    """Main loop for live inference and training"""
    while self.is_training:
        try:
            current_time = time.time()

            # 1. Check which timeframes need processing
            timeframes_to_process = self._get_active_timeframes(current_time)

            for timeframe in timeframes_to_process:
                # 2. Make prediction for next candle
                prediction = self._make_next_candle_prediction(timeframe)

                # 3. Store prediction
                if prediction:
                    self._store_prediction(prediction)

                # 4. Resolve previous predictions
                self._resolve_timeframe_predictions(timeframe)

                # 5. Train on resolved predictions
                if self._should_train(timeframe):
                    self._train_on_timeframe(timeframe)

            # Sleep until next second
            elapsed = time.time() - current_time
            sleep_time = max(0, 1.0 - elapsed)
            time.sleep(sleep_time)

        except Exception as e:
            logger.error(f"Error in live inference loop: {e}")
            time.sleep(1)
```
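
Sleeping for `1.0 - elapsed` keeps the period near one second, but each iteration starts wherever the previous one finished, so the loop can drift away from candle boundaries. One alternative, shown here as a sketch rather than the system's current behaviour, is to sleep until the next whole-interval boundary. `seconds_until_next_tick` is a hypothetical helper.

```python
import math


def seconds_until_next_tick(now: float, interval: float = 1.0) -> float:
    """Return the delay from `now` to the next multiple of `interval`.

    The result is always in the range (0, interval], so a call made exactly
    on a boundary waits one full interval.
    """
    next_tick = (math.floor(now / interval) + 1) * interval
    return next_tick - now
```

In the loop above, `time.sleep(seconds_until_next_tick(time.time()))` could stand in for the `elapsed`/`sleep_time` computation if boundary alignment matters.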
#### 2.2 Prediction Method
```python
from datetime import datetime
from typing import Dict, Optional

import torch


def _make_next_candle_prediction(self, timeframe: str) -> Optional[Dict]:
    """
    Predict next candle OHLCV values

    Returns None if no primary transformer is loaded, otherwise:
    {
        'timeframe': '1s',
        'timestamp': datetime,
        'predicted_open': float,
        'predicted_high': float,
        'predicted_low': float,
        'predicted_close': float,
        'predicted_volume': float,
        'action': 'BUY'|'SELL'|'HOLD',
        'confidence': float
    }
    """
    # Get current market state (600 candles)
    market_state = self._get_market_state(timeframe)

    # Get model prediction
    if self.orchestrator.primary_transformer:
        output = self.orchestrator.primary_transformer(market_state)

        # Extract next candle prediction
        next_candle = output['next_candles'][timeframe]
        action_probs = output['action_probs']

        return {
            'timeframe': timeframe,
            'timestamp': datetime.now(),
            'predicted_open': next_candle[0].item(),
            'predicted_high': next_candle[1].item(),
            'predicted_low': next_candle[2].item(),
            'predicted_close': next_candle[3].item(),
            'predicted_volume': next_candle[4].item(),
            # Index order of the action head: 0=HOLD, 1=BUY, 2=SELL
            'action': ['HOLD', 'BUY', 'SELL'][torch.argmax(action_probs).item()],
            'confidence': torch.max(action_probs).item()
        }

    return None
```
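
The action decoding at the end of the method can be isolated and checked without a model. The sketch below uses plain Python lists in place of tensors; `decode_action` is a hypothetical helper, and the `HOLD`/`BUY`/`SELL` index order is copied from the snippet above rather than confirmed against the model definition.

```python
from typing import Sequence, Tuple

# Index order as used in the prediction method: 0=HOLD, 1=BUY, 2=SELL.
ACTIONS = ("HOLD", "BUY", "SELL")


def decode_action(action_probs: Sequence[float]) -> Tuple[str, float]:
    """Return (action label, confidence) for the most probable action."""
    if len(action_probs) != len(ACTIONS):
        raise ValueError(f"expected {len(ACTIONS)} probabilities, got {len(action_probs)}")
    best = max(range(len(ACTIONS)), key=lambda i: action_probs[i])
    return ACTIONS[best], float(action_probs[best])
```

If the model's action head uses a different ordering, only the `ACTIONS` tuple needs to change.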
#### 2.3 Training Method
```python
def _train_on_timeframe(self, timeframe: str):
    """
    Train model on resolved predictions for this timeframe

    Process:
    1. Get resolved predictions (predicted vs actual)
    2. Create training batches
    3. Calculate loss
    4. Update model weights
    5. Save checkpoint (if needed)
    """
    # Get resolved predictions
    resolved = self._get_resolved_predictions(timeframe, limit=100)

    if len(resolved) < 32:  # Need minimum batch size
        return

    # Create training batches
    batches = self._create_training_batches(resolved)

    # Train model
    if self.orchestrator.primary_transformer_trainer:
        trainer = self.orchestrator.primary_transformer_trainer

        for batch in batches:
            result = trainer.train_step(batch)

            # Log progress
            if result:
                logger.debug(f"Trained on {timeframe}: loss={result['total_loss']:.4f}")

            # Save checkpoint every N batches
            if self.training_iteration % 100 == 0:
                self._save_checkpoint(timeframe)
```
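
`_create_training_batches` is not shown in this guide. A minimal sketch of what such a step might do, assuming it simply slices the resolved predictions into fixed-size chunks, follows. `iter_batches` and its `drop_last` flag are hypothetical.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(samples: Sequence[T], batch_size: int, drop_last: bool = True) -> Iterator[List[T]]:
    """Yield consecutive chunks of `batch_size` samples.

    With drop_last=True a trailing partial batch is skipped, so every
    yielded batch has exactly `batch_size` items.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(samples), batch_size):
        batch = list(samples[start:start + batch_size])
        if drop_last and len(batch) < batch_size:
            break
        yield batch
```

With `limit=100` and a batch size of 32, this yields three full batches and drops the remaining four samples.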
---
## Configuration
### Training Intervals
```python
training_config = {
    # Inference intervals (how often to predict)
    'inference_1s': 1,       # Every 1 second
    'inference_1m': 60,      # Every 1 minute
    'inference_1h': 3600,    # Every 1 hour
    'inference_1d': 86400,   # Every 1 day

    # Training intervals (how often to train)
    'training_1s': 10,       # Train every 10 seconds (10 samples)
    'training_1m': 300,      # Train every 5 minutes (5 samples)
    'training_1h': 3600,     # Train every 1 hour (1 sample)
    'training_1d': 86400,    # Train every 1 day (1 sample)

    # Batch sizes
    'batch_size_1s': 32,
    'batch_size_1m': 16,
    'batch_size_1h': 8,
    'batch_size_1d': 4,

    # Buffer sizes
    'buffer_size_1s': 1000,
    'buffer_size_1m': 500,
    'buffer_size_1h': 200,
    'buffer_size_1d': 100
}
```
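
One way the `_should_train` check in the main loop could use these settings is sketched below. The function name, its arguments, and the rule itself (interval elapsed and at least one batch of resolved samples available) are assumptions; the guide does not define the actual criterion.

```python
from typing import Mapping


def should_train(
    config: Mapping[str, int],
    timeframe: str,
    now: float,
    last_trained_at: float,
    resolved_count: int,
) -> bool:
    """Return True when the training interval has elapsed and a full batch is available."""
    interval = config[f"training_{timeframe}"]
    batch_size = config[f"batch_size_{timeframe}"]
    return (now - last_trained_at) >= interval and resolved_count >= batch_size


# Only the keys needed for the 1s timeframe, copied from the config above.
example_config = {"training_1s": 10, "batch_size_1s": 32}
```

Note that a 10-second interval produces only 10 new 1s samples while `batch_size_1s` is 32, so batches presumably draw on older entries in the buffer as well.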
### Performance Targets
| Timeframe | Predictions/Hour | Training/Hour | GPU Load | Memory |
|-----------|------------------|---------------|----------|--------|
| 1s | 3,600 | 360 (every 10s) | 30-50% | 2GB |
| 1m | 60 | 12 (every 5m) | 10-20% | 1GB |
| 1h | 1 | 1 (every 1h) | 5-10% | 500MB |
| 1d | 0.04 | 0.04 (every 1d) | <5% | 200MB |
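
The Predictions/Hour and Training/Hour columns follow directly from the configured intervals, as this short check shows. `per_hour` is a hypothetical helper; the intervals are the ones from the configuration above.

```python
SECONDS_PER_HOUR = 3600


def per_hour(interval_seconds: float) -> float:
    """Number of events per hour for an event that fires every `interval_seconds`."""
    return SECONDS_PER_HOUR / interval_seconds


# (inference interval, training interval) in seconds, per timeframe.
intervals = {"1s": (1, 10), "1m": (60, 300), "1h": (3600, 3600), "1d": (86400, 86400)}

for timeframe, (inference, training) in intervals.items():
    print(timeframe, round(per_hour(inference), 2), round(per_hour(training), 2))
```

The GPU load and memory columns are targets and cannot be derived this way.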
---
## Database Schema
### Predictions Table
```sql
CREATE TABLE predictions (
    prediction_id INTEGER PRIMARY KEY,
    model_name VARCHAR,
    symbol VARCHAR,
    timeframe VARCHAR,
    timestamp BIGINT,

    -- Predicted values
    predicted_open DOUBLE,
    predicted_high DOUBLE,
    predicted_low DOUBLE,
    predicted_close DOUBLE,
    predicted_volume DOUBLE,
    predicted_action VARCHAR,
    confidence DOUBLE,

    -- Actual values (filled when resolved)
    actual_open DOUBLE,
    actual_high DOUBLE,
    actual_low DOUBLE,
    actual_close DOUBLE,
    actual_volume DOUBLE,

    -- Accuracy metrics
    price_error DOUBLE,
    volume_error DOUBLE,
    action_correct BOOLEAN,
    reward DOUBLE,

    -- Status
    status VARCHAR, -- 'pending', 'resolved', 'expired'
    resolved_at BIGINT
);
```
```
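
The pending-to-resolved lifecycle of a row can be sketched with Python's built-in `sqlite3` module. The guide does not say which database engine is used, so SQLite is an assumption made for illustration; the table is trimmed to a few columns, `price_error` is taken to be the absolute close error, and the helper names are hypothetical.

```python
import sqlite3


def create_predictions_table(conn: sqlite3.Connection) -> None:
    """Create a reduced version of the predictions table (subset of columns)."""
    conn.execute(
        """
        CREATE TABLE predictions (
            prediction_id INTEGER PRIMARY KEY,
            timeframe VARCHAR,
            timestamp BIGINT,
            predicted_close DOUBLE,
            actual_close DOUBLE,
            price_error DOUBLE,
            status VARCHAR,
            resolved_at BIGINT
        )
        """
    )


def store_prediction(conn, timeframe: str, timestamp: int, predicted_close: float) -> int:
    """Insert a pending prediction and return its prediction_id."""
    cur = conn.execute(
        "INSERT INTO predictions (timeframe, timestamp, predicted_close, status) "
        "VALUES (?, ?, ?, 'pending')",
        (timeframe, timestamp, predicted_close),
    )
    return cur.lastrowid


def resolve_prediction(conn, prediction_id: int, actual_close: float, resolved_at: int) -> None:
    """Fill in the actual close and error for a pending prediction."""
    conn.execute(
        "UPDATE predictions SET actual_close = ?, "
        "price_error = ABS(predicted_close - ?), "
        "status = 'resolved', resolved_at = ? "
        "WHERE prediction_id = ? AND status = 'pending'",
        (actual_close, actual_close, resolved_at, prediction_id),
    )
```

The `status = 'pending'` guard in the update keeps an already resolved or expired row from being overwritten.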
---
## UI Integration
### Dashboard Controls
```html
<!-- Live Training Panel -->
<div class="live-training-panel">
    <h3>Live Inference &amp; Training</h3>
    <div class="status">
        <span id="live-status">Inactive</span>
        <button id="start-live-btn">Start Live Mode</button>
        <button id="stop-live-btn" disabled>Stop Live Mode</button>
    </div>
    <div class="metrics">
        <div class="metric">
            <label>Predictions/sec:</label>
            <span id="predictions-per-sec">0</span>
        </div>
        <div class="metric">
            <label>Training batches/min:</label>
            <span id="training-per-min">0</span>
        </div>
        <div class="metric">
            <label>Accuracy (1m):</label>
            <span id="accuracy-1m">0%</span>
        </div>
        <div class="metric">
            <label>GPU Load:</label>
            <span id="gpu-load">0%</span>
        </div>
    </div>
    <div class="recent-predictions">
        <h4>Recent Predictions</h4>
        <table id="predictions-table">
            <thead>
                <tr>
                    <th>Time</th>
                    <th>TF</th>
                    <th>Predicted</th>
                    <th>Actual</th>
                    <th>Error</th>
                    <th>Action</th>
                </tr>
            </thead>
            <tbody></tbody>
        </table>
    </div>
</div>
```
### API Endpoints
```python
# In ANNOTATE/web/app.py
from flask import jsonify, request


@app.route('/api/live-training/start', methods=['POST'])
def start_live_training():
    if orchestrator.start_live_training():
        return jsonify({'status': 'started'})
    return jsonify({'error': 'Failed to start'}), 500


@app.route('/api/live-training/stop', methods=['POST'])
def stop_live_training():
    if orchestrator.stop_live_training():
        return jsonify({'status': 'stopped'})
    return jsonify({'error': 'Failed to stop'}), 500


@app.route('/api/live-training/status', methods=['GET'])
def get_live_training_status():
    if orchestrator.enhanced_training_system:
        return jsonify({
            'active': orchestrator.is_live_training_active(),
            'predictions_per_sec': orchestrator.enhanced_training_system.get_prediction_rate(),
            'training_per_min': orchestrator.enhanced_training_system.get_training_rate(),
            'accuracy': orchestrator.enhanced_training_system.get_accuracy_stats()
        })
    return jsonify({'active': False})


@app.route('/api/live-training/predictions', methods=['GET'])
def get_recent_predictions():
    # Number of rows to return; defaults to 50 if the query parameter is absent
    limit = request.args.get('limit', 50, type=int)
    if orchestrator.enhanced_training_system:
        predictions = orchestrator.enhanced_training_system.get_recent_predictions(limit)
        return jsonify({'predictions': predictions})
    return jsonify({'predictions': []})
```
---
## Summary
### To Enable Live Mode:
1. **Initialize** `EnhancedRealtimeTrainingSystem` in orchestrator
2. **Add** start/stop methods to orchestrator
3. **Implement** prediction and training loops
4. **Create** UI controls and API endpoints
5. **Test** with 1s timeframe first
6. **Scale** to other timeframes
### Expected Behavior:
- Predict next candle every second
- Train on previous candle once result known
- Training lags prediction by one candle (1 second on the 1s timeframe), since it is retrospective
- Continuous learning from live data
- Real-time accuracy tracking
- Automatic checkpoint saving
### Performance:
- **1s timeframe**: 3,600 predictions/hour, 360 training batches/hour
- **GPU load**: 30-50% during active training
- **Memory**: ~2GB for 1s, less for longer timeframes
- **Latency**: <100ms per prediction

The system is designed, but it is not yet enabled: it still needs to be instantiated in the orchestrator and integrated with the main trading loop and the UI.