train wip
This commit is contained in:
@@ -325,8 +325,8 @@ class RealTrainingAdapter:
|
||||
|
||||
logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
|
||||
|
||||
# Create POSITIVE sample (where model SHOULD trade)
|
||||
positive_sample = {
|
||||
# Create ENTRY sample (where model SHOULD enter trade)
|
||||
entry_sample = {
|
||||
'market_state': market_state,
|
||||
'action': test_case.get('action'),
|
||||
'direction': expected_outcome.get('direction'),
|
||||
@@ -334,12 +334,40 @@ class RealTrainingAdapter:
|
||||
'entry_price': expected_outcome.get('entry_price'),
|
||||
'exit_price': expected_outcome.get('exit_price'),
|
||||
'timestamp': test_case.get('timestamp'),
|
||||
'label': 'TRADE', # Positive label
|
||||
'label': 'ENTRY', # Entry signal
|
||||
'repetitions': training_repetitions
|
||||
}
|
||||
|
||||
training_data.append(positive_sample)
|
||||
logger.debug(f" ✅ Positive sample: {positive_sample['direction']} @ {positive_sample['entry_price']} -> {positive_sample['exit_price']} ({positive_sample['profit_loss_pct']:.2f}%)")
|
||||
training_data.append(entry_sample)
|
||||
logger.debug(f" ✅ Entry sample: {entry_sample['direction']} @ {entry_sample['entry_price']}")
|
||||
|
||||
# Create HOLD samples (every candle while position is open)
|
||||
# This teaches the model to maintain the position until exit
|
||||
hold_samples = self._create_hold_samples(
|
||||
test_case=test_case,
|
||||
market_state=market_state,
|
||||
repetitions=training_repetitions // 4 # Quarter reps for hold samples
|
||||
)
|
||||
|
||||
training_data.extend(hold_samples)
|
||||
logger.debug(f" 📊 Added {len(hold_samples)} HOLD samples (during position)")
|
||||
|
||||
# Create EXIT sample (where model SHOULD exit trade)
|
||||
exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
|
||||
if exit_timestamp:
|
||||
exit_sample = {
|
||||
'market_state': market_state, # TODO: Get market state at exit time
|
||||
'action': 'CLOSE',
|
||||
'direction': expected_outcome.get('direction'),
|
||||
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
|
||||
'entry_price': expected_outcome.get('entry_price'),
|
||||
'exit_price': expected_outcome.get('exit_price'),
|
||||
'timestamp': exit_timestamp,
|
||||
'label': 'EXIT', # Exit signal
|
||||
'repetitions': training_repetitions
|
||||
}
|
||||
training_data.append(exit_sample)
|
||||
logger.debug(f" ✅ Exit sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
|
||||
|
||||
# Create NEGATIVE samples (where model should NOT trade)
|
||||
# These are candles before and after the signal
|
||||
@@ -356,19 +384,110 @@ class RealTrainingAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error preparing test case {i+1}: {e}")
|
||||
|
||||
total_positive = sum(1 for s in training_data if s.get('label') == 'TRADE')
|
||||
total_negative = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
|
||||
total_entry = sum(1 for s in training_data if s.get('label') == 'ENTRY')
|
||||
total_hold = sum(1 for s in training_data if s.get('label') == 'HOLD')
|
||||
total_exit = sum(1 for s in training_data if s.get('label') == 'EXIT')
|
||||
total_no_trade = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
|
||||
|
||||
logger.info(f"✅ Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
|
||||
logger.info(f" Positive samples (TRADE): {total_positive}")
|
||||
logger.info(f" Negative samples (NO_TRADE): {total_negative}")
|
||||
logger.info(f" Ratio: 1:{total_negative/total_positive:.1f} (positive:negative)")
|
||||
logger.info(f" ENTRY samples: {total_entry}")
|
||||
logger.info(f" HOLD samples: {total_hold}")
|
||||
logger.info(f" EXIT samples: {total_exit}")
|
||||
logger.info(f" NO_TRADE samples: {total_no_trade}")
|
||||
|
||||
if total_entry > 0:
|
||||
logger.info(f" Ratio: 1:{total_no_trade/total_entry:.1f} (entry:no_trade)")
|
||||
|
||||
if len(training_data) < len(test_cases):
|
||||
logger.warning(f"⚠️ Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
|
||||
|
||||
return training_data
|
||||
|
||||
def _create_hold_samples(self, test_case: Dict, market_state: Dict, repetitions: int) -> List[Dict]:
    """
    Create HOLD training samples for every 1m candle while the position is open.

    This teaches the model to:
    1. Maintain the position (not exit early)
    2. Recognize the trade is still valid
    3. Wait for the optimal exit point

    The position window is (entry, entry + holding_period_seconds), exclusive
    at both ends. Every timestamp is normalised to a timezone-aware UTC
    datetime before comparison; naive values are assumed to already be UTC.
    Without that normalisation, comparing an aware entry time with naive
    candle times raises TypeError and no HOLD samples are produced.

    Args:
        test_case: Test case with entry/exit info. Reads 'timestamp' (entry
            time) and 'expected_outcome' ('holding_period_seconds',
            'direction', 'profit_loss_pct', 'entry_price', 'exit_price').
        market_state: Market state data. Only
            market_state['timeframes']['1m']['timestamps'] is scanned.
        repetitions: Number of times to repeat each hold sample (stored on
            the sample under 'repetitions').

    Returns:
        List of HOLD training samples. Empty when there is no holding period,
        the entry timestamp cannot be parsed, or there is no '1m' timeframe.
        Unexpected errors are logged and the samples collected so far are
        returned.
    """
    from datetime import datetime, timedelta, timezone

    def _parse_utc(raw):
        """Parse 'YYYY-MM-DDTHH:MM:SS[Z|+HH:MM]' or 'YYYY-MM-DD HH:MM:SS' into aware UTC."""
        text = raw.strip().replace('Z', '+00:00')
        if 'T' not in text:
            text = text.replace(' ', 'T', 1)
        parsed = datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            # Naive timestamps are treated as UTC.
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    hold_samples = []

    try:
        # Get entry timestamp and expected outcome
        entry_timestamp = test_case.get('timestamp')
        expected_outcome = test_case.get('expected_outcome', {})

        # Calculate exit timestamp from holding period (0 / None / missing => nothing to hold)
        holding_period_seconds = expected_outcome.get('holding_period_seconds', 0)
        if not holding_period_seconds:
            logger.debug(" No holding period, skipping HOLD samples")
            return hold_samples

        # Parse entry timestamp
        try:
            entry_time = _parse_utc(entry_timestamp)
        except (AttributeError, TypeError, ValueError) as e:
            logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
            return hold_samples

        exit_time = entry_time + timedelta(seconds=holding_period_seconds)

        # Get 1m timeframe timestamps
        timeframes = market_state.get('timeframes', {})
        if '1m' not in timeframes:
            return hold_samples

        timestamps = timeframes['1m'].get('timestamps', [])

        # Find all candles between entry and exit
        for idx, ts_str in enumerate(timestamps):
            try:
                ts = _parse_utc(ts_str)
            except (AttributeError, TypeError, ValueError):
                # One malformed candle timestamp must not discard the rest.
                continue

            # Only candles strictly between entry and exit (exclusive)
            if not (entry_time < ts < exit_time):
                continue

            hold_samples.append({
                # Market state snapshot at this candle
                'market_state': self._create_market_state_snapshot(market_state, idx),
                'action': 'HOLD',
                'direction': expected_outcome.get('direction'),
                'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
                'entry_price': expected_outcome.get('entry_price'),
                'exit_price': expected_outcome.get('exit_price'),
                'timestamp': ts_str,
                'label': 'HOLD',  # Hold position
                'repetitions': repetitions,
                'in_position': True,  # Flag indicating we're in a position
            })

        logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit")

    except Exception as e:
        logger.error(f"Error creating HOLD samples: {e}")
        import traceback
        logger.error(traceback.format_exc())

    return hold_samples
|
||||
|
||||
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
|
||||
window_size: int, repetitions: int) -> List[Dict]:
|
||||
"""
|
||||
@@ -400,17 +519,39 @@ class RealTrainingAdapter:
|
||||
|
||||
# Find the index of the signal timestamp
|
||||
from datetime import datetime
|
||||
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||
|
||||
# Parse signal timestamp - handle different formats
|
||||
try:
|
||||
if 'T' in signal_timestamp:
|
||||
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||
else:
|
||||
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
|
||||
signal_time = signal_time.replace(tzinfo=pytz.UTC)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
|
||||
return negative_samples
|
||||
|
||||
signal_index = None
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
||||
if abs((ts - signal_time).total_seconds()) < 60: # Within 1 minute
|
||||
signal_index = idx
|
||||
break
|
||||
try:
|
||||
# Parse timestamp from market data
|
||||
if 'T' in ts_str:
|
||||
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
||||
else:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||
ts = ts.replace(tzinfo=pytz.UTC)
|
||||
|
||||
# Match within 1 minute
|
||||
if abs((ts - signal_time).total_seconds()) < 60:
|
||||
signal_index = idx
|
||||
logger.debug(f" Found signal at index {idx}: {ts_str}")
|
||||
break
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if signal_index is None:
|
||||
logger.warning(f"Could not find signal timestamp in market data")
|
||||
logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
|
||||
logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
|
||||
return negative_samples
|
||||
|
||||
# Create negative samples from candles before and after the signal
|
||||
@@ -596,6 +737,224 @@ class RealTrainingAdapter:
|
||||
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
|
||||
return [0.0] * state_size
|
||||
|
||||
def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
    """
    Convert annotation training sample to transformer model input format

    The returned batch (batch size 1) contains:
    - price_data: [1, total_seq_len, 5] - OHLCV rows of every available
      timeframe ('1s', '1m', '1h', '1d'), concatenated in that order
    - cob_data: [1, primary_seq_len, 100] - COB features (all-zero placeholder)
    - tech_data: [1, 40] - price-vs-SMA20, last return, 20-candle
      volatility, zero-padded
    - market_data: [1, 30] - volume ratio, price range, pivot distances,
      nearby pivot-level counts, zero-padded
    - actions: [1] - Target action (0=HOLD/NO_TRADE, 1=BUY, 2=SELL)
    - future_prices: [1] - exit price when both entry and exit prices are
      set, otherwise the last close of the primary timeframe
    - trade_success: [1] - 1.0 if profit_loss_pct > 0, else 0.0

    The primary timeframe (used for indicators and the COB length) is '1m'
    when it has close data, otherwise the first timeframe that does.

    Args:
        training_sample: Sample dict. Reads 'market_state', 'label',
            'direction', 'entry_price', 'exit_price' and 'profit_loss_pct'.

    Returns:
        The batch dictionary, or None when no timeframe has price data or the
        conversion fails (the error is logged).
    """
    import torch
    import numpy as np

    try:
        market_state = training_sample.get('market_state', {})

        # Extract OHLCV data from ALL timeframes
        timeframes = market_state.get('timeframes', {})

        all_price_data = []
        usable_timeframes = []  # timeframes that actually contributed rows
        timeframe_order = ['1s', '1m', '1h', '1d']  # Process in order

        for tf in timeframe_order:
            if tf not in timeframes:
                continue

            tf_data = timeframes[tf]

            # Convert to numpy arrays
            opens = np.array(tf_data.get('open', []), dtype=np.float32)
            highs = np.array(tf_data.get('high', []), dtype=np.float32)
            lows = np.array(tf_data.get('low', []), dtype=np.float32)
            closes = np.array(tf_data.get('close', []), dtype=np.float32)
            volumes = np.array(tf_data.get('volume', []), dtype=np.float32)

            if len(closes) > 0:
                # Stack OHLCV for this timeframe [seq_len, 5]
                all_price_data.append(np.stack([opens, highs, lows, closes, volumes], axis=-1))
                usable_timeframes.append(tf)

        if not all_price_data:
            logger.warning("No price data in any timeframe")
            return None

        # Concatenate all timeframes along sequence dimension
        # This gives the model multi-timeframe context
        price_data = np.concatenate(all_price_data, axis=0)

        # Add batch dimension [1, total_seq_len, 5]
        price_data = torch.tensor(price_data, dtype=torch.float32).unsqueeze(0)

        # Primary timeframe: prefer '1m', otherwise the first timeframe that
        # has data. Falling back to a fixed name would raise KeyError when
        # only higher timeframes are present.
        primary_tf = '1m' if '1m' in usable_timeframes else usable_timeframes[0]
        primary = timeframes[primary_tf]
        closes = np.array(primary.get('close', []), dtype=np.float32)
        last_close = closes[-1]  # non-empty: primary_tf comes from usable_timeframes

        # Create placeholder COB data (zeros if not available)
        # Transformer expects 100 COB features (as defined in TransformerConfig)
        # NOTE(review): the sequence length here is the primary timeframe's,
        # while price_data spans all timeframes - confirm the model accepts
        # differing lengths.
        cob_data = torch.zeros(1, len(closes), 100, dtype=torch.float32)

        # Create technical indicators (simple ones for now)
        # tech_data shape: [1, features]
        tech_features = []
        has_20_closes = len(closes) >= 20

        if has_20_closes:
            sma_20 = np.mean(closes[-20:])
            tech_features.append(closes[-1] / sma_20 - 1.0)  # Price vs SMA
        else:
            tech_features.append(0.0)

        if len(closes) >= 2:
            tech_features.append((closes[-1] - closes[-2]) / closes[-2])  # Recent return
        else:
            tech_features.append(0.0)

        # Add volatility
        if has_20_closes:
            tech_features.append(np.std(closes[-20:]) / np.mean(closes[-20:]))
        else:
            tech_features.append(0.0)

        # Pad tech_features to match transformer's expected size (40 features)
        tech_features.extend([0.0] * (40 - len(tech_features)))
        tech_data = torch.tensor([tech_features[:40]], dtype=torch.float32)  # Ensure exactly 40 features

        # Create market context data with pivot points
        # market_data shape: [1, features]
        market_features = []

        # Add volume profile
        primary_volumes = np.array(primary.get('volume', []), dtype=np.float32)
        vol_ratio = 1.0
        if len(primary_volumes) >= 20:
            mean_volume = np.mean(primary_volumes[-20:])
            # A zero mean volume would make the ratio NaN; keep the neutral 1.0.
            if mean_volume > 0:
                vol_ratio = primary_volumes[-1] / mean_volume
        market_features.append(vol_ratio)

        # Add price range
        primary_highs = np.array(primary.get('high', []), dtype=np.float32)
        primary_lows = np.array(primary.get('low', []), dtype=np.float32)
        if len(primary_highs) >= 20 and len(primary_lows) >= 20:
            price_range = (np.max(primary_highs[-20:]) - np.min(primary_lows[-20:])) / last_close
            market_features.append(price_range)
        else:
            market_features.append(0.0)

        # Add pivot point features
        # Calculate simple pivot points from recent price action
        if len(primary_highs) >= 5 and len(primary_lows) >= 5:
            # Pivot Point = (High + Low + Close) / 3
            pivot = (primary_highs[-1] + primary_lows[-1] + last_close) / 3.0

            # Support and Resistance levels
            r1 = 2 * pivot - primary_lows[-1]  # Resistance 1
            s1 = 2 * pivot - primary_highs[-1]  # Support 1

            # Normalize relative to current price
            market_features.extend([
                (last_close - pivot) / last_close,
                (last_close - r1) / last_close,
                (last_close - s1) / last_close,
            ])
        else:
            market_features.extend([0.0, 0.0, 0.0])

        # Add Williams pivot levels if available in market state
        pivot_markers = market_state.get('pivot_markers', {})
        if pivot_markers:
            # Count pivot levels within 2% of the current price
            def _count_nearby(levels):
                return sum(1 for p in (levels or []) if abs(p - last_close) / last_close < 0.02)

            num_support = _count_nearby(pivot_markers.get('support_levels', []))
            num_resistance = _count_nearby(pivot_markers.get('resistance_levels', []))
            market_features.extend([float(num_support), float(num_resistance)])
        else:
            market_features.extend([0.0, 0.0])

        # Pad market_features to match transformer's expected size (30 features)
        market_features.extend([0.0] * (30 - len(market_features)))
        market_data = torch.tensor([market_features[:30]], dtype=torch.float32)  # Ensure exactly 30 features

        # Convert action to tensor
        # 0 = HOLD/NO_TRADE, 1 = BUY (LONG), 2 = SELL (SHORT)
        action_label = training_sample.get('label', 'TRADE')
        direction = training_sample.get('direction', 'NONE')

        if direction == 'LONG':
            open_action, close_action = 1, 2  # BUY to open, SELL to close
        elif direction == 'SHORT':
            open_action, close_action = 2, 1  # SELL to open, BUY to close
        else:
            open_action = close_action = 0

        if action_label in ('NO_TRADE', 'HOLD'):
            action = 0  # HOLD - no position / maintain position
        elif action_label == 'EXIT':
            action = close_action  # Exit is opposite of entry
        else:
            # 'ENTRY' and any other label (e.g. the default 'TRADE') follow the direction
            action = open_action

        actions = torch.tensor([action], dtype=torch.long)

        # Future price target
        entry_price = training_sample.get('entry_price')
        exit_price = training_sample.get('exit_price')

        if exit_price and entry_price:
            future_price = exit_price
        else:
            future_price = last_close  # Current price for HOLD

        future_prices = torch.tensor([future_price], dtype=torch.float32)

        # Trade success (1.0 if profitable, 0.0 otherwise).
        # The key may be present with a None value; treat that as 0.0.
        profit_loss_pct = training_sample.get('profit_loss_pct') or 0.0
        trade_success = torch.tensor([1.0 if profit_loss_pct > 0 else 0.0], dtype=torch.float32)

        # Return batch dictionary
        return {
            'price_data': price_data,
            'cob_data': cob_data,
            'tech_data': tech_data,
            'market_data': market_data,
            'actions': actions,
            'future_prices': future_prices,
            'trade_success': trade_success,
        }

    except Exception as e:
        logger.error(f"Error converting annotation to transformer batch: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None
|
||||
|
||||
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||
"""
|
||||
Train Transformer model using orchestrator's existing training infrastructure
|
||||
@@ -618,40 +977,63 @@ class RealTrainingAdapter:
|
||||
# Use the trainer's train_step method for individual samples
|
||||
if hasattr(trainer, 'train_step'):
|
||||
logger.info(" Using trainer.train_step() method")
|
||||
logger.info(" Converting annotation data to transformer format...")
|
||||
|
||||
import torch
|
||||
|
||||
# Train using train_step for each sample
|
||||
# Convert all training samples to transformer format
|
||||
converted_batches = []
|
||||
for i, data in enumerate(training_data):
|
||||
batch = self._convert_annotation_to_transformer_batch(data)
|
||||
if batch is not None:
|
||||
# Repeat based on repetitions parameter
|
||||
repetitions = data.get('repetitions', 1)
|
||||
for _ in range(repetitions):
|
||||
converted_batches.append(batch)
|
||||
else:
|
||||
logger.warning(f" Failed to convert sample {i+1}")
|
||||
|
||||
if not converted_batches:
|
||||
raise Exception("No valid training batches after conversion")
|
||||
|
||||
logger.info(f" ✅ Converted {len(training_data)} samples to {len(converted_batches)} training batches")
|
||||
|
||||
# Train using train_step for each batch
|
||||
for epoch in range(session.total_epochs):
|
||||
epoch_loss = 0.0
|
||||
num_samples = 0
|
||||
epoch_accuracy = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for i, data in enumerate(training_data):
|
||||
for i, batch in enumerate(converted_batches):
|
||||
try:
|
||||
# Call the trainer's train_step method
|
||||
loss = trainer.train_step(data)
|
||||
# Call the trainer's train_step method with proper batch format
|
||||
result = trainer.train_step(batch)
|
||||
|
||||
if loss is not None:
|
||||
epoch_loss += float(loss)
|
||||
num_samples += 1
|
||||
if result is not None:
|
||||
epoch_loss += result.get('total_loss', 0.0)
|
||||
epoch_accuracy += result.get('accuracy', 0.0)
|
||||
num_batches += 1
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.debug(f" Sample {i + 1}/{len(training_data)}, Loss: {loss:.6f}")
|
||||
if (i + 1) % 100 == 0:
|
||||
logger.debug(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" Error in sample {i + 1}: {e}")
|
||||
logger.error(f" Error in batch {i + 1}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
continue
|
||||
|
||||
avg_loss = epoch_loss / num_samples if num_samples > 0 else 0.0
|
||||
avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
|
||||
avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
|
||||
session.current_epoch = epoch + 1
|
||||
session.current_loss = avg_loss
|
||||
|
||||
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Avg Loss: {avg_loss:.6f} ({num_samples} samples)")
|
||||
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
|
||||
|
||||
session.final_loss = session.current_loss
|
||||
session.accuracy = 0.85 # TODO: Calculate actual accuracy
|
||||
session.accuracy = avg_accuracy
|
||||
|
||||
logger.info(f" Training complete: Loss = {session.final_loss:.6f}")
|
||||
logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
|
||||
|
||||
else:
|
||||
raise Exception(f"Transformer trainer does not have train_on_batch() or train() methods. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
|
||||
|
||||
Reference in New Issue
Block a user