device fix, TZ fix
@@ -1478,7 +1478,11 @@ class DataProvider:
# Check for cached data and determine what we need to fetch
cached_data = self._load_monthly_data_from_cache(symbol)

end_time = datetime.utcnow()
import pytz
utc = pytz.UTC
sofia_tz = pytz.timezone('Europe/Sofia')

end_time = datetime.utcnow().replace(tzinfo=utc).astimezone(sofia_tz)
start_time = end_time - timedelta(days=30)

if cached_data is not None and not cached_data.empty:
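
Context for the hunk above: `datetime.utcnow()` is timezone-naive, and comparing a naive timestamp against tz-aware cache timestamps raises a TypeError, which appears to be what the switch to a Sofia-aware `end_time` guards against. A minimal sketch of the pitfall and the fix (illustrative values, not repository code):

    from datetime import datetime, timedelta
    import pytz

    utc = pytz.UTC
    sofia_tz = pytz.timezone('Europe/Sofia')

    naive_now = datetime.utcnow()                                 # no tzinfo attached
    aware_now = datetime.utcnow().replace(tzinfo=utc).astimezone(sofia_tz)

    cache_end = sofia_tz.localize(datetime(2024, 1, 1, 12, 0))    # hypothetical cached candle time
    # cache_end < naive_now  -> TypeError: can't compare offset-naive and offset-aware datetimes
    print(cache_end < aware_now)                                  # works: both sides are tz-aware
    start_time = aware_now - timedelta(days=30)
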
@@ -1496,6 +1500,12 @@ class DataProvider:
# Check if we need to fill gaps
gap_start = cache_end + timedelta(minutes=1)

# Ensure gap_start has same timezone as end_time for comparison
if gap_start.tzinfo is None:
gap_start = sofia_tz.localize(gap_start)
elif gap_start.tzinfo != sofia_tz:
gap_start = gap_start.astimezone(sofia_tz)

if gap_start < end_time:
# Need to fill gap from cache_end to now
logger.info(f"Filling gap from {gap_start} to {end_time}")
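
The gap check above normalizes `gap_start` before comparing it with the tz-aware `end_time`. A hedged sketch of that normalization as a standalone helper (the helper name `to_sofia` is illustrative, not from the repository):

    from datetime import datetime
    import pytz

    sofia_tz = pytz.timezone('Europe/Sofia')

    def to_sofia(ts: datetime) -> datetime:
        """Return ts as a tz-aware datetime in Europe/Sofia."""
        if ts.tzinfo is None:
            return sofia_tz.localize(ts)   # naive timestamps are assumed to be Sofia wall-clock time
        return ts.astimezone(sofia_tz)     # aware timestamps are converted from their own zone

    gap_start = to_sofia(datetime(2024, 1, 1, 12, 1))   # e.g. cache_end + 1 minute, stored naive
    end_time = datetime.now(pytz.UTC).astimezone(sofia_tz)
    if gap_start < end_time:
        print(f"Filling gap from {gap_start} to {end_time}")
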
@@ -1573,8 +1583,10 @@ class DataProvider:
'taker_buy_quote', 'ignore'
])

# Process columns
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
# Process columns with proper timezone handling
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
# Convert from UTC to Europe/Sofia timezone to match cached data
df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
for col in ['open', 'high', 'low', 'close', 'volume']:
df[col] = df[col].astype(float)
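
The parsing change above matters because `pd.to_datetime(..., unit='ms')` alone yields tz-naive timestamps (implicitly UTC wall-clock), while `utc=True` plus `tz_convert` yields Sofia local time that lines up with the cached data. A small illustration with a made-up millisecond timestamp:

    import pandas as pd

    ms = pd.Series([1_700_000_000_000])   # exchange open time in ms since the epoch (UTC)

    naive = pd.to_datetime(ms, unit='ms')                                  # tz-naive UTC wall-clock
    aware = pd.to_datetime(ms, unit='ms', utc=True).dt.tz_convert('Europe/Sofia')

    print(naive.iloc[0])   # 2023-11-14 22:13:20          (no timezone information)
    print(aware.iloc[0])   # 2023-11-15 00:13:20+02:00    (Sofia local time, UTC+2 in November)

    # Mixing the two styles in comparisons or merges either raises or silently
    # misaligns rows by the UTC offset, so the fetch path now parses with utc=True.
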
@@ -1644,8 +1656,10 @@ class DataProvider:
'taker_buy_quote', 'ignore'
])

# Process columns
batch_df['timestamp'] = pd.to_datetime(batch_df['timestamp'], unit='ms')
# Process columns with proper timezone handling
batch_df['timestamp'] = pd.to_datetime(batch_df['timestamp'], unit='ms', utc=True)
# Convert from UTC to Europe/Sofia timezone to match cached data
batch_df['timestamp'] = batch_df['timestamp'].dt.tz_convert('Europe/Sofia')
for col in ['open', 'high', 'low', 'close', 'volume']:
batch_df[col] = batch_df[col].astype(float)
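
The same conversion is applied to batched fetches so that new rows carry the same timezone as the cached frame before they are merged. A minimal sketch of why that matters (column names and values are assumptions for illustration):

    import pandas as pd

    # Cached candles, assumed to be stored with Europe/Sofia timestamps.
    cached = pd.DataFrame({
        'timestamp': pd.to_datetime([1_700_000_000_000], unit='ms', utc=True).tz_convert('Europe/Sofia'),
        'close': [2050.0],
    })

    # Newly fetched batch: parse millisecond epochs as UTC, then convert to the same zone.
    batch_df = pd.DataFrame({'timestamp': [1_700_000_060_000], 'close': [2051.0]})
    batch_df['timestamp'] = pd.to_datetime(batch_df['timestamp'], unit='ms', utc=True).dt.tz_convert('Europe/Sofia')

    # Both columns are now datetime64[ns, Europe/Sofia]; concatenation and sorting behave as expected.
    merged = pd.concat([cached, batch_df], ignore_index=True).sort_values('timestamp')
    print(merged['timestamp'].dtype)   # datetime64[ns, Europe/Sofia]
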
@@ -2377,11 +2377,10 @@ class TradingOrchestrator:
return None

async def _train_cnn_model(self, model, model_name: str, record: Dict, prediction: Dict, reward: float) -> bool:
"""Train CNN model with training samples"""
"""Train CNN model directly (no adapter)"""
try:
# Direct CNN model training (no adapter)
if hasattr(self, 'cnn_model') and self.cnn_model and 'cnn' in model_name.lower():
try:
# Direct CNN model training (no adapter)
if hasattr(self, 'cnn_model') and self.cnn_model and 'cnn' in model_name.lower():
symbol = record.get('symbol', 'ETH/USDT')
actual_action = prediction['action']
@@ -2441,47 +2440,36 @@ class TradingOrchestrator:
else:
logger.warning(f"No model input available for CNN training")
return False

except Exception as e:
logger.error(f"Error in direct CNN training: {e}")
return False

# Try direct model training methods
# Try model interface training methods
elif hasattr(model, 'add_training_sample'):
symbol = record.get('symbol', 'ETH/USDT')
actual_action = prediction['action']
model.add_training_sample(symbol, actual_action, reward)
logger.debug(f"Added training sample to {model_name}: action={actual_action}, reward={reward:.3f}")

# Trigger training if batch size is met
if hasattr(model, 'train') and hasattr(model, 'training_data') and hasattr(model, 'batch_size'):
if len(model.training_data) >= model.batch_size:
# If model has train method, trigger training
if hasattr(model, 'train') and callable(getattr(model, 'train')):
try:
training_start_time = time.time()
training_results = model.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000

if training_results and 'loss' in training_results:
current_loss = training_results['loss']
accuracy = training_results.get('accuracy', 0.0)

# Validate training results
if accuracy >= 0.99:
logger.warning(f"CNN training shows suspiciously high accuracy: {accuracy:.4f} - possible overfitting")
else:
self.update_model_loss(model_name, current_loss)

self.update_model_loss(model_name, current_loss)
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
logger.debug(f"Model {model_name} training completed: loss={current_loss:.4f}")
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
return True # Sample added successfully
except Exception as e:
logger.error(f"Error training {model_name}: {e}")

return True

# Try basic training method for EnhancedCNN
# Basic acknowledgment for other training methods
elif hasattr(model, 'train'):
logger.debug(f"Using basic train method for {model_name}")
# For now, just acknowledge that training was attempted
logger.debug(f"CNN model {model_name} training acknowledged (basic train method available)")
return True
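
The training path above only duck-types the model: it probes for `add_training_sample` and `train` with `hasattr` and reads a dict with a `'loss'` key (and optional `'accuracy'`) from `train()`. A minimal stand-in that satisfies that contract (not repository code; the buffering and loss values are fabricated for illustration):

    import random
    from typing import Dict, List, Tuple

    class ToyTrainableModel:
        """Toy model exposing the interface probed by _train_cnn_model."""

        def __init__(self) -> None:
            self.training_data: List[Tuple[str, str, float]] = []

        def add_training_sample(self, symbol: str, action: str, reward: float) -> None:
            self.training_data.append((symbol, action, reward))

        def train(self, epochs: int = 1) -> Dict[str, float]:
            # A real model would run optimizer steps here; this just fakes a shrinking loss.
            loss = max(0.01, 1.0 / (1 + len(self.training_data)))
            return {'loss': loss, 'accuracy': random.uniform(0.4, 0.7)}

    model = ToyTrainableModel()
    model.add_training_sample('ETH/USDT', 'BUY', reward=0.25)
    results = model.train(epochs=1)
    print(results['loss'], results.get('accuracy', 0.0))
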
@@ -2622,7 +2610,8 @@ class TradingOrchestrator:

# Convert to tensor and ensure proper device placement
device = next(self.cnn_model.parameters()).device
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
import torch as torch_module # Explicit import to avoid scoping issues
features_tensor = torch_module.tensor(features, dtype=torch_module.float32, device=device)

# Ensure batch dimension
if features_tensor.dim() == 1:
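
The `import torch as torch_module` line works around a Python scoping rule: any `import torch` inside a function makes `torch` a local name for the whole function body, so earlier references to the module-level `torch` in that same function raise `UnboundLocalError`. A short illustration (simplified, not repository code):

    import torch  # module-level import used elsewhere in the file

    def broken(features):
        t = torch.tensor(features)    # UnboundLocalError: the import below makes 'torch' local
        import torch
        return t

    def fixed(features, model):
        import torch as torch_module  # distinct alias, no shadowing of the module-level name
        device = next(model.parameters()).device   # keep the tensor on the model's device
        return torch_module.tensor(features, dtype=torch_module.float32, device=device)
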
@@ -2632,12 +2621,12 @@ class TradingOrchestrator:
self.cnn_model.eval()

# Get prediction from CNN model
with torch.no_grad():
with torch_module.no_grad():
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)

# Convert to probabilities using softmax
action_probs = torch.softmax(q_values, dim=1)
action_idx = torch.argmax(action_probs, dim=1).item()
action_probs = torch_module.softmax(q_values, dim=1)
action_idx = torch_module.argmax(action_probs, dim=1).item()
confidence = float(action_probs[0, action_idx].item())

# Map action index to action string
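
For reference, the post-processing applied to the CNN head output above is a plain softmax/argmax/confidence readout. A small sketch with an assumed 3-action output and an assumed BUY/SELL/HOLD ordering:

    import torch

    q_values = torch.tensor([[0.2, 1.5, -0.3]])   # batch of 1, three action logits (values made up)
    with torch.no_grad():
        action_probs = torch.softmax(q_values, dim=1)
        action_idx = torch.argmax(action_probs, dim=1).item()
        confidence = float(action_probs[0, action_idx].item())

    actions = ['BUY', 'SELL', 'HOLD']             # assumed mapping; the real table lives in the model code
    print(actions[action_idx], round(confidence, 3))
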
@@ -2679,9 +2668,9 @@ class TradingOrchestrator:
import traceback
traceback.print_exc()

# Fallback to direct model inference using BaseDataInput (unified approach)
# Remove this fallback - direct CNN inference should work above
if not predictions:
logger.warning(f"CNN adapter failed for {symbol}, trying direct model inference with BaseDataInput")
logger.debug(f"No CNN predictions generated for {symbol} - this is expected if CNN model is not properly initialized")

try:
# Use the already available base_data (no need to rebuild)
@@ -2694,10 +2683,9 @@ class TradingOrchestrator:

# Use the model's act method with unified input
if hasattr(model.model, 'act'):
# Convert to tensor format expected by enhanced_cnn
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
features_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=device)
# Convert to tensor format expected by enhanced_cnn
device = torch_module.device('cuda' if torch_module.cuda.is_available() else 'cpu')
features_tensor = torch_module.tensor(feature_vector, dtype=torch_module.float32, device=device)

# Call the model's act method
action_idx, confidence, action_probs = model.model.act(features_tensor, explore=False)
@@ -3263,9 +3251,6 @@ class TradingOrchestrator:
if not self.decision_fusion_enabled:
return

import torch
import torch.nn as nn

# Create decision fusion network
class DecisionFusionNet(nn.Module):
def __init__(self, input_size=32, hidden_size=64):
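
Only the constructor signature of `DecisionFusionNet` (`input_size=32`, `hidden_size=64`) is visible in this hunk; the layer layout below is an assumption sketched for illustration, with torch imported at module level as the change implies.

    import torch
    import torch.nn as nn

    class DecisionFusionNet(nn.Module):
        def __init__(self, input_size=32, hidden_size=64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 3),   # assumed: one logit per BUY/SELL/HOLD action
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    fusion = DecisionFusionNet()
    print(fusion(torch.randn(1, 32)).shape)   # torch.Size([1, 3])
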