Files
gogo2/ANNOTATE/core/real_training_adapter.py
2025-10-31 03:14:35 +02:00

1628 lines
73 KiB
Python

"""
Real Training Adapter for ANNOTATE System
This adapter connects the ANNOTATE annotation system to the REAL training implementations.
NO SIMULATION - Uses actual model training from NN/training and core modules.
Integrates with:
- NN/training/enhanced_realtime_training.py
- NN/training/model_manager.py
- core/unified_training_manager.py
- core/orchestrator.py
"""
import logging
import uuid
import time
import threading
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
try:
import pytz
except ImportError:
pytz = None
logger = logging.getLogger(__name__)
def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
"""
Unified timestamp parser that handles all formats and ensures UTC timezone.
Handles:
- ISO format with timezone: '2025-10-27T14:00:00+00:00'
- ISO format with Z: '2025-10-27T14:00:00Z'
- Space-separated with seconds: '2025-10-27 14:00:00'
- Space-separated without seconds: '2025-10-27 14:00'
Args:
timestamp_str: Timestamp string in various formats
Returns:
Timezone-aware datetime object in UTC
Raises:
ValueError: If timestamp cannot be parsed
"""
if not timestamp_str:
raise ValueError("Empty timestamp string")
# Try ISO format first (handles T separator and timezone info)
if 'T' in timestamp_str or '+' in timestamp_str:
try:
# Handle 'Z' suffix (Zulu time = UTC)
if timestamp_str.endswith('Z'):
timestamp_str = timestamp_str[:-1] + '+00:00'
return datetime.fromisoformat(timestamp_str)
except ValueError:
pass
# Try space-separated formats
# Replace space with T for fromisoformat compatibility
if ' ' in timestamp_str:
try:
# Try parsing with fromisoformat after converting space to T
dt = datetime.fromisoformat(timestamp_str.replace(' ', 'T'))
# Make timezone-aware if naive
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
except ValueError:
pass
# Try explicit format parsing as fallback
formats = [
'%Y-%m-%d %H:%M:%S', # With seconds
'%Y-%m-%d %H:%M', # Without seconds
'%Y-%m-%dT%H:%M:%S', # ISO without timezone
'%Y-%m-%dT%H:%M', # ISO without seconds or timezone
]
for fmt in formats:
try:
dt = datetime.strptime(timestamp_str, fmt)
# Make timezone-aware
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
except ValueError:
continue
# If all parsing attempts fail
raise ValueError(f"Could not parse timestamp: '{timestamp_str}'")
@dataclass
class TrainingSession:
"""Real training session tracking"""
training_id: str
model_name: str
test_cases_count: int
status: str # 'running', 'completed', 'failed'
current_epoch: int
total_epochs: int
current_loss: float
start_time: float
duration_seconds: Optional[float] = None
final_loss: Optional[float] = None
accuracy: Optional[float] = None
error: Optional[str] = None
class RealTrainingAdapter:
"""
Adapter for REAL model training using annotations.
This class bridges the ANNOTATE system with the actual training implementations.
NO SIMULATION CODE - All training is real.
"""
def __init__(self, orchestrator=None, data_provider=None):
"""
Initialize with real orchestrator and data provider
Args:
orchestrator: TradingOrchestrator instance with real models
data_provider: DataProvider for fetching real market data
"""
self.orchestrator = orchestrator
self.data_provider = data_provider
self.training_sessions: Dict[str, TrainingSession] = {}
# Import real training systems
self._import_training_systems()
logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")
def _import_training_systems(self):
"""Import real training system implementations"""
try:
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
self.enhanced_training_available = True
logger.info("EnhancedRealtimeTrainingSystem available")
except ImportError as e:
self.enhanced_training_available = False
logger.warning(f"EnhancedRealtimeTrainingSystem not available: {e}")
try:
from NN.training.model_manager import ModelManager
self.model_manager_available = True
logger.info("ModelManager available")
except ImportError as e:
self.model_manager_available = False
logger.warning(f"ModelManager not available: {e}")
try:
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
self.enhanced_rl_adapter_available = True
logger.info("EnhancedRLTrainingAdapter available")
except ImportError as e:
self.enhanced_rl_adapter_available = False
logger.warning(f"EnhancedRLTrainingAdapter not available: {e}")
def get_available_models(self) -> List[str]:
"""Get list of available models from orchestrator"""
if not self.orchestrator:
logger.error("Orchestrator not available")
return []
available = []
# Check which models are actually loaded in orchestrator
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
available.append("CNN")
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
available.append("DQN")
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
available.append("Transformer")
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
available.append("COB")
if hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
available.append("Extrema")
logger.info(f"Available models for training: {available}")
return available
def start_training(self, model_name: str, test_cases: List[Dict]) -> str:
"""
Start REAL training session with test cases
Args:
model_name: Name of model to train (CNN, DQN, Transformer, COB, Extrema)
test_cases: List of test cases from annotations
Returns:
training_id: Unique ID for this training session
"""
if not self.orchestrator:
raise Exception("Orchestrator not available - cannot train models")
training_id = str(uuid.uuid4())
# Create training session
session = TrainingSession(
training_id=training_id,
model_name=model_name,
test_cases_count=len(test_cases),
status='running',
current_epoch=0,
total_epochs=10, # Reasonable for annotation-based training
current_loss=0.0,
start_time=time.time()
)
self.training_sessions[training_id] = session
logger.info(f"Starting REAL training session: {training_id} for {model_name} with {len(test_cases)} test cases")
# Start actual training in background thread
thread = threading.Thread(
target=self._execute_real_training,
args=(training_id, model_name, test_cases),
daemon=True
)
thread.start()
return training_id
def _execute_real_training(self, training_id: str, model_name: str, test_cases: List[Dict]):
"""Execute REAL model training (runs in background thread)"""
session = self.training_sessions[training_id]
try:
logger.info(f"Executing REAL training for {model_name}")
logger.info(f" Training ID: {training_id}")
logger.info(f" Test cases: {len(test_cases)}")
# Prepare training data from test cases
training_data = self._prepare_training_data(test_cases)
if not training_data:
raise Exception("No valid training data prepared from test cases")
logger.info(f" Prepared {len(training_data)} training samples")
# Route to appropriate REAL training method
if model_name in ["CNN", "StandardizedCNN"]:
logger.info(" Starting CNN training...")
self._train_cnn_real(session, training_data)
elif model_name == "DQN":
logger.info(" Starting DQN training...")
self._train_dqn_real(session, training_data)
elif model_name == "Transformer":
logger.info(" Starting Transformer training...")
self._train_transformer_real(session, training_data)
elif model_name == "COB":
logger.info(" Starting COB training...")
self._train_cob_real(session, training_data)
elif model_name == "Extrema":
logger.info(" Starting Extrema training...")
self._train_extrema_real(session, training_data)
else:
raise Exception(f"Unknown model type: {model_name}")
# Mark as completed
session.status = 'completed'
session.duration_seconds = time.time() - session.start_time
logger.info(f" REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
logger.info(f" Final loss: {session.final_loss}")
logger.info(f" Accuracy: {session.accuracy}")
except Exception as e:
logger.error(f"REAL training failed: {e}", exc_info=True)
session.status = 'failed'
session.error = str(e)
session.duration_seconds = time.time() - session.start_time
logger.error(f"Training session {training_id} failed after {session.duration_seconds:.2f}s")
def _get_secondary_symbol(self, primary_symbol: str) -> str:
"""
Determine secondary symbol based on primary symbol
Rules:
- ETH/USDT -> BTC/USDT
- SOL/USDT -> BTC/USDT
- BTC/USDT -> ETH/USDT
Args:
primary_symbol: Primary trading symbol
Returns:
Secondary symbol for correlation analysis
"""
if 'BTC' in primary_symbol:
return 'ETH/USDT'
else:
return 'BTC/USDT'
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
"""
Fetch market state dynamically for a test case from DuckDB storage
This fetches HISTORICAL data at the specific timestamp from the annotation,
not current/latest data.
Args:
test_case: Test case dictionary with timestamp, symbol, etc.
Returns:
Market state dictionary with OHLCV data for all timeframes
"""
try:
if not self.data_provider:
logger.warning("DataProvider not available, cannot fetch market state")
return {}
symbol = test_case.get('symbol', 'ETH/USDT')
timestamp_str = test_case.get('timestamp')
if not timestamp_str:
logger.warning("No timestamp in test case")
return {}
# Parse timestamp using unified parser
try:
timestamp = parse_timestamp_to_utc(timestamp_str)
except Exception as e:
logger.warning(f"Could not parse timestamp '{timestamp_str}': {e}")
return {}
# Get training config
training_config = test_case.get('training_config', {})
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per batch
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
secondary_symbol = self._get_secondary_symbol(symbol)
logger.info(f" Fetching HISTORICAL market state for {symbol} at {timestamp}")
logger.info(f" Primary symbol: {symbol} - Timeframes: {timeframes}")
logger.info(f" Secondary symbol: {secondary_symbol} - Timeframe: 1m")
logger.info(f" Candles per batch: {candles_per_timeframe}")
# Calculate time range based on candles needed
# For 600 candles at 1m = 600 minutes = 10 hours
from datetime import timedelta
# Calculate time window for each timeframe to get 600 candles
time_windows = {
'1s': timedelta(seconds=candles_per_timeframe), # 600 seconds = 10 minutes
'1m': timedelta(minutes=candles_per_timeframe), # 600 minutes = 10 hours
'1h': timedelta(hours=candles_per_timeframe), # 600 hours = 25 days
'1d': timedelta(days=candles_per_timeframe) # 600 days = ~1.6 years
}
# Use the largest window to ensure we have enough data for all timeframes
max_window = max(time_windows.values())
start_time = timestamp - max_window
end_time = timestamp
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
market_state = {
'symbol': symbol,
'timestamp': timestamp_str,
'timeframes': {},
'secondary_symbol': secondary_symbol,
'secondary_timeframes': {}
}
# Try to get data from DuckDB storage first (historical data)
duckdb_storage = None
if hasattr(self.data_provider, 'duckdb_storage'):
duckdb_storage = self.data_provider.duckdb_storage
# Fetch primary symbol data (all timeframes)
logger.info(f" Fetching primary symbol data: {symbol}")
for timeframe in timeframes:
df = None
limit = candles_per_timeframe # Always fetch 600 candles
# Try DuckDB storage first (has historical data)
if duckdb_storage:
try:
df = duckdb_storage.get_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=limit,
direction='latest'
)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)")
except Exception as e:
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
# Fallback to data_provider (might have cached data)
if df is None or df.empty:
try:
# Use get_historical_data_replay for time-specific data
replay_data = self.data_provider.get_historical_data_replay(
symbol=symbol,
start_time=start_time,
end_time=end_time,
timeframes=[timeframe]
)
df = replay_data.get(timeframe)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from replay")
except Exception as e:
logger.debug(f" {timeframe}: Replay failed: {e}")
# Last resort: get latest data (not ideal but better than nothing)
if df is None or df.empty:
logger.warning(f" {timeframe}: No historical data found, using latest data as fallback")
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=limit # Use calculated limit
)
if df is not None and not df.empty:
# Convert to dict format
market_state['timeframes'][timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
logger.info(f" {symbol} {timeframe}: {len(df)} candles")
else:
logger.warning(f" {symbol} {timeframe}: No data available")
# Fetch secondary symbol data (1m timeframe only, 600 candles)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None
# Try DuckDB first
if duckdb_storage:
try:
secondary_df = duckdb_storage.get_ohlcv_data(
symbol=secondary_symbol,
timeframe='1m',
start_time=start_time,
end_time=end_time,
limit=candles_per_timeframe,
direction='latest'
)
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
# Fallback to replay
if secondary_df is None or secondary_df.empty:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=secondary_symbol,
start_time=start_time,
end_time=end_time,
timeframes=['1m']
)
secondary_df = replay_data.get('1m')
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
# Last resort: latest data
if secondary_df is None or secondary_df.empty:
logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
secondary_df = self.data_provider.get_historical_data(
symbol=secondary_symbol,
timeframe='1m',
limit=candles_per_timeframe
)
# Store secondary symbol data
if secondary_df is not None and not secondary_df.empty:
market_state['secondary_timeframes']['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles")
else:
logger.warning(f" {secondary_symbol} 1m: No data available")
# Verify we have data
if market_state['timeframes']:
total_primary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['timeframes'].values())
total_secondary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['secondary_timeframes'].values())
logger.info(f" [OK] Fetched {len(market_state['timeframes'])} primary timeframes ({total_primary} total candles)")
logger.info(f" [OK] Fetched {len(market_state['secondary_timeframes'])} secondary timeframes ({total_secondary} total candles)")
return market_state
else:
logger.warning(f" No market data fetched for any timeframe")
return {}
except Exception as e:
logger.error(f"Error fetching market state: {e}")
import traceback
logger.error(traceback.format_exc())
return {}
def _prepare_training_data(self, test_cases: List[Dict],
negative_samples_window: int = 15,
training_repetitions: int = 100) -> List[Dict]:
"""
Prepare training data from test cases with negative sampling
Args:
test_cases: List of test cases from annotations
negative_samples_window: Number of candles before/after signal where model should NOT trade
training_repetitions: Number of times to repeat training on each sample
Returns:
List of training samples with positive (trade) and negative (no-trade) examples
"""
training_data = []
logger.info(f"Preparing training data from {len(test_cases)} test cases...")
logger.info(f" Negative sampling: +/-{negative_samples_window} candles around signals")
logger.info(f" Training repetitions: {training_repetitions}x per sample")
for i, test_case in enumerate(test_cases):
try:
# Extract expected outcome
expected_outcome = test_case.get('expected_outcome', {})
if not expected_outcome:
logger.warning(f" Skipping test case {test_case.get('test_case_id')}: missing expected_outcome")
continue
# Check if market_state is provided, if not, fetch it dynamically
market_state = test_case.get('market_state', {})
if not market_state:
logger.info(f" Fetching market state dynamically for test case {i+1}...")
market_state = self._fetch_market_state_for_test_case(test_case)
if not market_state:
logger.warning(f" Skipping test case {test_case.get('test_case_id')}: could not fetch market state")
continue
logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
# Create ENTRY sample (where model SHOULD enter trade)
entry_sample = {
'market_state': market_state,
'action': test_case.get('action'),
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': test_case.get('timestamp'),
'label': 'ENTRY', # Entry signal
'repetitions': training_repetitions
}
training_data.append(entry_sample)
logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")
# Create HOLD samples (every candle while position is open)
# This teaches the model to maintain the position until exit
hold_samples = self._create_hold_samples(
test_case=test_case,
market_state=market_state,
repetitions=training_repetitions // 4 # Quarter reps for hold samples
)
training_data.extend(hold_samples)
if hold_samples:
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)")
# Create EXIT sample (where model SHOULD exit trade)
exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
if exit_timestamp:
exit_sample = {
'market_state': market_state, # TODO: Get market state at exit time
'action': 'CLOSE',
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': exit_timestamp,
'label': 'EXIT', # Exit signal
'repetitions': training_repetitions
}
training_data.append(exit_sample)
logger.info(f" Test case {i+1}: EXIT sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
# Create NEGATIVE samples (where model should NOT trade)
# These are candles before and after the signal (±15 candles)
# This teaches the model to recognize when NOT to enter
negative_samples = self._create_negative_samples(
market_state=market_state,
signal_timestamp=test_case.get('timestamp'),
window_size=negative_samples_window,
repetitions=training_repetitions // 2 # Half as many reps for negative samples
)
training_data.extend(negative_samples)
if negative_samples:
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)")
# Show breakdown of before/after
before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', '')))
after_count = len(negative_samples) - before_count
logger.info(f" -> {before_count} beforesignal, {after_count} after signal")
except Exception as e:
logger.error(f" Error preparing test case {i+1}: {e}")
total_entry = sum(1 for s in training_data if s.get('label') == 'ENTRY')
total_hold = sum(1 for s in training_data if s.get('label') == 'HOLD')
total_exit = sum(1 for s in training_data if s.get('label') == 'EXIT')
total_no_trade = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
logger.info(f" Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
logger.info(f" ENTRY samples: {total_entry}")
logger.info(f" HOLD samples: {total_hold}")
logger.info(f" EXIT samples: {total_exit}")
logger.info(f" NO_TRADE samples: {total_no_trade}")
if total_entry > 0:
logger.info(f" Ratio: 1:{total_no_trade/total_entry:.1f} (entry:no_trade)")
if len(training_data) < len(test_cases):
logger.warning(f" Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
return training_data
def _create_hold_samples(self, test_case: Dict, market_state: Dict, repetitions: int) -> List[Dict]:
"""
Create HOLD training samples for every candle while position is open
This teaches the model to:
1. Maintain the position (not exit early)
2. Recognize the trade is still valid
3. Wait for the optimal exit point
Args:
test_case: Test case with entry/exit info
market_state: Market state data
repetitions: Number of times to repeat each hold sample
Returns:
List of HOLD training samples
"""
hold_samples = []
try:
from datetime import datetime, timedelta
# Get entry and exit timestamps
entry_timestamp = test_case.get('timestamp')
expected_outcome = test_case.get('expected_outcome', {})
# Calculate exit timestamp from holding period
holding_period_seconds = expected_outcome.get('holding_period_seconds', 0)
if holding_period_seconds == 0:
logger.debug(" No holding period, skipping HOLD samples")
return hold_samples
# Parse entry timestamp using unified parser
try:
entry_time = parse_timestamp_to_utc(entry_timestamp)
except Exception as e:
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
return hold_samples
exit_time = entry_time + timedelta(seconds=holding_period_seconds)
# Get 1m timeframe timestamps
timeframes = market_state.get('timeframes', {})
if '1m' not in timeframes:
return hold_samples
timestamps = timeframes['1m'].get('timestamps', [])
# Find all candles between entry and exit
for idx, ts_str in enumerate(timestamps):
# Parse timestamp using unified parser
try:
ts = parse_timestamp_to_utc(ts_str)
except Exception as e:
logger.debug(f"Could not parse timestamp '{ts_str}': {e}")
continue
# If this candle is between entry and exit (exclusive)
if entry_time < ts < exit_time:
# Create market state snapshot at this candle
hold_market_state = self._create_market_state_snapshot(market_state, idx)
hold_sample = {
'market_state': hold_market_state,
'action': 'HOLD',
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': ts_str,
'label': 'HOLD', # Hold position
'repetitions': repetitions,
'in_position': True # Flag indicating we're in a position
}
hold_samples.append(hold_sample)
logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit")
except Exception as e:
logger.error(f"Error creating HOLD samples: {e}")
import traceback
logger.error(traceback.format_exc())
return hold_samples
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
window_size: int, repetitions: int) -> List[Dict]:
"""
Create negative training samples from candles around the signal
These samples teach the model when NOT to trade - crucial for reducing false signals!
Args:
market_state: Market state with OHLCV data
signal_timestamp: Timestamp of the actual signal
window_size: Number of candles before/after signal to use
repetitions: Number of times to repeat each negative sample
Returns:
List of negative training samples
"""
negative_samples = []
try:
# Get timestamps from market state (use 1m timeframe as reference)
timeframes = market_state.get('timeframes', {})
if '1m' not in timeframes:
logger.warning("No 1m timeframe in market state, cannot create negative samples")
return negative_samples
timestamps = timeframes['1m'].get('timestamps', [])
if not timestamps:
return negative_samples
# Find the index of the signal timestamp
from datetime import datetime
# Parse signal timestamp using unified parser
try:
signal_time = parse_timestamp_to_utc(signal_timestamp)
except Exception as e:
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
return negative_samples
signal_index = None
for idx, ts_str in enumerate(timestamps):
try:
# Parse timestamp using unified parser
ts = parse_timestamp_to_utc(ts_str)
# Match within 1 minute
if abs((ts - signal_time).total_seconds()) < 60:
signal_index = idx
logger.debug(f" Found signal at index {idx}: {ts_str}")
break
except Exception as e:
continue
if signal_index is None:
logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
return negative_samples
# Create negative samples from candles before and after the signal
# BEFORE signal: candles at signal_index - window_size to signal_index - 1
# AFTER signal: candles at signal_index + 1 to signal_index + window_size
negative_indices = []
# Before signal
for offset in range(1, window_size + 1):
idx = signal_index - offset
if 0 <= idx < len(timestamps):
negative_indices.append(idx)
# After signal
for offset in range(1, window_size + 1):
idx = signal_index + offset
if 0 <= idx < len(timestamps):
negative_indices.append(idx)
# Create negative samples for each index
for idx in negative_indices:
# Create a market state snapshot at this timestamp
negative_market_state = self._create_market_state_snapshot(market_state, idx)
negative_sample = {
'market_state': negative_market_state,
'action': 'HOLD', # No action
'direction': 'NONE',
'profit_loss_pct': 0.0,
'entry_price': None,
'exit_price': None,
'timestamp': timestamps[idx],
'label': 'NO_TRADE', # Negative label
'repetitions': repetitions
}
negative_samples.append(negative_sample)
logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles")
except Exception as e:
logger.error(f"Error creating negative samples: {e}")
return negative_samples
def _create_market_state_snapshot(self, market_state: Dict, candle_index: int) -> Dict:
"""
Create a market state snapshot at a specific candle index
This creates a "view" of the market as it was at that specific candle,
which is used for negative sampling.
"""
snapshot = {
'symbol': market_state.get('symbol'),
'timestamp': None, # Will be set from the candle
'timeframes': {}
}
# For each timeframe, create a snapshot up to the candle_index
for tf, tf_data in market_state.get('timeframes', {}).items():
timestamps = tf_data.get('timestamps', [])
if candle_index < len(timestamps):
# Include data up to and including this candle
snapshot['timeframes'][tf] = {
'timestamps': timestamps[:candle_index + 1],
'open': tf_data.get('open', [])[:candle_index + 1],
'high': tf_data.get('high', [])[:candle_index + 1],
'low': tf_data.get('low', [])[:candle_index + 1],
'close': tf_data.get('close', [])[:candle_index + 1],
'volume': tf_data.get('volume', [])[:candle_index + 1]
}
if tf == '1m':
snapshot['timestamp'] = timestamps[candle_index]
return snapshot
def _convert_to_cnn_input(self, data: Dict) -> tuple:
"""Convert annotation training data to CNN model input format (x, y tensors)"""
import torch
import numpy as np
try:
market_state = data.get('market_state', {})
timeframes = market_state.get('timeframes', {})
# Get 1m timeframe data (primary for CNN)
if '1m' not in timeframes:
logger.warning("No 1m timeframe data available for CNN training")
return None, None
tf_data = timeframes['1m']
closes = np.array(tf_data.get('close', []), dtype=np.float32)
if len(closes) == 0:
logger.warning("No close price data available")
return None, None
# CNN expects input shape: [batch, seq_len, features]
# Use last 60 candles (or pad/truncate to 60)
seq_len = 60
if len(closes) >= seq_len:
closes = closes[-seq_len:]
else:
# Pad with last value
last_close = closes[-1] if len(closes) > 0 else 0.0
closes = np.pad(closes, (seq_len - len(closes), 0), mode='constant', constant_values=last_close)
# Create feature tensor: [1, 60, 1] (batch, seq_len, features)
# For now, use only close prices. In full implementation, add OHLCV
x = torch.tensor(closes, dtype=torch.float32).unsqueeze(0).unsqueeze(-1) # [1, 60, 1]
# Convert action to target tensor
action = data.get('action', 'HOLD')
direction = data.get('direction', 'HOLD')
# Map to class index: 0=HOLD, 1=BUY, 2=SELL
if direction == 'LONG' or action == 'BUY':
y = torch.tensor([1], dtype=torch.long)
elif direction == 'SHORT' or action == 'SELL':
y = torch.tensor([2], dtype=torch.long)
else:
y = torch.tensor([0], dtype=torch.long)
return x, y
except Exception as e:
logger.error(f"Error converting to CNN input: {e}")
import traceback
logger.error(traceback.format_exc())
return None, None
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train CNN model with REAL training loop"""
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
raise Exception("CNN model not available in orchestrator")
model = self.orchestrator.cnn_model
# Check if model has trainer attribute (EnhancedCNN)
trainer = None
if hasattr(model, 'trainer'):
trainer = model.trainer
# Use the model's actual training method
if hasattr(model, 'train_on_annotations'):
# If model has annotation-specific training
for epoch in range(session.total_epochs):
loss = model.train_on_annotations(training_data)
session.current_epoch = epoch + 1
session.current_loss = loss if loss else 0.0
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
elif trainer and hasattr(trainer, 'train_step'):
# Use trainer's train_step method (EnhancedCNN)
logger.info(f"Training CNN using trainer.train_step() with {len(training_data)} samples")
for epoch in range(session.total_epochs):
epoch_loss = 0.0
valid_samples = 0
for data in training_data:
# Convert to model input format
x, y = self._convert_to_cnn_input(data)
if x is None or y is None:
continue
try:
# Call trainer's train_step with proper format
loss_dict = trainer.train_step(x, y)
# Extract loss from dict if it's a dict, otherwise use directly
if isinstance(loss_dict, dict):
loss = loss_dict.get('total_loss', loss_dict.get('main_loss', 0.0))
else:
loss = float(loss_dict) if loss_dict else 0.0
epoch_loss += loss
valid_samples += 1
except Exception as e:
logger.error(f"Error in CNN training step: {e}")
import traceback
logger.error(traceback.format_exc())
continue
if valid_samples > 0:
session.current_epoch = epoch + 1
session.current_loss = epoch_loss / valid_samples
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}, Samples: {valid_samples}")
else:
logger.warning(f"CNN Epoch {epoch + 1}/{session.total_epochs}: No valid samples processed")
session.current_epoch = epoch + 1
session.current_loss = 0.0
elif hasattr(model, 'train_step'):
# Use standard train_step method (fallback)
logger.warning("Using model.train_step() directly - may not work correctly")
for epoch in range(session.total_epochs):
epoch_loss = 0.0
valid_samples = 0
for data in training_data:
x, y = self._convert_to_cnn_input(data)
if x is None or y is None:
continue
try:
loss = model.train_step(x, y)
epoch_loss += loss if loss else 0.0
valid_samples += 1
except Exception as e:
logger.error(f"Error in CNN training step: {e}")
continue
if valid_samples > 0:
session.current_epoch = epoch + 1
session.current_loss = epoch_loss / valid_samples
logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
else:
raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy
def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train DQN model with REAL training loop"""
if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
raise Exception("DQN model not available in orchestrator")
agent = self.orchestrator.rl_agent
# Use EnhancedRLTrainingAdapter if available for better reward calculation
if self.enhanced_rl_adapter_available and hasattr(self.orchestrator, 'enhanced_rl_adapter'):
logger.info("Using EnhancedRLTrainingAdapter for DQN training")
# The enhanced adapter will handle training through its async loop
# For now, we'll use the traditional approach but with better state building
# Add experiences to replay buffer
for data in training_data:
# Calculate reward from profit/loss
reward = data['profit_loss_pct'] / 100.0 if data.get('profit_loss_pct') else 0.0
# Add to memory if agent has remember method
if hasattr(agent, 'remember'):
# Try to build proper state representation
state = self._build_state_from_data(data, agent)
action = 1 if data.get('direction') == 'LONG' else 0
agent.remember(state, action, reward, state, True)
# Train with replay
if hasattr(agent, 'replay'):
for epoch in range(session.total_epochs):
loss = agent.replay()
session.current_epoch = epoch + 1
session.current_loss = loss if loss else 0.0
logger.info(f"DQN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
else:
raise Exception("DQN agent does not have replay method")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy
def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
"""Build proper state representation from training data"""
try:
# Try to extract market state features
market_state = data.get('market_state', {})
# Get state size from agent
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
# Build feature vector from market state
features = []
# Add price-based features if available
if 'entry_price' in data:
features.append(float(data['entry_price']))
if 'exit_price' in data:
features.append(float(data['exit_price']))
if 'profit_loss_pct' in data:
features.append(float(data['profit_loss_pct']))
# Pad or truncate to match state size
if len(features) < state_size:
features.extend([0.0] * (state_size - len(features)))
else:
features = features[:state_size]
return features
except Exception as e:
logger.error(f"Error building state from data: {e}")
# Return zero state as fallback
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
return [0.0] * state_size
def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
"""
Convert annotation training sample to transformer model input format
The transformer expects:
- price_data: [batch, seq_len, features] - OHLCV sequences
- cob_data: [batch, seq_len, cob_features] - Change of Bid data
- tech_data: [batch, features] - Technical indicators
- market_data: [batch, features] - Market context
- actions: [batch] - Target actions (0=HOLD, 1=BUY, 2=SELL)
- future_prices: [batch] - Future price targets
- trade_success: [batch] - Whether trade was successful
"""
import torch
import numpy as np
try:
market_state = training_sample.get('market_state', {})
# Extract OHLCV data from ALL timeframes
timeframes = market_state.get('timeframes', {})
# Collect data from all available timeframes
all_price_data = []
timeframe_order = ['1s', '1m', '1h', '1d'] # Process in order
for tf in timeframe_order:
if tf not in timeframes:
continue
tf_data = timeframes[tf]
# Convert to numpy arrays
opens = np.array(tf_data.get('open', []), dtype=np.float32)
highs = np.array(tf_data.get('high', []), dtype=np.float32)
lows = np.array(tf_data.get('low', []), dtype=np.float32)
closes = np.array(tf_data.get('close', []), dtype=np.float32)
volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
if len(closes) > 0:
# Stack OHLCV for this timeframe [seq_len, 5]
tf_price_data = np.stack([opens, highs, lows, closes, volumes], axis=-1)
all_price_data.append(tf_price_data)
if not all_price_data:
logger.warning("No price data in any timeframe")
return None
# Use only the primary timeframe (1m) for transformer training
# The transformer expects a fixed sequence length of 150
primary_tf = '1m' if '1m' in timeframes else timeframe_order[0]
if primary_tf not in timeframes:
logger.warning(f"Primary timeframe {primary_tf} not available")
return None
# Get primary timeframe data
primary_data = timeframes[primary_tf]
closes = np.array(primary_data.get('close', []), dtype=np.float32)
if len(closes) == 0:
logger.warning("No data in primary timeframe")
return None
# Use the last 150 candles (or pad/truncate to exactly 150)
target_seq_len = 150 # Transformer expects exactly 150 sequence length
if len(closes) >= target_seq_len:
# Take the last 150 candles
price_data = np.stack([
np.array(primary_data.get('open', [])[-target_seq_len:], dtype=np.float32),
np.array(primary_data.get('high', [])[-target_seq_len:], dtype=np.float32),
np.array(primary_data.get('low', [])[-target_seq_len:], dtype=np.float32),
np.array(primary_data.get('close', [])[-target_seq_len:], dtype=np.float32),
np.array(primary_data.get('volume', [])[-target_seq_len:], dtype=np.float32)
], axis=-1)
else:
# Pad with the last available candle
last_open = primary_data.get('open', [0])[-1] if primary_data.get('open') else 0
last_high = primary_data.get('high', [0])[-1] if primary_data.get('high') else 0
last_low = primary_data.get('low', [0])[-1] if primary_data.get('low') else 0
last_close = primary_data.get('close', [0])[-1] if primary_data.get('close') else 0
last_volume = primary_data.get('volume', [0])[-1] if primary_data.get('volume') else 0
# Pad arrays to target length
opens = np.array(primary_data.get('open', []), dtype=np.float32)
highs = np.array(primary_data.get('high', []), dtype=np.float32)
lows = np.array(primary_data.get('low', []), dtype=np.float32)
closes = np.array(primary_data.get('close', []), dtype=np.float32)
volumes = np.array(primary_data.get('volume', []), dtype=np.float32)
# Pad with last values
while len(opens) < target_seq_len:
opens = np.append(opens, last_open)
highs = np.append(highs, last_high)
lows = np.append(lows, last_low)
closes = np.append(closes, last_close)
volumes = np.append(volumes, last_volume)
price_data = np.stack([opens, highs, lows, closes, volumes], axis=-1)
# Add batch dimension [1, 150, 5]
price_data = torch.tensor(price_data, dtype=torch.float32).unsqueeze(0)
# Sequence length is now exactly 150
total_seq_len = 150
# Create placeholder COB data (zeros if not available)
# COB data shape: [1, 150, cob_features]
# MUST match the total sequence length from price_data (150)
# Transformer expects 100 COB features (as defined in TransformerConfig)
cob_data = torch.zeros(1, 150, 100, dtype=torch.float32) # Match price seq_len (150)
# Create technical indicators (simple ones for now)
# tech_data shape: [1, features]
tech_features = []
# Use the closes data from the price_data we just created
closes_for_tech = price_data[0, :, 3].numpy() # Close prices from OHLCV data
# Add simple technical indicators
if len(closes_for_tech) >= 20:
sma_20 = np.mean(closes_for_tech[-20:])
tech_features.append(closes_for_tech[-1] / sma_20 - 1.0) # Price vs SMA
else:
tech_features.append(0.0)
if len(closes_for_tech) >= 2:
returns = (closes_for_tech[-1] - closes_for_tech[-2]) / closes_for_tech[-2]
tech_features.append(returns) # Recent return
else:
tech_features.append(0.0)
# Add volatility
if len(closes_for_tech) >= 20:
volatility = np.std(closes_for_tech[-20:]) / np.mean(closes_for_tech[-20:])
tech_features.append(volatility)
else:
tech_features.append(0.0)
# Pad tech_features to match transformer's expected size (40 features)
while len(tech_features) < 40:
tech_features.append(0.0)
tech_data = torch.tensor([tech_features[:40]], dtype=torch.float32) # Ensure exactly 40 features
# Create market context data with pivot points
# market_data shape: [1, features]
market_features = []
# Add volume profile
volumes_for_tech = price_data[0, :, 4].numpy() # Volume from OHLCV data
if len(volumes_for_tech) >= 20:
vol_ratio = volumes_for_tech[-1] / np.mean(volumes_for_tech[-20:])
market_features.append(vol_ratio)
else:
market_features.append(1.0)
# Add price range
highs_for_tech = price_data[0, :, 1].numpy() # High from OHLCV data
lows_for_tech = price_data[0, :, 2].numpy() # Low from OHLCV data
if len(highs_for_tech) >= 20 and len(lows_for_tech) >= 20:
price_range = (np.max(highs_for_tech[-20:]) - np.min(lows_for_tech[-20:])) / closes_for_tech[-1]
market_features.append(price_range)
else:
market_features.append(0.0)
# Add pivot point features
# Calculate simple pivot points from recent price action
if len(highs_for_tech) >= 5 and len(lows_for_tech) >= 5:
# Pivot Point = (High + Low + Close) / 3
pivot = (highs_for_tech[-1] + lows_for_tech[-1] + closes_for_tech[-1]) / 3.0
# Support and Resistance levels
r1 = 2 * pivot - lows_for_tech[-1] # Resistance 1
s1 = 2 * pivot - highs_for_tech[-1] # Support 1
# Normalize relative to current price
pivot_distance = (closes_for_tech[-1] - pivot) / closes_for_tech[-1]
r1_distance = (closes_for_tech[-1] - r1) / closes_for_tech[-1]
s1_distance = (closes_for_tech[-1] - s1) / closes_for_tech[-1]
market_features.extend([pivot_distance, r1_distance, s1_distance])
else:
market_features.extend([0.0, 0.0, 0.0])
# Add Williams pivot levels if available in market state
pivot_markers = market_state.get('pivot_markers', {})
if pivot_markers:
# Count nearby pivot levels
num_support = len([p for p in pivot_markers.get('support_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
num_resistance = len([p for p in pivot_markers.get('resistance_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
market_features.extend([float(num_support), float(num_resistance)])
else:
market_features.extend([0.0, 0.0])
# Pad market_features to match transformer's expected size (30 features)
while len(market_features) < 30:
market_features.append(0.0)
market_data = torch.tensor([market_features[:30]], dtype=torch.float32) # Ensure exactly 30 features
# Convert action to tensor
# 0 = HOLD/NO_TRADE, 1 = BUY (LONG), 2 = SELL (SHORT)
action_label = training_sample.get('label', 'TRADE')
direction = training_sample.get('direction', 'NONE')
in_position = training_sample.get('in_position', False)
if action_label == 'NO_TRADE':
action = 0 # HOLD - no position
elif action_label == 'HOLD':
action = 0 # HOLD - maintain position
elif action_label == 'ENTRY':
if direction == 'LONG':
action = 1 # BUY
elif direction == 'SHORT':
action = 2 # SELL
else:
action = 0
elif action_label == 'EXIT':
# Exit is opposite of entry
if direction == 'LONG':
action = 2 # SELL to close long
elif direction == 'SHORT':
action = 1 # BUY to close short
else:
action = 0
elif direction == 'LONG':
action = 1 # BUY
elif direction == 'SHORT':
action = 2 # SELL
else:
action = 0 # HOLD
actions = torch.tensor([action], dtype=torch.long)
# Future price target - NORMALIZED
# Model predicts price change ratio, not absolute price
entry_price = training_sample.get('entry_price')
exit_price = training_sample.get('exit_price')
current_price = closes_for_tech[-1] # Most recent close price
if exit_price and entry_price:
# Normalize: (exit_price - current_price) / current_price
# This gives the expected price change as a ratio
future_price_ratio = (exit_price - current_price) / current_price
else:
# For HOLD samples, expect no price change
future_price_ratio = 0.0
future_prices = torch.tensor([future_price_ratio], dtype=torch.float32)
# Trade success (1.0 if profitable, 0.0 otherwise)
# Shape must be [batch_size, 1] to match confidence head output
profit_loss_pct = training_sample.get('profit_loss_pct', 0.0)
trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32)
# Return batch dictionary
batch = {
'price_data': price_data,
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data,
'actions': actions,
'future_prices': future_prices,
'trade_success': trade_success
}
return batch
except Exception as e:
logger.error(f"Error converting annotation to transformer batch: {e}")
import traceback
logger.error(traceback.format_exc())
return None
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
"""
Train Transformer model using orchestrator's existing training infrastructure
Uses the orchestrator's primary_transformer_trainer which already has
all the training logic implemented!
"""
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
raise Exception("Transformer model not available in orchestrator")
# Get the trainer from orchestrator - it already has training methods!
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
raise Exception("Transformer trainer not available in orchestrator")
logger.info(f"Using orchestrator's TradingTransformerTrainer")
logger.info(f" Trainer type: {type(trainer).__name__}")
# Use the trainer's train_step method for individual samples
if hasattr(trainer, 'train_step'):
logger.info(" Using trainer.train_step() method")
logger.info(" Converting annotation data to transformer format...")
import torch
# Convert all training samples to transformer format
converted_batches = []
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# Repeat based on repetitions parameter
# IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
repetitions = data.get('repetitions', 1)
for _ in range(repetitions):
# Clone all tensors in the batch to ensure independence
cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
converted_batches.append(cloned_batch)
else:
logger.warning(f" Failed to convert sample {i+1}")
if not converted_batches:
raise Exception("No valid training batches after conversion")
logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
# Train using train_step for each batch
for epoch in range(session.total_epochs):
epoch_loss = 0.0
epoch_accuracy = 0.0
num_batches = 0
for i, batch in enumerate(converted_batches):
try:
# Call the trainer's train_step method with proper batch format
result = trainer.train_step(batch)
if result is not None:
epoch_loss += result.get('total_loss', 0.0)
epoch_accuracy += result.get('accuracy', 0.0)
num_batches += 1
if (i + 1) % 100 == 0:
logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}, Accuracy: {result.get('accuracy', 0.0):.2%}")
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")
import traceback
logger.error(traceback.format_exc())
continue
avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
session.current_epoch = epoch + 1
session.current_loss = avg_loss
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
session.final_loss = session.current_loss
session.accuracy = avg_accuracy
logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
else:
raise Exception(f"Transformer trainer does not have train_on_batch() or train() methods. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
def _train_cob_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train COB RL model with REAL training loop"""
if not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
raise Exception("COB RL model not available in orchestrator")
agent = self.orchestrator.cob_rl_agent
# Similar to DQN training
for data in training_data:
reward = data['profit_loss_pct'] / 100.0 if data.get('profit_loss_pct') else 0.0
if hasattr(agent, 'remember'):
state = [0.0] * agent.state_size if hasattr(agent, 'state_size') else []
action = 1 if data.get('direction') == 'LONG' else 0
agent.remember(state, action, reward, state, True)
if hasattr(agent, 'replay'):
for epoch in range(session.total_epochs):
loss = agent.replay()
session.current_epoch = epoch + 1
session.current_loss = loss if loss else 0.0
logger.info(f"COB RL Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
session.final_loss = session.current_loss
session.accuracy = 0.85
def _train_extrema_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train Extrema model with REAL training loop"""
if not hasattr(self.orchestrator, 'extrema_trainer') or not self.orchestrator.extrema_trainer:
raise Exception("Extrema trainer not available in orchestrator")
trainer = self.orchestrator.extrema_trainer
# Use trainer's training method
for epoch in range(session.total_epochs):
# TODO: Implement actual extrema training
session.current_epoch = epoch + 1
session.current_loss = 0.5 / (epoch + 1) # Placeholder
logger.info(f"Extrema Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
session.final_loss = session.current_loss
session.accuracy = 0.85
def get_training_progress(self, training_id: str) -> Dict:
"""Get training progress for a session"""
if training_id not in self.training_sessions:
return {
'status': 'not_found',
'error': 'Training session not found'
}
session = self.training_sessions[training_id]
return {
'status': session.status,
'model_name': session.model_name,
'test_cases_count': session.test_cases_count,
'current_epoch': session.current_epoch,
'total_epochs': session.total_epochs,
'current_loss': session.current_loss,
'final_loss': session.final_loss,
'accuracy': session.accuracy,
'duration_seconds': session.duration_seconds,
'error': session.error
}
# Real-time inference support
def start_realtime_inference(self, model_name: str, symbol: str, data_provider) -> str:
"""
Start real-time inference using orchestrator's REAL prediction methods
Args:
model_name: Name of model to use for inference
symbol: Trading symbol
data_provider: Data provider for market data
Returns:
inference_id: Unique ID for this inference session
"""
if not self.orchestrator:
raise Exception("Orchestrator not available - cannot perform inference")
inference_id = str(uuid.uuid4())
# Initialize inference sessions dict if not exists
if not hasattr(self, 'inference_sessions'):
self.inference_sessions = {}
# Create inference session
self.inference_sessions[inference_id] = {
'model_name': model_name,
'symbol': symbol,
'status': 'running',
'start_time': time.time(),
'signals': [],
'stop_flag': False
}
logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol}")
# Start inference loop in background thread
thread = threading.Thread(
target=self._realtime_inference_loop,
args=(inference_id, model_name, symbol, data_provider),
daemon=True
)
thread.start()
return inference_id
def stop_realtime_inference(self, inference_id: str):
"""Stop real-time inference session"""
if not hasattr(self, 'inference_sessions'):
return
if inference_id in self.inference_sessions:
self.inference_sessions[inference_id]['stop_flag'] = True
self.inference_sessions[inference_id]['status'] = 'stopped'
logger.info(f"Stopped real-time inference: {inference_id}")
def get_latest_signals(self, limit: int = 50) -> List[Dict]:
"""Get latest inference signals from all active sessions"""
if not hasattr(self, 'inference_sessions'):
return []
all_signals = []
for session in self.inference_sessions.values():
all_signals.extend(session.get('signals', []))
# Sort by timestamp and return latest
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
return all_signals[:limit]
def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider):
"""
Real-time inference loop using orchestrator's REAL prediction methods
This runs in a background thread and continuously makes predictions
using the actual model inference methods from the orchestrator.
"""
session = self.inference_sessions[inference_id]
try:
while not session['stop_flag']:
try:
# Use orchestrator's REAL prediction method
if hasattr(self.orchestrator, 'make_decision'):
# Get real prediction from orchestrator
decision = self.orchestrator.make_decision(symbol)
if decision:
# Store signal
signal = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'model': model_name,
'action': decision.action,
'confidence': decision.confidence,
'price': decision.price
}
session['signals'].append(signal)
# Keep only last 100 signals
if len(session['signals']) > 100:
session['signals'] = session['signals'][-100:]
logger.info(f"REAL Signal: {signal['action']} @ {signal['price']} (confidence: {signal['confidence']:.2f})")
# Sleep for 1 second before next inference
time.sleep(1)
except Exception as e:
logger.error(f"Error in REAL inference loop: {e}")
time.sleep(5)
logger.info(f"REAL inference loop stopped: {inference_id}")
except Exception as e:
logger.error(f"Fatal error in REAL inference loop: {e}")
session['status'] = 'error'
session['error'] = str(e)