beef up DQN model, fix training issues

Dobromir Popov
2025-07-27 20:48:44 +03:00
parent 1894d453c9
commit bd986f4534
6 changed files with 414 additions and 55 deletions


@@ -23,8 +23,9 @@ logger = logging.getLogger(__name__)
class DQNNetwork(nn.Module):
"""
Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
Handles 7850 input features from multi-timeframe, multi-asset data
TARGET: 50M parameters for enhanced learning capacity
"""
def __init__(self, input_dim: int, n_actions: int):
super(DQNNetwork, self).__init__()
@@ -40,36 +41,102 @@ class DQNNetwork(nn.Module):
self.n_actions = n_actions
# Deep network architecture optimized for trading features
self.network = nn.Sequential(
# Input layer
nn.Linear(self.input_size, 2048),
nn.ReLU(),
nn.Dropout(0.3),
# MASSIVE network architecture optimized for trading features
# Target: ~50M parameters
self.feature_extractor = nn.Sequential(
# Initial feature extraction with massive width
nn.Linear(self.input_size, 8192), # 7850 -> 8192 = ~64M weights
nn.LayerNorm(8192),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
# Hidden layers with residual-like connections
# Deep feature processing layers
nn.Linear(8192, 6144), # 8192 -> 6144 = ~50M weights
nn.LayerNorm(6144),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(6144, 4096), # 6144 -> 4096 = ~25M weights
nn.LayerNorm(4096),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(4096, 3072), # 4096 -> 3072 = ~12M weights
nn.LayerNorm(3072),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(3072, 2048), # 3072 -> 2048 = ~6M weights
nn.LayerNorm(2048),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
)
# Market regime detection head
self.regime_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
# Output layer for Q-values
nn.Linear(128, n_actions)
nn.LayerNorm(512),
nn.ReLU(inplace=True),
nn.Linear(512, 4) # trending, ranging, volatile, mixed
)
# Price prediction head
self.price_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(1024, 512),
nn.LayerNorm(512),
nn.ReLU(inplace=True),
nn.Linear(512, 3) # short, medium, long term price direction
)
# Volatility prediction head
self.volatility_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(1024, 256),
nn.LayerNorm(256),
nn.ReLU(inplace=True),
nn.Linear(256, 1) # predicted volatility
)
# Main Q-value head (dueling architecture)
self.value_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(1024, 512),
nn.LayerNorm(512),
nn.ReLU(inplace=True),
nn.Linear(512, 1) # State value
)
self.advantage_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(1024, 512),
nn.LayerNorm(512),
nn.ReLU(inplace=True),
nn.Linear(512, n_actions) # Action advantages
)
# Initialize weights
self._initialize_weights()
# Log parameter count
total_params = sum(p.numel() for p in self.parameters())
logger.info(f"DQN Network initialized with {total_params:,} parameters (target: 50M)")
def _initialize_weights(self):
"""Initialize network weights using Xavier initialization"""
@@ -78,6 +145,9 @@ class DQNNetwork(nn.Module):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def forward(self, x):
"""Forward pass through the network"""
@@ -87,7 +157,22 @@ class DQNNetwork(nn.Module):
elif x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension if needed
return self.network(x)
# Feature extraction
features = self.feature_extractor(x)
# Multiple prediction heads
regime_pred = self.regime_head(features)
price_pred = self.price_head(features)
volatility_pred = self.volatility_head(features)
# Dueling Q-network
value = self.value_head(features)
advantage = self.advantage_head(features)
# Combine value and advantage for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values, regime_pred, price_pred, volatility_pred, features
def act(self, state, explore=True):
"""
@@ -111,7 +196,7 @@ class DQNNetwork(nn.Module):
state = state.unsqueeze(0)
with torch.no_grad():
q_values = self.forward(state)
q_values, regime_pred, price_pred, volatility_pred, features = self.forward(state)
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
@@ -1010,22 +1095,34 @@ class DQNAgent:
logger.warning("Empty batch in _replay_standard")
return 0.0
# Get current Q values using safe wrapper
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Ensure model is in training mode for gradients
self.policy_net.train()
# Get current Q values - use the updated forward method
q_values_output = self.policy_net(states)
if isinstance(q_values_output, tuple):
current_q_values_all = q_values_output[0] # Extract Q-values from tuple
else:
current_q_values_all = q_values_output
current_q_values = current_q_values_all.gather(1, actions.unsqueeze(1)).squeeze(1)
# Enhanced Double DQN implementation
with torch.no_grad():
if self.use_double_dqn:
# Double DQN: Use policy network to select actions, target network to evaluate
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
policy_output = self.policy_net(next_states)
policy_q_values = policy_output[0] if isinstance(policy_output, tuple) else policy_output
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
target_output = self.target_net(next_states)
target_q_values_all = target_output[0] if isinstance(target_output, tuple) else target_output
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN: Use target network for both selection and evaluation
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = next_q_values.max(1)[0]
target_output = self.target_net(next_states)
target_q_values = target_output[0] if isinstance(target_output, tuple) else target_output
next_q_values = target_q_values.max(1)[0]
# Ensure tensor shapes are consistent
batch_size = states.shape[0]
@@ -1043,26 +1140,15 @@ class DQNAgent:
# Compute loss for Q value - ensure tensors require gradients
if not current_q_values.requires_grad:
logger.warning("Current Q values do not require gradients")
# Force training mode
self.policy_net.train()
return 0.0
q_loss = self.criterion(current_q_values, target_q_values.detach())
# Initialize total loss with Q loss
# Use only Q-loss for now to ensure clean gradients
total_loss = q_loss
# Add auxiliary losses if available and valid
try:
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
# Create simple extrema targets based on Q-values
with torch.no_grad():
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2 # Default to "neither"
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
total_loss = total_loss + 0.1 * extrema_loss
except Exception as e:
logger.debug(f"Could not calculate auxiliary loss: {e}")
# Reset gradients
self.optimizer.zero_grad()
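
For reference, a minimal standalone sketch (not from the commit) of the two mechanisms wired together above: the dueling aggregation from forward() and the Double DQN target used in _replay_standard. Only the shapes and formulas are taken from the diff; the random tensors stand in for network outputs.

import torch

batch_size, n_actions, gamma = 4, 3, 0.99

# Dueling aggregation: subtracting the mean advantage keeps V(s) and A(s, a) identifiable.
value = torch.randn(batch_size, 1)               # stand-in for value_head output
advantage = torch.randn(batch_size, n_actions)   # stand-in for advantage_head output
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

# Double DQN target: the policy net selects the next action, the target net evaluates it.
policy_next_q = torch.randn(batch_size, n_actions)  # stand-in for policy_net(next_states)[0]
target_next_q = torch.randn(batch_size, n_actions)  # stand-in for target_net(next_states)[0]
rewards, dones = torch.randn(batch_size), torch.zeros(batch_size)

next_actions = policy_next_q.argmax(dim=1)
next_q = target_next_q.gather(1, next_actions.unsqueeze(1)).squeeze(1)
target_q = rewards + gamma * next_q * (1 - dones)
print(q_values.shape, target_q.shape)  # torch.Size([4, 3]) torch.Size([4])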


@@ -28,10 +28,14 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
import ta
import warnings
from threading import Thread, Lock
from collections import deque
import math
# Suppress ta library deprecation warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="ta")
from .config import get_config
from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
from .cnn_monitor import log_cnn_prediction
@@ -1127,17 +1131,20 @@ class DataProvider:
# Convert timestamp to datetime if needed
if isinstance(timestamp, (int, float)):
tick_time = datetime.fromtimestamp(timestamp, tz=pd.Timestamp.now().tz)
# If no timezone info, assume UTC and convert to Europe/Sofia
if tick_time.tzinfo is None:
tick_time = tick_time.replace(tzinfo=pd.Timestamp.now(tz='UTC').tz)
tick_time = tick_time.astimezone(pd.Timestamp.now(tz='Europe/Sofia').tz)
import pytz
utc = pytz.UTC
sofia_tz = pytz.timezone('Europe/Sofia')
tick_time = datetime.fromtimestamp(timestamp, tz=utc)
tick_time = tick_time.astimezone(sofia_tz)
elif isinstance(timestamp, datetime):
import pytz
sofia_tz = pytz.timezone('Europe/Sofia')
tick_time = timestamp
# If no timezone info, assume UTC and convert to Europe/Sofia
if tick_time.tzinfo is None:
tick_time = tick_time.replace(tzinfo=pd.Timestamp.now(tz='UTC').tz)
tick_time = tick_time.astimezone(pd.Timestamp.now(tz='Europe/Sofia').tz)
utc = pytz.UTC
tick_time = utc.localize(tick_time)
tick_time = tick_time.astimezone(sofia_tz)
else:
continue
@@ -1177,6 +1184,16 @@ class DataProvider:
# Convert to DataFrame
df = pd.DataFrame(candles)
# Ensure timestamps are timezone-aware (Europe/Sofia)
if not df.empty and 'timestamp' in df.columns:
import pytz
sofia_tz = pytz.timezone('Europe/Sofia')
# If timestamps are not timezone-aware, make them Europe/Sofia
if df['timestamp'].dt.tz is None:
df['timestamp'] = df['timestamp'].dt.tz_localize(sofia_tz)
else:
df['timestamp'] = df['timestamp'].dt.tz_convert(sofia_tz)
df = df.sort_values('timestamp').reset_index(drop=True)
# Limit to requested number
@@ -1991,6 +2008,15 @@ class DataProvider:
if cache_file.exists():
try:
df = pd.read_parquet(cache_file)
# Ensure cached monthly data has proper timezone (Europe/Sofia)
if not df.empty and 'timestamp' in df.columns:
if df['timestamp'].dt.tz is None:
# If no timezone info, assume UTC and convert to Europe/Sofia
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
# Convert to Europe/Sofia if different timezone
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
return df
except Exception as parquet_e:
@@ -2266,6 +2292,15 @@ class DataProvider:
if cache_age < max_age:
try:
df = pd.read_parquet(cache_file)
# Ensure cached data has proper timezone (Europe/Sofia)
if not df.empty and 'timestamp' in df.columns:
if df['timestamp'].dt.tz is None:
# If no timezone info, assume UTC and convert to Europe/Sofia
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
# Convert to Europe/Sofia if different timezone
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
return df
except Exception as parquet_e:
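
The same localize-or-convert logic now appears three times in this file. A small helper along these lines (hypothetical name _ensure_sofia_tz; a sketch, not part of the commit) would express it once:

import pandas as pd

def _ensure_sofia_tz(df: pd.DataFrame) -> pd.DataFrame:
    """Make the 'timestamp' column timezone-aware in Europe/Sofia."""
    if df.empty or 'timestamp' not in df.columns:
        return df
    ts = pd.to_datetime(df['timestamp'])
    if ts.dt.tz is None:
        # Naive timestamps are assumed to be UTC, matching the diff above.
        ts = ts.dt.tz_localize('UTC')
    df['timestamp'] = ts.dt.tz_convert('Europe/Sofia')
    return df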


@@ -1672,7 +1672,7 @@ class TradingOrchestrator:
processing_time_ms=0.0, # We don't track this in orchestrator
memory_usage_mb=0.0, # We don't track this in orchestrator
input_features=input_features_array,
checkpoint_id=None,
metadata=inference_record.get('metadata', {})
)

Binary file not shown.

test_timezone_fix.py (new file, 102 lines)

@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Test Timezone Fix
This script tests that historical data timestamps are properly converted to the Europe/Sofia timezone.
"""
import asyncio
import pandas as pd
from datetime import datetime
from core.data_provider import DataProvider
async def test_timezone_fix():
"""Test the timezone conversion fix"""
print("=== Testing Timezone Fix ===")
# Initialize data provider
print("1. Initializing data provider...")
data_provider = DataProvider()
# Wait for initialization
await asyncio.sleep(2)
# Test different timeframes
timeframes = ['1m', '1h', '1d']
symbol = 'ETH/USDT'
for timeframe in timeframes:
print(f"\n2. Testing {timeframe} data for {symbol}:")
# Get historical data
df = data_provider.get_historical_data(symbol, timeframe, limit=5)
if df is not None and not df.empty:
print(f" ✅ Got {len(df)} candles")
# Check timezone
if 'timestamp' in df.columns:
first_timestamp = df['timestamp'].iloc[0]
last_timestamp = df['timestamp'].iloc[-1]
print(f" First timestamp: {first_timestamp}")
print(f" Last timestamp: {last_timestamp}")
# Check if timezone is Europe/Sofia
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
timezone_str = str(first_timestamp.tz)
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
print(f" ✅ Timezone is correct: {timezone_str}")
else:
print(f" ❌ Timezone is incorrect: {timezone_str}")
else:
print(" ❌ No timezone information found")
# Show time difference from UTC
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
print(f" UTC offset: {offset_hours:+.0f} hours")
if offset_hours == 2 or offset_hours == 3: # EET (+2) or EEST (+3)
print(" ✅ UTC offset is correct for Europe/Sofia")
else:
print(f" ❌ UTC offset is incorrect: {offset_hours:+.0f} hours")
# Show sample data
print(" Sample data:")
for i in range(min(3, len(df))):
row = df.iloc[i]
print(f" {row['timestamp']}: O={row['open']:.2f} H={row['high']:.2f} L={row['low']:.2f} C={row['close']:.2f}")
else:
print(" ❌ No timestamp column found")
else:
print(f" ❌ No data available for {timeframe}")
# Test current time comparison
print(f"\n3. Current time comparison:")
current_utc = datetime.utcnow()
current_local = datetime.now()
print(f"   Current UTC time: {current_utc}")
print(f"   Current local time: {current_local}")
# Calculate expected offset
import pytz
sofia_tz = pytz.timezone('Europe/Sofia')
current_sofia_tz = datetime.now(sofia_tz)
offset_hours = current_sofia_tz.utcoffset().total_seconds() / 3600
print(f" Europe/Sofia current time: {current_sofia_tz}")
print(f" Current UTC offset: {offset_hours:+.0f} hours")
if offset_hours == 2:
print(" ✅ Currently in EET (Eastern European Time)")
elif offset_hours == 3:
print(" ✅ Currently in EEST (Eastern European Summer Time)")
else:
print(f" ❌ Unexpected offset: {offset_hours:+.0f} hours")
print("\n✅ Timezone fix test completed!")
if __name__ == "__main__":
asyncio.run(test_timezone_fix())
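
The offset check the script performs by hand reduces to a single assertion; a compact equivalent (a sketch using the stdlib zoneinfo, Python 3.9+, rather than the script's pytz):

from datetime import datetime
from zoneinfo import ZoneInfo

now_sofia = datetime.now(ZoneInfo("Europe/Sofia"))
offset_hours = now_sofia.utcoffset().total_seconds() / 3600
# Europe/Sofia is UTC+2 (EET) in winter and UTC+3 (EEST) in summer.
assert offset_hours in (2, 3), f"unexpected Sofia offset: {offset_hours:+.0f}h"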

test_timezone_with_data.py (new file, 136 lines)

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
Test Timezone Fix with Data Fetching
This script tests timezone conversion by actually fetching data and checking timestamps.
"""
import asyncio
import pandas as pd
from datetime import datetime
from core.data_provider import DataProvider
async def test_timezone_with_data():
"""Test timezone conversion with actual data fetching"""
print("=== Testing Timezone Fix with Data Fetching ===")
# Initialize data provider
print("1. Initializing data provider...")
data_provider = DataProvider()
# Wait for initialization
await asyncio.sleep(2)
# Test direct Binance API call
print("\n2. Testing direct Binance API call:")
try:
# Call the internal Binance fetch method directly
df = data_provider._fetch_from_binance('ETH/USDT', '1h', 5)
if df is not None and not df.empty:
print(f" ✅ Got {len(df)} candles from Binance API")
# Check timezone
if 'timestamp' in df.columns:
first_timestamp = df['timestamp'].iloc[0]
last_timestamp = df['timestamp'].iloc[-1]
print(f" First timestamp: {first_timestamp}")
print(f" Last timestamp: {last_timestamp}")
# Check if timezone is Europe/Sofia
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
timezone_str = str(first_timestamp.tz)
print(f" Timezone: {timezone_str}")
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
print(f" ✅ Timezone is correct: {timezone_str}")
else:
print(f" ❌ Timezone is incorrect: {timezone_str}")
# Show UTC offset
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
print(f" UTC offset: {offset_hours:+.0f} hours")
if offset_hours == 2 or offset_hours == 3: # EET (+2) or EEST (+3)
print(" ✅ UTC offset is correct for Europe/Sofia")
else:
print(f" ❌ UTC offset is incorrect: {offset_hours:+.0f} hours")
# Compare with UTC time
print("\n Timestamp comparison:")
for i in range(min(2, len(df))):
row = df.iloc[i]
local_time = row['timestamp']
utc_time = local_time.tz_convert('UTC')
print(f" Local (Sofia): {local_time}")
print(f" UTC: {utc_time}")
print(f" Difference: {(local_time - utc_time).total_seconds() / 3600:+.0f} hours")
print()
else:
print(" ❌ No timestamp column found")
else:
print(" ❌ No data returned from Binance API")
except Exception as e:
print(f" ❌ Error fetching from Binance: {e}")
# Test MEXC API call as well
print("\n3. Testing MEXC API call:")
try:
df = data_provider._fetch_from_mexc('ETH/USDT', '1h', 3)
if df is not None and not df.empty:
print(f" ✅ Got {len(df)} candles from MEXC API")
# Check timezone
if 'timestamp' in df.columns:
first_timestamp = df['timestamp'].iloc[0]
print(f" First timestamp: {first_timestamp}")
# Check timezone
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
timezone_str = str(first_timestamp.tz)
print(f" Timezone: {timezone_str}")
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
print(f" ✅ MEXC timezone is correct: {timezone_str}")
else:
print(f" ❌ MEXC timezone is incorrect: {timezone_str}")
# Show UTC offset
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
print(f" UTC offset: {offset_hours:+.0f} hours")
else:
print(" ❌ No data returned from MEXC API")
except Exception as e:
print(f" ❌ Error fetching from MEXC: {e}")
# Show current timezone info
print(f"\n4. Current timezone information:")
import pytz
sofia_tz = pytz.timezone('Europe/Sofia')
current_sofia = datetime.now(sofia_tz)
current_utc = datetime.now(pytz.UTC)
print(f" Current Sofia time: {current_sofia}")
print(f" Current UTC time: {current_utc}")
print(f" Time difference: {(current_sofia - current_utc).total_seconds() / 3600:+.0f} hours")
# Check if it's summer time (EEST) or winter time (EET)
offset_hours = current_sofia.utcoffset().total_seconds() / 3600
if offset_hours == 3:
print(" ✅ Currently in EEST (Eastern European Summer Time)")
elif offset_hours == 2:
print(" ✅ Currently in EET (Eastern European Time)")
else:
print(f" ❌ Unexpected offset: {offset_hours:+.0f} hours")
print("\n✅ Timezone fix test with data completed!")
if __name__ == "__main__":
asyncio.run(test_timezone_with_data())