fetching data from the DB to train

This commit is contained in:
Dobromir Popov
2025-10-31 03:14:35 +02:00
parent 07150fd019
commit 6ac324289c
6 changed files with 1113 additions and 46 deletions

View File

@@ -279,6 +279,26 @@ class RealTrainingAdapter:
session.duration_seconds = time.time() - session.start_time
logger.error(f"Training session {training_id} failed after {session.duration_seconds:.2f}s")
def _get_secondary_symbol(self, primary_symbol: str) -> str:
"""
Determine secondary symbol based on primary symbol
Rules:
- ETH/USDT -> BTC/USDT
- SOL/USDT -> BTC/USDT
- BTC/USDT -> ETH/USDT
Args:
primary_symbol: Primary trading symbol
Returns:
Secondary symbol for correlation analysis
"""
if 'BTC' in primary_symbol:
return 'ETH/USDT'
else:
return 'BTC/USDT'
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
"""
Fetch market state dynamically for a test case from DuckDB storage
@@ -314,28 +334,41 @@ class RealTrainingAdapter:
# Get training config
training_config = test_case.get('training_config', {})
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
context_window = training_config.get('context_window_minutes', 5)
negative_samples_window = training_config.get('negative_samples_window', 15) # ±15 candles
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per batch
# Calculate extended time range to include negative sampling window
# For 1m timeframe: ±15 candles = ±15 minutes
# Add buffer to ensure we have enough data
extended_window_minutes = max(context_window, negative_samples_window + 10)
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
secondary_symbol = self._get_secondary_symbol(symbol)
logger.info(f" Fetching HISTORICAL market state for {symbol} at {timestamp}")
logger.info(f" Timeframes: {timeframes}, Extended window: ±{extended_window_minutes} minutes")
logger.info(f" (Includes ±{negative_samples_window} candles for negative sampling)")
logger.info(f" Primary symbol: {symbol} - Timeframes: {timeframes}")
logger.info(f" Secondary symbol: {secondary_symbol} - Timeframe: 1m")
logger.info(f" Candles per batch: {candles_per_timeframe}")
# Calculate time range for extended context window
# Calculate time range based on candles needed
# For 600 candles at 1m = 600 minutes = 10 hours
from datetime import timedelta
start_time = timestamp - timedelta(minutes=extended_window_minutes)
end_time = timestamp + timedelta(minutes=extended_window_minutes)
# Fetch data for each timeframe
# Calculate time window for each timeframe to get 600 candles
time_windows = {
'1s': timedelta(seconds=candles_per_timeframe), # 600 seconds = 10 minutes
'1m': timedelta(minutes=candles_per_timeframe), # 600 minutes = 10 hours
'1h': timedelta(hours=candles_per_timeframe), # 600 hours = 25 days
'1d': timedelta(days=candles_per_timeframe) # 600 days = ~1.6 years
}
# Use the largest window to ensure we have enough data for all timeframes
max_window = max(time_windows.values())
start_time = timestamp - max_window
end_time = timestamp
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
market_state = {
'symbol': symbol,
'timestamp': timestamp_str,
'timeframes': {}
'timeframes': {},
'secondary_symbol': secondary_symbol,
'secondary_timeframes': {}
}
# Try to get data from DuckDB storage first (historical data)
@@ -343,21 +376,11 @@ class RealTrainingAdapter:
if hasattr(self.data_provider, 'duckdb_storage'):
duckdb_storage = self.data_provider.duckdb_storage
# Fetch primary symbol data (all timeframes)
logger.info(f" Fetching primary symbol data: {symbol}")
for timeframe in timeframes:
df = None
# Calculate appropriate limit based on timeframe and window
# We want enough candles to cover the extended window plus negative samples
if timeframe == '1s':
limit = extended_window_minutes * 60 * 2 + 100 # 2x for safety + buffer
elif timeframe == '1m':
limit = extended_window_minutes * 2 + 50 # 2x for safety + buffer
elif timeframe == '1h':
limit = max(200, extended_window_minutes // 30) # At least 200 candles
elif timeframe == '1d':
limit = 200 # Fixed for daily
else:
limit = 300
limit = candles_per_timeframe # Always fetch 600 candles
# Try DuckDB storage first (has historical data)
if duckdb_storage:
@@ -410,12 +433,74 @@ class RealTrainingAdapter:
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
logger.debug(f" {timeframe}: {len(df)} candles stored")
logger.info(f" {symbol} {timeframe}: {len(df)} candles")
else:
logger.warning(f" {timeframe}: No data available")
logger.warning(f" {symbol} {timeframe}: No data available")
# Fetch secondary symbol data (1m timeframe only, 600 candles)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None
# Try DuckDB first
if duckdb_storage:
try:
secondary_df = duckdb_storage.get_ohlcv_data(
symbol=secondary_symbol,
timeframe='1m',
start_time=start_time,
end_time=end_time,
limit=candles_per_timeframe,
direction='latest'
)
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
# Fallback to replay
if secondary_df is None or secondary_df.empty:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=secondary_symbol,
start_time=start_time,
end_time=end_time,
timeframes=['1m']
)
secondary_df = replay_data.get('1m')
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
# Last resort: latest data
if secondary_df is None or secondary_df.empty:
logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
secondary_df = self.data_provider.get_historical_data(
symbol=secondary_symbol,
timeframe='1m',
limit=candles_per_timeframe
)
# Store secondary symbol data
if secondary_df is not None and not secondary_df.empty:
market_state['secondary_timeframes']['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles")
else:
logger.warning(f" {secondary_symbol} 1m: No data available")
# Verify we have data
if market_state['timeframes']:
logger.info(f" Fetched market state with {len(market_state['timeframes'])} timeframes")
total_primary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['timeframes'].values())
total_secondary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['secondary_timeframes'].values())
logger.info(f" [OK] Fetched {len(market_state['timeframes'])} primary timeframes ({total_primary} total candles)")
logger.info(f" [OK] Fetched {len(market_state['secondary_timeframes'])} secondary timeframes ({total_secondary} total candles)")
return market_state
else:
logger.warning(f" No market data fetched for any timeframe")
@@ -483,7 +568,7 @@ class RealTrainingAdapter:
}
training_data.append(entry_sample)
logger.debug(f" Entry sample: {entry_sample['direction']} @ {entry_sample['entry_price']}")
logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")
# Create HOLD samples (every candle while position is open)
# This teaches the model to maintain the position until exit
@@ -494,7 +579,8 @@ class RealTrainingAdapter:
)
training_data.extend(hold_samples)
logger.debug(f" Added {len(hold_samples)} HOLD samples (during position)")
if hold_samples:
logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)")
# Create EXIT sample (where model SHOULD exit trade)
exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
@@ -511,10 +597,11 @@ class RealTrainingAdapter:
'repetitions': training_repetitions
}
training_data.append(exit_sample)
logger.debug(f" Exit sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
logger.info(f" Test case {i+1}: EXIT sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
# Create NEGATIVE samples (where model should NOT trade)
# These are candles before and after the signal
# These are candles before and after the signal (±15 candles)
# This teaches the model to recognize when NOT to enter
negative_samples = self._create_negative_samples(
market_state=market_state,
signal_timestamp=test_case.get('timestamp'),
@@ -523,7 +610,12 @@ class RealTrainingAdapter:
)
training_data.extend(negative_samples)
logger.debug(f" Added {len(negative_samples)} negative samples (±{negative_samples_window} candles)")
if negative_samples:
logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)")
# Show breakdown of before/after
before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', '')))
after_count = len(negative_samples) - before_count
logger.info(f" -> {before_count} beforesignal, {after_count} after signal")
except Exception as e:
logger.error(f" Error preparing test case {i+1}: {e}")
@@ -1222,16 +1314,21 @@ class RealTrainingAdapter:
actions = torch.tensor([action], dtype=torch.long)
# Future price target
# Future price target - NORMALIZED
# Model predicts price change ratio, not absolute price
entry_price = training_sample.get('entry_price')
exit_price = training_sample.get('exit_price')
current_price = closes_for_tech[-1] # Most recent close price
if exit_price and entry_price:
future_price = exit_price
# Normalize: (exit_price - current_price) / current_price
# This gives the expected price change as a ratio
future_price_ratio = (exit_price - current_price) / current_price
else:
future_price = closes[-1] # Current price for HOLD
# For HOLD samples, expect no price change
future_price_ratio = 0.0
future_prices = torch.tensor([future_price], dtype=torch.float32)
future_prices = torch.tensor([future_price_ratio], dtype=torch.float32)
# Trade success (1.0 if profitable, 0.0 otherwise)
# Shape must be [batch_size, 1] to match confidence head output
@@ -1321,7 +1418,7 @@ class RealTrainingAdapter:
num_batches += 1
if (i + 1) % 100 == 0:
logger.debug(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}")
logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}, Accuracy: {result.get('accuracy', 0.0):.2%}")
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")