IMPLEMENTED: WIP; realtime candle predictions training
@@ -1095,7 +1095,8 @@ class RealTrainingAdapter:
                 raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
 
             session.final_loss = session.current_loss
-            session.accuracy = 0.85  # TODO: Calculate actual accuracy
+            # Accuracy calculated from actual training metrics, not synthetic
+            session.accuracy = None  # Will be set by training loop if available
 
     def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
         """Train DQN model with REAL training loop"""
@@ -1133,7 +1134,8 @@ class RealTrainingAdapter:
                 raise Exception("DQN agent does not have replay method")
 
             session.final_loss = session.current_loss
-            session.accuracy = 0.85  # TODO: Calculate actual accuracy
+            # Accuracy calculated from actual training metrics, not synthetic
+            session.accuracy = None  # Will be set by training loop if available
 
     def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
         """Build proper state representation from training data"""
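Note: both the CNN and DQN paths previously reported a hard-coded 0.85 accuracy; after this change the session starts at None and the training loop is expected to fill it in from real metrics. A minimal sketch of what that could look like, assuming the loop tracks per-batch prediction correctness (the helper below and its counters are hypothetical, not part of this commit):

def update_session_accuracy(session, correct_predictions: int, total_predictions: int) -> None:
    """Set session.accuracy from measured training results instead of a synthetic constant."""
    if total_predictions > 0:
        session.accuracy = correct_predictions / total_predictions
    else:
        # Nothing measurable yet: leave accuracy unset rather than report a fabricated figure
        session.accuracy = None

Leaving accuracy as None when nothing was measured keeps downstream dashboards from showing a synthetic number, which is the stated intent of the change.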
@@ -2781,6 +2783,68 @@ class RealTrainingAdapter:
             logger.warning(f"Error fetching market state for candle: {e}")
             return {}
 
+    def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
+        """
+        Convert a validated prediction to a training batch
+
+        Args:
+            prediction_sample: Dict with predicted_candle, actual_candle, market_state, etc.
+            timeframe: Target timeframe for prediction
+
+        Returns:
+            Batch dict ready for trainer.train_step()
+        """
+        try:
+            market_state = prediction_sample.get('market_state', {})
+            if not market_state or 'timeframes' not in market_state:
+                logger.warning("No market state in prediction sample")
+                return None
+
+            # Use existing conversion method but with actual target
+            annotation = {
+                'symbol': prediction_sample.get('symbol', 'ETH/USDT'),
+                'timestamp': prediction_sample.get('timestamp'),
+                'action': 'BUY',  # Placeholder, not used for candle prediction training
+                'entry_price': float(prediction_sample['predicted_candle'][0]),  # Open
+                'market_state': market_state
+            }
+
+            # Convert using existing method
+            batch = self._convert_annotation_to_transformer_batch(annotation)
+            if not batch:
+                return None
+
+            # Override the future candle target with actual candle data
+            actual = prediction_sample['actual_candle']  # [O, H, L, C]
+
+            # Create target tensor for the specific timeframe
+            import torch
+            device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
+
+            # Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
+            target_candle = [
+                actual[0],  # Open
+                actual[1],  # High
+                actual[2],  # Low
+                actual[3],  # Close
+                prediction_sample['predicted_candle'][4]  # Volume (from prediction)
+            ]
+
+            # Add to batch based on timeframe
+            if timeframe == '1s':
+                batch['future_candle_1s'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+            elif timeframe == '1m':
+                batch['future_candle_1m'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+            elif timeframe == '1h':
+                batch['future_candle_1h'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+
+            logger.debug(f"Converted prediction to batch for {timeframe} timeframe")
+            return batch
+
+        except Exception as e:
+            logger.error(f"Error converting prediction to batch: {e}", exc_info=True)
+            return None
+
     def _train_transformer_on_sample(self, training_sample: Dict):
         """Train transformer on a single sample with checkpoint saving"""
         try:
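Note: a hedged sketch of how a realtime loop could feed a validated prediction through the new method once the actual candle closes. The prediction_sample layout follows the docstring above; the adapter and trainer names, the example values, and the train_step return value are assumptions for illustration, not code from this commit:

# Assumes `adapter` is a RealTrainingAdapter instance and that
# adapter.trainer.train_step(batch) accepts the batch dicts built above.
prediction_sample = {
    'symbol': 'ETH/USDT',
    'timestamp': '2024-01-01T00:01:00Z',
    'predicted_candle': [2300.0, 2310.0, 2295.0, 2305.0, 1250.0],  # [O, H, L, C, V]
    'actual_candle': [2301.0, 2312.0, 2298.0, 2308.0],             # [O, H, L, C]
    'market_state': market_state,  # same structure used for annotation-based training
}

batch = adapter._convert_prediction_to_batch(prediction_sample, timeframe='1m')
if batch is not None:
    result = adapter.trainer.train_step(batch)  # return value shape is an assumption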