fetching data from the DB to train
@@ -279,6 +279,26 @@ class RealTrainingAdapter:
        session.duration_seconds = time.time() - session.start_time
        logger.error(f"Training session {training_id} failed after {session.duration_seconds:.2f}s")

    def _get_secondary_symbol(self, primary_symbol: str) -> str:
        """
        Determine secondary symbol based on primary symbol

        Rules:
        - ETH/USDT -> BTC/USDT
        - SOL/USDT -> BTC/USDT
        - BTC/USDT -> ETH/USDT

        Args:
            primary_symbol: Primary trading symbol

        Returns:
            Secondary symbol for correlation analysis
        """
        if 'BTC' in primary_symbol:
            return 'ETH/USDT'
        else:
            return 'BTC/USDT'
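The mapping collapses to one rule: BTC pairs correlate against ETH/USDT, everything else against BTC/USDT. A standalone restatement with spot checks (a minimal sketch mirroring the method body above; the free function name is local to this example):

```python
# Standalone restatement of the rule above; mirrors the method body.
def get_secondary_symbol(primary_symbol: str) -> str:
    # BTC pairs correlate against ETH; everything else against BTC.
    return 'ETH/USDT' if 'BTC' in primary_symbol else 'BTC/USDT'

assert get_secondary_symbol('ETH/USDT') == 'BTC/USDT'
assert get_secondary_symbol('SOL/USDT') == 'BTC/USDT'
assert get_secondary_symbol('BTC/USDT') == 'ETH/USDT'
```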
    def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
        """
        Fetch market state dynamically for a test case from DuckDB storage
@@ -314,28 +334,41 @@ class RealTrainingAdapter:
        # Get training config
        training_config = test_case.get('training_config', {})
        timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
        context_window = training_config.get('context_window_minutes', 5)
        negative_samples_window = training_config.get('negative_samples_window', 15)  # ±15 candles
        candles_per_timeframe = training_config.get('candles_per_timeframe', 600)  # 600 candles per batch

        # Calculate extended time range to include negative sampling window
        # For 1m timeframe: ±15 candles = ±15 minutes
        # Add buffer to ensure we have enough data
        extended_window_minutes = max(context_window, negative_samples_window + 10)
        # Determine secondary symbol based on primary symbol
        # ETH/SOL -> BTC, BTC -> ETH
        secondary_symbol = self._get_secondary_symbol(symbol)

        logger.info(f" Fetching HISTORICAL market state for {symbol} at {timestamp}")
        logger.info(f" Timeframes: {timeframes}, Extended window: ±{extended_window_minutes} minutes")
        logger.info(f" (Includes ±{negative_samples_window} candles for negative sampling)")
        logger.info(f" Primary symbol: {symbol} - Timeframes: {timeframes}")
        logger.info(f" Secondary symbol: {secondary_symbol} - Timeframe: 1m")
        logger.info(f" Candles per batch: {candles_per_timeframe}")

        # Calculate time range for extended context window
        # Calculate time range based on candles needed
        # For 600 candles at 1m = 600 minutes = 10 hours
        from datetime import timedelta
        start_time = timestamp - timedelta(minutes=extended_window_minutes)
        end_time = timestamp + timedelta(minutes=extended_window_minutes)

        # Fetch data for each timeframe
        # Calculate time window for each timeframe to get 600 candles
        time_windows = {
            '1s': timedelta(seconds=candles_per_timeframe),  # 600 seconds = 10 minutes
            '1m': timedelta(minutes=candles_per_timeframe),  # 600 minutes = 10 hours
            '1h': timedelta(hours=candles_per_timeframe),    # 600 hours = 25 days
            '1d': timedelta(days=candles_per_timeframe)      # 600 days = ~1.6 years
        }

        # Use the largest window to ensure we have enough data for all timeframes
        max_window = max(time_windows.values())
        start_time = timestamp - max_window
        end_time = timestamp
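A minimal, self-contained sketch of the window arithmetic above: with a fixed candle budget, the look-back span grows with the timeframe, and the daily window dominates the final query range (the timestamp here is hypothetical):

```python
from datetime import datetime, timedelta

candles_per_timeframe = 600  # default from training_config above

time_windows = {
    '1s': timedelta(seconds=candles_per_timeframe),  # 10 minutes
    '1m': timedelta(minutes=candles_per_timeframe),  # 10 hours
    '1h': timedelta(hours=candles_per_timeframe),    # 25 days
    '1d': timedelta(days=candles_per_timeframe),     # ~1.6 years
}

timestamp = datetime(2025, 1, 15, 12, 0)  # hypothetical annotation time
max_window = max(time_windows.values())   # timedelta(days=600) dominates
start_time, end_time = timestamp - max_window, timestamp

# The range ends at the annotation time rather than being centered on it,
# so every timeframe has at least candles_per_timeframe candles of history.
print(start_time, '->', end_time)
```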
        # Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
        market_state = {
            'symbol': symbol,
            'timestamp': timestamp_str,
            'timeframes': {}
            'timeframes': {},
            'secondary_symbol': secondary_symbol,
            'secondary_timeframes': {}
        }
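For orientation, a hypothetical, truncated example of the payload this method assembles once the fetches below succeed (field names come from the dict literal above; all values are invented):

```python
# Hypothetical, truncated payload; field names from the dict literal above.
market_state = {
    'symbol': 'ETH/USDT',
    'timestamp': '2025-01-15 12:00:00',
    'timeframes': {
        '1m': {
            'timestamps': ['2025-01-15 11:59:00', '2025-01-15 12:00:00'],
            'open': [2501.0, 2502.5],
            'high': [2503.0, 2504.0],
            'low': [2500.0, 2501.5],
            'close': [2502.5, 2503.0],
            'volume': [12.3, 9.8],
        },
        # '1s', '1h' and '1d' entries share the same OHLCV-lists shape
    },
    'secondary_symbol': 'BTC/USDT',
    'secondary_timeframes': {
        '1m': {},  # filled with the same OHLCV-lists shape further down
    },
}
```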
        # Try to get data from DuckDB storage first (historical data)
@@ -343,21 +376,11 @@ class RealTrainingAdapter:
        if hasattr(self.data_provider, 'duckdb_storage'):
            duckdb_storage = self.data_provider.duckdb_storage

            # Fetch primary symbol data (all timeframes)
            logger.info(f" Fetching primary symbol data: {symbol}")
            for timeframe in timeframes:
                df = None

                # Calculate appropriate limit based on timeframe and window
                # We want enough candles to cover the extended window plus negative samples
                if timeframe == '1s':
                    limit = extended_window_minutes * 60 * 2 + 100  # 2x for safety + buffer
                elif timeframe == '1m':
                    limit = extended_window_minutes * 2 + 50  # 2x for safety + buffer
                elif timeframe == '1h':
                    limit = max(200, extended_window_minutes // 30)  # At least 200 candles
                elif timeframe == '1d':
                    limit = 200  # Fixed for daily
                else:
                    limit = 300
                limit = candles_per_timeframe  # Always fetch 600 candles

                # Try DuckDB storage first (has historical data)
                if duckdb_storage:
@@ -410,12 +433,74 @@ class RealTrainingAdapter:
                        'close': df['close'].tolist(),
                        'volume': df['volume'].tolist()
                    }
                    logger.debug(f" {timeframe}: {len(df)} candles stored")
                    logger.info(f" {symbol} {timeframe}: {len(df)} candles")
                else:
                    logger.warning(f" {timeframe}: No data available")
                    logger.warning(f" {symbol} {timeframe}: No data available")
            # Fetch secondary symbol data (1m timeframe only, 600 candles)
            logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
            secondary_df = None

            # Try DuckDB first
            if duckdb_storage:
                try:
                    secondary_df = duckdb_storage.get_ohlcv_data(
                        symbol=secondary_symbol,
                        timeframe='1m',
                        start_time=start_time,
                        end_time=end_time,
                        limit=candles_per_timeframe,
                        direction='latest'
                    )
                    if secondary_df is not None and not secondary_df.empty:
                        logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
                except Exception as e:
                    logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")

            # Fallback to replay
            if secondary_df is None or secondary_df.empty:
                try:
                    replay_data = self.data_provider.get_historical_data_replay(
                        symbol=secondary_symbol,
                        start_time=start_time,
                        end_time=end_time,
                        timeframes=['1m']
                    )
                    secondary_df = replay_data.get('1m')
                    if secondary_df is not None and not secondary_df.empty:
                        logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
                except Exception as e:
                    logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")

            # Last resort: latest data
            if secondary_df is None or secondary_df.empty:
                logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
                secondary_df = self.data_provider.get_historical_data(
                    symbol=secondary_symbol,
                    timeframe='1m',
                    limit=candles_per_timeframe
                )

            # Store secondary symbol data
            if secondary_df is not None and not secondary_df.empty:
                market_state['secondary_timeframes']['1m'] = {
                    'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
                    'open': secondary_df['open'].tolist(),
                    'high': secondary_df['high'].tolist(),
                    'low': secondary_df['low'].tolist(),
                    'close': secondary_df['close'].tolist(),
                    'volume': secondary_df['volume'].tolist()
                }
                logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles")
            else:
                logger.warning(f" {secondary_symbol} 1m: No data available")
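The secondary-symbol fetch above is a three-tier fallback: exact historical range from DuckDB, then the provider's replay path, then latest candles as a last resort. A condensed sketch of the same chain as a standalone helper (the helper name is hypothetical; the three provider calls are exactly the ones used in this hunk):

```python
import logging
from datetime import datetime
from typing import Optional

import pandas as pd

logger = logging.getLogger(__name__)

def fetch_1m_with_fallback(data_provider, symbol: str,
                           start_time: datetime, end_time: datetime,
                           limit: int) -> Optional[pd.DataFrame]:
    """Hypothetical condensation of the DuckDB -> replay -> latest chain above."""
    storage = getattr(data_provider, 'duckdb_storage', None)

    # 1. Exact historical range from DuckDB, if the storage backend exists.
    if storage:
        try:
            df = storage.get_ohlcv_data(symbol=symbol, timeframe='1m',
                                        start_time=start_time, end_time=end_time,
                                        limit=limit, direction='latest')
            if df is not None and not df.empty:
                return df
        except Exception as e:
            logger.debug(f"{symbol} 1m: DuckDB query failed: {e}")

    # 2. Replay path: reconstructs the historical window from the provider.
    try:
        replay = data_provider.get_historical_data_replay(
            symbol=symbol, start_time=start_time, end_time=end_time,
            timeframes=['1m'])
        df = replay.get('1m')
        if df is not None and not df.empty:
            return df
    except Exception as e:
        logger.debug(f"{symbol} 1m: Replay failed: {e}")

    # 3. Last resort: latest candles, which may not match the annotation window.
    logger.warning(f"{symbol} 1m: No historical data, using latest as fallback")
    return data_provider.get_historical_data(symbol=symbol, timeframe='1m',
                                             limit=limit)
```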
        # Verify we have data
        if market_state['timeframes']:
            logger.info(f" Fetched market state with {len(market_state['timeframes'])} timeframes")
            total_primary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['timeframes'].values())
            total_secondary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['secondary_timeframes'].values())
            logger.info(f" [OK] Fetched {len(market_state['timeframes'])} primary timeframes ({total_primary} total candles)")
            logger.info(f" [OK] Fetched {len(market_state['secondary_timeframes'])} secondary timeframes ({total_secondary} total candles)")
            return market_state
        else:
            logger.warning(f" No market data fetched for any timeframe")
@@ -483,7 +568,7 @@ class RealTrainingAdapter:
                }

                training_data.append(entry_sample)
                logger.debug(f" Entry sample: {entry_sample['direction']} @ {entry_sample['entry_price']}")
                logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")

                # Create HOLD samples (every candle while position is open)
                # This teaches the model to maintain the position until exit
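The `_create_hold_samples` call itself falls between hunks. A hedged sketch of what such a helper plausibly does, inferred only from the comments and logging here (names and fields hypothetical, not the repository's implementation):

```python
from datetime import datetime
from typing import Dict, List

def create_hold_samples(entry_ts: datetime, exit_ts: datetime,
                        candle_timestamps: List[datetime]) -> List[Dict]:
    """Hypothetical sketch only: the real _create_hold_samples is not part
    of this diff. One HOLD sample per candle strictly inside the open
    position, so the model learns to keep holding until the exit."""
    return [{'action': 'HOLD', 'timestamp': ts}
            for ts in candle_timestamps if entry_ts < ts < exit_ts]
```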
@@ -494,7 +579,8 @@ class RealTrainingAdapter:
                )

                training_data.extend(hold_samples)
                logger.debug(f" Added {len(hold_samples)} HOLD samples (during position)")
                if hold_samples:
                    logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)")

                # Create EXIT sample (where model SHOULD exit trade)
                exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
@@ -511,10 +597,11 @@ class RealTrainingAdapter:
                    'repetitions': training_repetitions
                }
                training_data.append(exit_sample)
                logger.debug(f" Exit sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
                logger.info(f" Test case {i+1}: EXIT sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")

                # Create NEGATIVE samples (where model should NOT trade)
                # These are candles before and after the signal
                # These are candles before and after the signal (±15 candles)
                # This teaches the model to recognize when NOT to enter
                negative_samples = self._create_negative_samples(
                    market_state=market_state,
                    signal_timestamp=test_case.get('timestamp'),
@@ -523,7 +610,12 @@ class RealTrainingAdapter:
                )

                training_data.extend(negative_samples)
                logger.debug(f" ➕ Added {len(negative_samples)} negative samples (±{negative_samples_window} candles)")
                if negative_samples:
                    logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)")
                    # Show breakdown of before/after
                    before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', '')))
                    after_count = len(negative_samples) - before_count
                    logger.info(f" -> {before_count} before signal, {after_count} after signal")

            except Exception as e:
                logger.error(f" Error preparing test case {i+1}: {e}")
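A hedged sketch of the ±window negative-sampling rule the comments describe (`_create_negative_samples` itself is not part of this diff; this index-based helper is hypothetical):

```python
from typing import List

def negative_sample_indices(signal_idx: int, window: int,
                            n_candles: int) -> List[int]:
    """Hypothetical sketch of the ±window rule: every candle within
    `window` steps of the signal, except the signal candle itself,
    becomes a NO_TRADE sample."""
    lo = max(0, signal_idx - window)
    hi = min(n_candles, signal_idx + window + 1)
    return [i for i in range(lo, hi) if i != signal_idx]

# With window=15 and the signal mid-series: 15 before + 15 after = 30 samples.
assert len(negative_sample_indices(100, 15, 600)) == 30
```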
@@ -1222,16 +1314,21 @@ class RealTrainingAdapter:

        actions = torch.tensor([action], dtype=torch.long)

        # Future price target
        # Future price target - NORMALIZED
        # Model predicts price change ratio, not absolute price
        entry_price = training_sample.get('entry_price')
        exit_price = training_sample.get('exit_price')
        current_price = closes_for_tech[-1]  # Most recent close price

        if exit_price and entry_price:
            future_price = exit_price
            # Normalize: (exit_price - current_price) / current_price
            # This gives the expected price change as a ratio
            future_price_ratio = (exit_price - current_price) / current_price
        else:
            future_price = closes[-1]  # Current price for HOLD
            # For HOLD samples, expect no price change
            future_price_ratio = 0.0

        future_prices = torch.tensor([future_price], dtype=torch.float32)
        future_prices = torch.tensor([future_price_ratio], dtype=torch.float32)

        # Trade success (1.0 if profitable, 0.0 otherwise)
        # Shape must be [batch_size, 1] to match confidence head output
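A worked example of the normalization introduced here, with hypothetical prices:

```python
import torch

# Hypothetical prices illustrating the ratio target above.
current_price = 2510.0  # most recent close in the context window
exit_price = 2550.0     # annotated exit

future_price_ratio = (exit_price - current_price) / current_price
future_prices = torch.tensor([future_price_ratio], dtype=torch.float32)
print(f"{future_price_ratio:.4f}")  # 0.0159: a ~1.6% expected move

# The target is scale-free, so assets at very different price levels train
# on the same range; a predicted ratio maps back to a price when needed:
predicted_price = current_price * (1.0 + future_prices.item())
```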
@@ -1321,7 +1418,7 @@ class RealTrainingAdapter:
                num_batches += 1

                if (i + 1) % 100 == 0:
                    logger.debug(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}")
                    logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}, Accuracy: {result.get('accuracy', 0.0):.2%}")

            except Exception as e:
                logger.error(f" Error in batch {i + 1}: {e}")