improved model loading and training

This commit is contained in:
Dobromir Popov
2025-10-31 01:22:49 +02:00
parent 7ddf98bf18
commit ba91740e4c
7 changed files with 745 additions and 186 deletions

View File

@@ -17,9 +17,14 @@ import time
import threading
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta, timezone
from pathlib import Path
try:
import pytz
except ImportError:
pytz = None
logger = logging.getLogger(__name__)
@@ -201,10 +206,11 @@ class RealTrainingAdapter:
logger.info(f" Accuracy: {session.accuracy}")
except Exception as e:
logger.error(f" REAL training failed: {e}", exc_info=True)
logger.error(f"REAL training failed: {e}", exc_info=True)
session.status = 'failed'
session.error = str(e)
session.duration_seconds = time.time() - session.start_time
logger.error(f"Training session {training_id} failed after {session.duration_seconds:.2f}s")
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
"""
@@ -441,7 +447,10 @@ class RealTrainingAdapter:
entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
else:
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
entry_time = entry_time.replace(tzinfo=pytz.UTC)
if pytz:
entry_time = entry_time.replace(tzinfo=pytz.UTC)
else:
entry_time = entry_time.replace(tzinfo=timezone.utc)
except Exception as e:
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
return hold_samples
@@ -526,7 +535,10 @@ class RealTrainingAdapter:
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
else:
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
signal_time = signal_time.replace(tzinfo=pytz.UTC)
if pytz:
signal_time = signal_time.replace(tzinfo=pytz.UTC)
else:
signal_time = signal_time.replace(tzinfo=timezone.utc)
except Exception as e:
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
return negative_samples
@@ -539,7 +551,10 @@ class RealTrainingAdapter:
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
else:
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
ts = ts.replace(tzinfo=pytz.UTC)
if pytz:
ts = ts.replace(tzinfo=pytz.UTC)
else:
ts = ts.replace(tzinfo=timezone.utc)
# Match within 1 minute
if abs((ts - signal_time).total_seconds()) < 60:
@@ -631,6 +646,61 @@ class RealTrainingAdapter:
return snapshot
def _convert_to_cnn_input(self, data: Dict) -> tuple:
"""Convert annotation training data to CNN model input format (x, y tensors)"""
import torch
import numpy as np
try:
market_state = data.get('market_state', {})
timeframes = market_state.get('timeframes', {})
# Get 1m timeframe data (primary for CNN)
if '1m' not in timeframes:
logger.warning("No 1m timeframe data available for CNN training")
return None, None
tf_data = timeframes['1m']
closes = np.array(tf_data.get('close', []), dtype=np.float32)
if len(closes) == 0:
logger.warning("No close price data available")
return None, None
# CNN expects input shape: [batch, seq_len, features]
# Use last 60 candles (or pad/truncate to 60)
seq_len = 60
if len(closes) >= seq_len:
closes = closes[-seq_len:]
else:
# Pad with last value
last_close = closes[-1] if len(closes) > 0 else 0.0
closes = np.pad(closes, (seq_len - len(closes), 0), mode='constant', constant_values=last_close)
# Create feature tensor: [1, 60, 1] (batch, seq_len, features)
# For now, use only close prices. In full implementation, add OHLCV
x = torch.tensor(closes, dtype=torch.float32).unsqueeze(0).unsqueeze(-1) # [1, 60, 1]
# Convert action to target tensor
action = data.get('action', 'HOLD')
direction = data.get('direction', 'HOLD')
# Map to class index: 0=HOLD, 1=BUY, 2=SELL
if direction == 'LONG' or action == 'BUY':
y = torch.tensor([1], dtype=torch.long)
elif direction == 'SHORT' or action == 'SELL':
y = torch.tensor([2], dtype=torch.long)
else:
y = torch.tensor([0], dtype=torch.long)
return x, y
except Exception as e:
logger.error(f"Error converting to CNN input: {e}")
import traceback
logger.error(traceback.format_exc())
return None, None
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train CNN model with REAL training loop"""
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
@@ -638,6 +708,11 @@ class RealTrainingAdapter:
model = self.orchestrator.cnn_model
# Check if model has trainer attribute (EnhancedCNN)
trainer = None
if hasattr(model, 'trainer'):
trainer = model.trainer
# Use the model's actual training method
if hasattr(model, 'train_on_annotations'):
# If model has annotation-specific training
@@ -646,21 +721,73 @@ class RealTrainingAdapter:
session.current_epoch = epoch + 1
session.current_loss = loss if loss else 0.0
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
elif hasattr(model, 'train_step'):
# Use standard train_step method
elif trainer and hasattr(trainer, 'train_step'):
# Use trainer's train_step method (EnhancedCNN)
logger.info(f"Training CNN using trainer.train_step() with {len(training_data)} samples")
for epoch in range(session.total_epochs):
epoch_loss = 0.0
for data in training_data:
# Convert to model input format and train
# This depends on the model's expected input
loss = model.train_step(data)
epoch_loss += loss if loss else 0.0
valid_samples = 0
session.current_epoch = epoch + 1
session.current_loss = epoch_loss / len(training_data)
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
for data in training_data:
# Convert to model input format
x, y = self._convert_to_cnn_input(data)
if x is None or y is None:
continue
try:
# Call trainer's train_step with proper format
loss_dict = trainer.train_step(x, y)
# Extract loss from dict if it's a dict, otherwise use directly
if isinstance(loss_dict, dict):
loss = loss_dict.get('total_loss', loss_dict.get('main_loss', 0.0))
else:
loss = float(loss_dict) if loss_dict else 0.0
epoch_loss += loss
valid_samples += 1
except Exception as e:
logger.error(f"Error in CNN training step: {e}")
import traceback
logger.error(traceback.format_exc())
continue
if valid_samples > 0:
session.current_epoch = epoch + 1
session.current_loss = epoch_loss / valid_samples
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}, Samples: {valid_samples}")
else:
logger.warning(f"CNN Epoch {epoch + 1}/{session.total_epochs}: No valid samples processed")
session.current_epoch = epoch + 1
session.current_loss = 0.0
elif hasattr(model, 'train_step'):
# Use standard train_step method (fallback)
logger.warning("Using model.train_step() directly - may not work correctly")
for epoch in range(session.total_epochs):
epoch_loss = 0.0
valid_samples = 0
for data in training_data:
x, y = self._convert_to_cnn_input(data)
if x is None or y is None:
continue
try:
loss = model.train_step(x, y)
epoch_loss += loss if loss else 0.0
valid_samples += 1
except Exception as e:
logger.error(f"Error in CNN training step: {e}")
continue
if valid_samples > 0:
session.current_epoch = epoch + 1
session.current_loss = epoch_loss / valid_samples
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
else:
raise Exception("CNN model does not have train_on_annotations or train_step method")
raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy