T model trend prediction added
This commit is contained in:
@@ -574,16 +574,17 @@ class RealTrainingAdapter:
|
||||
training_data.append(entry_sample)
|
||||
logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")
|
||||
|
||||
# Create HOLD samples (every candle while position is open)
|
||||
# Create HOLD samples (every 13 candles while position is open)
|
||||
# This teaches the model to maintain the position until exit
|
||||
hold_samples = self._create_hold_samples(
|
||||
test_case=test_case,
|
||||
market_state=market_state
|
||||
market_state=market_state,
|
||||
sample_interval=13 # One sample every 13 candles
|
||||
)
|
||||
|
||||
training_data.extend(hold_samples)
|
||||
if hold_samples:
|
||||
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)")
|
||||
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (every 13 candles during position)")
|
||||
|
||||
# Create EXIT sample (where model SHOULD exit trade)
|
||||
# Exit info is in expected_outcome, not annotation_metadata
|
||||
@@ -606,21 +607,20 @@ class RealTrainingAdapter:
|
||||
logger.info(f" Test case {i+1}: EXIT sample @ {exit_price} ({expected_outcome.get('profit_loss_pct', 0):.2f}%)")
|
||||
|
||||
# Create NEGATIVE samples (where model should NOT trade)
|
||||
# These are candles before and after the signal (±15 candles)
|
||||
# 5 candles before entry + 5 candles after exit = 10 NO_TRADE samples per annotation
|
||||
# This teaches the model to recognize when NOT to enter
|
||||
negative_samples = self._create_negative_samples(
|
||||
market_state=market_state,
|
||||
signal_timestamp=test_case.get('timestamp'),
|
||||
window_size=negative_samples_window
|
||||
entry_timestamp=test_case.get('timestamp'),
|
||||
exit_timestamp=None, # Will be calculated from holding period
|
||||
holding_period_seconds=expected_outcome.get('holding_period_seconds', 0),
|
||||
samples_before=5, # 5 candles before entry
|
||||
samples_after=5 # 5 candles after exit
|
||||
)
|
||||
|
||||
training_data.extend(negative_samples)
|
||||
if negative_samples:
|
||||
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)")
|
||||
# Show breakdown of before/after
|
||||
before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', '')))
|
||||
after_count = len(negative_samples) - before_count
|
||||
logger.info(f" -> {before_count} before signal, {after_count} after signal")
|
||||
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (5 before entry + 5 after exit)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" Error preparing test case {i+1}: {e}")
|
||||
@@ -644,9 +644,9 @@ class RealTrainingAdapter:
|
||||
|
||||
return training_data
|
||||
|
||||
def _create_hold_samples(self, test_case: Dict, market_state: Dict) -> List[Dict]:
|
||||
def _create_hold_samples(self, test_case: Dict, market_state: Dict, sample_interval: int = 13) -> List[Dict]:
|
||||
"""
|
||||
Create HOLD training samples for every candle while position is open
|
||||
Create HOLD training samples at intervals while position is open
|
||||
|
||||
This teaches the model to:
|
||||
1. Maintain the position (not exit early)
|
||||
@@ -656,6 +656,7 @@ class RealTrainingAdapter:
|
||||
Args:
|
||||
test_case: Test case with entry/exit info
|
||||
market_state: Market state data
|
||||
sample_interval: Create one sample every N candles (default: 13)
|
||||
|
||||
Returns:
|
||||
List of HOLD training samples
|
||||
@@ -691,7 +692,8 @@ class RealTrainingAdapter:
|
||||
|
||||
timestamps = timeframes['1m'].get('timestamps', [])
|
||||
|
||||
# Find all candles between entry and exit
|
||||
# Find all candles between entry and exit, sample every N candles
|
||||
candles_in_position = []
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
# Parse timestamp using unified parser
|
||||
try:
|
||||
@@ -702,24 +704,43 @@ class RealTrainingAdapter:
|
||||
|
||||
# If this candle is between entry and exit (exclusive)
|
||||
if entry_time < ts < exit_time:
|
||||
# Create market state snapshot at this candle
|
||||
hold_market_state = self._create_market_state_snapshot(market_state, idx)
|
||||
|
||||
hold_sample = {
|
||||
'market_state': hold_market_state,
|
||||
'action': 'HOLD',
|
||||
'direction': expected_outcome.get('direction'),
|
||||
'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
|
||||
'entry_price': expected_outcome.get('entry_price'),
|
||||
'exit_price': expected_outcome.get('exit_price'),
|
||||
'timestamp': ts_str,
|
||||
'label': 'HOLD', # Hold position
|
||||
'in_position': True # Flag indicating we're in a position
|
||||
}
|
||||
|
||||
hold_samples.append(hold_sample)
|
||||
candles_in_position.append((idx, ts_str, ts))
|
||||
|
||||
logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit")
|
||||
# Sample every Nth candle (e.g., every 13 candles)
|
||||
for i in range(0, len(candles_in_position), sample_interval):
|
||||
idx, ts_str, ts = candles_in_position[i]
|
||||
|
||||
# Create market state snapshot at this candle
|
||||
hold_market_state = self._create_market_state_snapshot(market_state, idx)
|
||||
|
||||
# Calculate current unrealized PnL at this point
|
||||
entry_price = expected_outcome.get('entry_price', 0)
|
||||
current_price = timeframes['1m']['close'][idx] if idx < len(timeframes['1m']['close']) else entry_price
|
||||
direction = expected_outcome.get('direction')
|
||||
|
||||
if entry_price > 0 and current_price > 0:
|
||||
if direction == 'LONG':
|
||||
current_pnl = (current_price - entry_price) / entry_price * 100
|
||||
else: # SHORT
|
||||
current_pnl = (entry_price - current_price) / entry_price * 100
|
||||
else:
|
||||
current_pnl = 0.0
|
||||
|
||||
hold_sample = {
|
||||
'market_state': hold_market_state,
|
||||
'action': 'HOLD',
|
||||
'direction': direction,
|
||||
'profit_loss_pct': current_pnl, # Current unrealized PnL
|
||||
'entry_price': entry_price,
|
||||
'exit_price': expected_outcome.get('exit_price'),
|
||||
'timestamp': ts_str,
|
||||
'label': 'HOLD', # Hold position
|
||||
'in_position': True # Flag indicating we're in a position
|
||||
}
|
||||
|
||||
hold_samples.append(hold_sample)
|
||||
|
||||
logger.debug(f" Created {len(hold_samples)} HOLD samples (every {sample_interval} candles, {len(candles_in_position)} total candles in position)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating HOLD samples: {e}")
|
||||
@@ -728,24 +749,30 @@ class RealTrainingAdapter:
|
||||
|
||||
return hold_samples
|
||||
|
||||
def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
|
||||
window_size: int) -> List[Dict]:
|
||||
def _create_negative_samples(self, market_state: Dict, entry_timestamp: str,
|
||||
exit_timestamp: Optional[str], holding_period_seconds: int,
|
||||
samples_before: int = 5, samples_after: int = 5) -> List[Dict]:
|
||||
"""
|
||||
Create negative training samples from candles around the signal
|
||||
Create negative training samples from candles before entry and after exit
|
||||
|
||||
These samples teach the model when NOT to trade - crucial for reducing false signals!
|
||||
|
||||
Args:
|
||||
market_state: Market state with OHLCV data
|
||||
signal_timestamp: Timestamp of the actual signal
|
||||
window_size: Number of candles before/after signal to use
|
||||
entry_timestamp: Timestamp of entry signal
|
||||
exit_timestamp: Timestamp of exit signal (optional, calculated from holding period)
|
||||
holding_period_seconds: Duration of the trade in seconds
|
||||
samples_before: Number of candles before entry (default: 5)
|
||||
samples_after: Number of candles after exit (default: 5)
|
||||
|
||||
Returns:
|
||||
List of negative training samples
|
||||
List of negative training samples (NO_TRADE)
|
||||
"""
|
||||
negative_samples = []
|
||||
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
# Get timestamps from market state (use 1m timeframe as reference)
|
||||
timeframes = market_state.get('timeframes', {})
|
||||
if '1m' not in timeframes:
|
||||
@@ -756,55 +783,65 @@ class RealTrainingAdapter:
|
||||
if not timestamps:
|
||||
return negative_samples
|
||||
|
||||
# Find the index of the signal timestamp
|
||||
from datetime import datetime
|
||||
|
||||
# Parse signal timestamp using unified parser
|
||||
# Parse entry timestamp
|
||||
try:
|
||||
signal_time = parse_timestamp_to_utc(signal_timestamp)
|
||||
entry_time = parse_timestamp_to_utc(entry_timestamp)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
|
||||
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
|
||||
return negative_samples
|
||||
|
||||
signal_index = None
|
||||
# Calculate exit time
|
||||
exit_time = entry_time + timedelta(seconds=holding_period_seconds)
|
||||
|
||||
# Find entry and exit indices
|
||||
entry_index = None
|
||||
exit_index = None
|
||||
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
try:
|
||||
# Parse timestamp using unified parser
|
||||
ts = parse_timestamp_to_utc(ts_str)
|
||||
|
||||
# Match within 1 minute
|
||||
if abs((ts - signal_time).total_seconds()) < 60:
|
||||
signal_index = idx
|
||||
logger.debug(f" Found signal at index {idx}: {ts_str}")
|
||||
# Match entry within 1 minute
|
||||
if entry_index is None and abs((ts - entry_time).total_seconds()) < 60:
|
||||
entry_index = idx
|
||||
|
||||
# Match exit within 1 minute
|
||||
if exit_index is None and abs((ts - exit_time).total_seconds()) < 60:
|
||||
exit_index = idx
|
||||
|
||||
if entry_index is not None and exit_index is not None:
|
||||
break
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if signal_index is None:
|
||||
logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
|
||||
logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
|
||||
if entry_index is None:
|
||||
logger.warning(f"Could not find entry timestamp in market data")
|
||||
return negative_samples
|
||||
|
||||
# Create negative samples from candles before and after the signal
|
||||
# BEFORE signal: candles at signal_index - window_size to signal_index - 1
|
||||
# AFTER signal: candles at signal_index + 1 to signal_index + window_size
|
||||
# If exit not found, estimate it
|
||||
if exit_index is None:
|
||||
# Estimate: 1 minute per candle
|
||||
candles_in_trade = holding_period_seconds // 60
|
||||
exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1)
|
||||
logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)")
|
||||
|
||||
# Create NO_TRADE samples: 5 before entry + 5 after exit
|
||||
negative_indices = []
|
||||
|
||||
# Before signal
|
||||
for offset in range(1, window_size + 1):
|
||||
idx = signal_index - offset
|
||||
# 5 candles BEFORE entry
|
||||
for offset in range(1, samples_before + 1):
|
||||
idx = entry_index - offset
|
||||
if 0 <= idx < len(timestamps):
|
||||
negative_indices.append(idx)
|
||||
negative_indices.append(('before_entry', idx))
|
||||
|
||||
# After signal
|
||||
for offset in range(1, window_size + 1):
|
||||
idx = signal_index + offset
|
||||
# 5 candles AFTER exit
|
||||
for offset in range(1, samples_after + 1):
|
||||
idx = exit_index + offset
|
||||
if 0 <= idx < len(timestamps):
|
||||
negative_indices.append(idx)
|
||||
negative_indices.append(('after_exit', idx))
|
||||
|
||||
# Create negative samples for each index
|
||||
for idx in negative_indices:
|
||||
# Create negative samples
|
||||
for location, idx in negative_indices:
|
||||
# Create a market state snapshot at this timestamp
|
||||
negative_market_state = self._create_market_state_snapshot(market_state, idx)
|
||||
|
||||
@@ -816,12 +853,13 @@ class RealTrainingAdapter:
|
||||
'entry_price': None,
|
||||
'exit_price': None,
|
||||
'timestamp': timestamps[idx],
|
||||
'label': 'NO_TRADE' # Negative label
|
||||
'label': 'NO_TRADE', # Negative label
|
||||
'in_position': False # Not in position
|
||||
}
|
||||
|
||||
negative_samples.append(negative_sample)
|
||||
|
||||
logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles")
|
||||
logger.debug(f" Created {len(negative_samples)} NO_TRADE samples ({samples_before} before entry + {samples_after} after exit)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating negative samples: {e}")
|
||||
@@ -1423,6 +1461,33 @@ class RealTrainingAdapter:
|
||||
# FIXED: Ensure shape is [1, 1] not [1] to match BCELoss requirements
|
||||
trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32) # [1, 1]
|
||||
|
||||
# NEW: Trend vector target for trend analysis optimization
|
||||
# Calculate expected trend from entry to exit
|
||||
direction = training_sample.get('direction', 'NONE')
|
||||
|
||||
if direction == 'LONG':
|
||||
# Upward trend: positive angle, positive direction
|
||||
trend_angle = 0.785 # ~45 degrees in radians (pi/4)
|
||||
trend_direction = 1.0 # Upward
|
||||
elif direction == 'SHORT':
|
||||
# Downward trend: negative angle, negative direction
|
||||
trend_angle = -0.785 # ~-45 degrees
|
||||
trend_direction = -1.0 # Downward
|
||||
else:
|
||||
# No trend
|
||||
trend_angle = 0.0
|
||||
trend_direction = 0.0
|
||||
|
||||
# Steepness based on profit potential
|
||||
if exit_price and entry_price and entry_price > 0:
|
||||
price_change_pct = abs((exit_price - entry_price) / entry_price)
|
||||
trend_steepness = min(price_change_pct * 10, 1.0) # Normalize to [0, 1]
|
||||
else:
|
||||
trend_steepness = 0.0
|
||||
|
||||
# Create trend target tensor [batch, 3]: [angle, steepness, direction]
|
||||
trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32) # [1, 3]
|
||||
|
||||
# Return batch dictionary with ALL timeframes
|
||||
batch = {
|
||||
# Multi-timeframe price data
|
||||
@@ -1440,8 +1505,9 @@ class RealTrainingAdapter:
|
||||
|
||||
# Training targets
|
||||
'actions': actions, # [1]
|
||||
'future_prices': future_prices, # [1]
|
||||
'future_prices': future_prices, # [1, 1]
|
||||
'trade_success': trade_success, # [1, 1]
|
||||
'trend_target': trend_target, # [1, 3] - NEW: [angle, steepness, direction]
|
||||
|
||||
# Legacy support (use 1m as default)
|
||||
'price_data': price_data_1m if price_data_1m is not None else ref_data
|
||||
@@ -1646,13 +1712,14 @@ class RealTrainingAdapter:
|
||||
batch_loss = result.get('total_loss', 0.0)
|
||||
batch_accuracy = result.get('accuracy', 0.0)
|
||||
batch_candle_accuracy = result.get('candle_accuracy', 0.0)
|
||||
batch_trend_loss = result.get('trend_loss', 0.0)
|
||||
epoch_loss += batch_loss
|
||||
epoch_accuracy += batch_accuracy
|
||||
num_batches += 1
|
||||
|
||||
# Log first batch and every 5th batch for debugging
|
||||
if (i + 1) == 1 or (i + 1) % 5 == 0:
|
||||
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
|
||||
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}")
|
||||
else:
|
||||
logger.warning(f" Batch {i + 1} returned None result - skipping")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user