t pred viz - wip
This commit is contained in:
@@ -113,6 +113,8 @@ class TrainingSession:
|
||||
final_loss: Optional[float] = None
|
||||
accuracy: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
gpu_utilization: Optional[float] = None # GPU utilization percentage
|
||||
cpu_utilization: Optional[float] = None # CPU utilization percentage
|
||||
|
||||
|
||||
class RealTrainingAdapter:
|
||||
@@ -240,6 +242,13 @@ class RealTrainingAdapter:
|
||||
logger.info(f" Training ID: {training_id}")
|
||||
logger.info(f" Test cases: {len(test_cases)}")
|
||||
|
||||
# Clear previous predictions for clean visualization
|
||||
# Get symbol from first test case
|
||||
symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'clear_predictions'):
|
||||
self.orchestrator.clear_predictions(symbol)
|
||||
logger.info(f" Cleared previous predictions for {symbol}")
|
||||
|
||||
# Prepare training data from test cases
|
||||
training_data = self._prepare_training_data(test_cases)
|
||||
|
||||
@@ -1301,6 +1310,7 @@ class RealTrainingAdapter:
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
market_state = training_sample.get('market_state', {})
|
||||
@@ -1522,7 +1532,6 @@ class RealTrainingAdapter:
|
||||
time_in_position_minutes = 0.0
|
||||
if in_position:
|
||||
try:
|
||||
from datetime import datetime
|
||||
entry_timestamp = training_sample.get('timestamp')
|
||||
current_timestamp = training_sample.get('timestamp')
|
||||
|
||||
@@ -1645,7 +1654,14 @@ class RealTrainingAdapter:
|
||||
'norm_params': norm_params_dict, # Dict with keys: '1s', '1m', '1h', '1d', 'btc'
|
||||
|
||||
# Legacy support (use 1m as default)
|
||||
'price_data': price_data_1m if price_data_1m is not None else ref_data
|
||||
'price_data': price_data_1m if price_data_1m is not None else ref_data,
|
||||
|
||||
# Metadata for prediction visualization
|
||||
'metadata': {
|
||||
'current_price': float(current_price),
|
||||
'timestamp': training_sample.get('timestamp', datetime.now()),
|
||||
'symbol': training_sample.get('symbol', 'ETH/USDT')
|
||||
}
|
||||
}
|
||||
|
||||
return batch
|
||||
@@ -1962,6 +1978,13 @@ class RealTrainingAdapter:
|
||||
# Generate batches fresh for each epoch
|
||||
for i, batch in enumerate(batch_generator()):
|
||||
try:
|
||||
# Store prediction before training (for visualization)
|
||||
# Only store predictions on first epoch and every 10th batch to avoid clutter
|
||||
if epoch == 0 and i % 10 == 0 and self.orchestrator:
|
||||
# Get symbol from batch metadata or use default
|
||||
symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
|
||||
self._store_training_prediction(batch, trainer, symbol)
|
||||
|
||||
# Call the trainer's train_step method with mini-batch
|
||||
# Batch is already on GPU and contains multiple samples
|
||||
result = trainer.train_step(batch, accumulate_gradients=False)
|
||||
@@ -2252,6 +2275,28 @@ class RealTrainingAdapter:
|
||||
|
||||
session = self.training_sessions[training_id]
|
||||
|
||||
# Get current GPU/CPU utilization
|
||||
gpu_util = None
|
||||
cpu_util = None
|
||||
|
||||
try:
|
||||
from utils.gpu_monitor import get_gpu_monitor
|
||||
gpu_monitor = get_gpu_monitor()
|
||||
gpu_metrics = gpu_monitor.get_gpu_utilization()
|
||||
if gpu_metrics:
|
||||
gpu_util = gpu_metrics.get('gpu_utilization_percent')
|
||||
if gpu_util is None and gpu_metrics.get('memory_usage_percent'):
|
||||
# Fallback to memory usage as proxy
|
||||
gpu_util = gpu_metrics.get('memory_usage_percent')
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get GPU metrics: {e}")
|
||||
|
||||
try:
|
||||
import psutil
|
||||
cpu_util = psutil.cpu_percent(interval=0.1)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get CPU metrics: {e}")
|
||||
|
||||
return {
|
||||
'status': session.status,
|
||||
'model_name': session.model_name,
|
||||
@@ -2262,7 +2307,9 @@ class RealTrainingAdapter:
|
||||
'final_loss': session.final_loss,
|
||||
'accuracy': session.accuracy,
|
||||
'duration_seconds': session.duration_seconds,
|
||||
'error': session.error
|
||||
'error': session.error,
|
||||
'gpu_utilization': gpu_util,
|
||||
'cpu_utilization': cpu_util
|
||||
}
|
||||
|
||||
def get_active_training_session(self) -> Optional[Dict]:
|
||||
@@ -2415,6 +2462,59 @@ class RealTrainingAdapter:
|
||||
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
|
||||
return all_signals[:limit]
|
||||
|
||||
def _store_training_prediction(self, batch: Dict, trainer, symbol: str):
    """Store one prediction from a training batch for visualization.

    Runs a gradient-free forward pass of ``trainer.model`` on ``batch``,
    takes the most probable action of the first sample, and hands a
    prediction record to ``self.orchestrator.store_transformer_prediction``
    when the orchestrator exposes that method.

    This is best-effort: any exception is logged at debug level and
    swallowed so that a visualization failure never interrupts training.

    Args:
        batch: Training batch. Read keys: 'price_data_1s', 'price_data_1m',
            'price_data_1h', 'price_data_1d', 'tech_data', 'market_data'
            and 'metadata' (with 'current_price' and 'timestamp').
        trainer: Object exposing the model as ``trainer.model``.
        symbol: Symbol under which the prediction is stored.
    """
    try:
        import torch
        # Imported locally so the timestamp fallback below cannot fail
        # with a NameError if the module does not import datetime.
        from datetime import datetime

        model = trainer.model

        # Make prediction on the batch (without training)
        with torch.no_grad():
            model.eval()
            try:
                outputs = model(
                    price_data_1s=batch.get('price_data_1s'),
                    price_data_1m=batch.get('price_data_1m'),
                    price_data_1h=batch.get('price_data_1h'),
                    price_data_1d=batch.get('price_data_1d'),
                    tech_data=batch.get('tech_data'),
                    market_data=batch.get('market_data')
                )
            finally:
                # Always go back to training mode, even when the forward
                # pass raised; otherwise the following train step would
                # run with the model still in eval mode.
                model.train()

        # Extract action prediction
        action_probs = outputs.get('action_probs')
        if action_probs is None:
            return

        # Only the first sample is visualized. Flattening the leading
        # dimensions keeps this working for multi-sample batches, where
        # calling .item() on the full argmax result would raise.
        first_probs = action_probs.reshape(-1, action_probs.shape[-1])[0]
        action_idx = int(torch.argmax(first_probs).item())
        confidence = first_probs[action_idx].item()

        # Map to BUY/SELL/HOLD
        actions = ['BUY', 'SELL', 'HOLD']
        action = actions[action_idx] if action_idx < len(actions) else 'HOLD'

        # (price multiplier, price change in percent) per action.
        # HOLD is reported as "no move" rather than as a drop.
        price_multiplier, price_change = {
            'BUY': (1.01, 1.0),
            'SELL': (0.99, -1.0),
        }.get(action, (1.0, 0.0))

        # Get current price from batch metadata
        metadata = batch.get('metadata', {})
        current_price = metadata.get('current_price', 0)
        timestamp = metadata.get('timestamp', datetime.now())

        # Without a positive reference price there is nothing to plot.
        if not current_price > 0:
            return

        # Store in orchestrator
        if hasattr(self.orchestrator, 'store_transformer_prediction'):
            self.orchestrator.store_transformer_prediction(symbol, {
                'timestamp': timestamp,
                'current_price': current_price,
                'predicted_price': current_price * price_multiplier,
                'price_change': price_change,
                'confidence': confidence,
                'action': action,
                'horizon_minutes': 10,
                'source': 'training'
            })
            logger.debug(f"Stored training prediction: {action} @ {current_price} (conf: {confidence:.2f})")

    except Exception as e:
        logger.debug(f"Error storing training prediction: {e}")
|
||||
|
||||
def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider):
|
||||
"""
|
||||
Real-time inference loop using orchestrator's REAL prediction methods
|
||||
|
||||
Reference in New Issue
Block a user