train wip

Dobromir Popov
2025-10-25 00:47:56 +03:00
parent e816cb9795
commit bd213c44e0


@@ -325,8 +325,8 @@ class RealTrainingAdapter:
logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
# Create POSITIVE sample (where model SHOULD trade)
positive_sample = {
# Create ENTRY sample (where model SHOULD enter trade)
entry_sample = {
'market_state': market_state,
'action': test_case.get('action'),
'direction': expected_outcome.get('direction'),
@@ -334,12 +334,40 @@ class RealTrainingAdapter:
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': test_case.get('timestamp'),
'label': 'TRADE', # Positive label
'label': 'ENTRY', # Entry signal
'repetitions': training_repetitions
}
training_data.append(positive_sample)
logger.debug(f"Positive sample: {positive_sample['direction']} @ {positive_sample['entry_price']} -> {positive_sample['exit_price']} ({positive_sample['profit_loss_pct']:.2f}%)")
training_data.append(entry_sample)
logger.debug(f"Entry sample: {entry_sample['direction']} @ {entry_sample['entry_price']}")
# Create HOLD samples (every candle while position is open)
# This teaches the model to maintain the position until exit
hold_samples = self._create_hold_samples(
test_case=test_case,
market_state=market_state,
repetitions=training_repetitions // 4 # Quarter reps for hold samples
)
training_data.extend(hold_samples)
logger.debug(f" 📊 Added {len(hold_samples)} HOLD samples (during position)")
# Create EXIT sample (where model SHOULD exit trade)
exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
if exit_timestamp:
exit_sample = {
'market_state': market_state, # TODO: Get market state at exit time
'action': 'CLOSE',
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': exit_timestamp,
'label': 'EXIT', # Exit signal
'repetitions': training_repetitions
}
training_data.append(exit_sample)
logger.debug(f" ✅ Exit sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
# Create NEGATIVE samples (where model should NOT trade)
# These are candles before and after the signal
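A minimal sketch of the label taxonomy this change introduces, assuming the field names shown in the snippets above:

# Sketch only: each annotated trade now yields four kinds of samples.
LABEL_SEMANTICS = {
    'ENTRY':    'signal candle - open a position in the annotated direction',
    'HOLD':     'each candle between entry and exit - keep the position open',
    'EXIT':     'exit candle - close the position',
    'NO_TRADE': 'candles around the signal - stay flat',
}

def weighted_repetitions(label: str, base_reps: int) -> int:
    # HOLD samples are down-weighted to a quarter of the base repetitions,
    # mirroring _create_hold_samples(..., repetitions=training_repetitions // 4).
    return base_reps // 4 if label == 'HOLD' else base_reps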
@@ -356,19 +384,110 @@ class RealTrainingAdapter:
except Exception as e:
logger.error(f"❌ Error preparing test case {i+1}: {e}")
total_positive = sum(1 for s in training_data if s.get('label') == 'TRADE')
total_negative = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
total_entry = sum(1 for s in training_data if s.get('label') == 'ENTRY')
total_hold = sum(1 for s in training_data if s.get('label') == 'HOLD')
total_exit = sum(1 for s in training_data if s.get('label') == 'EXIT')
total_no_trade = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
logger.info(f"✅ Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
logger.info(f" Positive samples (TRADE): {total_positive}")
logger.info(f" Negative samples (NO_TRADE): {total_negative}")
logger.info(f" Ratio: 1:{total_negative/total_positive:.1f} (positive:negative)")
logger.info(f" ENTRY samples: {total_entry}")
logger.info(f" HOLD samples: {total_hold}")
logger.info(f" EXIT samples: {total_exit}")
logger.info(f" NO_TRADE samples: {total_no_trade}")
if total_entry > 0:
logger.info(f" Ratio: 1:{total_no_trade/total_entry:.1f} (entry:no_trade)")
if len(training_data) < len(test_cases):
logger.warning(f"⚠️ Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
return training_data
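The four summary counters above can also be computed in a single pass; a compact equivalent, assuming the same 'label' keys:

import collections

def summarize_labels(training_data):
    # One-pass equivalent of the four sum(...) counters above.
    counts = collections.Counter(s.get('label') for s in training_data)
    ratio = counts['NO_TRADE'] / counts['ENTRY'] if counts['ENTRY'] else float('nan')
    return counts, ratio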
def _create_hold_samples(self, test_case: Dict, market_state: Dict, repetitions: int) -> List[Dict]:
"""
Create HOLD training samples for every candle while the position is open
This teaches the model to:
1. Maintain the position (not exit early)
2. Recognize the trade is still valid
3. Wait for the optimal exit point
Args:
test_case: Test case with entry/exit info
market_state: Market state data
repetitions: Number of times to repeat each hold sample
Returns:
List of HOLD training samples
"""
hold_samples = []
try:
from datetime import datetime, timedelta
import pytz  # needed for the UTC tzinfo fallbacks below
# Get entry and exit timestamps
entry_timestamp = test_case.get('timestamp')
expected_outcome = test_case.get('expected_outcome', {})
# Calculate exit timestamp from holding period
holding_period_seconds = expected_outcome.get('holding_period_seconds', 0)
if holding_period_seconds == 0:
logger.debug(" No holding period, skipping HOLD samples")
return hold_samples
# Parse entry timestamp
try:
if 'T' in entry_timestamp:
entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
else:
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
entry_time = entry_time.replace(tzinfo=pytz.UTC)
except Exception as e:
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
return hold_samples
exit_time = entry_time + timedelta(seconds=holding_period_seconds)
# Get 1m timeframe timestamps
timeframes = market_state.get('timeframes', {})
if '1m' not in timeframes:
return hold_samples
timestamps = timeframes['1m'].get('timestamps', [])
# Find all candles between entry and exit
for idx, ts_str in enumerate(timestamps):
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
if ts.tzinfo is None:
    ts = ts.replace(tzinfo=pytz.UTC)  # normalize to aware UTC so comparison with entry_time is valid
# If this candle is between entry and exit (exclusive)
if entry_time < ts < exit_time:
# Create market state snapshot at this candle
hold_market_state = self._create_market_state_snapshot(market_state, idx)
hold_sample = {
'market_state': hold_market_state,
'action': 'HOLD',
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': ts_str,
'label': 'HOLD', # Hold position
'repetitions': repetitions,
'in_position': True # Flag indicating we're in a position
}
hold_samples.append(hold_sample)
logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit")
except Exception as e:
logger.error(f"Error creating HOLD samples: {e}")
import traceback
logger.error(traceback.format_exc())
return hold_samples
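A usage sketch for the method above; the values and the adapter/market_state names are illustrative, not from the commit:

# With a 5-minute holding period and 1m candles, the candles strictly
# between entry and exit yield up to 4 HOLD samples.
test_case = {
    'timestamp': '2025-10-24T12:00:00Z',
    'expected_outcome': {
        'direction': 'LONG',
        'entry_price': 100.0,
        'exit_price': 101.5,
        'profit_loss_pct': 1.5,
        'holding_period_seconds': 300,
    },
}
hold_samples = adapter._create_hold_samples(test_case, market_state, repetitions=25)
# Each returned sample has label='HOLD' and in_position=True.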
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
window_size: int, repetitions: int) -> List[Dict]:
"""
@@ -400,17 +519,39 @@ class RealTrainingAdapter:
# Find the index of the signal timestamp
from datetime import datetime
import pytz  # needed for the tzinfo fallbacks below
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
# Parse signal timestamp - handle different formats
try:
if 'T' in signal_timestamp:
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
else:
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
signal_time = signal_time.replace(tzinfo=pytz.UTC)
except Exception as e:
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
return negative_samples
signal_index = None
for idx, ts_str in enumerate(timestamps):
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
if abs((ts - signal_time).total_seconds()) < 60: # Within 1 minute
signal_index = idx
break
try:
# Parse timestamp from market data
if 'T' in ts_str:
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
else:
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
ts = ts.replace(tzinfo=pytz.UTC)
# Match within 1 minute
if abs((ts - signal_time).total_seconds()) < 60:
signal_index = idx
logger.debug(f" Found signal at index {idx}: {ts_str}")
break
except Exception:
    continue  # skip timestamps that fail to parse
if signal_index is None:
logger.warning(f"Could not find signal timestamp in market data")
logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
return negative_samples
# Create negative samples from candles before and after the signal
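The same parse-then-normalize pattern now appears in both _create_hold_samples and here; a shared helper along these lines (hypothetical name) would remove the duplication:

import pytz
from datetime import datetime

def _parse_utc(ts: str) -> datetime:
    # Accepts ISO 8601 ('2025-10-24T12:00:00Z') or '%Y-%m-%d %H:%M:%S'
    # and always returns a timezone-aware UTC datetime, so comparisons
    # between parsed timestamps are safe.
    if 'T' in ts:
        dt = datetime.fromisoformat(ts.replace('Z', '+00:00'))
    else:
        dt = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
    return dt if dt.tzinfo else dt.replace(tzinfo=pytz.UTC)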
@@ -596,6 +737,224 @@ class RealTrainingAdapter:
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
return [0.0] * state_size
def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
"""
Convert annotation training sample to transformer model input format
The transformer expects:
- price_data: [batch, seq_len, features] - OHLCV sequences
- cob_data: [batch, seq_len, cob_features] - Change of Bid data
- tech_data: [batch, features] - Technical indicators
- market_data: [batch, features] - Market context
- actions: [batch] - Target actions (0=HOLD, 1=BUY, 2=SELL)
- future_prices: [batch] - Future price targets
- trade_success: [batch] - Whether trade was successful
"""
import torch
import numpy as np
try:
market_state = training_sample.get('market_state', {})
# Extract OHLCV data from ALL timeframes
timeframes = market_state.get('timeframes', {})
# Collect data from all available timeframes
all_price_data = []
timeframe_order = ['1s', '1m', '1h', '1d'] # Process in order
for tf in timeframe_order:
if tf not in timeframes:
continue
tf_data = timeframes[tf]
# Convert to numpy arrays
opens = np.array(tf_data.get('open', []), dtype=np.float32)
highs = np.array(tf_data.get('high', []), dtype=np.float32)
lows = np.array(tf_data.get('low', []), dtype=np.float32)
closes = np.array(tf_data.get('close', []), dtype=np.float32)
volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
if len(closes) > 0:
# Stack OHLCV for this timeframe [seq_len, 5]
tf_price_data = np.stack([opens, highs, lows, closes, volumes], axis=-1)
all_price_data.append(tf_price_data)
if not all_price_data:
logger.warning("No price data in any timeframe")
return None
# Concatenate all timeframes along sequence dimension
# This gives the model multi-timeframe context
price_data = np.concatenate(all_price_data, axis=0)
# Add batch dimension [1, total_seq_len, 5]
price_data = torch.tensor(price_data, dtype=torch.float32).unsqueeze(0)
# Get primary timeframe for reference
# Fall back to the first timeframe that actually has close data
primary_tf = '1m' if '1m' in timeframes else next(tf for tf in timeframe_order if tf in timeframes and timeframes[tf].get('close'))
closes = np.array(timeframes[primary_tf].get('close', []), dtype=np.float32)
# Create placeholder COB data (zeros if not available)
# COB data shape: [1, seq_len, cob_features]
# Transformer expects 100 COB features (as defined in TransformerConfig)
cob_data = torch.zeros(1, len(closes), 100, dtype=torch.float32) # 100 COB features
# Create technical indicators (simple ones for now)
# tech_data shape: [1, features]
tech_features = []
# Add simple technical indicators
if len(closes) >= 20:
sma_20 = np.mean(closes[-20:])
tech_features.append(closes[-1] / sma_20 - 1.0) # Price vs SMA
else:
tech_features.append(0.0)
if len(closes) >= 2:
returns = (closes[-1] - closes[-2]) / closes[-2]
tech_features.append(returns) # Recent return
else:
tech_features.append(0.0)
# Add volatility
if len(closes) >= 20:
volatility = np.std(closes[-20:]) / np.mean(closes[-20:])
tech_features.append(volatility)
else:
tech_features.append(0.0)
# Pad tech_features to match transformer's expected size (40 features)
while len(tech_features) < 40:
tech_features.append(0.0)
tech_data = torch.tensor([tech_features[:40]], dtype=torch.float32) # Ensure exactly 40 features
# Create market context data with pivot points
# market_data shape: [1, features]
market_features = []
# Add volume profile
primary_volumes = np.array(timeframes[primary_tf].get('volume', []), dtype=np.float32)
if len(primary_volumes) >= 20:
vol_ratio = primary_volumes[-1] / np.mean(primary_volumes[-20:])
market_features.append(vol_ratio)
else:
market_features.append(1.0)
# Add price range
primary_highs = np.array(timeframes[primary_tf].get('high', []), dtype=np.float32)
primary_lows = np.array(timeframes[primary_tf].get('low', []), dtype=np.float32)
if len(primary_highs) >= 20 and len(primary_lows) >= 20:
price_range = (np.max(primary_highs[-20:]) - np.min(primary_lows[-20:])) / closes[-1]
market_features.append(price_range)
else:
market_features.append(0.0)
# Add pivot point features
# Calculate simple pivot points from recent price action
if len(primary_highs) >= 5 and len(primary_lows) >= 5:
# Pivot Point = (High + Low + Close) / 3
pivot = (primary_highs[-1] + primary_lows[-1] + closes[-1]) / 3.0
# Support and Resistance levels
r1 = 2 * pivot - primary_lows[-1] # Resistance 1
s1 = 2 * pivot - primary_highs[-1] # Support 1
# Normalize relative to current price
pivot_distance = (closes[-1] - pivot) / closes[-1]
r1_distance = (closes[-1] - r1) / closes[-1]
s1_distance = (closes[-1] - s1) / closes[-1]
market_features.extend([pivot_distance, r1_distance, s1_distance])
else:
market_features.extend([0.0, 0.0, 0.0])
# Add Williams pivot levels if available in market state
pivot_markers = market_state.get('pivot_markers', {})
if pivot_markers:
# Count nearby pivot levels
num_support = len([p for p in pivot_markers.get('support_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
num_resistance = len([p for p in pivot_markers.get('resistance_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
market_features.extend([float(num_support), float(num_resistance)])
else:
market_features.extend([0.0, 0.0])
# Pad market_features to match transformer's expected size (30 features)
while len(market_features) < 30:
market_features.append(0.0)
market_data = torch.tensor([market_features[:30]], dtype=torch.float32) # Ensure exactly 30 features
# Convert action to tensor
# 0 = HOLD/NO_TRADE, 1 = BUY (LONG), 2 = SELL (SHORT)
action_label = training_sample.get('label', 'TRADE')
direction = training_sample.get('direction', 'NONE')
in_position = training_sample.get('in_position', False)
if action_label == 'NO_TRADE':
action = 0 # HOLD - no position
elif action_label == 'HOLD':
action = 0 # HOLD - maintain position
elif action_label == 'ENTRY':
if direction == 'LONG':
action = 1 # BUY
elif direction == 'SHORT':
action = 2 # SELL
else:
action = 0
elif action_label == 'EXIT':
# Exit is opposite of entry
if direction == 'LONG':
action = 2 # SELL to close long
elif direction == 'SHORT':
action = 1 # BUY to close short
else:
action = 0
elif direction == 'LONG':
action = 1 # BUY
elif direction == 'SHORT':
action = 2 # SELL
else:
action = 0 # HOLD
actions = torch.tensor([action], dtype=torch.long)
# Future price target
entry_price = training_sample.get('entry_price')
exit_price = training_sample.get('exit_price')
if exit_price and entry_price:
future_price = exit_price
else:
future_price = closes[-1] # Current price for HOLD
future_prices = torch.tensor([future_price], dtype=torch.float32)
# Trade success (1.0 if profitable, 0.0 otherwise)
profit_loss_pct = training_sample.get('profit_loss_pct', 0.0)
trade_success = torch.tensor([1.0 if profit_loss_pct > 0 else 0.0], dtype=torch.float32)
# Return batch dictionary
batch = {
'price_data': price_data,
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data,
'actions': actions,
'future_prices': future_prices,
'trade_success': trade_success
}
return batch
except Exception as e:
logger.error(f"Error converting annotation to transformer batch: {e}")
import traceback
logger.error(traceback.format_exc())
return None
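As a shape sanity check for the batch contract described in the docstring (seq_len is arbitrary here; the feature widths 5/100/40/30 come from the code above):

import torch

def make_dummy_batch(seq_len: int = 600) -> dict:
    # Mirrors the tensors produced above: OHLCV sequence, 100 COB features,
    # 40 technical features, 30 market-context features, scalar targets.
    return {
        'price_data':    torch.zeros(1, seq_len, 5),
        'cob_data':      torch.zeros(1, seq_len, 100),
        'tech_data':     torch.zeros(1, 40),
        'market_data':   torch.zeros(1, 30),
        'actions':       torch.tensor([0], dtype=torch.long),  # 0=HOLD, 1=BUY, 2=SELL
        'future_prices': torch.tensor([0.0]),
        'trade_success': torch.tensor([0.0]),
    }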
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
"""
Train Transformer model using orchestrator's existing training infrastructure
@@ -618,40 +977,63 @@ class RealTrainingAdapter:
# Use the trainer's train_step method for individual samples
if hasattr(trainer, 'train_step'):
logger.info(" Using trainer.train_step() method")
logger.info(" Converting annotation data to transformer format...")
import torch
# Train using train_step for each sample
# Convert all training samples to transformer format
converted_batches = []
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# Repeat based on repetitions parameter
repetitions = data.get('repetitions', 1)
for _ in range(repetitions):
converted_batches.append(batch)
else:
logger.warning(f" Failed to convert sample {i+1}")
if not converted_batches:
raise Exception("No valid training batches after conversion")
logger.info(f" ✅ Converted {len(training_data)} samples to {len(converted_batches)} training batches")
# Train using train_step for each batch
for epoch in range(session.total_epochs):
epoch_loss = 0.0
num_samples = 0
epoch_accuracy = 0.0
num_batches = 0
for i, data in enumerate(training_data):
for i, batch in enumerate(converted_batches):
try:
# Call the trainer's train_step method
loss = trainer.train_step(data)
# Call the trainer's train_step method with proper batch format
result = trainer.train_step(batch)
if loss is not None:
epoch_loss += float(loss)
num_samples += 1
if result is not None:
epoch_loss += result.get('total_loss', 0.0)
epoch_accuracy += result.get('accuracy', 0.0)
num_batches += 1
if (i + 1) % 10 == 0:
logger.debug(f" Sample {i + 1}/{len(training_data)}, Loss: {loss:.6f}")
if (i + 1) % 100 == 0:
logger.debug(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}")
except Exception as e:
logger.error(f" Error in sample {i + 1}: {e}")
logger.error(f" Error in batch {i + 1}: {e}")
import traceback
logger.error(traceback.format_exc())
continue
avg_loss = epoch_loss / num_samples if num_samples > 0 else 0.0
avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
session.current_epoch = epoch + 1
session.current_loss = avg_loss
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Avg Loss: {avg_loss:.6f} ({num_samples} samples)")
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy
session.accuracy = avg_accuracy
logger.info(f" Training complete: Loss = {session.final_loss:.6f}")
logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
else:
raise Exception(f"Transformer trainer does not have a train_step() method. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
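One caveat with the repetition expansion above: duplicates of the same sample are appended back-to-back, so consecutive train_step calls see identical batches. A per-epoch shuffle (a sketch, assuming train_step has no ordering requirement) would spread the repetitions across the epoch:

import random

for epoch in range(session.total_epochs):
    random.shuffle(converted_batches)  # break up runs of identical batches
    for batch in converted_batches:
        result = trainer.train_step(batch)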