Live Inference & Training Mode Guide
Overview
The system has an EnhancedRealtimeTrainingSystem that can perform:
- Live Inference: Predict next candle every second
- Retrospective Training: Train on previous candle once result is known
- Multi-Timeframe: Process 1s, 1m, 1h, 1d candles independently
Current Status
✅ Available
- EnhancedRealtimeTrainingSystem class exists in NN/training/enhanced_realtime_training.py
- Comprehensive feature engineering
- Multi-model support (DQN, CNN, COB RL)
- Prediction tracking database
- Experience replay buffers
❌ Not Enabled
- Not instantiated in orchestrator
- No integration with main trading loop
- No UI controls to start/stop
Architecture
Live Inference Flow
Every 1 second:
┌─────────────────────────────────────────┐
│  1. Fetch Latest Data                   │
│     - 1s candle (just closed)           │
│     - 1m candle (if minute boundary)    │
│     - 1h candle (if hour boundary)      │
│     - 1d candle (if day boundary)       │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  2. Make Predictions                    │
│     - Next 1s candle OHLCV              │
│     - Next 1m candle OHLCV (if needed)  │
│     - Trading action (BUY/SELL/HOLD)    │
│     - Confidence score                  │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  3. Store Predictions                   │
│     - Save to prediction_database       │
│     - Track prediction_id               │
│     - Wait for resolution               │
└─────────────────────────────────────────┘
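The boundary checks in step 1 can be sketched as below. This is a minimal illustration, assuming candles are aligned to whole multiples of their length in Unix time (UTC); `TIMEFRAME_SECONDS` and `active_timeframes` are hypothetical names and not part of the existing code.

```python
from typing import List

# Candle length in seconds for each timeframe (illustrative mapping).
TIMEFRAME_SECONDS = {"1s": 1, "1m": 60, "1h": 3600, "1d": 86400}


def active_timeframes(epoch_seconds: int) -> List[str]:
    """Return the timeframes whose candle closes exactly at this second.

    Assumes candles are aligned to Unix-epoch multiples of their length (UTC).
    """
    return [
        timeframe
        for timeframe, length in TIMEFRAME_SECONDS.items()
        if epoch_seconds % length == 0
    ]
```

Under this assumption the 1s timeframe is processed on every tick, while the longer timeframes only appear when the tick lands on their boundary.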
Retrospective Training Flow
Every 1 second (after candle closes):
┌─────────────────────────────────────────┐
│  1. Get Previous Candle Result          │
│     - Actual OHLCV values               │
│     - Price change                      │
│     - Volume                            │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  2. Resolve Predictions                 │
│     - Compare predicted vs actual       │
│     - Calculate reward/loss             │
│     - Update prediction accuracy        │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  3. Create Training Experience          │
│     - State: market data before candle  │
│     - Action: predicted action          │
│     - Reward: based on accuracy         │
│     - Next State: market data after     │
└──────────────┬──────────────────────────┘
               │
               ▼
┌─────────────────────────────────────────┐
│  4. Train Models (if enough samples)    │
│     - Batch training (32-64 samples)    │
│     - Update model weights              │
│     - Save checkpoint                   │
└─────────────────────────────────────────┘
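Steps 2 and 3 might look roughly like the sketch below. The guide only says the reward is "based on accuracy", so the scoring rule here (plus or minus the confidence depending on whether the action matched the candle direction, zero for HOLD) is an assumption made for illustration. `Experience` and `direction_reward` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class Experience:
    """One replay-buffer entry built from a resolved prediction."""
    state: Sequence[float]       # market features before the candle
    action: str                  # 'BUY', 'SELL' or 'HOLD'
    reward: float
    next_state: Sequence[float]  # market features after the candle


def direction_reward(action: str, open_price: float, close_price: float,
                     confidence: float) -> float:
    """Hypothetical reward: +confidence for a correct direction, -confidence otherwise."""
    if action == "HOLD":
        return 0.0  # HOLD is treated as neutral in this sketch
    change = close_price - open_price
    correct = (action == "BUY" and change > 0) or (action == "SELL" and change < 0)
    return confidence if correct else -confidence


# Example: a BUY predicted with 0.8 confidence, and the candle closed higher.
experience = Experience(
    state=[100.0, 100.2],
    action="BUY",
    reward=direction_reward("BUY", open_price=100.0, close_price=101.0, confidence=0.8),
    next_state=[100.2, 101.0],
)
```

Scaling the reward by confidence penalizes confident mistakes more than hesitant ones; a price-error-based reward would be an equally plausible choice.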
Implementation Plan
Phase 1: Enable Realtime Training System
1.1 Initialize in Orchestrator
# In core/orchestrator.py __init__()
if ENHANCED_TRAINING_AVAILABLE:
    self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
        orchestrator=self,
        data_provider=self.data_provider,
        dashboard=None  # Optional dashboard integration
    )
    logger.info("EnhancedRealtimeTrainingSystem initialized")
else:
    self.enhanced_training_system = None
    logger.warning("EnhancedRealtimeTrainingSystem not available")
1.2 Add Start/Stop Methods
# In core/orchestrator.py
def start_live_training(self):
    """Start live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.start_training()
        logger.info("Live training mode started")
        return True
    else:
        logger.error("Enhanced training system not available")
        return False

def stop_live_training(self):
    """Stop live inference and training mode"""
    if self.enhanced_training_system:
        self.enhanced_training_system.stop_training()
        logger.info("Live training mode stopped")
        return True
    return False

def is_live_training_active(self) -> bool:
    """Check if live training is active"""
    if self.enhanced_training_system:
        return self.enhanced_training_system.is_training
    return False
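The orchestrator methods above rely on `start_training()`, `stop_training()` and `is_training` on the training system. One way such an interface could be backed by a background thread is sketched below. `LoopRunner` is a hypothetical stand-in written for this guide; the real EnhancedRealtimeTrainingSystem may manage its loop differently.

```python
import threading
import time
from typing import Callable, Optional


class LoopRunner:
    """Calls `tick` roughly once per `interval` seconds in a daemon thread."""

    def __init__(self, tick: Callable[[], None], interval: float = 1.0):
        self._tick = tick
        self._interval = interval
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    @property
    def is_training(self) -> bool:
        return self._thread is not None and self._thread.is_alive()

    def start_training(self) -> None:
        if self.is_training:
            return  # already running; do not start a second thread
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop_training(self, timeout: float = 5.0) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout)

    def _run(self) -> None:
        while not self._stop.is_set():
            started = time.monotonic()
            self._tick()
            remaining = self._interval - (time.monotonic() - started)
            # Event.wait returns early when stop_training() is called.
            self._stop.wait(max(0.0, remaining))
```

Waiting on the event instead of calling `time.sleep` lets a stop request interrupt the pause, so shutdown does not have to wait out a full interval.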
Phase 2: Implement Prediction & Training Loop
2.1 Main Loop (runs every 1 second)
# In EnhancedRealtimeTrainingSystem
def _live_inference_loop(self):
    """Main loop for live inference and training"""
    while self.is_training:
        try:
            current_time = time.time()

            # 1. Check which timeframes need processing
            timeframes_to_process = self._get_active_timeframes(current_time)

            for timeframe in timeframes_to_process:
                # 2. Make prediction for next candle
                prediction = self._make_next_candle_prediction(timeframe)

                # 3. Store prediction
                if prediction:
                    self._store_prediction(prediction)

                # 4. Resolve previous predictions
                self._resolve_timeframe_predictions(timeframe)

                # 5. Train on resolved predictions
                if self._should_train(timeframe):
                    self._train_on_timeframe(timeframe)

            # Sleep until next second
            elapsed = time.time() - current_time
            sleep_time = max(0, 1.0 - elapsed)
            time.sleep(sleep_time)

        except Exception as e:
            logger.error(f"Error in live inference loop: {e}")
            time.sleep(1)
2.2 Prediction Method
# Requires: import torch; from datetime import datetime; from typing import Dict, Optional
def _make_next_candle_prediction(self, timeframe: str) -> Optional[Dict]:
    """
    Predict next candle OHLCV values

    Returns None when no primary transformer is loaded, otherwise:
    {
        'timeframe': '1s',
        'timestamp': datetime,
        'predicted_open': float,
        'predicted_high': float,
        'predicted_low': float,
        'predicted_close': float,
        'predicted_volume': float,
        'action': 'BUY'|'SELL'|'HOLD',
        'confidence': float
    }
    """
    # Get current market state (600 candles)
    market_state = self._get_market_state(timeframe)

    # Get model prediction
    if self.orchestrator.primary_transformer:
        output = self.orchestrator.primary_transformer(market_state)

        # Extract next candle prediction
        next_candle = output['next_candles'][timeframe]
        action_probs = output['action_probs']

        return {
            'timeframe': timeframe,
            'timestamp': datetime.now(),
            'predicted_open': next_candle[0].item(),
            'predicted_high': next_candle[1].item(),
            'predicted_low': next_candle[2].item(),
            'predicted_close': next_candle[3].item(),
            'predicted_volume': next_candle[4].item(),
            'action': ['HOLD', 'BUY', 'SELL'][torch.argmax(action_probs).item()],
            'confidence': torch.max(action_probs).item()
        }

    return None
2.3 Training Method
def _train_on_timeframe(self, timeframe: str):
    """
    Train model on resolved predictions for this timeframe

    Process:
    1. Get resolved predictions (predicted vs actual)
    2. Create training batches
    3. Calculate loss
    4. Update model weights
    5. Save checkpoint (if needed)
    """
    # Get resolved predictions
    resolved = self._get_resolved_predictions(timeframe, limit=100)

    if len(resolved) < 32:  # Need minimum batch size
        return

    # Create training batches
    batches = self._create_training_batches(resolved)

    # Train model
    if self.orchestrator.primary_transformer_trainer:
        trainer = self.orchestrator.primary_transformer_trainer

        for batch in batches:
            result = trainer.train_step(batch)

            # Count batches so the checkpoint check below can fire
            # (assumes training_iteration is not advanced elsewhere)
            self.training_iteration += 1

            # Log progress
            if result:
                logger.debug(f"Trained on {timeframe}: loss={result['total_loss']:.4f}")

            # Save checkpoint every 100 batches
            if self.training_iteration % 100 == 0:
                self._save_checkpoint(timeframe)
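The `_create_training_batches` helper is called above but not shown. A minimal sketch of the chunking it might perform follows; `make_batches` is a hypothetical name, and dropping the incomplete final batch is an assumption chosen to keep batch shapes uniform.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def make_batches(samples: Sequence[T], batch_size: int,
                 drop_last: bool = True) -> Iterator[List[T]]:
    """Split samples into consecutive batches of `batch_size` items.

    With drop_last=True a trailing batch smaller than batch_size is discarded.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(samples), batch_size):
        batch = list(samples[start:start + batch_size])
        if drop_last and len(batch) < batch_size:
            return
        yield batch
```

With the `limit=100` used above and a batch size of 32, this yields three full batches and discards the remaining four samples.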
Configuration
Training Intervals
training_config = {
    # Inference intervals (how often to predict)
    'inference_1s': 1,      # Every 1 second
    'inference_1m': 60,     # Every 1 minute
    'inference_1h': 3600,   # Every 1 hour
    'inference_1d': 86400,  # Every 1 day

    # Training intervals (how often to train)
    'training_1s': 10,      # Train every 10 seconds (10 samples)
    'training_1m': 300,     # Train every 5 minutes (5 samples)
    'training_1h': 3600,    # Train every 1 hour (1 sample)
    'training_1d': 86400,   # Train every 1 day (1 sample)

    # Batch sizes
    'batch_size_1s': 32,
    'batch_size_1m': 16,
    'batch_size_1h': 8,
    'batch_size_1d': 4,

    # Buffer sizes
    'buffer_size_1s': 1000,
    'buffer_size_1m': 500,
    'buffer_size_1h': 200,
    'buffer_size_1d': 100
}
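To show how the three per-timeframe settings could interact, here is a small sketch of a bounded buffer with a training gate. `TimeframeBuffer` is a hypothetical class, and the rule "train only when a full batch is buffered and the training interval has elapsed" is an assumption, not the documented behavior of `_should_train`.

```python
from collections import deque


class TimeframeBuffer:
    """Bounded experience buffer with a simple training gate for one timeframe."""

    def __init__(self, buffer_size: int, batch_size: int, training_interval: float):
        # deque(maxlen=...) silently drops the oldest sample once full.
        self.samples = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.training_interval = training_interval
        self.last_trained = float("-inf")  # never trained yet

    def add(self, sample) -> None:
        self.samples.append(sample)

    def should_train(self, now: float) -> bool:
        enough_samples = len(self.samples) >= self.batch_size
        interval_elapsed = now - self.last_trained >= self.training_interval
        return enough_samples and interval_elapsed

    def mark_trained(self, now: float) -> None:
        self.last_trained = now


# Values taken from the 1s entries of training_config.
buffer_1s = TimeframeBuffer(buffer_size=1000, batch_size=32, training_interval=10)
```

Under this rule the 1s timeframe would start training after 32 resolved samples and then at most once every 10 seconds.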
Performance Targets
| Timeframe | Predictions/Hour | Training/Hour | GPU Load | Memory |
|---|---|---|---|---|
| 1s | 3,600 | 360 (every 10s) | 30-50% | 2GB |
| 1m | 60 | 12 (every 5m) | 10-20% | 1GB |
| 1h | 1 | 1 (every 1h) | 5-10% | 500MB |
| 1d | 0.04 | 0.04 (every 1d) | <5% | 200MB |
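The Predictions/Hour and Training/Hour columns follow directly from the configured intervals: events per hour equals 3600 divided by the interval in seconds. The short check below reproduces them; the GPU and memory figures are targets from the table and are not derived here.

```python
SECONDS_PER_HOUR = 3600

# Intervals in seconds, copied from training_config.
inference_interval = {"1s": 1, "1m": 60, "1h": 3600, "1d": 86400}
training_interval = {"1s": 10, "1m": 300, "1h": 3600, "1d": 86400}


def per_hour(interval_seconds: int) -> float:
    """Number of events per hour for an event that fires every interval_seconds."""
    return SECONDS_PER_HOUR / interval_seconds


for timeframe in inference_interval:
    predictions = per_hour(inference_interval[timeframe])
    trainings = per_hour(training_interval[timeframe])
    print(f"{timeframe}: {predictions:.2f} predictions/h, {trainings:.2f} trainings/h")
```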
Database Schema
Predictions Table
CREATE TABLE predictions (
    prediction_id INTEGER PRIMARY KEY,
    model_name VARCHAR,
    symbol VARCHAR,
    timeframe VARCHAR,
    timestamp BIGINT,

    -- Predicted values
    predicted_open DOUBLE,
    predicted_high DOUBLE,
    predicted_low DOUBLE,
    predicted_close DOUBLE,
    predicted_volume DOUBLE,
    predicted_action VARCHAR,
    confidence DOUBLE,

    -- Actual values (filled when resolved)
    actual_open DOUBLE,
    actual_high DOUBLE,
    actual_low DOUBLE,
    actual_close DOUBLE,
    actual_volume DOUBLE,

    -- Accuracy metrics
    price_error DOUBLE,
    volume_error DOUBLE,
    action_correct BOOLEAN,
    reward DOUBLE,

    -- Status
    status VARCHAR, -- 'pending', 'resolved', 'expired'
    resolved_at BIGINT
);
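The pending-to-resolved lifecycle of a row can be sketched with Python's built-in `sqlite3` module. The guide does not name the storage engine, so SQLite is an assumption here, and the table is trimmed to a handful of columns for brevity. `store_prediction` and `resolve_prediction` are hypothetical helper names.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Reduced version of the predictions table: only the columns used below.
conn.execute("""
    CREATE TABLE predictions (
        prediction_id INTEGER PRIMARY KEY,
        timeframe VARCHAR,
        timestamp BIGINT,
        predicted_close DOUBLE,
        actual_close DOUBLE,
        price_error DOUBLE,
        status VARCHAR,
        resolved_at BIGINT
    )
""")


def store_prediction(conn, timeframe: str, predicted_close: float, timestamp: int) -> int:
    """Insert a pending prediction and return its prediction_id."""
    cursor = conn.execute(
        "INSERT INTO predictions (timeframe, timestamp, predicted_close, status) "
        "VALUES (?, ?, ?, 'pending')",
        (timeframe, timestamp, predicted_close),
    )
    conn.commit()
    return cursor.lastrowid


def resolve_prediction(conn, prediction_id: int, actual_close: float, resolved_at: int) -> None:
    """Fill in the actual value and absolute price error for a pending prediction."""
    conn.execute(
        "UPDATE predictions "
        "SET actual_close = ?, price_error = ABS(predicted_close - ?), "
        "    status = 'resolved', resolved_at = ? "
        "WHERE prediction_id = ? AND status = 'pending'",
        (actual_close, actual_close, resolved_at, prediction_id),
    )
    conn.commit()
```

Restricting the update to rows still marked 'pending' keeps a prediction from being resolved twice.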
UI Integration
Dashboard Controls
<!-- Live Training Panel -->
<div class="live-training-panel">
    <h3>Live Inference & Training</h3>

    <div class="status">
        <span id="live-status">Inactive</span>
        <button id="start-live-btn">Start Live Mode</button>
        <button id="stop-live-btn" disabled>Stop Live Mode</button>
    </div>

    <div class="metrics">
        <div class="metric">
            <label>Predictions/sec:</label>
            <span id="predictions-per-sec">0</span>
        </div>
        <div class="metric">
            <label>Training batches/min:</label>
            <span id="training-per-min">0</span>
        </div>
        <div class="metric">
            <label>Accuracy (1m):</label>
            <span id="accuracy-1m">0%</span>
        </div>
        <div class="metric">
            <label>GPU Load:</label>
            <span id="gpu-load">0%</span>
        </div>
    </div>

    <div class="recent-predictions">
        <h4>Recent Predictions</h4>
        <table id="predictions-table">
            <thead>
                <tr>
                    <th>Time</th>
                    <th>TF</th>
                    <th>Predicted</th>
                    <th>Actual</th>
                    <th>Error</th>
                    <th>Action</th>
                </tr>
            </thead>
            <tbody></tbody>
        </table>
    </div>
</div>
API Endpoints
# In ANNOTATE/web/app.py
@app.route('/api/live-training/start', methods=['POST'])
def start_live_training():
    if orchestrator.start_live_training():
        return jsonify({'status': 'started'})
    return jsonify({'error': 'Failed to start'}), 500

@app.route('/api/live-training/stop', methods=['POST'])
def stop_live_training():
    if orchestrator.stop_live_training():
        return jsonify({'status': 'stopped'})
    return jsonify({'error': 'Failed to stop'}), 500

@app.route('/api/live-training/status', methods=['GET'])
def get_live_training_status():
    if orchestrator.enhanced_training_system:
        return jsonify({
            'active': orchestrator.is_live_training_active(),
            'predictions_per_sec': orchestrator.enhanced_training_system.get_prediction_rate(),
            'training_per_min': orchestrator.enhanced_training_system.get_training_rate(),
            'accuracy': orchestrator.enhanced_training_system.get_accuracy_stats()
        })
    return jsonify({'active': False})

@app.route('/api/live-training/predictions', methods=['GET'])
def get_recent_predictions():
    limit = request.args.get('limit', 50, type=int)
    if orchestrator.enhanced_training_system:
        predictions = orchestrator.enhanced_training_system.get_recent_predictions(limit)
        return jsonify({'predictions': predictions})
    return jsonify({'predictions': []})
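The status endpoint reads `get_prediction_rate()` and `get_training_rate()`, whose implementations are not shown in this guide. A sliding-window counter is one plausible way to back them. `RateTracker` below is a hypothetical helper, and the 60-second default window is an arbitrary choice.

```python
import time
from collections import deque
from typing import Optional


class RateTracker:
    """Counts events inside a sliding time window and reports events per second."""

    def __init__(self, window_seconds: float = 60.0):
        self.window = window_seconds
        self._events = deque()  # event timestamps, oldest first

    def record(self, now: Optional[float] = None) -> None:
        """Register one event (for example, one stored prediction)."""
        self._events.append(time.monotonic() if now is None else now)

    def rate(self, now: Optional[float] = None) -> float:
        """Average events per second over the most recent window."""
        now = time.monotonic() if now is None else now
        cutoff = now - self.window
        # Discard timestamps that have fallen out of the window.
        while self._events and self._events[0] <= cutoff:
            self._events.popleft()
        return len(self._events) / self.window
```

Calling `record()` once per stored prediction and returning `rate()` from the endpoint would give the dashboard a smoothed predictions-per-second figure; multiplying by 60 gives a per-minute value for training batches.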
Summary
To Enable Live Mode:
- Initialize EnhancedRealtimeTrainingSystem in orchestrator
- Add start/stop methods to orchestrator
- Implement prediction and training loops
- Create UI controls and API endpoints
- Test with 1s timeframe first
- Scale to other timeframes
Expected Behavior:
- ✅ Predict next candle every second
- ✅ Train on previous candle once result known
- ✅ 1 second delay for training (retrospective)
- ✅ Continuous learning from live data
- ✅ Real-time accuracy tracking
- ✅ Automatic checkpoint saving
Performance (targets):
- 1s timeframe: 3,600 predictions/hour, 360 training batches/hour
- GPU load: 30-50% during active training
- Memory: ~2GB for 1s, less for longer timeframes
- Latency: <100ms per prediction
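The EnhancedRealtimeTrainingSystem class already exists; enabling live mode still requires instantiating it in the orchestrator, integrating it with the main trading loop, and adding the UI controls.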