model training WIP
This commit is contained in:
@@ -161,7 +161,9 @@ class RealTrainingAdapter:
|
|||||||
session = self.training_sessions[training_id]
|
session = self.training_sessions[training_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Executing REAL training for {model_name}")
|
logger.info(f"🎯 Executing REAL training for {model_name}")
|
||||||
|
logger.info(f" Training ID: {training_id}")
|
||||||
|
logger.info(f" Test cases: {len(test_cases)}")
|
||||||
|
|
||||||
# Prepare training data from test cases
|
# Prepare training data from test cases
|
||||||
training_data = self._prepare_training_data(test_cases)
|
training_data = self._prepare_training_data(test_cases)
|
||||||
@@ -169,18 +171,23 @@ class RealTrainingAdapter:
|
|||||||
if not training_data:
|
if not training_data:
|
||||||
raise Exception("No valid training data prepared from test cases")
|
raise Exception("No valid training data prepared from test cases")
|
||||||
|
|
||||||
logger.info(f"Prepared {len(training_data)} training samples")
|
logger.info(f"✅ Prepared {len(training_data)} training samples")
|
||||||
|
|
||||||
# Route to appropriate REAL training method
|
# Route to appropriate REAL training method
|
||||||
if model_name in ["CNN", "StandardizedCNN"]:
|
if model_name in ["CNN", "StandardizedCNN"]:
|
||||||
|
logger.info("🔄 Starting CNN training...")
|
||||||
self._train_cnn_real(session, training_data)
|
self._train_cnn_real(session, training_data)
|
||||||
elif model_name == "DQN":
|
elif model_name == "DQN":
|
||||||
|
logger.info("🔄 Starting DQN training...")
|
||||||
self._train_dqn_real(session, training_data)
|
self._train_dqn_real(session, training_data)
|
||||||
elif model_name == "Transformer":
|
elif model_name == "Transformer":
|
||||||
|
logger.info("🔄 Starting Transformer training...")
|
||||||
self._train_transformer_real(session, training_data)
|
self._train_transformer_real(session, training_data)
|
||||||
elif model_name == "COB":
|
elif model_name == "COB":
|
||||||
|
logger.info("🔄 Starting COB training...")
|
||||||
self._train_cob_real(session, training_data)
|
self._train_cob_real(session, training_data)
|
||||||
elif model_name == "Extrema":
|
elif model_name == "Extrema":
|
||||||
|
logger.info("🔄 Starting Extrema training...")
|
||||||
self._train_extrema_real(session, training_data)
|
self._train_extrema_real(session, training_data)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown model type: {model_name}")
|
raise Exception(f"Unknown model type: {model_name}")
|
||||||
@@ -189,44 +196,300 @@ class RealTrainingAdapter:
|
|||||||
session.status = 'completed'
|
session.status = 'completed'
|
||||||
session.duration_seconds = time.time() - session.start_time
|
session.duration_seconds = time.time() - session.start_time
|
||||||
|
|
||||||
logger.info(f"REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
|
logger.info(f"✅ REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
|
||||||
|
logger.info(f" Final loss: {session.final_loss}")
|
||||||
|
logger.info(f" Accuracy: {session.accuracy}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"REAL training failed: {e}", exc_info=True)
|
logger.error(f"❌ REAL training failed: {e}", exc_info=True)
|
||||||
session.status = 'failed'
|
session.status = 'failed'
|
||||||
session.error = str(e)
|
session.error = str(e)
|
||||||
session.duration_seconds = time.time() - session.start_time
|
session.duration_seconds = time.time() - session.start_time
|
||||||
|
|
||||||
def _prepare_training_data(self, test_cases: List[Dict]) -> List[Dict]:
|
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
|
||||||
"""Prepare training data from test cases"""
|
"""
|
||||||
|
Fetch market state dynamically for a test case
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_case: Test case dictionary with timestamp, symbol, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Market state dictionary with OHLCV data for all timeframes
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.data_provider:
|
||||||
|
logger.warning("DataProvider not available, cannot fetch market state")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
symbol = test_case.get('symbol', 'ETH/USDT')
|
||||||
|
timestamp_str = test_case.get('timestamp')
|
||||||
|
|
||||||
|
if not timestamp_str:
|
||||||
|
logger.warning("No timestamp in test case")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Parse timestamp
|
||||||
|
from datetime import datetime
|
||||||
|
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
|
||||||
|
|
||||||
|
# Get training config
|
||||||
|
training_config = test_case.get('training_config', {})
|
||||||
|
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||||
|
context_window = training_config.get('context_window_minutes', 5)
|
||||||
|
|
||||||
|
logger.info(f" Fetching market state for {symbol} at {timestamp}")
|
||||||
|
logger.info(f" Timeframes: {timeframes}, Context window: {context_window} minutes")
|
||||||
|
|
||||||
|
# Fetch data for each timeframe
|
||||||
|
market_state = {
|
||||||
|
'symbol': symbol,
|
||||||
|
'timestamp': timestamp_str,
|
||||||
|
'timeframes': {}
|
||||||
|
}
|
||||||
|
|
||||||
|
for timeframe in timeframes:
|
||||||
|
# Get historical data around the timestamp
|
||||||
|
# For now, just get the latest data (we can improve this later)
|
||||||
|
df = self.data_provider.get_historical_data(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframe=timeframe,
|
||||||
|
limit=100 # Get 100 candles for context
|
||||||
|
)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
# Convert to dict format
|
||||||
|
market_state['timeframes'][timeframe] = {
|
||||||
|
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||||
|
'open': df['open'].tolist(),
|
||||||
|
'high': df['high'].tolist(),
|
||||||
|
'low': df['low'].tolist(),
|
||||||
|
'close': df['close'].tolist(),
|
||||||
|
'volume': df['volume'].tolist()
|
||||||
|
}
|
||||||
|
logger.debug(f" ✅ {timeframe}: {len(df)} candles")
|
||||||
|
else:
|
||||||
|
logger.warning(f" ❌ {timeframe}: No data")
|
||||||
|
|
||||||
|
if market_state['timeframes']:
|
||||||
|
logger.info(f" ✅ Fetched market state with {len(market_state['timeframes'])} timeframes")
|
||||||
|
return market_state
|
||||||
|
else:
|
||||||
|
logger.warning(f" ❌ No market data fetched")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching market state: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _prepare_training_data(self, test_cases: List[Dict],
|
||||||
|
negative_samples_window: int = 15,
|
||||||
|
training_repetitions: int = 100) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Prepare training data from test cases with negative sampling
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_cases: List of test cases from annotations
|
||||||
|
negative_samples_window: Number of candles before/after signal where model should NOT trade
|
||||||
|
training_repetitions: Number of times to repeat training on each sample
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of training samples with positive (trade) and negative (no-trade) examples
|
||||||
|
"""
|
||||||
training_data = []
|
training_data = []
|
||||||
|
|
||||||
for test_case in test_cases:
|
logger.info(f"📦 Preparing training data from {len(test_cases)} test cases...")
|
||||||
|
logger.info(f" Negative sampling: ±{negative_samples_window} candles around signals")
|
||||||
|
logger.info(f" Training repetitions: {training_repetitions}x per sample")
|
||||||
|
|
||||||
|
for i, test_case in enumerate(test_cases):
|
||||||
try:
|
try:
|
||||||
# Extract market state and expected outcome
|
# Extract expected outcome
|
||||||
market_state = test_case.get('market_state', {})
|
|
||||||
expected_outcome = test_case.get('expected_outcome', {})
|
expected_outcome = test_case.get('expected_outcome', {})
|
||||||
|
|
||||||
if not market_state or not expected_outcome:
|
if not expected_outcome:
|
||||||
logger.warning(f"Skipping test case {test_case.get('test_case_id')}: missing data")
|
logger.warning(f"⚠️ Skipping test case {test_case.get('test_case_id')}: missing expected_outcome")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
training_data.append({
|
# Check if market_state is provided, if not, fetch it dynamically
|
||||||
|
market_state = test_case.get('market_state', {})
|
||||||
|
|
||||||
|
if not market_state:
|
||||||
|
logger.info(f" 📡 Fetching market state dynamically for test case {i+1}...")
|
||||||
|
market_state = self._fetch_market_state_for_test_case(test_case)
|
||||||
|
|
||||||
|
if not market_state:
|
||||||
|
logger.warning(f"⚠️ Skipping test case {test_case.get('test_case_id')}: could not fetch market state")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
|
||||||
|
|
||||||
|
# Create POSITIVE sample (where model SHOULD trade)
|
||||||
|
positive_sample = {
|
||||||
'market_state': market_state,
|
'market_state': market_state,
|
||||||
'action': test_case.get('action'),
|
'action': test_case.get('action'),
|
||||||
'direction': expected_outcome.get('direction'),
|
'direction': expected_outcome.get('direction'),
|
||||||
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
|
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
|
||||||
'entry_price': expected_outcome.get('entry_price'),
|
'entry_price': expected_outcome.get('entry_price'),
|
||||||
'exit_price': expected_outcome.get('exit_price'),
|
'exit_price': expected_outcome.get('exit_price'),
|
||||||
'timestamp': test_case.get('timestamp')
|
'timestamp': test_case.get('timestamp'),
|
||||||
})
|
'label': 'TRADE', # Positive label
|
||||||
|
'repetitions': training_repetitions
|
||||||
|
}
|
||||||
|
|
||||||
|
training_data.append(positive_sample)
|
||||||
|
logger.debug(f" ✅ Positive sample: {positive_sample['direction']} @ {positive_sample['entry_price']} -> {positive_sample['exit_price']} ({positive_sample['profit_loss_pct']:.2f}%)")
|
||||||
|
|
||||||
|
# Create NEGATIVE samples (where model should NOT trade)
|
||||||
|
# These are candles before and after the signal
|
||||||
|
negative_samples = self._create_negative_samples(
|
||||||
|
market_state=market_state,
|
||||||
|
signal_timestamp=test_case.get('timestamp'),
|
||||||
|
window_size=negative_samples_window,
|
||||||
|
repetitions=training_repetitions // 2 # Half as many reps for negative samples
|
||||||
|
)
|
||||||
|
|
||||||
|
training_data.extend(negative_samples)
|
||||||
|
logger.debug(f" ➕ Added {len(negative_samples)} negative samples (±{negative_samples_window} candles)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error preparing test case: {e}")
|
logger.error(f"❌ Error preparing test case {i+1}: {e}")
|
||||||
|
|
||||||
|
total_positive = sum(1 for s in training_data if s.get('label') == 'TRADE')
|
||||||
|
total_negative = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
|
||||||
|
|
||||||
|
logger.info(f"✅ Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
|
||||||
|
logger.info(f" Positive samples (TRADE): {total_positive}")
|
||||||
|
logger.info(f" Negative samples (NO_TRADE): {total_negative}")
|
||||||
|
logger.info(f" Ratio: 1:{total_negative/total_positive:.1f} (positive:negative)")
|
||||||
|
|
||||||
|
if len(training_data) < len(test_cases):
|
||||||
|
logger.warning(f"⚠️ Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
|
||||||
|
|
||||||
logger.info(f"Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
|
|
||||||
return training_data
|
return training_data
|
||||||
|
|
||||||
|
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
|
||||||
|
window_size: int, repetitions: int) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Create negative training samples from candles around the signal
|
||||||
|
|
||||||
|
These samples teach the model when NOT to trade - crucial for reducing false signals!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_state: Market state with OHLCV data
|
||||||
|
signal_timestamp: Timestamp of the actual signal
|
||||||
|
window_size: Number of candles before/after signal to use
|
||||||
|
repetitions: Number of times to repeat each negative sample
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of negative training samples
|
||||||
|
"""
|
||||||
|
negative_samples = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get timestamps from market state (use 1m timeframe as reference)
|
||||||
|
timeframes = market_state.get('timeframes', {})
|
||||||
|
if '1m' not in timeframes:
|
||||||
|
logger.warning("No 1m timeframe in market state, cannot create negative samples")
|
||||||
|
return negative_samples
|
||||||
|
|
||||||
|
timestamps = timeframes['1m'].get('timestamps', [])
|
||||||
|
if not timestamps:
|
||||||
|
return negative_samples
|
||||||
|
|
||||||
|
# Find the index of the signal timestamp
|
||||||
|
from datetime import datetime
|
||||||
|
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||||
|
|
||||||
|
signal_index = None
|
||||||
|
for idx, ts_str in enumerate(timestamps):
|
||||||
|
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
||||||
|
if abs((ts - signal_time).total_seconds()) < 60: # Within 1 minute
|
||||||
|
signal_index = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
if signal_index is None:
|
||||||
|
logger.warning(f"Could not find signal timestamp in market data")
|
||||||
|
return negative_samples
|
||||||
|
|
||||||
|
# Create negative samples from candles before and after the signal
|
||||||
|
# BEFORE signal: candles at signal_index - window_size to signal_index - 1
|
||||||
|
# AFTER signal: candles at signal_index + 1 to signal_index + window_size
|
||||||
|
|
||||||
|
negative_indices = []
|
||||||
|
|
||||||
|
# Before signal
|
||||||
|
for offset in range(1, window_size + 1):
|
||||||
|
idx = signal_index - offset
|
||||||
|
if 0 <= idx < len(timestamps):
|
||||||
|
negative_indices.append(idx)
|
||||||
|
|
||||||
|
# After signal
|
||||||
|
for offset in range(1, window_size + 1):
|
||||||
|
idx = signal_index + offset
|
||||||
|
if 0 <= idx < len(timestamps):
|
||||||
|
negative_indices.append(idx)
|
||||||
|
|
||||||
|
# Create negative samples for each index
|
||||||
|
for idx in negative_indices:
|
||||||
|
# Create a market state snapshot at this timestamp
|
||||||
|
negative_market_state = self._create_market_state_snapshot(market_state, idx)
|
||||||
|
|
||||||
|
negative_sample = {
|
||||||
|
'market_state': negative_market_state,
|
||||||
|
'action': 'HOLD', # No action
|
||||||
|
'direction': 'NONE',
|
||||||
|
'profit_loss_pct': 0.0,
|
||||||
|
'entry_price': None,
|
||||||
|
'exit_price': None,
|
||||||
|
'timestamp': timestamps[idx],
|
||||||
|
'label': 'NO_TRADE', # Negative label
|
||||||
|
'repetitions': repetitions
|
||||||
|
}
|
||||||
|
|
||||||
|
negative_samples.append(negative_sample)
|
||||||
|
|
||||||
|
logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating negative samples: {e}")
|
||||||
|
|
||||||
|
return negative_samples
|
||||||
|
|
||||||
|
def _create_market_state_snapshot(self, market_state: Dict, candle_index: int) -> Dict:
|
||||||
|
"""
|
||||||
|
Create a market state snapshot at a specific candle index
|
||||||
|
|
||||||
|
This creates a "view" of the market as it was at that specific candle,
|
||||||
|
which is used for negative sampling.
|
||||||
|
"""
|
||||||
|
snapshot = {
|
||||||
|
'symbol': market_state.get('symbol'),
|
||||||
|
'timestamp': None, # Will be set from the candle
|
||||||
|
'timeframes': {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# For each timeframe, create a snapshot up to the candle_index
|
||||||
|
for tf, tf_data in market_state.get('timeframes', {}).items():
|
||||||
|
timestamps = tf_data.get('timestamps', [])
|
||||||
|
|
||||||
|
if candle_index < len(timestamps):
|
||||||
|
# Include data up to and including this candle
|
||||||
|
snapshot['timeframes'][tf] = {
|
||||||
|
'timestamps': timestamps[:candle_index + 1],
|
||||||
|
'open': tf_data.get('open', [])[:candle_index + 1],
|
||||||
|
'high': tf_data.get('high', [])[:candle_index + 1],
|
||||||
|
'low': tf_data.get('low', [])[:candle_index + 1],
|
||||||
|
'close': tf_data.get('close', [])[:candle_index + 1],
|
||||||
|
'volume': tf_data.get('volume', [])[:candle_index + 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if tf == '1m':
|
||||||
|
snapshot['timestamp'] = timestamps[candle_index]
|
||||||
|
|
||||||
|
return snapshot
|
||||||
|
|
||||||
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
|
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||||
"""Train CNN model with REAL training loop"""
|
"""Train CNN model with REAL training loop"""
|
||||||
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
||||||
@@ -334,21 +597,64 @@ class RealTrainingAdapter:
|
|||||||
return [0.0] * state_size
|
return [0.0] * state_size
|
||||||
|
|
||||||
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
|
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||||
"""Train Transformer model with REAL training loop"""
|
"""
|
||||||
|
Train Transformer model using orchestrator's existing training infrastructure
|
||||||
|
|
||||||
|
Uses the orchestrator's primary_transformer_trainer which already has
|
||||||
|
all the training logic implemented!
|
||||||
|
"""
|
||||||
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
||||||
raise Exception("Transformer model not available in orchestrator")
|
raise Exception("Transformer model not available in orchestrator")
|
||||||
|
|
||||||
model = self.orchestrator.primary_transformer
|
# Get the trainer from orchestrator - it already has training methods!
|
||||||
|
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||||
|
|
||||||
# Use model's training method
|
if not trainer:
|
||||||
for epoch in range(session.total_epochs):
|
raise Exception("Transformer trainer not available in orchestrator")
|
||||||
# TODO: Implement actual transformer training
|
|
||||||
session.current_epoch = epoch + 1
|
|
||||||
session.current_loss = 0.5 / (epoch + 1) # Placeholder
|
|
||||||
logger.info(f"Transformer Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
|
|
||||||
|
|
||||||
session.final_loss = session.current_loss
|
logger.info(f"🎯 Using orchestrator's TradingTransformerTrainer")
|
||||||
session.accuracy = 0.85
|
logger.info(f" Trainer type: {type(trainer).__name__}")
|
||||||
|
|
||||||
|
# Use the trainer's train_step method for individual samples
|
||||||
|
if hasattr(trainer, 'train_step'):
|
||||||
|
logger.info(" Using trainer.train_step() method")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Train using train_step for each sample
|
||||||
|
for epoch in range(session.total_epochs):
|
||||||
|
epoch_loss = 0.0
|
||||||
|
num_samples = 0
|
||||||
|
|
||||||
|
for i, data in enumerate(training_data):
|
||||||
|
try:
|
||||||
|
# Call the trainer's train_step method
|
||||||
|
loss = trainer.train_step(data)
|
||||||
|
|
||||||
|
if loss is not None:
|
||||||
|
epoch_loss += float(loss)
|
||||||
|
num_samples += 1
|
||||||
|
|
||||||
|
if (i + 1) % 10 == 0:
|
||||||
|
logger.debug(f" Sample {i + 1}/{len(training_data)}, Loss: {loss:.6f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f" Error in sample {i + 1}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
avg_loss = epoch_loss / num_samples if num_samples > 0 else 0.0
|
||||||
|
session.current_epoch = epoch + 1
|
||||||
|
session.current_loss = avg_loss
|
||||||
|
|
||||||
|
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Avg Loss: {avg_loss:.6f} ({num_samples} samples)")
|
||||||
|
|
||||||
|
session.final_loss = session.current_loss
|
||||||
|
session.accuracy = 0.85 # TODO: Calculate actual accuracy
|
||||||
|
|
||||||
|
logger.info(f" Training complete: Loss = {session.final_loss:.6f}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Transformer trainer does not have train_on_batch() or train() methods. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
|
||||||
|
|
||||||
def _train_cob_real(self, session: TrainingSession, training_data: List[Dict]):
|
def _train_cob_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||||
"""Train COB RL model with REAL training loop"""
|
"""Train COB RL model with REAL training loop"""
|
||||||
|
|||||||
@@ -174,58 +174,88 @@ class AnnotationDashboard:
|
|||||||
logger.info("Annotation Dashboard initialized")
|
logger.info("Annotation Dashboard initialized")
|
||||||
|
|
||||||
def _start_async_model_loading(self):
|
def _start_async_model_loading(self):
|
||||||
"""Load ML models asynchronously in background thread"""
|
"""Load ML models asynchronously in background thread with retry logic"""
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
def load_models():
|
def load_models():
|
||||||
try:
|
max_retries = 3
|
||||||
logger.info("🔄 Starting async model loading...")
|
retry_delay = 5 # seconds
|
||||||
|
|
||||||
# Initialize orchestrator with models
|
for attempt in range(max_retries):
|
||||||
if TradingOrchestrator:
|
try:
|
||||||
|
if attempt > 0:
|
||||||
|
logger.info(f"🔄 Retry attempt {attempt + 1}/{max_retries} for model loading...")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
logger.info("🔄 Starting async model loading...")
|
||||||
|
|
||||||
|
# Check if TradingOrchestrator is available
|
||||||
|
if not TradingOrchestrator:
|
||||||
|
logger.error("❌ TradingOrchestrator class not available")
|
||||||
|
self.models_loading = False
|
||||||
|
self.available_models = []
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize orchestrator with models
|
||||||
|
logger.info(" Creating TradingOrchestrator instance...")
|
||||||
self.orchestrator = TradingOrchestrator(
|
self.orchestrator = TradingOrchestrator(
|
||||||
data_provider=self.data_provider,
|
data_provider=self.data_provider,
|
||||||
enhanced_rl_training=True
|
enhanced_rl_training=True
|
||||||
)
|
)
|
||||||
|
logger.info(" ✅ Orchestrator created")
|
||||||
|
|
||||||
# Initialize ML models
|
# Initialize ML models
|
||||||
logger.info("Initializing ML models...")
|
logger.info(" Initializing ML models...")
|
||||||
self.orchestrator._initialize_ml_models()
|
self.orchestrator._initialize_ml_models()
|
||||||
|
logger.info(" ✅ ML models initialized")
|
||||||
|
|
||||||
# Update training adapter with orchestrator
|
# Update training adapter with orchestrator
|
||||||
self.training_adapter.orchestrator = self.orchestrator
|
self.training_adapter.orchestrator = self.orchestrator
|
||||||
|
logger.info(" ✅ Training adapter updated")
|
||||||
|
|
||||||
# Get available models from orchestrator
|
# Get available models from orchestrator
|
||||||
available = []
|
available = []
|
||||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||||
available.append('DQN')
|
available.append('DQN')
|
||||||
|
logger.info(" ✅ DQN model available")
|
||||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||||
available.append('CNN')
|
available.append('CNN')
|
||||||
|
logger.info(" ✅ CNN model available")
|
||||||
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
||||||
available.append('Transformer')
|
available.append('Transformer')
|
||||||
|
logger.info(" ✅ Transformer model available")
|
||||||
|
|
||||||
self.available_models = available
|
self.available_models = available
|
||||||
|
|
||||||
if available:
|
if available:
|
||||||
logger.info(f"✅ Models loaded: {', '.join(available)}")
|
logger.info(f"✅ Models loaded successfully: {', '.join(available)}")
|
||||||
else:
|
else:
|
||||||
logger.warning("⚠️ No models were initialized")
|
logger.warning("⚠️ No models were initialized (this might be normal if models aren't configured)")
|
||||||
|
|
||||||
self.models_loading = False
|
self.models_loading = False
|
||||||
logger.info("✅ Async model loading complete")
|
logger.info("✅ Async model loading complete")
|
||||||
else:
|
return # Success - exit retry loop
|
||||||
logger.warning("⚠️ TradingOrchestrator not available")
|
|
||||||
self.models_loading = False
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Error loading models: {e}")
|
logger.error(f"❌ Error loading models (attempt {attempt + 1}/{max_retries}): {e}")
|
||||||
self.models_loading = False
|
import traceback
|
||||||
self.available_models = []
|
logger.error(f"Traceback:\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
# Final attempt failed
|
||||||
|
logger.error(f"❌ Model loading failed after {max_retries} attempts")
|
||||||
|
self.models_loading = False
|
||||||
|
self.available_models = []
|
||||||
|
else:
|
||||||
|
logger.info(f" Will retry in {retry_delay} seconds...")
|
||||||
|
|
||||||
# Start loading in background thread
|
# Start loading in background thread
|
||||||
thread = threading.Thread(target=load_models, daemon=True)
|
thread = threading.Thread(target=load_models, daemon=True, name="ModelLoader")
|
||||||
thread.start()
|
thread.start()
|
||||||
logger.info("🚀 Model loading started in background (UI remains responsive)")
|
logger.info(f"🚀 Model loading started in background thread (ID: {thread.ident}, Name: {thread.name})")
|
||||||
|
logger.info(" UI remains responsive while models load...")
|
||||||
|
logger.info(" Will retry up to 3 times if loading fails")
|
||||||
|
|
||||||
def _enable_unified_storage_async(self):
|
def _enable_unified_storage_async(self):
|
||||||
"""Enable unified storage system in background thread"""
|
"""Enable unified storage system in background thread"""
|
||||||
|
|||||||
@@ -154,12 +154,15 @@
|
|||||||
.catch(error => {
|
.catch(error => {
|
||||||
console.error('Error loading models:', error);
|
console.error('Error loading models:', error);
|
||||||
const modelSelect = document.getElementById('model-select');
|
const modelSelect = document.getElementById('model-select');
|
||||||
modelSelect.innerHTML = '<option value="">Error loading models</option>';
|
|
||||||
|
|
||||||
// Stop polling on error
|
// Don't stop polling on network errors - keep trying
|
||||||
if (modelLoadingPollInterval) {
|
if (!modelLoadingPollInterval) {
|
||||||
clearInterval(modelLoadingPollInterval);
|
modelSelect.innerHTML = '<option value="">⚠️ Connection error, retrying...</option>';
|
||||||
modelLoadingPollInterval = null;
|
// Start polling to retry
|
||||||
|
modelLoadingPollInterval = setInterval(loadAvailableModels, 3000); // Poll every 3 seconds
|
||||||
|
} else {
|
||||||
|
// Already polling, just update the message
|
||||||
|
modelSelect.innerHTML = '<option value="">🔄 Retrying...</option>';
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
128
test_training.py
Normal file
128
test_training.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
Test script to verify model training works correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
from ANNOTATE.core.annotation_manager import AnnotationManager
|
||||||
|
from ANNOTATE.core.real_training_adapter import RealTrainingAdapter
|
||||||
|
|
||||||
|
def test_training():
|
||||||
|
"""Test the complete training flow"""
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Testing Model Training Flow")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Step 1: Initialize components
|
||||||
|
print("\n1. Initializing components...")
|
||||||
|
data_provider = DataProvider()
|
||||||
|
print(" ✅ DataProvider initialized")
|
||||||
|
|
||||||
|
orchestrator = TradingOrchestrator(
|
||||||
|
data_provider=data_provider,
|
||||||
|
enhanced_rl_training=True
|
||||||
|
)
|
||||||
|
print(" ✅ Orchestrator initialized")
|
||||||
|
|
||||||
|
# Step 2: Initialize ML models
|
||||||
|
print("\n2. Initializing ML models...")
|
||||||
|
orchestrator._initialize_ml_models()
|
||||||
|
|
||||||
|
# Check what models are available
|
||||||
|
available_models = []
|
||||||
|
if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
|
||||||
|
available_models.append('DQN')
|
||||||
|
print(" ✅ DQN model available")
|
||||||
|
if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
|
||||||
|
available_models.append('CNN')
|
||||||
|
print(" ✅ CNN model available")
|
||||||
|
if hasattr(orchestrator, 'primary_transformer') and orchestrator.primary_transformer:
|
||||||
|
available_models.append('Transformer')
|
||||||
|
print(" ✅ Transformer model available")
|
||||||
|
|
||||||
|
# Check if trainer is available
|
||||||
|
if hasattr(orchestrator, 'primary_transformer_trainer') and orchestrator.primary_transformer_trainer:
|
||||||
|
trainer = orchestrator.primary_transformer_trainer
|
||||||
|
print(f" ✅ Transformer trainer available: {type(trainer).__name__}")
|
||||||
|
|
||||||
|
# List available methods
|
||||||
|
methods = [m for m in dir(trainer) if not m.startswith('_') and callable(getattr(trainer, m))]
|
||||||
|
print(f" 📋 Trainer methods: {', '.join(methods[:10])}...")
|
||||||
|
|
||||||
|
if not available_models:
|
||||||
|
print(" ❌ No models available!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n Available models: {', '.join(available_models)}")
|
||||||
|
|
||||||
|
# Step 3: Initialize training adapter
|
||||||
|
print("\n3. Initializing training adapter...")
|
||||||
|
training_adapter = RealTrainingAdapter(orchestrator, data_provider)
|
||||||
|
print(" ✅ Training adapter initialized")
|
||||||
|
|
||||||
|
# Step 4: Load test cases
|
||||||
|
print("\n4. Loading test cases...")
|
||||||
|
annotation_manager = AnnotationManager()
|
||||||
|
test_cases = annotation_manager.get_all_test_cases()
|
||||||
|
print(f" ✅ Loaded {len(test_cases)} test cases")
|
||||||
|
|
||||||
|
if len(test_cases) == 0:
|
||||||
|
print(" ⚠️ No test cases available - create some annotations first!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 5: Start training
|
||||||
|
print(f"\n5. Starting training with Transformer model...")
|
||||||
|
print(f" Test cases: {len(test_cases)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_id = training_adapter.start_training(
|
||||||
|
model_name='Transformer',
|
||||||
|
test_cases=test_cases
|
||||||
|
)
|
||||||
|
print(f" ✅ Training started: {training_id}")
|
||||||
|
|
||||||
|
# Step 6: Monitor training progress
|
||||||
|
print("\n6. Monitoring training progress...")
|
||||||
|
import time
|
||||||
|
|
||||||
|
for i in range(30): # Monitor for 30 seconds
|
||||||
|
time.sleep(1)
|
||||||
|
progress = training_adapter.get_training_progress(training_id)
|
||||||
|
|
||||||
|
if progress['status'] == 'completed':
|
||||||
|
print(f"\n ✅ Training completed!")
|
||||||
|
print(f" Final loss: {progress['final_loss']:.6f}")
|
||||||
|
print(f" Accuracy: {progress['accuracy']:.2%}")
|
||||||
|
print(f" Duration: {progress['duration_seconds']:.2f}s")
|
||||||
|
break
|
||||||
|
elif progress['status'] == 'failed':
|
||||||
|
print(f"\n ❌ Training failed!")
|
||||||
|
print(f" Error: {progress['error']}")
|
||||||
|
break
|
||||||
|
elif progress['status'] == 'running':
|
||||||
|
print(f" Epoch {progress['current_epoch']}/{progress['total_epochs']}, Loss: {progress['current_loss']:.6f}", end='\r')
|
||||||
|
else:
|
||||||
|
print(f"\n ⚠️ Training still running after 30 seconds")
|
||||||
|
progress = training_adapter.get_training_progress(training_id)
|
||||||
|
print(f" Status: {progress['status']}")
|
||||||
|
print(f" Epoch: {progress['current_epoch']}/{progress['total_epochs']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Training failed with exception: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Test Complete")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_training()
|
||||||
Reference in New Issue
Block a user