beef up DQN model, fix training issues
This commit is contained in:
@ -23,8 +23,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DQNNetwork(nn.Module):
|
class DQNNetwork(nn.Module):
|
||||||
"""
|
"""
|
||||||
Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
|
Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
|
||||||
Handles 7850 input features from multi-timeframe, multi-asset data
|
Handles 7850 input features from multi-timeframe, multi-asset data
|
||||||
|
TARGET: 50M parameters for enhanced learning capacity
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_dim: int, n_actions: int):
|
def __init__(self, input_dim: int, n_actions: int):
|
||||||
super(DQNNetwork, self).__init__()
|
super(DQNNetwork, self).__init__()
|
||||||
@ -40,37 +41,103 @@ class DQNNetwork(nn.Module):
|
|||||||
|
|
||||||
self.n_actions = n_actions
|
self.n_actions = n_actions
|
||||||
|
|
||||||
# Deep network architecture optimized for trading features
|
# MASSIVE network architecture optimized for trading features
|
||||||
self.network = nn.Sequential(
|
# Target: ~50M parameters
|
||||||
# Input layer
|
self.feature_extractor = nn.Sequential(
|
||||||
nn.Linear(self.input_size, 2048),
|
# Initial feature extraction with massive width
|
||||||
nn.ReLU(),
|
nn.Linear(self.input_size, 8192), # 7850 -> 8192 = ~64M weights
|
||||||
nn.Dropout(0.3),
|
nn.LayerNorm(8192),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
# Hidden layers with residual-like connections
|
# Deep feature processing layers
|
||||||
|
nn.Linear(8192, 6144), # 8192 -> 6144 = ~50M weights
|
||||||
|
nn.LayerNorm(6144),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
|
nn.Linear(6144, 4096), # 6144 -> 4096 = ~25M weights
|
||||||
|
nn.LayerNorm(4096),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
|
nn.Linear(4096, 3072), # 4096 -> 3072 = ~12M weights
|
||||||
|
nn.LayerNorm(3072),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
|
nn.Linear(3072, 2048), # 3072 -> 2048 = ~6M weights
|
||||||
|
nn.LayerNorm(2048),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Market regime detection head
|
||||||
|
self.regime_head = nn.Sequential(
|
||||||
nn.Linear(2048, 1024),
|
nn.Linear(2048, 1024),
|
||||||
nn.ReLU(),
|
nn.LayerNorm(1024),
|
||||||
nn.Dropout(0.3),
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.Linear(1024, 512),
|
nn.Linear(1024, 512),
|
||||||
nn.ReLU(),
|
nn.LayerNorm(512),
|
||||||
nn.Dropout(0.3),
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(512, 4) # trending, ranging, volatile, mixed
|
||||||
|
)
|
||||||
|
|
||||||
nn.Linear(512, 256),
|
# Price prediction head
|
||||||
nn.ReLU(),
|
self.price_head = nn.Sequential(
|
||||||
nn.Dropout(0.2),
|
nn.Linear(2048, 1024),
|
||||||
|
nn.LayerNorm(1024),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(512, 3) # short, medium, long term price direction
|
||||||
|
)
|
||||||
|
|
||||||
nn.Linear(256, 128),
|
# Volatility prediction head
|
||||||
nn.ReLU(),
|
self.volatility_head = nn.Sequential(
|
||||||
nn.Dropout(0.2),
|
nn.Linear(2048, 1024),
|
||||||
|
nn.LayerNorm(1024),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(1024, 256),
|
||||||
|
nn.LayerNorm(256),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(256, 1) # predicted volatility
|
||||||
|
)
|
||||||
|
|
||||||
# Output layer for Q-values
|
# Main Q-value head (dueling architecture)
|
||||||
nn.Linear(128, n_actions)
|
self.value_head = nn.Sequential(
|
||||||
|
nn.Linear(2048, 1024),
|
||||||
|
nn.LayerNorm(1024),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(512, 1) # State value
|
||||||
|
)
|
||||||
|
|
||||||
|
self.advantage_head = nn.Sequential(
|
||||||
|
nn.Linear(2048, 1024),
|
||||||
|
nn.LayerNorm(1024),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(512, n_actions) # Action advantages
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
|
# Log parameter count
|
||||||
|
total_params = sum(p.numel() for p in self.parameters())
|
||||||
|
logger.info(f"DQN Network initialized with {total_params:,} parameters (target: 50M)")
|
||||||
|
|
||||||
def _initialize_weights(self):
|
def _initialize_weights(self):
|
||||||
"""Initialize network weights using Xavier initialization"""
|
"""Initialize network weights using Xavier initialization"""
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
@ -78,6 +145,9 @@ class DQNNetwork(nn.Module):
|
|||||||
nn.init.xavier_uniform_(module.weight)
|
nn.init.xavier_uniform_(module.weight)
|
||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
nn.init.constant_(module.weight, 1.0)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward pass through the network"""
|
"""Forward pass through the network"""
|
||||||
@ -87,7 +157,22 @@ class DQNNetwork(nn.Module):
|
|||||||
elif x.dim() == 1:
|
elif x.dim() == 1:
|
||||||
x = x.unsqueeze(0) # Add batch dimension if needed
|
x = x.unsqueeze(0) # Add batch dimension if needed
|
||||||
|
|
||||||
return self.network(x)
|
# Feature extraction
|
||||||
|
features = self.feature_extractor(x)
|
||||||
|
|
||||||
|
# Multiple prediction heads
|
||||||
|
regime_pred = self.regime_head(features)
|
||||||
|
price_pred = self.price_head(features)
|
||||||
|
volatility_pred = self.volatility_head(features)
|
||||||
|
|
||||||
|
# Dueling Q-network
|
||||||
|
value = self.value_head(features)
|
||||||
|
advantage = self.advantage_head(features)
|
||||||
|
|
||||||
|
# Combine value and advantage for Q-values
|
||||||
|
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
return q_values, regime_pred, price_pred, volatility_pred, features
|
||||||
|
|
||||||
def act(self, state, explore=True):
|
def act(self, state, explore=True):
|
||||||
"""
|
"""
|
||||||
@ -111,7 +196,7 @@ class DQNNetwork(nn.Module):
|
|||||||
state = state.unsqueeze(0)
|
state = state.unsqueeze(0)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_values = self.forward(state)
|
q_values, regime_pred, price_pred, volatility_pred, features = self.forward(state)
|
||||||
|
|
||||||
# Get action probabilities using softmax
|
# Get action probabilities using softmax
|
||||||
action_probs = F.softmax(q_values, dim=1)
|
action_probs = F.softmax(q_values, dim=1)
|
||||||
@ -1010,22 +1095,34 @@ class DQNAgent:
|
|||||||
logger.warning("Empty batch in _replay_standard")
|
logger.warning("Empty batch in _replay_standard")
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Get current Q values using safe wrapper
|
# Ensure model is in training mode for gradients
|
||||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
self.policy_net.train()
|
||||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
||||||
|
# Get current Q values - use the updated forward method
|
||||||
|
q_values_output = self.policy_net(states)
|
||||||
|
if isinstance(q_values_output, tuple):
|
||||||
|
current_q_values_all = q_values_output[0] # Extract Q-values from tuple
|
||||||
|
else:
|
||||||
|
current_q_values_all = q_values_output
|
||||||
|
|
||||||
|
current_q_values = current_q_values_all.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|
||||||
# Enhanced Double DQN implementation
|
# Enhanced Double DQN implementation
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.use_double_dqn:
|
if self.use_double_dqn:
|
||||||
# Double DQN: Use policy network to select actions, target network to evaluate
|
# Double DQN: Use policy network to select actions, target network to evaluate
|
||||||
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
|
policy_output = self.policy_net(next_states)
|
||||||
|
policy_q_values = policy_output[0] if isinstance(policy_output, tuple) else policy_output
|
||||||
next_actions = policy_q_values.argmax(1)
|
next_actions = policy_q_values.argmax(1)
|
||||||
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
|
||||||
|
target_output = self.target_net(next_states)
|
||||||
|
target_q_values_all = target_output[0] if isinstance(target_output, tuple) else target_output
|
||||||
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||||
else:
|
else:
|
||||||
# Standard DQN: Use target network for both selection and evaluation
|
# Standard DQN: Use target network for both selection and evaluation
|
||||||
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
target_output = self.target_net(next_states)
|
||||||
next_q_values = next_q_values.max(1)[0]
|
target_q_values = target_output[0] if isinstance(target_output, tuple) else target_output
|
||||||
|
next_q_values = target_q_values.max(1)[0]
|
||||||
|
|
||||||
# Ensure tensor shapes are consistent
|
# Ensure tensor shapes are consistent
|
||||||
batch_size = states.shape[0]
|
batch_size = states.shape[0]
|
||||||
@ -1043,26 +1140,15 @@ class DQNAgent:
|
|||||||
# Compute loss for Q value - ensure tensors require gradients
|
# Compute loss for Q value - ensure tensors require gradients
|
||||||
if not current_q_values.requires_grad:
|
if not current_q_values.requires_grad:
|
||||||
logger.warning("Current Q values do not require gradients")
|
logger.warning("Current Q values do not require gradients")
|
||||||
|
# Force training mode
|
||||||
|
self.policy_net.train()
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
q_loss = self.criterion(current_q_values, target_q_values.detach())
|
q_loss = self.criterion(current_q_values, target_q_values.detach())
|
||||||
|
|
||||||
# Initialize total loss with Q loss
|
# Use only Q-loss for now to ensure clean gradients
|
||||||
total_loss = q_loss
|
total_loss = q_loss
|
||||||
|
|
||||||
# Add auxiliary losses if available and valid
|
|
||||||
try:
|
|
||||||
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
|
|
||||||
# Create simple extrema targets based on Q-values
|
|
||||||
with torch.no_grad():
|
|
||||||
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2 # Default to "neither"
|
|
||||||
|
|
||||||
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
|
|
||||||
total_loss = total_loss + 0.1 * extrema_loss
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Could not calculate auxiliary loss: {e}")
|
|
||||||
|
|
||||||
# Reset gradients
|
# Reset gradients
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
@ -28,10 +28,14 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import ta
|
import ta
|
||||||
|
import warnings
|
||||||
from threading import Thread, Lock
|
from threading import Thread, Lock
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
# Suppress ta library deprecation warnings
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning, module="ta")
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
|
from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
|
||||||
from .cnn_monitor import log_cnn_prediction
|
from .cnn_monitor import log_cnn_prediction
|
||||||
@ -1127,17 +1131,20 @@ class DataProvider:
|
|||||||
|
|
||||||
# Convert timestamp to datetime if needed
|
# Convert timestamp to datetime if needed
|
||||||
if isinstance(timestamp, (int, float)):
|
if isinstance(timestamp, (int, float)):
|
||||||
tick_time = datetime.fromtimestamp(timestamp, tz=pd.Timestamp.now().tz)
|
import pytz
|
||||||
# If no timezone info, assume UTC and convert to Europe/Sofia
|
utc = pytz.UTC
|
||||||
if tick_time.tzinfo is None:
|
sofia_tz = pytz.timezone('Europe/Sofia')
|
||||||
tick_time = tick_time.replace(tzinfo=pd.Timestamp.now(tz='UTC').tz)
|
tick_time = datetime.fromtimestamp(timestamp, tz=utc)
|
||||||
tick_time = tick_time.astimezone(pd.Timestamp.now(tz='Europe/Sofia').tz)
|
tick_time = tick_time.astimezone(sofia_tz)
|
||||||
elif isinstance(timestamp, datetime):
|
elif isinstance(timestamp, datetime):
|
||||||
|
import pytz
|
||||||
|
sofia_tz = pytz.timezone('Europe/Sofia')
|
||||||
tick_time = timestamp
|
tick_time = timestamp
|
||||||
# If no timezone info, assume UTC and convert to Europe/Sofia
|
# If no timezone info, assume UTC and convert to Europe/Sofia
|
||||||
if tick_time.tzinfo is None:
|
if tick_time.tzinfo is None:
|
||||||
tick_time = tick_time.replace(tzinfo=pd.Timestamp.now(tz='UTC').tz)
|
utc = pytz.UTC
|
||||||
tick_time = tick_time.astimezone(pd.Timestamp.now(tz='Europe/Sofia').tz)
|
tick_time = utc.localize(tick_time)
|
||||||
|
tick_time = tick_time.astimezone(sofia_tz)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -1177,6 +1184,16 @@ class DataProvider:
|
|||||||
|
|
||||||
# Convert to DataFrame
|
# Convert to DataFrame
|
||||||
df = pd.DataFrame(candles)
|
df = pd.DataFrame(candles)
|
||||||
|
# Ensure timestamps are timezone-aware (Europe/Sofia)
|
||||||
|
if not df.empty and 'timestamp' in df.columns:
|
||||||
|
import pytz
|
||||||
|
sofia_tz = pytz.timezone('Europe/Sofia')
|
||||||
|
# If timestamps are not timezone-aware, make them Europe/Sofia
|
||||||
|
if df['timestamp'].dt.tz is None:
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_localize(sofia_tz)
|
||||||
|
else:
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_convert(sofia_tz)
|
||||||
|
|
||||||
df = df.sort_values('timestamp').reset_index(drop=True)
|
df = df.sort_values('timestamp').reset_index(drop=True)
|
||||||
|
|
||||||
# Limit to requested number
|
# Limit to requested number
|
||||||
@ -1991,6 +2008,15 @@ class DataProvider:
|
|||||||
if cache_file.exists():
|
if cache_file.exists():
|
||||||
try:
|
try:
|
||||||
df = pd.read_parquet(cache_file)
|
df = pd.read_parquet(cache_file)
|
||||||
|
# Ensure cached monthly data has proper timezone (Europe/Sofia)
|
||||||
|
if not df.empty and 'timestamp' in df.columns:
|
||||||
|
if df['timestamp'].dt.tz is None:
|
||||||
|
# If no timezone info, assume UTC and convert to Europe/Sofia
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
|
||||||
|
elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
|
||||||
|
# Convert to Europe/Sofia if different timezone
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
|
||||||
logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
|
logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
|
||||||
return df
|
return df
|
||||||
except Exception as parquet_e:
|
except Exception as parquet_e:
|
||||||
@ -2266,6 +2292,15 @@ class DataProvider:
|
|||||||
if cache_age < max_age:
|
if cache_age < max_age:
|
||||||
try:
|
try:
|
||||||
df = pd.read_parquet(cache_file)
|
df = pd.read_parquet(cache_file)
|
||||||
|
# Ensure cached data has proper timezone (Europe/Sofia)
|
||||||
|
if not df.empty and 'timestamp' in df.columns:
|
||||||
|
if df['timestamp'].dt.tz is None:
|
||||||
|
# If no timezone info, assume UTC and convert to Europe/Sofia
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
|
||||||
|
elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
|
||||||
|
# Convert to Europe/Sofia if different timezone
|
||||||
|
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
|
||||||
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
|
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
|
||||||
return df
|
return df
|
||||||
except Exception as parquet_e:
|
except Exception as parquet_e:
|
||||||
|
@ -1672,7 +1672,7 @@ class TradingOrchestrator:
|
|||||||
processing_time_ms=0.0, # We don't track this in orchestrator
|
processing_time_ms=0.0, # We don't track this in orchestrator
|
||||||
memory_usage_mb=0.0, # We don't track this in orchestrator
|
memory_usage_mb=0.0, # We don't track this in orchestrator
|
||||||
input_features=input_features_array,
|
input_features=input_features_array,
|
||||||
checkpoint_id=None,
|
checkpoint_id=None,f
|
||||||
metadata=inference_record.get('metadata', {})
|
metadata=inference_record.get('metadata', {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Binary file not shown.
102
test_timezone_fix.py
Normal file
102
test_timezone_fix.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Timezone Fix
|
||||||
|
|
||||||
|
This script tests that historical data timestamps are properly converted to Europe/Sofia timezone.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
|
||||||
|
async def test_timezone_fix():
|
||||||
|
"""Test the timezone conversion fix"""
|
||||||
|
print("=== Testing Timezone Fix ===")
|
||||||
|
|
||||||
|
# Initialize data provider
|
||||||
|
print("1. Initializing data provider...")
|
||||||
|
data_provider = DataProvider()
|
||||||
|
|
||||||
|
# Wait for initialization
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Test different timeframes
|
||||||
|
timeframes = ['1m', '1h', '1d']
|
||||||
|
symbol = 'ETH/USDT'
|
||||||
|
|
||||||
|
for timeframe in timeframes:
|
||||||
|
print(f"\n2. Testing {timeframe} data for {symbol}:")
|
||||||
|
|
||||||
|
# Get historical data
|
||||||
|
df = data_provider.get_historical_data(symbol, timeframe, limit=5)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
print(f" ✅ Got {len(df)} candles")
|
||||||
|
|
||||||
|
# Check timezone
|
||||||
|
if 'timestamp' in df.columns:
|
||||||
|
first_timestamp = df['timestamp'].iloc[0]
|
||||||
|
last_timestamp = df['timestamp'].iloc[-1]
|
||||||
|
|
||||||
|
print(f" First timestamp: {first_timestamp}")
|
||||||
|
print(f" Last timestamp: {last_timestamp}")
|
||||||
|
|
||||||
|
# Check if timezone is Europe/Sofia
|
||||||
|
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
|
||||||
|
timezone_str = str(first_timestamp.tz)
|
||||||
|
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
|
||||||
|
print(f" ✅ Timezone is correct: {timezone_str}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ Timezone is incorrect: {timezone_str}")
|
||||||
|
else:
|
||||||
|
print(" ❌ No timezone information found")
|
||||||
|
|
||||||
|
# Show time difference from UTC
|
||||||
|
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
|
||||||
|
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
|
||||||
|
print(f" UTC offset: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
if offset_hours == 2 or offset_hours == 3: # EET (+2) or EEST (+3)
|
||||||
|
print(" ✅ UTC offset is correct for Europe/Sofia")
|
||||||
|
else:
|
||||||
|
print(f" ❌ UTC offset is incorrect: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
# Show sample data
|
||||||
|
print(" Sample data:")
|
||||||
|
for i in range(min(3, len(df))):
|
||||||
|
row = df.iloc[i]
|
||||||
|
print(f" {row['timestamp']}: O={row['open']:.2f} H={row['high']:.2f} L={row['low']:.2f} C={row['close']:.2f}")
|
||||||
|
else:
|
||||||
|
print(" ❌ No timestamp column found")
|
||||||
|
else:
|
||||||
|
print(f" ❌ No data available for {timeframe}")
|
||||||
|
|
||||||
|
# Test current time comparison
|
||||||
|
print(f"\n3. Current time comparison:")
|
||||||
|
current_utc = datetime.utcnow()
|
||||||
|
current_sofia = datetime.now()
|
||||||
|
|
||||||
|
print(f" Current UTC time: {current_utc}")
|
||||||
|
print(f" Current local time: {current_sofia}")
|
||||||
|
|
||||||
|
# Calculate expected offset
|
||||||
|
import pytz
|
||||||
|
sofia_tz = pytz.timezone('Europe/Sofia')
|
||||||
|
current_sofia_tz = datetime.now(sofia_tz)
|
||||||
|
offset_hours = current_sofia_tz.utcoffset().total_seconds() / 3600
|
||||||
|
|
||||||
|
print(f" Europe/Sofia current time: {current_sofia_tz}")
|
||||||
|
print(f" Current UTC offset: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
if offset_hours == 2:
|
||||||
|
print(" ✅ Currently in EET (Eastern European Time)")
|
||||||
|
elif offset_hours == 3:
|
||||||
|
print(" ✅ Currently in EEST (Eastern European Summer Time)")
|
||||||
|
else:
|
||||||
|
print(f" ❌ Unexpected offset: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
print("\n✅ Timezone fix test completed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_timezone_fix())
|
136
test_timezone_with_data.py
Normal file
136
test_timezone_with_data.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Timezone Fix with Data Fetching
|
||||||
|
|
||||||
|
This script tests timezone conversion by actually fetching data and checking timestamps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
|
||||||
|
async def test_timezone_with_data():
|
||||||
|
"""Test timezone conversion with actual data fetching"""
|
||||||
|
print("=== Testing Timezone Fix with Data Fetching ===")
|
||||||
|
|
||||||
|
# Initialize data provider
|
||||||
|
print("1. Initializing data provider...")
|
||||||
|
data_provider = DataProvider()
|
||||||
|
|
||||||
|
# Wait for initialization
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Test direct Binance API call
|
||||||
|
print("\n2. Testing direct Binance API call:")
|
||||||
|
try:
|
||||||
|
# Call the internal Binance fetch method directly
|
||||||
|
df = data_provider._fetch_from_binance('ETH/USDT', '1h', 5)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
print(f" ✅ Got {len(df)} candles from Binance API")
|
||||||
|
|
||||||
|
# Check timezone
|
||||||
|
if 'timestamp' in df.columns:
|
||||||
|
first_timestamp = df['timestamp'].iloc[0]
|
||||||
|
last_timestamp = df['timestamp'].iloc[-1]
|
||||||
|
|
||||||
|
print(f" First timestamp: {first_timestamp}")
|
||||||
|
print(f" Last timestamp: {last_timestamp}")
|
||||||
|
|
||||||
|
# Check if timezone is Europe/Sofia
|
||||||
|
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
|
||||||
|
timezone_str = str(first_timestamp.tz)
|
||||||
|
print(f" Timezone: {timezone_str}")
|
||||||
|
|
||||||
|
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
|
||||||
|
print(f" ✅ Timezone is correct: {timezone_str}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ Timezone is incorrect: {timezone_str}")
|
||||||
|
|
||||||
|
# Show UTC offset
|
||||||
|
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
|
||||||
|
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
|
||||||
|
print(f" UTC offset: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
if offset_hours == 2 or offset_hours == 3: # EET (+2) or EEST (+3)
|
||||||
|
print(" ✅ UTC offset is correct for Europe/Sofia")
|
||||||
|
else:
|
||||||
|
print(f" ❌ UTC offset is incorrect: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
# Compare with UTC time
|
||||||
|
print("\n Timestamp comparison:")
|
||||||
|
for i in range(min(2, len(df))):
|
||||||
|
row = df.iloc[i]
|
||||||
|
local_time = row['timestamp']
|
||||||
|
utc_time = local_time.astimezone(pd.Timestamp.now(tz='UTC').tz)
|
||||||
|
|
||||||
|
print(f" Local (Sofia): {local_time}")
|
||||||
|
print(f" UTC: {utc_time}")
|
||||||
|
print(f" Difference: {(local_time - utc_time).total_seconds() / 3600:+.0f} hours")
|
||||||
|
print()
|
||||||
|
else:
|
||||||
|
print(" ❌ No timestamp column found")
|
||||||
|
else:
|
||||||
|
print(" ❌ No data returned from Binance API")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error fetching from Binance: {e}")
|
||||||
|
|
||||||
|
# Test MEXC API call as well
|
||||||
|
print("\n3. Testing MEXC API call:")
|
||||||
|
try:
|
||||||
|
df = data_provider._fetch_from_mexc('ETH/USDT', '1h', 3)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
print(f" ✅ Got {len(df)} candles from MEXC API")
|
||||||
|
|
||||||
|
# Check timezone
|
||||||
|
if 'timestamp' in df.columns:
|
||||||
|
first_timestamp = df['timestamp'].iloc[0]
|
||||||
|
print(f" First timestamp: {first_timestamp}")
|
||||||
|
|
||||||
|
# Check timezone
|
||||||
|
if hasattr(first_timestamp, 'tz') and first_timestamp.tz is not None:
|
||||||
|
timezone_str = str(first_timestamp.tz)
|
||||||
|
print(f" Timezone: {timezone_str}")
|
||||||
|
|
||||||
|
if 'Europe/Sofia' in timezone_str or 'EET' in timezone_str or 'EEST' in timezone_str:
|
||||||
|
print(f" ✅ MEXC timezone is correct: {timezone_str}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ MEXC timezone is incorrect: {timezone_str}")
|
||||||
|
|
||||||
|
# Show UTC offset
|
||||||
|
if hasattr(first_timestamp, 'utcoffset') and first_timestamp.utcoffset() is not None:
|
||||||
|
offset_hours = first_timestamp.utcoffset().total_seconds() / 3600
|
||||||
|
print(f" UTC offset: {offset_hours:+.0f} hours")
|
||||||
|
else:
|
||||||
|
print(" ❌ No data returned from MEXC API")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error fetching from MEXC: {e}")
|
||||||
|
|
||||||
|
# Show current timezone info
|
||||||
|
print(f"\n4. Current timezone information:")
|
||||||
|
import pytz
|
||||||
|
sofia_tz = pytz.timezone('Europe/Sofia')
|
||||||
|
current_sofia = datetime.now(sofia_tz)
|
||||||
|
current_utc = datetime.now(pytz.UTC)
|
||||||
|
|
||||||
|
print(f" Current Sofia time: {current_sofia}")
|
||||||
|
print(f" Current UTC time: {current_utc}")
|
||||||
|
print(f" Time difference: {(current_sofia - current_utc).total_seconds() / 3600:+.0f} hours")
|
||||||
|
|
||||||
|
# Check if it's summer time (EEST) or winter time (EET)
|
||||||
|
offset_hours = current_sofia.utcoffset().total_seconds() / 3600
|
||||||
|
if offset_hours == 3:
|
||||||
|
print(" ✅ Currently in EEST (Eastern European Summer Time)")
|
||||||
|
elif offset_hours == 2:
|
||||||
|
print(" ✅ Currently in EET (Eastern European Time)")
|
||||||
|
else:
|
||||||
|
print(f" ❌ Unexpected offset: {offset_hours:+.0f} hours")
|
||||||
|
|
||||||
|
print("\n✅ Timezone fix test with data completed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_timezone_with_data())
|
Reference in New Issue
Block a user