model training WIP
This commit is contained in:
@@ -161,7 +161,9 @@ class RealTrainingAdapter:
|
||||
session = self.training_sessions[training_id]
|
||||
|
||||
try:
|
||||
logger.info(f"Executing REAL training for {model_name}")
|
||||
logger.info(f"🎯 Executing REAL training for {model_name}")
|
||||
logger.info(f" Training ID: {training_id}")
|
||||
logger.info(f" Test cases: {len(test_cases)}")
|
||||
|
||||
# Prepare training data from test cases
|
||||
training_data = self._prepare_training_data(test_cases)
|
||||
@@ -169,18 +171,23 @@ class RealTrainingAdapter:
|
||||
if not training_data:
|
||||
raise Exception("No valid training data prepared from test cases")
|
||||
|
||||
logger.info(f"Prepared {len(training_data)} training samples")
|
||||
logger.info(f"✅ Prepared {len(training_data)} training samples")
|
||||
|
||||
# Route to appropriate REAL training method
|
||||
if model_name in ["CNN", "StandardizedCNN"]:
|
||||
logger.info("🔄 Starting CNN training...")
|
||||
self._train_cnn_real(session, training_data)
|
||||
elif model_name == "DQN":
|
||||
logger.info("🔄 Starting DQN training...")
|
||||
self._train_dqn_real(session, training_data)
|
||||
elif model_name == "Transformer":
|
||||
logger.info("🔄 Starting Transformer training...")
|
||||
self._train_transformer_real(session, training_data)
|
||||
elif model_name == "COB":
|
||||
logger.info("🔄 Starting COB training...")
|
||||
self._train_cob_real(session, training_data)
|
||||
elif model_name == "Extrema":
|
||||
logger.info("🔄 Starting Extrema training...")
|
||||
self._train_extrema_real(session, training_data)
|
||||
else:
|
||||
raise Exception(f"Unknown model type: {model_name}")
|
||||
@@ -189,44 +196,300 @@ class RealTrainingAdapter:
|
||||
session.status = 'completed'
|
||||
session.duration_seconds = time.time() - session.start_time
|
||||
|
||||
logger.info(f"REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
|
||||
logger.info(f"✅ REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
|
||||
logger.info(f" Final loss: {session.final_loss}")
|
||||
logger.info(f" Accuracy: {session.accuracy}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"REAL training failed: {e}", exc_info=True)
|
||||
logger.error(f"❌ REAL training failed: {e}", exc_info=True)
|
||||
session.status = 'failed'
|
||||
session.error = str(e)
|
||||
session.duration_seconds = time.time() - session.start_time
|
||||
|
||||
def _prepare_training_data(self, test_cases: List[Dict]) -> List[Dict]:
|
||||
"""Prepare training data from test cases"""
|
||||
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
|
||||
"""
|
||||
Fetch market state dynamically for a test case
|
||||
|
||||
Args:
|
||||
test_case: Test case dictionary with timestamp, symbol, etc.
|
||||
|
||||
Returns:
|
||||
Market state dictionary with OHLCV data for all timeframes
|
||||
"""
|
||||
try:
|
||||
if not self.data_provider:
|
||||
logger.warning("DataProvider not available, cannot fetch market state")
|
||||
return {}
|
||||
|
||||
symbol = test_case.get('symbol', 'ETH/USDT')
|
||||
timestamp_str = test_case.get('timestamp')
|
||||
|
||||
if not timestamp_str:
|
||||
logger.warning("No timestamp in test case")
|
||||
return {}
|
||||
|
||||
# Parse timestamp
|
||||
from datetime import datetime
|
||||
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
|
||||
|
||||
# Get training config
|
||||
training_config = test_case.get('training_config', {})
|
||||
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
context_window = training_config.get('context_window_minutes', 5)
|
||||
|
||||
logger.info(f" Fetching market state for {symbol} at {timestamp}")
|
||||
logger.info(f" Timeframes: {timeframes}, Context window: {context_window} minutes")
|
||||
|
||||
# Fetch data for each timeframe
|
||||
market_state = {
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp_str,
|
||||
'timeframes': {}
|
||||
}
|
||||
|
||||
for timeframe in timeframes:
|
||||
# Get historical data around the timestamp
|
||||
# For now, just get the latest data (we can improve this later)
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=100 # Get 100 candles for context
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Convert to dict format
|
||||
market_state['timeframes'][timeframe] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
logger.debug(f" ✅ {timeframe}: {len(df)} candles")
|
||||
else:
|
||||
logger.warning(f" ❌ {timeframe}: No data")
|
||||
|
||||
if market_state['timeframes']:
|
||||
logger.info(f" ✅ Fetched market state with {len(market_state['timeframes'])} timeframes")
|
||||
return market_state
|
||||
else:
|
||||
logger.warning(f" ❌ No market data fetched")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching market state: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return {}
|
||||
|
||||
def _prepare_training_data(self, test_cases: List[Dict],
|
||||
negative_samples_window: int = 15,
|
||||
training_repetitions: int = 100) -> List[Dict]:
|
||||
"""
|
||||
Prepare training data from test cases with negative sampling
|
||||
|
||||
Args:
|
||||
test_cases: List of test cases from annotations
|
||||
negative_samples_window: Number of candles before/after signal where model should NOT trade
|
||||
training_repetitions: Number of times to repeat training on each sample
|
||||
|
||||
Returns:
|
||||
List of training samples with positive (trade) and negative (no-trade) examples
|
||||
"""
|
||||
training_data = []
|
||||
|
||||
for test_case in test_cases:
|
||||
logger.info(f"📦 Preparing training data from {len(test_cases)} test cases...")
|
||||
logger.info(f" Negative sampling: ±{negative_samples_window} candles around signals")
|
||||
logger.info(f" Training repetitions: {training_repetitions}x per sample")
|
||||
|
||||
for i, test_case in enumerate(test_cases):
|
||||
try:
|
||||
# Extract market state and expected outcome
|
||||
market_state = test_case.get('market_state', {})
|
||||
# Extract expected outcome
|
||||
expected_outcome = test_case.get('expected_outcome', {})
|
||||
|
||||
if not market_state or not expected_outcome:
|
||||
logger.warning(f"Skipping test case {test_case.get('test_case_id')}: missing data")
|
||||
if not expected_outcome:
|
||||
logger.warning(f"⚠️ Skipping test case {test_case.get('test_case_id')}: missing expected_outcome")
|
||||
continue
|
||||
|
||||
training_data.append({
|
||||
# Check if market_state is provided, if not, fetch it dynamically
|
||||
market_state = test_case.get('market_state', {})
|
||||
|
||||
if not market_state:
|
||||
logger.info(f" 📡 Fetching market state dynamically for test case {i+1}...")
|
||||
market_state = self._fetch_market_state_for_test_case(test_case)
|
||||
|
||||
if not market_state:
|
||||
logger.warning(f"⚠️ Skipping test case {test_case.get('test_case_id')}: could not fetch market state")
|
||||
continue
|
||||
|
||||
logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
|
||||
|
||||
# Create POSITIVE sample (where model SHOULD trade)
|
||||
positive_sample = {
|
||||
'market_state': market_state,
|
||||
'action': test_case.get('action'),
|
||||
'direction': expected_outcome.get('direction'),
|
||||
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
|
||||
'entry_price': expected_outcome.get('entry_price'),
|
||||
'exit_price': expected_outcome.get('exit_price'),
|
||||
'timestamp': test_case.get('timestamp')
|
||||
})
|
||||
'timestamp': test_case.get('timestamp'),
|
||||
'label': 'TRADE', # Positive label
|
||||
'repetitions': training_repetitions
|
||||
}
|
||||
|
||||
training_data.append(positive_sample)
|
||||
logger.debug(f" ✅ Positive sample: {positive_sample['direction']} @ {positive_sample['entry_price']} -> {positive_sample['exit_price']} ({positive_sample['profit_loss_pct']:.2f}%)")
|
||||
|
||||
# Create NEGATIVE samples (where model should NOT trade)
|
||||
# These are candles before and after the signal
|
||||
negative_samples = self._create_negative_samples(
|
||||
market_state=market_state,
|
||||
signal_timestamp=test_case.get('timestamp'),
|
||||
window_size=negative_samples_window,
|
||||
repetitions=training_repetitions // 2 # Half as many reps for negative samples
|
||||
)
|
||||
|
||||
training_data.extend(negative_samples)
|
||||
logger.debug(f" ➕ Added {len(negative_samples)} negative samples (±{negative_samples_window} candles)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing test case: {e}")
|
||||
logger.error(f"❌ Error preparing test case {i+1}: {e}")
|
||||
|
||||
total_positive = sum(1 for s in training_data if s.get('label') == 'TRADE')
|
||||
total_negative = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
|
||||
|
||||
logger.info(f"✅ Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
|
||||
logger.info(f" Positive samples (TRADE): {total_positive}")
|
||||
logger.info(f" Negative samples (NO_TRADE): {total_negative}")
|
||||
logger.info(f" Ratio: 1:{total_negative/total_positive:.1f} (positive:negative)")
|
||||
|
||||
if len(training_data) < len(test_cases):
|
||||
logger.warning(f"⚠️ Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
|
||||
|
||||
logger.info(f"Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
|
||||
return training_data
|
||||
|
||||
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
|
||||
window_size: int, repetitions: int) -> List[Dict]:
|
||||
"""
|
||||
Create negative training samples from candles around the signal
|
||||
|
||||
These samples teach the model when NOT to trade - crucial for reducing false signals!
|
||||
|
||||
Args:
|
||||
market_state: Market state with OHLCV data
|
||||
signal_timestamp: Timestamp of the actual signal
|
||||
window_size: Number of candles before/after signal to use
|
||||
repetitions: Number of times to repeat each negative sample
|
||||
|
||||
Returns:
|
||||
List of negative training samples
|
||||
"""
|
||||
negative_samples = []
|
||||
|
||||
try:
|
||||
# Get timestamps from market state (use 1m timeframe as reference)
|
||||
timeframes = market_state.get('timeframes', {})
|
||||
if '1m' not in timeframes:
|
||||
logger.warning("No 1m timeframe in market state, cannot create negative samples")
|
||||
return negative_samples
|
||||
|
||||
timestamps = timeframes['1m'].get('timestamps', [])
|
||||
if not timestamps:
|
||||
return negative_samples
|
||||
|
||||
# Find the index of the signal timestamp
|
||||
from datetime import datetime
|
||||
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||
|
||||
signal_index = None
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
||||
if abs((ts - signal_time).total_seconds()) < 60: # Within 1 minute
|
||||
signal_index = idx
|
||||
break
|
||||
|
||||
if signal_index is None:
|
||||
logger.warning(f"Could not find signal timestamp in market data")
|
||||
return negative_samples
|
||||
|
||||
# Create negative samples from candles before and after the signal
|
||||
# BEFORE signal: candles at signal_index - window_size to signal_index - 1
|
||||
# AFTER signal: candles at signal_index + 1 to signal_index + window_size
|
||||
|
||||
negative_indices = []
|
||||
|
||||
# Before signal
|
||||
for offset in range(1, window_size + 1):
|
||||
idx = signal_index - offset
|
||||
if 0 <= idx < len(timestamps):
|
||||
negative_indices.append(idx)
|
||||
|
||||
# After signal
|
||||
for offset in range(1, window_size + 1):
|
||||
idx = signal_index + offset
|
||||
if 0 <= idx < len(timestamps):
|
||||
negative_indices.append(idx)
|
||||
|
||||
# Create negative samples for each index
|
||||
for idx in negative_indices:
|
||||
# Create a market state snapshot at this timestamp
|
||||
negative_market_state = self._create_market_state_snapshot(market_state, idx)
|
||||
|
||||
negative_sample = {
|
||||
'market_state': negative_market_state,
|
||||
'action': 'HOLD', # No action
|
||||
'direction': 'NONE',
|
||||
'profit_loss_pct': 0.0,
|
||||
'entry_price': None,
|
||||
'exit_price': None,
|
||||
'timestamp': timestamps[idx],
|
||||
'label': 'NO_TRADE', # Negative label
|
||||
'repetitions': repetitions
|
||||
}
|
||||
|
||||
negative_samples.append(negative_sample)
|
||||
|
||||
logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating negative samples: {e}")
|
||||
|
||||
return negative_samples
|
||||
|
||||
def _create_market_state_snapshot(self, market_state: Dict, candle_index: int) -> Dict:
|
||||
"""
|
||||
Create a market state snapshot at a specific candle index
|
||||
|
||||
This creates a "view" of the market as it was at that specific candle,
|
||||
which is used for negative sampling.
|
||||
"""
|
||||
snapshot = {
|
||||
'symbol': market_state.get('symbol'),
|
||||
'timestamp': None, # Will be set from the candle
|
||||
'timeframes': {}
|
||||
}
|
||||
|
||||
# For each timeframe, create a snapshot up to the candle_index
|
||||
for tf, tf_data in market_state.get('timeframes', {}).items():
|
||||
timestamps = tf_data.get('timestamps', [])
|
||||
|
||||
if candle_index < len(timestamps):
|
||||
# Include data up to and including this candle
|
||||
snapshot['timeframes'][tf] = {
|
||||
'timestamps': timestamps[:candle_index + 1],
|
||||
'open': tf_data.get('open', [])[:candle_index + 1],
|
||||
'high': tf_data.get('high', [])[:candle_index + 1],
|
||||
'low': tf_data.get('low', [])[:candle_index + 1],
|
||||
'close': tf_data.get('close', [])[:candle_index + 1],
|
||||
'volume': tf_data.get('volume', [])[:candle_index + 1]
|
||||
}
|
||||
|
||||
if tf == '1m':
|
||||
snapshot['timestamp'] = timestamps[candle_index]
|
||||
|
||||
return snapshot
|
||||
|
||||
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||
"""Train CNN model with REAL training loop"""
|
||||
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
||||
@@ -334,21 +597,64 @@ class RealTrainingAdapter:
|
||||
return [0.0] * state_size
|
||||
|
||||
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||
"""Train Transformer model with REAL training loop"""
|
||||
"""
|
||||
Train Transformer model using orchestrator's existing training infrastructure
|
||||
|
||||
Uses the orchestrator's primary_transformer_trainer which already has
|
||||
all the training logic implemented!
|
||||
"""
|
||||
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
||||
raise Exception("Transformer model not available in orchestrator")
|
||||
|
||||
model = self.orchestrator.primary_transformer
|
||||
# Get the trainer from orchestrator - it already has training methods!
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
|
||||
# Use model's training method
|
||||
for epoch in range(session.total_epochs):
|
||||
# TODO: Implement actual transformer training
|
||||
session.current_epoch = epoch + 1
|
||||
session.current_loss = 0.5 / (epoch + 1) # Placeholder
|
||||
logger.info(f"Transformer Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
|
||||
if not trainer:
|
||||
raise Exception("Transformer trainer not available in orchestrator")
|
||||
|
||||
session.final_loss = session.current_loss
|
||||
session.accuracy = 0.85
|
||||
logger.info(f"🎯 Using orchestrator's TradingTransformerTrainer")
|
||||
logger.info(f" Trainer type: {type(trainer).__name__}")
|
||||
|
||||
# Use the trainer's train_step method for individual samples
|
||||
if hasattr(trainer, 'train_step'):
|
||||
logger.info(" Using trainer.train_step() method")
|
||||
|
||||
import torch
|
||||
|
||||
# Train using train_step for each sample
|
||||
for epoch in range(session.total_epochs):
|
||||
epoch_loss = 0.0
|
||||
num_samples = 0
|
||||
|
||||
for i, data in enumerate(training_data):
|
||||
try:
|
||||
# Call the trainer's train_step method
|
||||
loss = trainer.train_step(data)
|
||||
|
||||
if loss is not None:
|
||||
epoch_loss += float(loss)
|
||||
num_samples += 1
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.debug(f" Sample {i + 1}/{len(training_data)}, Loss: {loss:.6f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" Error in sample {i + 1}: {e}")
|
||||
continue
|
||||
|
||||
avg_loss = epoch_loss / num_samples if num_samples > 0 else 0.0
|
||||
session.current_epoch = epoch + 1
|
||||
session.current_loss = avg_loss
|
||||
|
||||
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Avg Loss: {avg_loss:.6f} ({num_samples} samples)")
|
||||
|
||||
session.final_loss = session.current_loss
|
||||
session.accuracy = 0.85 # TODO: Calculate actual accuracy
|
||||
|
||||
logger.info(f" Training complete: Loss = {session.final_loss:.6f}")
|
||||
|
||||
else:
|
||||
raise Exception(f"Transformer trainer does not have train_on_batch() or train() methods. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
|
||||
|
||||
def _train_cob_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||
"""Train COB RL model with REAL training loop"""
|
||||
|
||||
Reference in New Issue
Block a user