T model trend prediction added

This commit is contained in:
Dobromir Popov
2025-11-10 20:12:22 +02:00
parent 999dea9eb0
commit 27039c70a3
2 changed files with 152 additions and 70 deletions

View File

@@ -574,16 +574,17 @@ class RealTrainingAdapter:
training_data.append(entry_sample)
logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")
# Create HOLD samples (every candle while position is open)
# Create HOLD samples (every 13 candles while position is open)
# This teaches the model to maintain the position until exit
hold_samples = self._create_hold_samples(
test_case=test_case,
market_state=market_state
market_state=market_state,
sample_interval=13 # One sample every 13 candles
)
training_data.extend(hold_samples)
if hold_samples:
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)")
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (every 13 candles during position)")
# Create EXIT sample (where model SHOULD exit trade)
# Exit info is in expected_outcome, not annotation_metadata
@@ -606,21 +607,20 @@ class RealTrainingAdapter:
logger.info(f" Test case {i+1}: EXIT sample @ {exit_price} ({expected_outcome.get('profit_loss_pct', 0):.2f}%)")
# Create NEGATIVE samples (where model should NOT trade)
# These are candles before and after the signal (±15 candles)
# 5 candles before entry + 5 candles after exit = 10 NO_TRADE samples per annotation
# This teaches the model to recognize when NOT to enter
negative_samples = self._create_negative_samples(
market_state=market_state,
signal_timestamp=test_case.get('timestamp'),
window_size=negative_samples_window
entry_timestamp=test_case.get('timestamp'),
exit_timestamp=None, # Will be calculated from holding period
holding_period_seconds=expected_outcome.get('holding_period_seconds', 0),
samples_before=5, # 5 candles before entry
samples_after=5 # 5 candles after exit
)
training_data.extend(negative_samples)
if negative_samples:
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)")
# Show breakdown of before/after
before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', '')))
after_count = len(negative_samples) - before_count
logger.info(f" -> {before_count} before signal, {after_count} after signal")
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (5 before entry + 5 after exit)")
except Exception as e:
logger.error(f" Error preparing test case {i+1}: {e}")
@@ -644,9 +644,9 @@ class RealTrainingAdapter:
return training_data
def _create_hold_samples(self, test_case: Dict, market_state: Dict) -> List[Dict]:
def _create_hold_samples(self, test_case: Dict, market_state: Dict, sample_interval: int = 13) -> List[Dict]:
"""
Create HOLD training samples for every candle while position is open
Create HOLD training samples at intervals while position is open
This teaches the model to:
1. Maintain the position (not exit early)
@@ -656,6 +656,7 @@ class RealTrainingAdapter:
Args:
test_case: Test case with entry/exit info
market_state: Market state data
sample_interval: Create one sample every N candles (default: 13)
Returns:
List of HOLD training samples
@@ -691,7 +692,8 @@ class RealTrainingAdapter:
timestamps = timeframes['1m'].get('timestamps', [])
# Find all candles between entry and exit
# Find all candles between entry and exit, sample every N candles
candles_in_position = []
for idx, ts_str in enumerate(timestamps):
# Parse timestamp using unified parser
try:
@@ -702,24 +704,43 @@ class RealTrainingAdapter:
# If this candle is between entry and exit (exclusive)
if entry_time < ts < exit_time:
# Create market state snapshot at this candle
hold_market_state = self._create_market_state_snapshot(market_state, idx)
hold_sample = {
'market_state': hold_market_state,
'action': 'HOLD',
'direction': expected_outcome.get('direction'),
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
'entry_price': expected_outcome.get('entry_price'),
'exit_price': expected_outcome.get('exit_price'),
'timestamp': ts_str,
'label': 'HOLD', # Hold position
'in_position': True # Flag indicating we're in a position
}
hold_samples.append(hold_sample)
candles_in_position.append((idx, ts_str, ts))
logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit")
# Sample every Nth candle (e.g., every 13 candles)
for i in range(0, len(candles_in_position), sample_interval):
idx, ts_str, ts = candles_in_position[i]
# Create market state snapshot at this candle
hold_market_state = self._create_market_state_snapshot(market_state, idx)
# Calculate current unrealized PnL at this point
entry_price = expected_outcome.get('entry_price', 0)
current_price = timeframes['1m']['close'][idx] if idx < len(timeframes['1m']['close']) else entry_price
direction = expected_outcome.get('direction')
if entry_price > 0 and current_price > 0:
if direction == 'LONG':
current_pnl = (current_price - entry_price) / entry_price * 100
else: # SHORT
current_pnl = (entry_price - current_price) / entry_price * 100
else:
current_pnl = 0.0
hold_sample = {
'market_state': hold_market_state,
'action': 'HOLD',
'direction': direction,
'profit_loss_pct': current_pnl, # Current unrealized PnL
'entry_price': entry_price,
'exit_price': expected_outcome.get('exit_price'),
'timestamp': ts_str,
'label': 'HOLD', # Hold position
'in_position': True # Flag indicating we're in a position
}
hold_samples.append(hold_sample)
logger.debug(f" Created {len(hold_samples)} HOLD samples (every {sample_interval} candles, {len(candles_in_position)} total candles in position)")
except Exception as e:
logger.error(f"Error creating HOLD samples: {e}")
@@ -728,24 +749,30 @@ class RealTrainingAdapter:
return hold_samples
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
window_size: int) -> List[Dict]:
def _create_negative_samples(self, market_state: Dict, entry_timestamp: str,
exit_timestamp: Optional[str], holding_period_seconds: int,
samples_before: int = 5, samples_after: int = 5) -> List[Dict]:
"""
Create negative training samples from candles around the signal
Create negative training samples from candles before entry and after exit
These samples teach the model when NOT to trade - crucial for reducing false signals!
Args:
market_state: Market state with OHLCV data
signal_timestamp: Timestamp of the actual signal
window_size: Number of candles before/after signal to use
entry_timestamp: Timestamp of entry signal
exit_timestamp: Timestamp of exit signal (optional, calculated from holding period)
holding_period_seconds: Duration of the trade in seconds
samples_before: Number of candles before entry (default: 5)
samples_after: Number of candles after exit (default: 5)
Returns:
List of negative training samples
List of negative training samples (NO_TRADE)
"""
negative_samples = []
try:
from datetime import timedelta
# Get timestamps from market state (use 1m timeframe as reference)
timeframes = market_state.get('timeframes', {})
if '1m' not in timeframes:
@@ -756,55 +783,65 @@ class RealTrainingAdapter:
if not timestamps:
return negative_samples
# Find the index of the signal timestamp
from datetime import datetime
# Parse signal timestamp using unified parser
# Parse entry timestamp
try:
signal_time = parse_timestamp_to_utc(signal_timestamp)
entry_time = parse_timestamp_to_utc(entry_timestamp)
except Exception as e:
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
return negative_samples
signal_index = None
# Calculate exit time
exit_time = entry_time + timedelta(seconds=holding_period_seconds)
# Find entry and exit indices
entry_index = None
exit_index = None
for idx, ts_str in enumerate(timestamps):
try:
# Parse timestamp using unified parser
ts = parse_timestamp_to_utc(ts_str)
# Match within 1 minute
if abs((ts - signal_time).total_seconds()) < 60:
signal_index = idx
logger.debug(f" Found signal at index {idx}: {ts_str}")
# Match entry within 1 minute
if entry_index is None and abs((ts - entry_time).total_seconds()) < 60:
entry_index = idx
# Match exit within 1 minute
if exit_index is None and abs((ts - exit_time).total_seconds()) < 60:
exit_index = idx
if entry_index is not None and exit_index is not None:
break
except Exception as e:
continue
if signal_index is None:
logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
if entry_index is None:
logger.warning(f"Could not find entry timestamp in market data")
return negative_samples
# Create negative samples from candles before and after the signal
# BEFORE signal: candles at signal_index - window_size to signal_index - 1
# AFTER signal: candles at signal_index + 1 to signal_index + window_size
# If exit not found, estimate it
if exit_index is None:
# Estimate: 1 minute per candle
candles_in_trade = holding_period_seconds // 60
exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1)
logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)")
# Create NO_TRADE samples: 5 before entry + 5 after exit
negative_indices = []
# Before signal
for offset in range(1, window_size + 1):
idx = signal_index - offset
# 5 candles BEFORE entry
for offset in range(1, samples_before + 1):
idx = entry_index - offset
if 0 <= idx < len(timestamps):
negative_indices.append(idx)
negative_indices.append(('before_entry', idx))
# After signal
for offset in range(1, window_size + 1):
idx = signal_index + offset
# 5 candles AFTER exit
for offset in range(1, samples_after + 1):
idx = exit_index + offset
if 0 <= idx < len(timestamps):
negative_indices.append(idx)
negative_indices.append(('after_exit', idx))
# Create negative samples for each index
for idx in negative_indices:
# Create negative samples
for location, idx in negative_indices:
# Create a market state snapshot at this timestamp
negative_market_state = self._create_market_state_snapshot(market_state, idx)
@@ -816,12 +853,13 @@ class RealTrainingAdapter:
'entry_price': None,
'exit_price': None,
'timestamp': timestamps[idx],
'label': 'NO_TRADE' # Negative label
'label': 'NO_TRADE', # Negative label
'in_position': False # Not in position
}
negative_samples.append(negative_sample)
logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles")
logger.debug(f" Created {len(negative_samples)} NO_TRADE samples ({samples_before} before entry + {samples_after} after exit)")
except Exception as e:
logger.error(f"Error creating negative samples: {e}")
@@ -1423,6 +1461,33 @@ class RealTrainingAdapter:
# FIXED: Ensure shape is [1, 1] not [1] to match BCELoss requirements
trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32) # [1, 1]
# NEW: Trend vector target for trend analysis optimization
# Calculate expected trend from entry to exit
direction = training_sample.get('direction', 'NONE')
if direction == 'LONG':
# Upward trend: positive angle, positive direction
trend_angle = 0.785 # ~45 degrees in radians (pi/4)
trend_direction = 1.0 # Upward
elif direction == 'SHORT':
# Downward trend: negative angle, negative direction
trend_angle = -0.785 # ~-45 degrees
trend_direction = -1.0 # Downward
else:
# No trend
trend_angle = 0.0
trend_direction = 0.0
# Steepness based on profit potential
if exit_price and entry_price and entry_price > 0:
price_change_pct = abs((exit_price - entry_price) / entry_price)
trend_steepness = min(price_change_pct * 10, 1.0) # Normalize to [0, 1]
else:
trend_steepness = 0.0
# Create trend target tensor [batch, 3]: [angle, steepness, direction]
trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32) # [1, 3]
# Return batch dictionary with ALL timeframes
batch = {
# Multi-timeframe price data
@@ -1440,8 +1505,9 @@ class RealTrainingAdapter:
# Training targets
'actions': actions, # [1]
'future_prices': future_prices, # [1]
'future_prices': future_prices, # [1, 1]
'trade_success': trade_success, # [1, 1]
'trend_target': trend_target, # [1, 3] - NEW: [angle, steepness, direction]
# Legacy support (use 1m as default)
'price_data': price_data_1m if price_data_1m is not None else ref_data
@@ -1646,13 +1712,14 @@ class RealTrainingAdapter:
batch_loss = result.get('total_loss', 0.0)
batch_accuracy = result.get('accuracy', 0.0)
batch_candle_accuracy = result.get('candle_accuracy', 0.0)
batch_trend_loss = result.get('trend_loss', 0.0)
epoch_loss += batch_loss
epoch_accuracy += batch_accuracy
num_batches += 1
# Log first batch and every 5th batch for debugging
if (i + 1) == 1 or (i + 1) % 5 == 0:
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}")
else:
logger.warning(f" Batch {i + 1} returned None result - skipping")