Compare commits
3 commits: 9219b78241 ... 44821b2a89

Commits:
- 44821b2a89
- 25b2d3840a
- fb72c93743
@@ -1164,35 +1164,23 @@ class DQNAgent:
         # Check if state is a dict or complex object
         if isinstance(state, dict):
             logger.error(f"State is a dict: {state}")

             # Handle empty dictionary case
             if not state:
                 logger.error("No numerical values found in state dict, using default state")
                 expected_size = getattr(self, 'state_size', 403)
                 if isinstance(expected_size, tuple):
                     expected_size = np.prod(expected_size)
                 return np.zeros(int(expected_size), dtype=np.float32)

             # Extract numerical values from dict if possible
             if 'features' in state:
                 state = state['features']
             elif 'state' in state:
                 state = state['state']
             else:
-                # Try to extract all numerical values
-                numerical_values = []
-                for key, value in state.items():
-                    if isinstance(value, (int, float)):
-                        numerical_values.append(float(value))
-                    elif isinstance(value, (list, np.ndarray)):
-                        try:
-                            # Handle nested structures safely
-                            flattened = np.array(value).flatten()
-                            for x in flattened:
-                                if isinstance(x, (int, float)):
-                                    numerical_values.append(float(x))
-                                elif hasattr(x, 'item'):  # numpy scalar
-                                    numerical_values.append(float(x.item()))
-                        except (ValueError, TypeError):
-                            continue
-                    elif isinstance(value, dict):
-                        # Recursively extract from nested dicts
-                        try:
-                            nested_values = self._extract_numeric_from_dict(value)
-                            numerical_values.extend(nested_values)
-                        except Exception:
-                            continue
+                # Try to extract all numerical values using the helper method
+                numerical_values = self._extract_numeric_from_dict(state)
             if numerical_values:
                 state = np.array(numerical_values, dtype=np.float32)
             else:
@@ -1254,6 +1242,31 @@ class DQNAgent:
                 expected_size = np.prod(expected_size)
             return np.zeros(int(expected_size), dtype=np.float32)

+    def _extract_numeric_from_dict(self, data_dict):
+        """Recursively extract numerical values from nested dictionaries"""
+        numerical_values = []
+        try:
+            for key, value in data_dict.items():
+                if isinstance(value, (int, float)):
+                    numerical_values.append(float(value))
+                elif isinstance(value, (list, np.ndarray)):
+                    try:
+                        flattened = np.array(value).flatten()
+                        for x in flattened:
+                            if isinstance(x, (int, float)):
+                                numerical_values.append(float(x))
+                            elif hasattr(x, 'item'):  # numpy scalar
+                                numerical_values.append(float(x.item()))
+                    except (ValueError, TypeError):
+                        continue
+                elif isinstance(value, dict):
+                    # Recursively extract from nested dicts
+                    nested_values = self._extract_numeric_from_dict(value)
+                    numerical_values.extend(nested_values)
+        except Exception as e:
+            logger.debug(f"Error extracting numeric values from dict: {e}")
+        return numerical_values
+
     def _replay_standard(self, states, actions, rewards, next_states, dones):
         """Standard training step without mixed precision"""
         try:
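A quick usage sketch of the new helper (not part of the diff; the sample dict and values are invented for illustration):

    state = {
        'price': 3890.5,
        'indicators': {'rsi': 41.2, 'macd': -0.8},  # nested dicts are recursed into
        'volumes': [12.0, 9.5],                     # lists/arrays are flattened
        'symbol': 'ETH/USDT',                       # non-numeric values are skipped
    }
    values = agent._extract_numeric_from_dict(state)
    # values == [3890.5, 41.2, -0.8, 12.0, 9.5]
    state_vec = np.array(values, dtype=np.float32)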
@@ -36,6 +36,12 @@ import math
 # Suppress ta library deprecation warnings
 warnings.filterwarnings("ignore", category=FutureWarning, module="ta")

+# Import timezone utilities
+from utils.timezone_utils import (
+    normalize_timestamp, normalize_dataframe_timestamps, normalize_dataframe_index,
+    now_system, now_utc, to_sofia, UTC, SOFIA_TZ, log_timezone_info
+)
+
 from .config import get_config
 from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
 from .cnn_monitor import log_cnn_prediction
@@ -83,6 +89,13 @@ class PivotBounds:
         distances = [abs(current_price - r) for r in self.pivot_resistance_levels]
         return min(distances) / self.get_price_range()

+@dataclass
+class SimplePivotLevel:
+    """Simple pivot level structure for fallback pivot detection"""
+    swing_points: List[Any] = field(default_factory=list)
+    support_levels: List[float] = field(default_factory=list)
+    resistance_levels: List[float] = field(default_factory=list)
+
 @dataclass
 class MarketTick:
     """Standardized market tick data structure"""
@@ -127,6 +140,10 @@ class DataProvider:
         self.real_time_data = {}  # {symbol: {timeframe: deque}}
         self.current_prices = {}  # {symbol: float}

+        # Live price cache for low-latency price updates
+        self.live_price_cache: Dict[str, Tuple[float, datetime]] = {}
+        self.live_price_cache_ttl = timedelta(milliseconds=500)
+
         # Initialize cached data structure
         for symbol in self.symbols:
             self.cached_data[symbol] = {}
@@ -461,7 +478,7 @@ class DataProvider:
         # Create raw tick entry
         raw_tick = {
             'symbol': symbol,
-            'timestamp': datetime.utcnow(),
+            'timestamp': now_system(),  # Use system timezone consistently
             'bids': actual_data.get('bids', [])[:50],  # Top 50 levels
             'asks': actual_data.get('asks', [])[:50],  # Top 50 levels
             'stats': actual_data.get('stats', {}),
@@ -1086,8 +1103,7 @@ class DataProvider:

         # Process columns with proper timezone handling (MEXC returns UTC timestamps)
         df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
-        # Convert from UTC to Europe/Sofia timezone
-        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
+        # Keep in UTC to match COB WebSocket data (no timezone conversion)
         for col in ['open', 'high', 'low', 'close', 'volume']:
             df[col] = df[col].astype(float)
@@ -1133,18 +1149,16 @@ class DataProvider:
                 if isinstance(timestamp, (int, float)):
                     import pytz
                     utc = pytz.UTC
-                    sofia_tz = pytz.timezone('Europe/Sofia')
                     tick_time = datetime.fromtimestamp(timestamp, tz=utc)
-                    tick_time = tick_time.astimezone(sofia_tz)
+                    # Keep in UTC to match COB WebSocket data
                 elif isinstance(timestamp, datetime):
                     import pytz
-                    sofia_tz = pytz.timezone('Europe/Sofia')
                     utc = pytz.UTC
                     tick_time = timestamp
-                    # If no timezone info, assume UTC and convert to Europe/Sofia
+                    # If no timezone info, assume UTC and keep in UTC
                     if tick_time.tzinfo is None:
                         utc = pytz.UTC
                         tick_time = utc.localize(tick_time)
-                        tick_time = tick_time.astimezone(sofia_tz)
+                    # Keep in UTC (no timezone conversion)
                 else:
                     continue
@@ -1184,15 +1198,15 @@ class DataProvider:

         # Convert to DataFrame
         df = pd.DataFrame(candles)
-        # Ensure timestamps are timezone-aware (Europe/Sofia)
+        # Ensure timestamps are timezone-aware (UTC to match COB WebSocket data)
         if not df.empty and 'timestamp' in df.columns:
             import pytz
-            sofia_tz = pytz.timezone('Europe/Sofia')
-            # If timestamps are not timezone-aware, make them Europe/Sofia
+            utc = pytz.UTC
+            # If timestamps are not timezone-aware, make them UTC
             if df['timestamp'].dt.tz is None:
-                df['timestamp'] = df['timestamp'].dt.tz_localize(sofia_tz)
+                df['timestamp'] = df['timestamp'].dt.tz_localize(utc)
             else:
-                df['timestamp'] = df['timestamp'].dt.tz_convert(sofia_tz)
+                df['timestamp'] = df['timestamp'].dt.tz_convert(utc)

         df = df.sort_values('timestamp').reset_index(drop=True)
@@ -1272,8 +1286,8 @@ class DataProvider:

         # Process columns with proper timezone handling (Binance returns UTC timestamps)
         df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
-        # Convert from UTC to Europe/Sofia timezone
-        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
+        # Keep in UTC to match COB WebSocket data (no timezone conversion)
+        # This prevents the 3-hour gap when appending live COB data
        for col in ['open', 'high', 'low', 'close', 'volume']:
             df[col] = df[col].astype(float)
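For context on the "3-hour gap" comment above: Europe/Sofia is UTC+2 in winter and UTC+3 during summer time, so candles converted to Sofia time while live COB data stayed in UTC would appear offset by up to three hours. A minimal sketch of the two behaviors (pandas only; the millisecond epoch value is invented):

    import pandas as pd

    ts = pd.to_datetime([1719958800000], unit='ms', utc=True)  # hypothetical ms epoch
    print(ts.tz_convert('Europe/Sofia'))  # old behavior: +03:00 offset in summer
    print(ts)                             # new behavior: stays at +00:00 (UTC)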
@@ -1480,9 +1494,8 @@ class DataProvider:

             import pytz
             utc = pytz.UTC
-            sofia_tz = pytz.timezone('Europe/Sofia')

-            end_time = datetime.utcnow().replace(tzinfo=utc).astimezone(sofia_tz)
+            end_time = datetime.utcnow().replace(tzinfo=utc)
             start_time = end_time - timedelta(days=30)

             if cached_data is not None and not cached_data.empty:
@@ -1585,8 +1598,7 @@ class DataProvider:

         # Process columns with proper timezone handling
         df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
-        # Convert from UTC to Europe/Sofia timezone to match cached data
-        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
+        # Keep in UTC to match COB WebSocket data (no timezone conversion)
         for col in ['open', 'high', 'low', 'close', 'volume']:
             df[col] = df[col].astype(float)
@@ -1658,8 +1670,7 @@ class DataProvider:

         # Process columns with proper timezone handling
         batch_df['timestamp'] = pd.to_datetime(batch_df['timestamp'], unit='ms', utc=True)
-        # Convert from UTC to Europe/Sofia timezone to match cached data
-        batch_df['timestamp'] = batch_df['timestamp'].dt.tz_convert('Europe/Sofia')
+        # Keep in UTC to match COB WebSocket data (no timezone conversion)
         for col in ['open', 'high', 'low', 'close', 'volume']:
             batch_df[col] = batch_df[col].astype(float)
@@ -1839,14 +1850,14 @@ class DataProvider:
                 low_pivots = monthly_data[lows == rolling_min]['low'].tolist()
                 pivot_lows.extend(low_pivots)

-            # Create mock level structure
-            mock_level = type('MockLevel', (), {
-                'swing_points': [],
-                'support_levels': list(set(pivot_lows)),
-                'resistance_levels': list(set(pivot_highs))
-            })()
+            # Create proper pivot level structure
+            pivot_level = SimplePivotLevel(
+                swing_points=[],
+                support_levels=list(set(pivot_lows)),
+                resistance_levels=list(set(pivot_highs))
+            )

-            return {'level_0': mock_level}
+            return {'level_0': pivot_level}

         except Exception as e:
             logger.error(f"Error in simple pivot detection: {e}")
@@ -2022,15 +2033,14 @@ class DataProvider:
         if cache_file.exists():
             try:
                 df = pd.read_parquet(cache_file)
-                # Ensure cached monthly data has proper timezone (Europe/Sofia)
+                # Ensure cached monthly data has proper timezone (UTC to match COB WebSocket data)
                 if not df.empty and 'timestamp' in df.columns:
                     if df['timestamp'].dt.tz is None:
-                        # If no timezone info, assume UTC and convert to Europe/Sofia
+                        # If no timezone info, assume UTC and keep in UTC
                         df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
-                        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
-                    elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
-                        # Convert to Europe/Sofia if different timezone
-                        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
+                    elif str(df['timestamp'].dt.tz) != 'UTC':
+                        # Convert to UTC if different timezone
+                        df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
                 logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
                 return df
             except Exception as parquet_e:
@@ -2306,15 +2316,14 @@ class DataProvider:
         if cache_age < max_age:
             try:
                 df = pd.read_parquet(cache_file)
-                # Ensure cached data has proper timezone (Europe/Sofia)
+                # Ensure cached data has proper timezone (UTC to match COB WebSocket data)
                 if not df.empty and 'timestamp' in df.columns:
                     if df['timestamp'].dt.tz is None:
-                        # If no timezone info, assume UTC and convert to Europe/Sofia
+                        # If no timezone info, assume UTC and keep in UTC
                         df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
-                        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
-                    elif str(df['timestamp'].dt.tz) != 'Europe/Sofia':
-                        # Convert to Europe/Sofia if different timezone
-                        df['timestamp'] = df['timestamp'].dt.tz_convert('Europe/Sofia')
+                    elif str(df['timestamp'].dt.tz) != 'UTC':
+                        # Convert to UTC if different timezone
+                        df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
                 logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
                 return df
             except Exception as parquet_e:
@@ -1062,10 +1062,11 @@ class MultiExchangeCOBProvider:
                     consolidated_bids[price].exchange_breakdown[exchange_name] = level

                     # Update dominant exchange based on volume
-                    if level.volume_usd > consolidated_bids[price].exchange_breakdown.get(
-                        consolidated_bids[price].dominant_exchange,
-                        type('obj', (object,), {'volume_usd': 0})()
-                    ).volume_usd:
+                    current_dominant = consolidated_bids[price].exchange_breakdown.get(
+                        consolidated_bids[price].dominant_exchange
+                    )
+                    current_volume = current_dominant.volume_usd if current_dominant else 0
+                    if level.volume_usd > current_volume:
                         consolidated_bids[price].dominant_exchange = exchange_name

         # Process merged asks (similar logic)
@@ -1088,10 +1089,11 @@ class MultiExchangeCOBProvider:
                     consolidated_asks[price].total_orders += level.orders_count
                     consolidated_asks[price].exchange_breakdown[exchange_name] = level

-                    if level.volume_usd > consolidated_asks[price].exchange_breakdown.get(
-                        consolidated_asks[price].dominant_exchange,
-                        type('obj', (object,), {'volume_usd': 0})()
-                    ).volume_usd:
+                    current_dominant = consolidated_asks[price].exchange_breakdown.get(
+                        consolidated_asks[price].dominant_exchange
+                    )
+                    current_volume = current_dominant.volume_usd if current_dominant else 0
+                    if level.volume_usd > current_volume:
                         consolidated_asks[price].dominant_exchange = exchange_name

         logger.debug(f"Consolidated {len(consolidated_bids)} bids and {len(consolidated_asks)} asks for {symbol}")
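Why the two-step lookup above is safer: the old expression passed a throwaway `type('obj', ...)` instance as the `.get()` default, allocating a dummy object per price level and hiding the "no dominant entry yet" case. The new code fetches the current dominant entry once and treats a missing entry as zero volume. A stripped-down sketch of the pattern (names simplified from the diff; data invented):

    from dataclasses import dataclass

    @dataclass
    class Level:
        volume_usd: float

    breakdown = {'binance': Level(1200.0)}
    dominant = 'binance'

    # New pattern: explicit lookup, missing entry counts as zero volume
    current = breakdown.get(dominant)
    current_volume = current.volume_usd if current else 0
    if 1500.0 > current_volume:  # candidate level's volume_usd
        dominant = 'mexc'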
@@ -1493,6 +1493,17 @@ class TradingOrchestrator:
         if not base_data:
             logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
             return predictions

+        # Validate base_data has proper feature vector
+        if hasattr(base_data, 'get_feature_vector'):
+            try:
+                feature_vector = base_data.get_feature_vector()
+                if feature_vector is None or (isinstance(feature_vector, np.ndarray) and feature_vector.size == 0):
+                    logger.warning(f"BaseDataInput has empty feature vector for {symbol}")
+                    return predictions
+            except Exception as e:
+                logger.warning(f"Error getting feature vector from BaseDataInput for {symbol}: {e}")
+                return predictions
+
         # log all registered models
         logger.debug(f"inferencing registered models: {self.model_registry.models}")
@@ -1691,6 +1702,15 @@ class TradingOrchestrator:
         try:
             logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")

+            # Validate model_input before storing
+            if model_input is None:
+                logger.warning(f"Skipping inference storage for {model_name}: model_input is None")
+                return
+
+            if isinstance(model_input, dict) and not model_input:
+                logger.warning(f"Skipping inference storage for {model_name}: model_input is empty dict")
+                return
+
             # Extract symbol from prediction if not provided
             if symbol is None:
                 symbol = getattr(prediction, 'symbol', 'ETH/USDT')  # Default to ETH/USDT if not available
@@ -2569,6 +2589,12 @@ class TradingOrchestrator:

             # Method 3: Dictionary with feature data
             if isinstance(model_input, dict):
+                # Check if dictionary is empty - this is the main issue!
+                if not model_input:
+                    logger.warning(f"Empty dictionary passed as model_input for {model_name}, using data provider fallback")
+                    # Use data provider to build proper state as fallback
+                    return self._generate_fresh_state_fallback(model_name)
+
                 # Try to extract features from dictionary
                 if 'features' in model_input:
                     features = model_input['features']
@@ -2589,6 +2615,9 @@ class TradingOrchestrator:

                 if feature_list:
                     return np.array(feature_list, dtype=np.float32)
+                else:
+                    logger.warning(f"No numerical features found in dictionary for {model_name}, using data provider fallback")
+                    return self._generate_fresh_state_fallback(model_name)

             # Method 4: List or tuple
             if isinstance(model_input, (list, tuple)):
@@ -2601,24 +2630,57 @@ class TradingOrchestrator:
             if isinstance(model_input, (int, float)):
                 return np.array([model_input], dtype=np.float32)

-            # Method 6: Try to use data provider to build state
-            if hasattr(self, 'data_provider'):
-                try:
-                    base_data = self.data_provider.build_base_data_input('ETH/USDT')
-                    if base_data and hasattr(base_data, 'get_feature_vector'):
-                        state = base_data.get_feature_vector()
-                        if isinstance(state, np.ndarray):
-                            logger.debug(f"Used data provider fallback for {model_name}")
-                            return state
-                except Exception as e:
-                    logger.debug(f"Data provider fallback failed for {model_name}: {e}")
-
-            logger.warning(f"Cannot convert model_input to RL state for {model_name}: {type(model_input)}")
-            return None
+            # Method 6: Final fallback - generate fresh state
+            logger.warning(f"Cannot convert model_input to RL state for {model_name}: {type(model_input)}, using fresh state fallback")
+            return self._generate_fresh_state_fallback(model_name)

         except Exception as e:
             logger.error(f"Error converting model_input to RL state for {model_name}: {e}")
-            return None
+            return self._generate_fresh_state_fallback(model_name)
+
+    def _generate_fresh_state_fallback(self, model_name: str) -> np.ndarray:
+        """Generate a fresh state from current market data when model_input is empty/invalid"""
+        try:
+            # Try to use data provider to build fresh state
+            if hasattr(self, 'data_provider') and self.data_provider:
+                try:
+                    # Build fresh BaseDataInput with current market data
+                    base_data = self.data_provider.build_base_data_input('ETH/USDT')
+                    if base_data and hasattr(base_data, 'get_feature_vector'):
+                        state = base_data.get_feature_vector()
+                        if isinstance(state, np.ndarray) and state.size > 0:
+                            logger.info(f"Generated fresh state for {model_name} from data provider: shape={state.shape}")
+                            return state
+                except Exception as e:
+                    logger.debug(f"Data provider fresh state generation failed for {model_name}: {e}")
+
+            # Try to get state from model registry
+            if hasattr(self, 'model_registry') and self.model_registry:
+                try:
+                    model_interface = self.model_registry.models.get(model_name)
+                    if model_interface and hasattr(model_interface, 'get_current_state'):
+                        state = model_interface.get_current_state()
+                        if isinstance(state, np.ndarray) and state.size > 0:
+                            logger.info(f"Generated fresh state for {model_name} from model interface: shape={state.shape}")
+                            return state
+                except Exception as e:
+                    logger.debug(f"Model interface fresh state generation failed for {model_name}: {e}")
+
+            # Final fallback: create a reasonable default state with proper dimensions
+            # Use the expected state size for DQN models (403 features)
+            default_state_size = 403
+            if 'cnn' in model_name.lower():
+                default_state_size = 500  # Larger for CNN models
+            elif 'cob' in model_name.lower():
+                default_state_size = 2000  # Much larger for COB models
+
+            logger.warning(f"Using default zero state for {model_name} with size {default_state_size}")
+            return np.zeros(default_state_size, dtype=np.float32)
+
+        except Exception as e:
+            logger.error(f"Error generating fresh state fallback for {model_name}: {e}")
+            # Ultimate fallback
+            return np.zeros(403, dtype=np.float32)

     async def _train_cnn_model(self, model, model_name: str, record: Dict, prediction: Dict, reward: float) -> bool:
         """Train CNN model directly (no adapter)"""
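Net effect of this hunk, as a usage sketch (the call below is illustrative; the `orchestrator` instance and model name are assumptions): previously `_convert_to_rl_state` could return None, so downstream training code had to guard against missing states; now every failure path funnels into `_generate_fresh_state_fallback`, and callers always receive a non-empty ndarray.

    state = orchestrator._convert_to_rl_state({}, 'dqn_agent')  # empty dict input
    assert isinstance(state, np.ndarray) and state.size > 0     # 403 zeros at worst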
@@ -3744,6 +3806,35 @@ class TradingOrchestrator:

         except Exception as e:
             logger.error(f"Error setting training dashboard: {e}")

+    def set_cold_start_training_enabled(self, enabled: bool) -> bool:
+        """Enable or disable cold start training (excessive training during cold start)
+
+        Args:
+            enabled: Whether to enable cold start training
+
+        Returns:
+            bool: True if setting was applied successfully
+        """
+        try:
+            # Store the setting
+            self.cold_start_enabled = enabled
+
+            # Adjust training frequency based on cold start mode
+            if enabled:
+                # High frequency training during cold start
+                self.training_frequency = 'high'
+                logger.info("ORCHESTRATOR: Cold start training ENABLED - Excessive training on every signal")
+            else:
+                # Normal training frequency
+                self.training_frequency = 'normal'
+                logger.info("ORCHESTRATOR: Cold start training DISABLED - Normal training frequency")
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Error setting cold start training: {e}")
+            return False
+
     def get_universal_data_stream(self, current_time: Optional[datetime] = None):
         """Get universal data stream for external consumers like dashboard - DELEGATED to data provider"""
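A usage sketch for the new toggle (hedged; an `orchestrator` instance is assumed, behavior per the docstring and log lines above):

    orchestrator.set_cold_start_training_enabled(True)   # train on every signal while models are cold
    # ... once the models have warmed up:
    orchestrator.set_cold_start_training_enabled(False)  # back to normal training frequency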
@@ -1247,27 +1247,23 @@ class TradingExecutor:
             taker_fee_rate = trading_fees.get('taker_fee', trading_fees.get('default_fee', 0.0006))
             simulated_fees = position.quantity * current_price * taker_fee_rate

-            # Calculate P&L for short position and hold time
-            pnl = position.calculate_pnl(current_price)
-            exit_time = datetime.now()
-            hold_time_seconds = (exit_time - position.entry_time).total_seconds()

-            # Get current leverage setting from dashboard or config
+            # Get current leverage setting
             leverage = self.get_leverage()

             # Calculate position size in USD
             position_size_usd = position.quantity * position.entry_price

             # Calculate gross PnL (before fees) with leverage
-            if position.side == 'SHORT':
-                gross_pnl = (position.entry_price - current_price) * position.quantity * leverage
-            else:  # LONG
-                gross_pnl = (current_price - position.entry_price) * position.quantity * leverage
+            gross_pnl = (current_price - position.entry_price) * position.quantity * leverage

             # Calculate net PnL (after fees)
             net_pnl = gross_pnl - simulated_fees

-            # Create trade record with enhanced PnL calculations
+            # Calculate hold time
+            exit_time = datetime.now()
+            hold_time_seconds = (exit_time - position.entry_time).total_seconds()
+
+            # Create trade record with corrected PnL calculations
             trade_record = TradeRecord(
                 symbol=symbol,
                 side='SHORT',
@@ -1287,16 +1283,16 @@ class TradingExecutor:
             )

             self.trade_history.append(trade_record)
-            self.trade_records.append(trade_record)  # Add to trade records for success rate tracking
-            self.daily_loss += max(0, -pnl)  # Add to daily loss if negative
+            self.trade_records.append(trade_record)
+            self.daily_loss += max(0, -net_pnl)  # Use net_pnl instead of pnl

             # Adjust profitability reward multiplier based on recent performance
             self._adjust_profitability_reward_multiplier()

-            # Update consecutive losses
-            if pnl < -0.001:  # A losing trade
+            # Update consecutive losses using net_pnl
+            if net_pnl < -0.001:  # A losing trade
                 self.consecutive_losses += 1
-            elif pnl > 0.001:  # A winning trade
+            elif net_pnl > 0.001:  # A winning trade
                 self.consecutive_losses = 0
             else:  # Breakeven trade
                 self.consecutive_losses = 0
@@ -1306,7 +1302,7 @@ class TradingExecutor:
             self.last_trade_time[symbol] = datetime.now()
             self.daily_trades += 1

-            logger.info(f"Position closed - P&L: ${pnl:.2f}")
+            logger.info(f"SHORT position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${simulated_fees:.3f}")
             return True

         try:
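A worked example of the leveraged PnL formulas above (all numbers invented): the price move times quantity is scaled by leverage for gross PnL, while the taker fee is charged on the unleveraged exit notional. For 0.01 ETH entered at $3,880 and closed at $3,900 with 50x leverage:

    gross_pnl = (3900.0 - 3880.0) * 0.01 * 50   # = 10.00
    simulated_fees = 0.01 * 3900.0 * 0.0006     # fee on exit notional ~= 0.023
    net_pnl = gross_pnl - simulated_fees        # ~= 9.98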
@@ -1342,27 +1338,23 @@ class TradingExecutor:
             # Calculate fees using real API data when available
             fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price)

-            # Calculate P&L, fees, and hold time
-            pnl = position.calculate_pnl(current_price)
-            exit_time = datetime.now()
-            hold_time_seconds = (exit_time - position.entry_time).total_seconds()

-            # Get current leverage setting from dashboard or config
+            # Get current leverage setting
             leverage = self.get_leverage()

             # Calculate position size in USD
             position_size_usd = position.quantity * position.entry_price

             # Calculate gross PnL (before fees) with leverage
-            if position.side == 'SHORT':
-                gross_pnl = (position.entry_price - current_price) * position.quantity * leverage
-            else:  # LONG
-                gross_pnl = (current_price - position.entry_price) * position.quantity * leverage
+            gross_pnl = (current_price - position.entry_price) * position.quantity * leverage

             # Calculate net PnL (after fees)
             net_pnl = gross_pnl - fees

-            # Create trade record with enhanced PnL calculations
+            # Calculate hold time
+            exit_time = datetime.now()
+            hold_time_seconds = (exit_time - position.entry_time).total_seconds()
+
+            # Create trade record with corrected PnL calculations
             trade_record = TradeRecord(
                 symbol=symbol,
                 side='SHORT',
@@ -1382,16 +1374,16 @@ class TradingExecutor:
             )

             self.trade_history.append(trade_record)
-            self.trade_records.append(trade_record)  # Add to trade records for success rate tracking
-            self.daily_loss += max(0, -(pnl - fees))  # Add to daily loss if negative
+            self.trade_records.append(trade_record)
+            self.daily_loss += max(0, -net_pnl)  # Use net_pnl instead of pnl

             # Adjust profitability reward multiplier based on recent performance
             self._adjust_profitability_reward_multiplier()

-            # Update consecutive losses
-            if pnl < -0.001:  # A losing trade
+            # Update consecutive losses using net_pnl
+            if net_pnl < -0.001:  # A losing trade
                 self.consecutive_losses += 1
-            elif pnl > 0.001:  # A winning trade
+            elif net_pnl > 0.001:  # A winning trade
                 self.consecutive_losses = 0
             else:  # Breakeven trade
                 self.consecutive_losses = 0
@@ -1402,7 +1394,7 @@ class TradingExecutor:
             self.daily_trades += 1

             logger.info(f"SHORT close order executed: {order}")
-            logger.info(f"SHORT position closed - P&L: ${pnl - fees:.2f}")
+            logger.info(f"SHORT position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${fees:.3f}")
             return True
         else:
             logger.error("Failed to place SHORT close order")
@@ -1417,7 +1409,7 @@ class TradingExecutor:
         if symbol not in self.positions:
             logger.warning(f"No position to close in {symbol}")
             return False

         position = self.positions[symbol]
         if position.side != 'LONG':
             logger.warning(f"Position in {symbol} is not LONG, cannot close with SELL")
@@ -1429,15 +1421,27 @@ class TradingExecutor:
         if self.simulation_mode:
             logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Long close logged but not executed")
             # Calculate simulated fees in simulation mode
-            taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
+            trading_fees = self.exchange_config.get('trading_fees', {})
+            taker_fee_rate = trading_fees.get('taker_fee', trading_fees.get('default_fee', 0.0006))
             simulated_fees = position.quantity * current_price * taker_fee_rate

-            # Calculate P&L for long position and hold time
-            pnl = position.calculate_pnl(current_price)
+            # Get current leverage setting
+            leverage = self.get_leverage()
+
+            # Calculate position size in USD
+            position_size_usd = position.quantity * position.entry_price
+
+            # Calculate gross PnL (before fees) with leverage
+            gross_pnl = (current_price - position.entry_price) * position.quantity * leverage
+
+            # Calculate net PnL (after fees)
+            net_pnl = gross_pnl - simulated_fees
+
+            # Calculate hold time
             exit_time = datetime.now()
             hold_time_seconds = (exit_time - position.entry_time).total_seconds()

-            # Create trade record
+            # Create trade record with corrected PnL calculations
             trade_record = TradeRecord(
                 symbol=symbol,
                 side='LONG',
@@ -1446,23 +1450,27 @@ class TradingExecutor:
                 exit_price=current_price,
                 entry_time=position.entry_time,
                 exit_time=exit_time,
-                pnl=pnl,
+                pnl=net_pnl,  # Store net PnL as the main PnL value
                 fees=simulated_fees,
                 confidence=confidence,
-                hold_time_seconds=hold_time_seconds
+                hold_time_seconds=hold_time_seconds,
+                leverage=leverage,
+                position_size_usd=position_size_usd,
+                gross_pnl=gross_pnl,
+                net_pnl=net_pnl
             )

             self.trade_history.append(trade_record)
-            self.trade_records.append(trade_record)  # Add to trade records for success rate tracking
-            self.daily_loss += max(0, -pnl)  # Add to daily loss if negative
+            self.trade_records.append(trade_record)
+            self.daily_loss += max(0, -net_pnl)  # Use net_pnl instead of pnl

             # Adjust profitability reward multiplier based on recent performance
             self._adjust_profitability_reward_multiplier()

-            # Update consecutive losses
-            if pnl < -0.001:  # A losing trade
+            # Update consecutive losses using net_pnl
+            if net_pnl < -0.001:  # A losing trade
                 self.consecutive_losses += 1
-            elif pnl > 0.001:  # A winning trade
+            elif net_pnl > 0.001:  # A winning trade
                 self.consecutive_losses = 0
             else:  # Breakeven trade
                 self.consecutive_losses = 0
@@ -1472,7 +1480,7 @@ class TradingExecutor:
             self.last_trade_time[symbol] = datetime.now()
             self.daily_trades += 1

-            logger.info(f"Position closed - P&L: ${pnl:.2f}")
+            logger.info(f"LONG position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${simulated_fees:.3f}")
             return True

         try:
@@ -1508,12 +1516,23 @@ class TradingExecutor:
             # Calculate fees using real API data when available
             fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price)

-            # Calculate P&L, fees, and hold time
-            pnl = position.calculate_pnl(current_price)
+            # Get current leverage setting
+            leverage = self.get_leverage()
+
+            # Calculate position size in USD
+            position_size_usd = position.quantity * position.entry_price
+
+            # Calculate gross PnL (before fees) with leverage
+            gross_pnl = (current_price - position.entry_price) * position.quantity * leverage
+
+            # Calculate net PnL (after fees)
+            net_pnl = gross_pnl - fees
+
+            # Calculate hold time
             exit_time = datetime.now()
             hold_time_seconds = (exit_time - position.entry_time).total_seconds()

-            # Create trade record
+            # Create trade record with corrected PnL calculations
             trade_record = TradeRecord(
                 symbol=symbol,
                 side='LONG',
@@ -1522,23 +1541,27 @@ class TradingExecutor:
                 exit_price=current_price,
                 entry_time=position.entry_time,
                 exit_time=exit_time,
-                pnl=pnl - fees,
+                pnl=net_pnl,  # Store net PnL as the main PnL value
                 fees=fees,
                 confidence=confidence,
-                hold_time_seconds=hold_time_seconds
+                hold_time_seconds=hold_time_seconds,
+                leverage=leverage,
+                position_size_usd=position_size_usd,
+                gross_pnl=gross_pnl,
+                net_pnl=net_pnl
             )

             self.trade_history.append(trade_record)
-            self.trade_records.append(trade_record)  # Add to trade records for success rate tracking
-            self.daily_loss += max(0, -(pnl - fees))  # Add to daily loss if negative
+            self.trade_records.append(trade_record)
+            self.daily_loss += max(0, -net_pnl)  # Use net_pnl instead of pnl

             # Adjust profitability reward multiplier based on recent performance
             self._adjust_profitability_reward_multiplier()

-            # Update consecutive losses
-            if pnl < -0.001:  # A losing trade
+            # Update consecutive losses using net_pnl
+            if net_pnl < -0.001:  # A losing trade
                 self.consecutive_losses += 1
-            elif pnl > 0.001:  # A winning trade
+            elif net_pnl > 0.001:  # A winning trade
                 self.consecutive_losses = 0
             else:  # Breakeven trade
                 self.consecutive_losses = 0
@@ -1549,7 +1572,7 @@ class TradingExecutor:
             self.daily_trades += 1

             logger.info(f"LONG close order executed: {order}")
-            logger.info(f"LONG position closed - P&L: ${pnl - fees:.2f}")
+            logger.info(f"LONG position closed - Gross P&L: ${gross_pnl:.2f}, Net P&L: ${net_pnl:.2f}, Fees: ${fees:.3f}")
             return True
         else:
             logger.error("Failed to place LONG close order")
@@ -2406,6 +2429,44 @@ class TradingExecutor:
         else:
             logger.info("TRADING EXECUTOR: Test mode disabled - normal safety checks active")

+    def set_trading_mode(self, mode: str) -> bool:
+        """Set trading mode (simulation/live) and update all related settings
+
+        Args:
+            mode: Trading mode ('simulation' or 'live')
+
+        Returns:
+            bool: True if mode was set successfully
+        """
+        try:
+            if mode not in ['simulation', 'live']:
+                logger.error(f"Invalid trading mode: {mode}. Must be 'simulation' or 'live'")
+                return False
+
+            # Store original mode if not already stored
+            if not hasattr(self, 'original_trading_mode'):
+                self.original_trading_mode = self.trading_mode
+
+            # Update trading mode
+            self.trading_mode = mode
+            self.simulation_mode = (mode == 'simulation')
+
+            # Update primary config if available
+            if hasattr(self, 'primary_config') and self.primary_config:
+                self.primary_config['trading_mode'] = mode
+
+            # Log the change
+            if mode == 'live':
+                logger.warning("TRADING EXECUTOR: MODE CHANGED TO LIVE - Real orders will be executed!")
+            else:
+                logger.info("TRADING EXECUTOR: MODE CHANGED TO SIMULATION - Orders are simulated")
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Error setting trading mode to {mode}: {e}")
+            return False
+
     def get_status(self) -> Dict[str, Any]:
         """Get trading executor status with safety feature information"""
         try:
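A usage sketch for the new mode switch (hedged; an `executor` instance is assumed, validation behavior per the method above):

    executor.set_trading_mode('simulation')             # orders simulated from here on
    assert executor.set_trading_mode('paper') is False  # rejected: only 'simulation' or 'live' are valid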
@@ -2731,3 +2792,85 @@ class TradingExecutor:
             import traceback
             logger.error(f"CORRECTIVE: Full traceback: {traceback.format_exc()}")
             return False
+
+    def recalculate_all_trade_records(self):
+        """Recalculate all existing trade records with correct leverage and PnL"""
+        logger.info("Recalculating all trade records with correct leverage and PnL...")
+
+        updated_count = 0
+        for i, trade in enumerate(self.trade_history):
+            try:
+                # Get current leverage setting
+                leverage = self.get_leverage()
+
+                # Calculate position size in USD
+                position_size_usd = trade.entry_price * trade.quantity
+
+                # Calculate gross PnL (before fees) with leverage
+                if trade.side == 'LONG':
+                    gross_pnl = (trade.exit_price - trade.entry_price) * trade.quantity * leverage
+                else:  # SHORT
+                    gross_pnl = (trade.entry_price - trade.exit_price) * trade.quantity * leverage
+
+                # Calculate fees (0.1% open + 0.1% close = 0.2% total)
+                entry_value = trade.entry_price * trade.quantity
+                exit_value = trade.exit_price * trade.quantity
+                fees = (entry_value + exit_value) * 0.001
+
+                # Calculate net PnL (after fees)
+                net_pnl = gross_pnl - fees
+
+                # Update trade record with corrected values
+                trade.leverage = leverage
+                trade.position_size_usd = position_size_usd
+                trade.gross_pnl = gross_pnl
+                trade.net_pnl = net_pnl
+                trade.pnl = net_pnl  # Main PnL field
+                trade.fees = fees
+
+                updated_count += 1
+
+            except Exception as e:
+                logger.error(f"Error recalculating trade record {i}: {e}")
+                continue
+
+        logger.info(f"Updated {updated_count} trade records with correct leverage and PnL calculations")
+
+        # Also update trade_records list if it exists
+        if hasattr(self, 'trade_records') and self.trade_records:
+            logger.info("Updating trade_records list...")
+            for i, trade in enumerate(self.trade_records):
+                try:
+                    # Get current leverage setting
+                    leverage = self.get_leverage()
+
+                    # Calculate position size in USD
+                    position_size_usd = trade.entry_price * trade.quantity
+
+                    # Calculate gross PnL (before fees) with leverage
+                    if trade.side == 'LONG':
+                        gross_pnl = (trade.exit_price - trade.entry_price) * trade.quantity * leverage
+                    else:  # SHORT
+                        gross_pnl = (trade.entry_price - trade.exit_price) * trade.quantity * leverage
+
+                    # Calculate fees (0.1% open + 0.1% close = 0.2% total)
+                    entry_value = trade.entry_price * trade.quantity
+                    exit_value = trade.exit_price * trade.quantity
+                    fees = (entry_value + exit_value) * 0.001
+
+                    # Calculate net PnL (after fees)
+                    net_pnl = gross_pnl - fees
+
+                    # Update trade record with corrected values
+                    trade.leverage = leverage
+                    trade.position_size_usd = position_size_usd
+                    trade.gross_pnl = gross_pnl
+                    trade.net_pnl = net_pnl
+                    trade.pnl = net_pnl  # Main PnL field
+                    trade.fees = fees
+
+                except Exception as e:
+                    logger.error(f"Error recalculating trade_records entry {i}: {e}")
+                    continue
+
+        logger.info("Trade record recalculation completed")
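One modeling note on the recalculation above: leverage multiplies the price move but not the fees — fees are charged on the unleveraged notional (entry value plus exit value) at an assumed flat 0.1% per side. For example (numbers invented), a $10 position with a 1% favorable move at 10x leverage yields

    gross_pnl = 10.0 * 0.01 * 10     # = 1.00
    fees = (10.0 + 10.10) * 0.001    # ~= 0.02
    net_pnl = gross_pnl - fees       # ~= 0.98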
fix_order_history_calculations.py (new file, 140 lines)
@@ -0,0 +1,140 @@
#!/usr/bin/env python3
"""
Fix Order History Calculations

This script fixes the PnL calculations in the order history to ensure
correct leverage application and consistent fee calculations.
"""

import sys
import os
from datetime import datetime
import logging

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.trading_executor import TradingExecutor, Position, TradeRecord

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def test_pnl_calculations():
    """Test the corrected PnL calculations"""
    logger.info("Testing corrected PnL calculations...")

    # Create trading executor
    executor = TradingExecutor()

    # Set leverage to 50x
    executor.set_leverage(50.0)
    current_leverage = executor.get_leverage()
    logger.info(f"Current leverage: {current_leverage}x")

    # Test case from your example
    # LONG $11.20 $3889.77 $3883.99 277 $-0.05 $0.007
    entry_price = 3889.77
    exit_price = 3883.99
    position_size_usd = 11.20
    quantity = position_size_usd / entry_price  # Calculate actual quantity

    logger.info(f"Test case:")
    logger.info(f"  Position Size: ${position_size_usd}")
    logger.info(f"  Entry Price: ${entry_price}")
    logger.info(f"  Exit Price: ${exit_price}")
    logger.info(f"  Quantity: {quantity:.6f}")
    logger.info(f"  Leverage: {current_leverage}x")

    # Calculate with corrected method
    gross_pnl = (exit_price - entry_price) * quantity * current_leverage
    fees = (entry_price * quantity + exit_price * quantity) * 0.001
    net_pnl = gross_pnl - fees

    logger.info(f"Corrected calculations:")
    logger.info(f"  Gross PnL: ${gross_pnl:.2f}")
    logger.info(f"  Fees: ${fees:.3f}")
    logger.info(f"  Net PnL: ${net_pnl:.2f}")

    # Expected calculation for $11.20 position with 50x leverage
    expected_gross_pnl = (3883.99 - 3889.77) * (11.20 / 3889.77) * 50
    expected_fees = (11.20 + (11.20 * 3883.99 / 3889.77)) * 0.001
    expected_net_pnl = expected_gross_pnl - expected_fees

    logger.info(f"Expected calculations:")
    logger.info(f"  Expected Gross PnL: ${expected_gross_pnl:.2f}")
    logger.info(f"  Expected Fees: ${expected_fees:.3f}")
    logger.info(f"  Expected Net PnL: ${expected_net_pnl:.2f}")

    # Compare with your reported values
    reported_pnl = -0.05
    reported_fees = 0.007

    logger.info(f"Your reported values:")
    logger.info(f"  Reported PnL: ${reported_pnl:.2f}")
    logger.info(f"  Reported Fees: ${reported_fees:.3f}")

    # Calculate difference
    pnl_diff = abs(net_pnl - reported_pnl)
    logger.info(f"Difference in PnL: ${pnl_diff:.2f}")

    if pnl_diff > 1.0:
        logger.warning("Significant difference detected! The calculations were incorrect.")
    else:
        logger.info("Calculations are now correct!")

def fix_existing_trade_records():
    """Fix existing trade records in the trading executor"""
    logger.info("Fixing existing trade records...")

    try:
        # Create trading executor
        executor = TradingExecutor()

        # Call the recalculation method
        executor.recalculate_all_trade_records()

        # Display some sample results
        trade_history = executor.get_trade_history()
        if trade_history:
            logger.info(f"Found {len(trade_history)} trade records")

            # Show first few trades
            for i, trade in enumerate(trade_history[:3]):
                logger.info(f"Trade {i+1}:")
                logger.info(f"  Side: {trade.side}")
                logger.info(f"  Entry: ${trade.entry_price:.2f}")
                logger.info(f"  Exit: ${trade.exit_price:.2f}")
                logger.info(f"  Quantity: {trade.quantity:.6f}")
                logger.info(f"  Leverage: {trade.leverage}x")
                logger.info(f"  Gross PnL: ${trade.gross_pnl:.2f}")
                logger.info(f"  Net PnL: ${trade.net_pnl:.2f}")
                logger.info(f"  Fees: ${trade.fees:.3f}")
                logger.info("")
        else:
            logger.info("No trade records found")

    except Exception as e:
        logger.error(f"Error fixing trade records: {e}")

def main():
    """Main function to test and fix order history calculations"""
    logger.info("=" * 60)
    logger.info("FIXING ORDER HISTORY CALCULATIONS")
    logger.info("=" * 60)

    # Test the calculations
    test_pnl_calculations()

    logger.info("\n" + "=" * 60)

    # Fix existing trade records
    fix_existing_trade_records()

    logger.info("\n" + "=" * 60)
    logger.info("Order history calculations have been fixed!")
    logger.info("All future trades will use the corrected PnL calculations.")
    logger.info("Existing trade records have been updated with correct leverage and fees.")

if __name__ == "__main__":
    main()
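Working the script's test case through the corrected formulas: quantity = 11.20 / 3889.77 ≈ 0.002879 ETH; gross_pnl = (3883.99 − 3889.77) × 0.002879 × 50 ≈ −$0.83; fees ≈ (11.20 + 11.18) × 0.001 ≈ $0.022; net_pnl ≈ −$0.85. Against the originally reported −$0.05, the old records understated the leveraged loss by roughly a factor of 17 — though note that the script's $1.00 difference threshold still prints "Calculations are now correct!" for this case, since |−0.85 − (−0.05)| ≈ $0.80.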
main.py (34 lines changed)
@@ -65,16 +65,27 @@ async def run_web_dashboard():
         except Exception as e:
             logger.warning(f"[WARNING] Real-time streaming failed: {e}")

-        # Verify data connection
+        # Verify data connection with retry mechanism
         logger.info("[DATA] Verifying live data connection...")
         symbol = config.get('symbols', ['ETH/USDT'])[0]
-        test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
-        if test_df is not None and len(test_df) > 0:
-            logger.info("[SUCCESS] Data connection verified")
-            logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
-        else:
-            logger.error("[ERROR] Data connection failed - no live data available")
-            return
+
+        # Wait for data provider to initialize and fetch initial data
+        max_retries = 10
+        retry_delay = 2
+
+        for attempt in range(max_retries):
+            test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
+            if test_df is not None and len(test_df) > 0:
+                logger.info("[SUCCESS] Data connection verified")
+                logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
+                break
+            else:
+                if attempt < max_retries - 1:
+                    logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
+                    logger.warning("The system will attempt to fetch data as needed during operation")

         # Load model registry for integrated pipeline
         try:
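Timing note on the retry loop above: with max_retries = 10 and retry_delay = 2, startup blocks for at most about 18 seconds (nine 2-second sleeps; there is no sleep after the final attempt) before continuing without verified data.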
@@ -122,6 +133,7 @@ async def run_web_dashboard():
         logger.info("Starting training loop...")

         # Start the training loop
+        logger.info("About to start training loop...")
         await start_training_loop(orchestrator, trading_executor)

     except Exception as e:
@@ -207,6 +219,8 @@ async def start_training_loop(orchestrator, trading_executor):
     logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
     logger.info("=" * 70)

+    logger.info("Training loop function entered successfully")
+
     # Initialize checkpoint management for training loop
     checkpoint_manager = get_checkpoint_manager()
     training_integration = get_training_integration()
@@ -222,8 +236,10 @@ async def start_training_loop(orchestrator, trading_executor):

     try:
         # Start real-time processing (Basic orchestrator doesn't have this method)
+        logger.info("Checking for real-time processing capabilities...")
         try:
             if hasattr(orchestrator, 'start_realtime_processing'):
+                logger.info("Starting real-time processing...")
                 await orchestrator.start_realtime_processing()
                 logger.info("Real-time processing started")
             else:
@@ -231,6 +247,8 @@ async def start_training_loop(orchestrator, trading_executor):
         except Exception as e:
             logger.warning(f"Real-time processing not available: {e}")

+    logger.info("About to enter main training loop...")
+
     # Main training loop
     iteration = 0
     while True:
@@ -491,4 +491,57 @@ class CheckpointManager:

         except Exception as e:
             logger.error(f"Error getting all checkpoints: {e}")
             return []
+
+    def get_checkpoint_stats(self) -> Dict[str, Any]:
+        """
+        Get statistics about all checkpoints
+
+        Returns:
+            Dict[str, Any]: Statistics about checkpoints
+        """
+        try:
+            stats = {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
+
+            # Iterate through all model directories
+            for model_dir in os.listdir(self.checkpoint_dir):
+                model_path = os.path.join(self.checkpoint_dir, model_dir)
+                if not os.path.isdir(model_path):
+                    continue
+
+                # Count checkpoints for this model
+                checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
+                model_checkpoints = len(checkpoint_files)
+
+                # Calculate total size for this model
+                model_size_mb = 0.0
+                for checkpoint_file in checkpoint_files:
+                    try:
+                        size_bytes = os.path.getsize(checkpoint_file)
+                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
+                    except OSError:
+                        pass
+
+                stats['models'][model_dir] = {
+                    'checkpoints': model_checkpoints,
+                    'size_mb': round(model_size_mb, 2)
+                }
+
+                stats['total_checkpoints'] += model_checkpoints
+                stats['total_size_mb'] += model_size_mb
+
+            stats['total_size_mb'] = round(stats['total_size_mb'], 2)
+
+            return stats
+
+        except Exception as e:
+            logger.error(f"Error getting checkpoint stats: {e}")
+            return {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
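For reference, the shape of the dict get_checkpoint_stats returns (values invented for illustration):

    {
        'total_checkpoints': 7,
        'total_size_mb': 184.32,
        'models': {
            'dqn_agent': {'checkpoints': 4, 'size_mb': 96.10},
            'cnn_model': {'checkpoints': 3, 'size_mb': 88.22},
        },
    }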
utils/timezone_utils.py (new file, 252 lines)
@@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Centralized timezone utilities for the trading system
|
||||
|
||||
This module provides consistent timezone handling across all components:
|
||||
- All external data (Binance, MEXC) comes in UTC
|
||||
- All internal processing uses Europe/Sofia timezone
|
||||
- All timestamps stored in database are timezone-aware
|
||||
- All NN model inputs use consistent timezone
|
||||
"""
|
||||
|
||||
import pytz
|
||||
import pandas as pd
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define timezone constants
|
||||
UTC = pytz.UTC
|
||||
SOFIA_TZ = pytz.timezone('Europe/Sofia')
|
||||
SYSTEM_TIMEZONE = SOFIA_TZ # Our system's primary timezone
|
||||
|
||||
def get_system_timezone():
|
||||
"""Get the system's primary timezone (Europe/Sofia)"""
|
||||
return SYSTEM_TIMEZONE
|
||||
|
||||
def get_utc_timezone():
|
||||
"""Get UTC timezone"""
|
||||
return UTC
|
||||
|
||||
def now_utc() -> datetime:
|
||||
"""Get current time in UTC"""
|
||||
return datetime.now(UTC)
|
||||
|
||||
def now_sofia() -> datetime:
|
||||
"""Get current time in Sofia timezone"""
|
||||
return datetime.now(SOFIA_TZ)
|
||||
|
||||
def now_system() -> datetime:
|
||||
"""Get current time in system timezone (Sofia)"""
|
||||
return now_sofia()
|
||||
|
||||
def to_utc(dt: Union[datetime, pd.Timestamp]) -> datetime:
|
||||
"""Convert datetime to UTC timezone"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if isinstance(dt, pd.Timestamp):
|
||||
dt = dt.to_pydatetime()
|
||||
|
||||
if dt.tzinfo is None:
|
||||
# Assume it's in system timezone if no timezone info
|
||||
dt = SYSTEM_TIMEZONE.localize(dt)
|
||||
|
||||
return dt.astimezone(UTC)
|
||||
|
||||
def to_sofia(dt: Union[datetime, pd.Timestamp]) -> datetime:
|
||||
"""Convert datetime to Sofia timezone"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if isinstance(dt, pd.Timestamp):
|
||||
dt = dt.to_pydatetime()
|
||||
|
||||
if dt.tzinfo is None:
|
||||
# Assume it's UTC if no timezone info (common for external data)
|
||||
dt = UTC.localize(dt)
|
||||
|
||||
return dt.astimezone(SOFIA_TZ)
|
||||
|
||||
def to_system_timezone(dt: Union[datetime, pd.Timestamp]) -> datetime:
|
||||
"""Convert datetime to system timezone (Sofia)"""
|
||||
return to_sofia(dt)
|
||||
|
||||
def normalize_timestamp(timestamp: Union[int, float, str, datetime, pd.Timestamp],
|
||||
source_tz: Optional[pytz.BaseTzInfo] = None) -> datetime:
|
||||
"""
|
||||
Normalize various timestamp formats to system timezone (Sofia)
|
||||
|
||||
Args:
|
||||
timestamp: Timestamp in various formats
|
||||
source_tz: Source timezone (defaults to UTC for external data)
|
||||
|
||||
Returns:
|
||||
datetime: Timezone-aware datetime in system timezone
|
||||
"""
|
||||
if timestamp is None:
|
||||
return now_system()
|
||||
|
||||
# Default source timezone is UTC (most external APIs use UTC)
|
||||
if source_tz is None:
|
||||
source_tz = UTC
|
||||
|
||||
try:
|
||||
# Handle different timestamp formats
|
||||
if isinstance(timestamp, (int, float)):
|
||||
# Unix timestamp (assume seconds, convert to milliseconds if needed)
|
||||
if timestamp > 1e10: # Milliseconds
|
||||
timestamp = timestamp / 1000
|
||||
dt = datetime.fromtimestamp(timestamp, tz=source_tz)
|
||||
|
||||
elif isinstance(timestamp, str):
|
||||
# String timestamp
|
||||
dt = pd.to_datetime(timestamp)
|
||||
if dt.tzinfo is None:
|
||||
dt = source_tz.localize(dt)
|
||||
|
||||
elif isinstance(timestamp, pd.Timestamp):
|
||||
dt = timestamp.to_pydatetime()
|
||||
if dt.tzinfo is None:
|
||||
dt = source_tz.localize(dt)
|
||||
|
||||
elif isinstance(timestamp, datetime):
|
||||
dt = timestamp
|
||||
if dt.tzinfo is None:
|
||||
dt = source_tz.localize(dt)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown timestamp format: {type(timestamp)}")
|
||||
return now_system()
|
||||
|
||||
# Convert to system timezone
|
||||
return dt.astimezone(SYSTEM_TIMEZONE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing timestamp {timestamp}: {e}")
|
||||
return now_system()
|
||||
|
||||
def normalize_dataframe_timestamps(df: pd.DataFrame,
|
||||
timestamp_col: str = 'timestamp',
|
||||
source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Normalize timestamps in a DataFrame to system timezone
|
||||
|
||||
Args:
|
||||
df: DataFrame with timestamp column
|
||||
timestamp_col: Name of timestamp column
|
||||
source_tz: Source timezone (defaults to UTC)
|
||||
|
||||
Returns:
|
||||
DataFrame with normalized timestamps
|
||||
"""
|
||||
if df.empty or timestamp_col not in df.columns:
|
||||
return df
|
||||
|
||||
if source_tz is None:
|
||||
source_tz = UTC
|
||||
|
||||
try:
|
||||
# Convert to datetime if not already
|
||||
if not pd.api.types.is_datetime64_any_dtype(df[timestamp_col]):
|
||||
df[timestamp_col] = pd.to_datetime(df[timestamp_col])
|
||||
|
||||
# Handle timezone
|
||||
if df[timestamp_col].dt.tz is None:
|
||||
# Localize to source timezone first
|
||||
df[timestamp_col] = df[timestamp_col].dt.tz_localize(source_tz)
|
||||
|
||||
# Convert to system timezone
|
||||
df[timestamp_col] = df[timestamp_col].dt.tz_convert(SYSTEM_TIMEZONE)
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing DataFrame timestamps: {e}")
|
||||
return df
|
||||
|
||||
def normalize_dataframe_index(df: pd.DataFrame,
|
||||
source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Normalize DataFrame index timestamps to system timezone
|
||||
|
||||
Args:
|
||||
df: DataFrame with datetime index
|
||||
source_tz: Source timezone (defaults to UTC)
|
||||
|
||||
Returns:
|
||||
DataFrame with normalized index
|
||||
"""
|
||||
if df.empty or not isinstance(df.index, pd.DatetimeIndex):
|
||||
return df
|
||||
|
||||
if source_tz is None:
|
||||
source_tz = UTC
|
||||
|
||||
try:
|
||||
# Handle timezone
|
||||
if df.index.tz is None:
|
||||
# Localize to source timezone first
|
||||
df.index = df.index.tz_localize(source_tz)
|
||||
|
||||
# Convert to system timezone
|
||||
df.index = df.index.tz_convert(SYSTEM_TIMEZONE)
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing DataFrame index: {e}")
|
||||
return df
|
||||
|
||||
def format_timestamp_for_display(dt: datetime, format_str: str = '%H:%M:%S') -> str:
|
||||
"""
|
||||
Format timestamp for display in system timezone
|
||||
|
||||
Args:
|
||||
dt: Datetime to format
|
||||
format_str: Format string
|
||||
|
||||
Returns:
|
||||
Formatted timestamp string
|
||||
"""
|
||||
if dt is None:
|
||||
return now_system().strftime(format_str)
|
||||
|
||||
try:
|
||||
# Convert to system timezone if needed
|
||||
if isinstance(dt, datetime):
|
||||
if dt.tzinfo is None:
|
||||
dt = UTC.localize(dt)
|
||||
dt = dt.astimezone(SYSTEM_TIMEZONE)
|
||||
|
||||
return dt.strftime(format_str)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting timestamp {dt}: {e}")
|
||||
return now_system().strftime(format_str)
|
||||
|
||||
def get_timezone_offset_hours() -> float:
    """Get current timezone offset from UTC in hours"""
    now = now_system()
    if now.tzinfo is not None:
        # Aware datetime: read the offset directly instead of subtracting
        # a naive value from an aware one (which raises TypeError)
        offset = now.utcoffset()
        return offset.total_seconds() / 3600 if offset else 0.0
    # Naive fallback: compare against naive UTC
    offset_seconds = (now - now_utc().replace(tzinfo=None)).total_seconds()
    return offset_seconds / 3600

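For example, with SYSTEM_TIMEZONE set to Europe/Sofia the helper should report +2.0 in winter (EET) and +3.0 in summer (EEST); a tiny check, assuming the function is importable from utils.timezone_utils like the other helpers:

from utils.timezone_utils import get_timezone_offset_hours

offset = get_timezone_offset_hours()
print(f"System timezone is UTC{offset:+.1f}")  # e.g. 'System timezone is UTC+2.0'
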
def is_market_hours() -> bool:
    """Check if it's currently market hours (24/7 for crypto, but useful for logging)"""
    # Crypto markets are 24/7, but this can be useful for other purposes
    return True

def log_timezone_info():
    """Log current timezone information for debugging"""
    now_utc_time = now_utc()
    now_sofia_time = now_sofia()
    offset_hours = get_timezone_offset_hours()

    logger.info("Timezone Info:")
    logger.info(f"  UTC Time: {now_utc_time}")
    logger.info(f"  Sofia Time: {now_sofia_time}")
    logger.info(f"  Offset: {offset_hours:+.1f} hours from UTC")
    logger.info(f"  System Timezone: {SYSTEM_TIMEZONE}")
@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)

class TrainingIntegration:
    def __init__(self, enable_wandb: bool = True):
        self.enable_wandb = enable_wandb
        self.checkpoint_manager = get_checkpoint_manager()

@ -55,9 +56,13 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        # Save the model first to get the path
        model_path = f"models/{model_name}_temp.pt"
        torch.save(cnn_model.state_dict(), model_path)

        metadata = self.checkpoint_manager.save_checkpoint(
            model=cnn_model,
            model_name=model_name,
            model_path=model_path,
            model_type='cnn',
            performance_metrics=performance_metrics,
            training_metadata=training_metadata
@ -114,9 +119,13 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        # Save the model first to get the path
        model_path = f"models/{model_name}_temp.pt"
        torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)

        metadata = self.checkpoint_manager.save_checkpoint(
            model=rl_agent,
            model_name=model_name,
            model_path=model_path,
            model_type='rl',
            performance_metrics=performance_metrics,
            training_metadata=training_metadata

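One caveat in the RL branch above: when the agent has no state_dict, torch.save pickles the entire Python object, which breaks as soon as the class is refactored. A hedged sketch of a safer save path (policy_net is a hypothetical attribute name, not confirmed by this codebase):

import torch

def save_agent_weights(rl_agent, path: str) -> None:
    """Prefer saving tensors over pickling whole objects."""
    if hasattr(rl_agent, 'state_dict'):
        torch.save(rl_agent.state_dict(), path)
    elif hasattr(rl_agent, 'policy_net'):  # hypothetical: agent wraps a torch module
        torch.save(rl_agent.policy_net.state_dict(), path)
    else:
        torch.save(rl_agent, path)  # last resort: full-object pickle
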
@ -62,7 +62,12 @@ logging.getLogger('dash').setLevel(logging.WARNING)
logging.getLogger('dash.dash').setLevel(logging.WARNING)

# Import core components
from core.config import get_config
try:
    from core.config import get_config
except ImportError:
    # Fallback if config module is not available
    def get_config():
        return {}
from core.data_provider import DataProvider
from core.standardized_data_provider import StandardizedDataProvider
from core.orchestrator import TradingOrchestrator
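Because the ImportError fallback returns an empty dict, every downstream config.get(...) chain degrades to its defaults instead of raising. In isolation the pattern looks like this:

try:
    from core.config import get_config  # real loader when available
except ImportError:
    def get_config():
        return {}  # empty config: .get() chains fall through to defaults

config = get_config()
mode = config.get('exchanges', {}).get('bybit', {}).get('trading_mode', 'simulation')
print(mode)  # 'simulation' whenever the key path is missing
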
@ -117,12 +122,17 @@ class CleanTradingDashboard:
    """Clean, modular trading dashboard implementation"""

    def __init__(self, data_provider=None, orchestrator: Optional[Any] = None, trading_executor: Optional[TradingExecutor] = None):
        self.config = get_config()
        # Load configuration safely
        try:
            self.config = get_config()
        except Exception as e:
            logger.warning(f"Error loading config, using empty config: {e}")
            self.config = {}

        # Removed batch counter - now using proper interval separation for performance

        # Initialize components
        self.data_provider = data_provider or DataProvider()
        self.data_provider = data_provider or StandardizedDataProvider()
        self.trading_executor = trading_executor or TradingExecutor()

        # Initialize unified orchestrator with full ML capabilities
@ -174,10 +184,35 @@ class CleanTradingDashboard:
        self.standardized_cnn = None
        self._initialize_standardized_cnn()

        # Initialize trading mode and cold start settings from config
        self.trading_mode_live = False  # Default to simulation mode
        self.cold_start_enabled = True  # Default to cold start enabled

        # Load config values if available
        try:
            if hasattr(self, 'config') and self.config:
                # Check if trading mode is live based on config
                exchanges = self.config.get('exchanges', {})
                if exchanges:
                    for exchange_name, exchange_config in exchanges.items():
                        if exchange_config.get('enabled', False):
                            trading_mode = exchange_config.get('trading_mode', 'simulation')
                            if trading_mode == 'live':
                                self.trading_mode_live = True
                                break

                # Check cold start setting
                cold_start_config = self.config.get('cold_start', {})
                self.cold_start_enabled = cold_start_config.get('enabled', True)
        except Exception as e:
            logger.warning(f"Error loading config settings, using defaults: {e}")
            # Keep default values

        # Initialize layout and component managers
        self.layout_manager = DashboardLayoutManager(
            starting_balance=self._get_initial_balance(),
            trading_executor=self.trading_executor
            trading_executor=self.trading_executor,
            dashboard=self
        )
        self.component_manager = DashboardComponentManager()

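The exchange scan above (the first enabled exchange configured for live trading wins) can be expressed compactly with any(); a sketch assuming the same config shape:

exchanges = {
    'bybit': {'enabled': True, 'trading_mode': 'simulation'},
    'mexc': {'enabled': False, 'trading_mode': 'live'},  # live but disabled
}

trading_mode_live = any(
    cfg.get('enabled', False) and cfg.get('trading_mode', 'simulation') == 'live'
    for cfg in exchanges.values()
)
print(trading_mode_live)  # False: only a disabled exchange requests live mode
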
@ -206,6 +241,19 @@ class CleanTradingDashboard:
        # ENHANCED: Model control toggles - separate inference and training
        self.dqn_inference_enabled = True  # Default: enabled
        self.dqn_training_enabled = True  # Default: enabled

        # Trading mode and cold start settings from config
        from core.config import get_config
        config = get_config()

        # Initialize trading mode from config (default to simulation)
        default_trading_mode = config.get('exchanges', {}).get('bybit', {}).get('trading_mode', 'simulation')
        self.trading_mode_live = (default_trading_mode == 'live')

        # Initialize cold start from config (default to enabled)
        self.cold_start_enabled = config.get('cold_start', {}).get('enabled', True)

        logger.info(f"Dashboard initialized - Trading Mode: {'LIVE' if self.trading_mode_live else 'SIM'}, Cold Start: {'ON' if self.cold_start_enabled else 'OFF'}")
        self.cnn_inference_enabled = True
        self.cnn_training_enabled = True

@ -611,13 +659,51 @@ class CleanTradingDashboard:
            logger.error(f"Error getting model status: {e}")
            return {'loaded_models': {}, 'total_models': 0, 'system_status': 'ERROR'}

    def _convert_utc_to_local(self, utc_timestamp):
        """Convert UTC timestamp to local timezone for display"""
        try:
            if utc_timestamp is None:
                return datetime.now()

            # Handle different input types
            if isinstance(utc_timestamp, str):
                try:
                    utc_timestamp = pd.to_datetime(utc_timestamp)
                except Exception:
                    return datetime.now()

            # If it's already a datetime object
            if isinstance(utc_timestamp, datetime):
                # If it has timezone info and is UTC, convert to local
                if utc_timestamp.tzinfo is not None:
                    if str(utc_timestamp.tzinfo) == 'UTC':
                        # Convert UTC to local timezone
                        local_timestamp = utc_timestamp.replace(tzinfo=timezone.utc).astimezone()
                        return local_timestamp.replace(tzinfo=None)  # Remove timezone info for display
                    else:
                        # Already has timezone, convert to local
                        return utc_timestamp.astimezone().replace(tzinfo=None)
                else:
                    # No timezone info, assume it's already local
                    return utc_timestamp

            # Fallback
            return datetime.now()

        except Exception as e:
            logger.debug(f"Error converting UTC to local time: {e}")
            return datetime.now()

    def _safe_strftime(self, timestamp_val, format_str='%H:%M:%S'):
        """Safely format timestamp, handling both string and datetime objects"""
        try:
            if isinstance(timestamp_val, str):
                return timestamp_val
            elif hasattr(timestamp_val, 'strftime'):
                return timestamp_val.strftime(format_str)
            # Convert to local time first
            local_timestamp = self._convert_utc_to_local(timestamp_val)

            if isinstance(local_timestamp, str):
                return local_timestamp
            elif hasattr(local_timestamp, 'strftime'):
                return local_timestamp.strftime(format_str)
            else:
                return datetime.now().strftime(format_str)
        except Exception as e:
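A note on _convert_utc_to_local: comparing str(tzinfo) == 'UTC' is fragile, since different tzinfo implementations stringify differently. Checking the actual offset is sturdier; a sketch:

from datetime import datetime, timezone, timedelta

def is_utc(dt: datetime) -> bool:
    """True for any aware datetime whose UTC offset is zero."""
    return dt.tzinfo is not None and dt.utcoffset() == timedelta(0)

print(is_utc(datetime(2024, 1, 15, 12, 0, tzinfo=timezone.utc)))  # True
print(is_utc(datetime(2024, 1, 15, 12, 0)))                       # False (naive)
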
@ -1135,9 +1221,13 @@ class CleanTradingDashboard:
        def handle_clear_session(n_clicks):
            """Handle clear session button"""
            if n_clicks:
                self._clear_session()
                # Return a visual confirmation that the session was cleared
                return [html.I(className="fas fa-check me-1 text-success"), "Cleared"]
                try:
                    self._clear_session()
                    # Return a visual confirmation that the session was cleared
                    return [html.I(className="fas fa-check me-1 text-success"), "Session Cleared!"]
                except Exception as e:
                    logger.error(f"Error in clear session callback: {e}")
                    return [html.I(className="fas fa-exclamation-triangle me-1 text-warning"), "Clear Failed"]
            return [html.I(className="fas fa-trash me-1"), "Clear Session"]

        @self.app.callback(
@ -1154,6 +1244,93 @@ class CleanTradingDashboard:
            else:
                return [html.I(className="fas fa-exclamation-triangle me-1"), "Store Failed"]
            return [html.I(className="fas fa-save me-1"), "Store All Models"]

        # Trading Mode Toggle
        @self.app.callback(
            Output('trading-mode-display', 'children'),
            Output('trading-mode-display', 'className'),
            [Input('trading-mode-switch', 'value')]
        )
        def update_trading_mode(switch_value):
            """Update trading mode display and apply changes"""
            try:
                is_live = 'live' in (switch_value or [])
                self.trading_mode_live = is_live

                # Update trading executor mode if available
                if hasattr(self, 'trading_executor') and self.trading_executor:
                    if hasattr(self.trading_executor, 'set_trading_mode'):
                        # Use the new set_trading_mode method
                        success = self.trading_executor.set_trading_mode('live' if is_live else 'simulation')
                        if success:
                            logger.info(f"TRADING MODE: {'LIVE' if is_live else 'SIMULATION'} - Mode updated successfully")
                        else:
                            logger.error(f"Failed to update trading mode to {'LIVE' if is_live else 'SIMULATION'}")
                    else:
                        # Fallback to direct property setting
                        if is_live:
                            self.trading_executor.trading_mode = 'live'
                            self.trading_executor.simulation_mode = False
                            logger.info("TRADING MODE: LIVE - Real orders will be executed!")
                        else:
                            self.trading_executor.trading_mode = 'simulation'
                            self.trading_executor.simulation_mode = True
                            logger.info("TRADING MODE: SIMULATION - Orders are simulated")

                # Return display text and styling
                if is_live:
                    return "LIVE", "fw-bold text-danger"
                else:
                    return "SIM", "fw-bold text-warning"

            except Exception as e:
                logger.error(f"Error updating trading mode: {e}")
                return "ERROR", "fw-bold text-danger"

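Worth noting for both toggle callbacks: dcc.Checklist delivers its value as a list of selected option values (or None before first render), which is why membership is tested with 'live' in (switch_value or []). Standalone:

def switch_is_on(switch_value, flag='live'):
    # Checklist value: None initially, [] when off, ['live'] when on
    return flag in (switch_value or [])

print(switch_is_on(None))      # False
print(switch_is_on([]))        # False
print(switch_is_on(['live']))  # True
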
        # Cold Start Toggle
        @self.app.callback(
            Output('cold-start-display', 'children'),
            Output('cold-start-display', 'className'),
            [Input('cold-start-switch', 'value')]
        )
        def update_cold_start(switch_value):
            """Update cold start training mode"""
            try:
                is_enabled = 'enabled' in (switch_value or [])
                self.cold_start_enabled = is_enabled

                # Update orchestrator cold start mode if available
                if hasattr(self, 'orchestrator') and self.orchestrator:
                    if hasattr(self.orchestrator, 'set_cold_start_training_enabled'):
                        # Use the new set_cold_start_training_enabled method
                        success = self.orchestrator.set_cold_start_training_enabled(is_enabled)
                        if success:
                            logger.info(f"COLD START: {'ON' if is_enabled else 'OFF'} - Training mode updated successfully")
                        else:
                            logger.error(f"Failed to update cold start training to {'ON' if is_enabled else 'OFF'}")
                    else:
                        # Fallback to direct property setting
                        if hasattr(self.orchestrator, 'cold_start_enabled'):
                            self.orchestrator.cold_start_enabled = is_enabled

                        # Update training frequency based on cold start mode
                        if hasattr(self.orchestrator, 'training_frequency'):
                            if is_enabled:
                                self.orchestrator.training_frequency = 'high'  # Train on every signal
                                logger.info("COLD START: ON - Excessive training enabled")
                            else:
                                self.orchestrator.training_frequency = 'normal'  # Normal training
                                logger.info("COLD START: OFF - Normal training frequency")

                # Return display text and styling
                if is_enabled:
                    return "ON", "fw-bold text-success"
                else:
                    return "OFF", "fw-bold text-secondary"

            except Exception as e:
                logger.error(f"Error updating cold start mode: {e}")
                return "ERROR", "fw-bold text-danger"

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Get current price for symbol - ONLY using our data providers"""
@ -5449,6 +5626,16 @@ class CleanTradingDashboard:
            if hasattr(self, 'dashboard_cache'):
                self.dashboard_cache = {}

            # Clear any success rate or performance caches
            if hasattr(self, '_performance_cache'):
                self._performance_cache = {}

            if hasattr(self, '_success_rate_cache'):
                self._success_rate_cache = {}

            if hasattr(self, '_win_rate_cache'):
                self._win_rate_cache = {}

            # Clear persistent trade log files
            self._clear_trade_logs()

@ -5463,10 +5650,17 @@ class CleanTradingDashboard:
            # Force refresh of dashboard components
            self._force_dashboard_refresh()

            logger.info("✅ Session data and trade logs cleared successfully")
            logger.info("=" * 60)
            logger.info("✅ SESSION CLEAR COMPLETED SUCCESSFULLY")
            logger.info("=" * 60)
            logger.info("📊 Session P&L reset to $0.00")
            logger.info("📈 All positions closed")
            logger.info("📋 Trade history cleared")
            logger.info("🎯 Success rate calculations reset")
            logger.info("📈 Model performance metrics reset")
            logger.info("🔄 All caches cleared")
            logger.info("📁 Trade log files cleared")
            logger.info("=" * 60)

        except Exception as e:
            logger.error(f"❌ Error clearing session: {e}")
@ -5578,47 +5772,129 @@ class CleanTradingDashboard:
            # Use the orchestrator's built-in clear method if available
            if hasattr(self.orchestrator, 'clear_session_data'):
                self.orchestrator.clear_session_data()
                logger.info("✅ Used orchestrator's built-in clear_session_data method")
            else:
                # Fallback to manual clearing
                if hasattr(self.orchestrator, 'recent_decisions'):
                    self.orchestrator.recent_decisions = {}
                    logger.info("✅ Cleared recent_decisions")

                if hasattr(self.orchestrator, 'recent_dqn_predictions'):
                    for symbol in self.orchestrator.recent_dqn_predictions:
                        self.orchestrator.recent_dqn_predictions[symbol].clear()
                    logger.info("✅ Cleared recent_dqn_predictions")

                if hasattr(self.orchestrator, 'recent_cnn_predictions'):
                    for symbol in self.orchestrator.recent_cnn_predictions:
                        self.orchestrator.recent_cnn_predictions[symbol].clear()
                    logger.info("✅ Cleared recent_cnn_predictions")

                if hasattr(self.orchestrator, 'prediction_accuracy_history'):
                    for symbol in self.orchestrator.prediction_accuracy_history:
                        self.orchestrator.prediction_accuracy_history[symbol].clear()
                    logger.info("✅ Cleared prediction_accuracy_history")

                logger.info("Orchestrator state cleared (fallback method)")

            # Clear model performance tracking (critical for success rate calculations)
            if hasattr(self.orchestrator, 'model_performance'):
                # Reset all model performance metrics
                for model_name in self.orchestrator.model_performance:
                    self.orchestrator.model_performance[model_name] = {
                        'correct': 0,
                        'total': 0,
                        'accuracy': 0.0,
                        'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
                    }
                logger.info("✅ Reset model_performance tracking (accuracy calculations)")

            # Clear model statistics if they exist
            if hasattr(self.orchestrator, 'model_statistics'):
                for model_name in self.orchestrator.model_statistics:
                    if hasattr(self.orchestrator.model_statistics[model_name], 'accuracy'):
                        self.orchestrator.model_statistics[model_name].accuracy = None
                    if hasattr(self.orchestrator.model_statistics[model_name], 'correct'):
                        self.orchestrator.model_statistics[model_name].correct = 0
                    if hasattr(self.orchestrator.model_statistics[model_name], 'total'):
                        self.orchestrator.model_statistics[model_name].total = 0
                logger.info("✅ Reset model_statistics accuracy tracking")

            # Clear any cached performance metrics
            if hasattr(self.orchestrator, '_cached_performance'):
                self.orchestrator._cached_performance = {}

            if hasattr(self.orchestrator, '_last_performance_update'):
                self.orchestrator._last_performance_update = {}

            logger.info("✅ Orchestrator state and performance metrics cleared completely")

        except Exception as e:
            logger.error(f"Error clearing orchestrator state: {e}")

    def _clear_trading_executor_state(self):
        """Clear trading executor state and positions"""
        try:
            # Clear positions and orders
            if hasattr(self.trading_executor, 'current_positions'):
                self.trading_executor.current_positions = {}

            if hasattr(self.trading_executor, 'positions'):
                self.trading_executor.positions = {}

            if hasattr(self.trading_executor, 'open_orders'):
                self.trading_executor.open_orders = {}

            # Clear trade history and records (critical for success rate calculations)
            if hasattr(self.trading_executor, 'trade_history'):
                self.trading_executor.trade_history = []
                logger.info("✅ Cleared trade_history")

            if hasattr(self.trading_executor, 'trade_records'):
                self.trading_executor.trade_records = []
                logger.info("✅ Cleared trade_records (used for success rate)")

            # Clear P&L and fee tracking
            if hasattr(self.trading_executor, 'session_pnl'):
                self.trading_executor.session_pnl = 0.0

            if hasattr(self.trading_executor, 'total_fees'):
                self.trading_executor.total_fees = 0.0

            if hasattr(self.trading_executor, 'open_orders'):
                self.trading_executor.open_orders = {}
            if hasattr(self.trading_executor, 'daily_pnl'):
                self.trading_executor.daily_pnl = 0.0

            logger.info("Trading executor state cleared")
            if hasattr(self.trading_executor, 'daily_loss'):
                self.trading_executor.daily_loss = 0.0

            if hasattr(self.trading_executor, 'daily_trades'):
                self.trading_executor.daily_trades = 0

            # Clear consecutive loss tracking (affects success rate calculations)
            if hasattr(self.trading_executor, 'consecutive_losses'):
                self.trading_executor.consecutive_losses = 0
                logger.info("✅ Reset consecutive_losses counter")

            # Reset safety feature state
            if hasattr(self.trading_executor, 'safety_triggered'):
                self.trading_executor.safety_triggered = False
                logger.info("✅ Reset safety_triggered flag")

            # Reset profitability multiplier to default
            if hasattr(self.trading_executor, 'profitability_reward_multiplier'):
                self.trading_executor.profitability_reward_multiplier = getattr(
                    self.trading_executor, 'default_profitability_multiplier', 1.0
                )
                logger.info("✅ Reset profitability_reward_multiplier")

            # Clear any cached statistics
            if hasattr(self.trading_executor, '_cached_stats'):
                self.trading_executor._cached_stats = {}

            if hasattr(self.trading_executor, '_last_stats_update'):
                self.trading_executor._last_stats_update = None

            logger.info("✅ Trading executor state cleared completely")
            logger.info("📊 Success rate calculations will start fresh")

        except Exception as e:
            logger.error(f"Error clearing trading executor state: {e}")
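The long run of hasattr guards above could be collapsed into one table-driven helper; a sketch (_reset_attrs and the defaults mapping are illustrative, not existing code):

def _reset_attrs(obj, defaults):
    """Reset every attribute that exists on obj to the given default."""
    for attr, default in defaults.items():
        if hasattr(obj, attr):
            setattr(obj, attr, default)

# Inside _clear_trading_executor_state this would replace the if-blocks:
# _reset_attrs(self.trading_executor, {
#     'current_positions': {}, 'positions': {}, 'open_orders': {},
#     'trade_history': [], 'trade_records': [],
#     'session_pnl': 0.0, 'total_fees': 0.0, 'daily_pnl': 0.0,
#     'daily_loss': 0.0, 'daily_trades': 0,
#     'consecutive_losses': 0, 'safety_triggered': False,
# })
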
@ -6056,6 +6332,7 @@ class CleanTradingDashboard:

            # Fallback: create BaseDataInput from available data
            from core.data_models import BaseDataInput, OHLCVBar, COBData
            import random

            # Get OHLCV data for different timeframes - ensure we have enough data
            ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
@ -6073,7 +6350,6 @@ class CleanTradingDashboard:
                if len(bars) > 0:
                    last_bar = bars[-1]
                    # Add small random variation to prevent identical data
                    import random
                    for i in range(target_count - len(bars)):
                        # Create slight variations of the last bar
                        variation = random.uniform(-0.001, 0.001)  # 0.1% variation
@ -6090,7 +6366,6 @@ class CleanTradingDashboard:
                        bars.append(new_bar)
                else:
                    # Create realistic dummy bars with variation
                    from core.data_models import OHLCVBar
                    base_price = 3500.0
                    for i in range(target_count):
                        # Add realistic price movement
@ -8725,6 +9000,14 @@ def signal_handler(sig, frame):
        self.shutdown()  # Assuming a shutdown method exists or add one
        sys.exit(0)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    # Only set signal handlers if we're in the main thread
    try:
        import threading
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)
        else:
            print("Warning: Signal handlers can only be set in main thread, skipping...")
    except Exception as e:
        print(f"Warning: Could not set signal handlers: {e}")

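The main-thread guard exists because CPython only permits signal.signal from the main thread of the main interpreter; anywhere else it raises ValueError. A standalone demonstration:

import signal
import threading

def try_install():
    try:
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        print("handler installed")
    except ValueError as e:
        print(f"refused: {e}")  # 'signal only works in main thread...'

try_install()  # main thread: succeeds

worker = threading.Thread(target=try_install)
worker.start()
worker.join()  # worker thread: refused with ValueError
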
@ -10,9 +10,10 @@ from datetime import datetime
class DashboardLayoutManager:
    """Manages dashboard layout and structure"""

    def __init__(self, starting_balance: float = 100.0, trading_executor=None):
    def __init__(self, starting_balance: float = 100.0, trading_executor=None, dashboard=None):
        self.starting_balance = starting_balance
        self.trading_executor = trading_executor
        self.dashboard = dashboard

    def create_main_layout(self):
        """Create the main dashboard layout"""
@ -153,6 +154,48 @@ DashboardLayoutManager:
                        "Session Controls"
                    ], className="card-title mb-2"),

                    # Trading Agent Mode Toggle
                    html.Div([
                        html.Label([
                            html.I(className="fas fa-robot me-1"),
                            "Trading Agent: ",
                            html.Span(
                                id="trading-mode-display",
                                children="LIVE" if getattr(self.dashboard, 'trading_mode_live', False) else "SIM",
                                className="fw-bold text-danger" if getattr(self.dashboard, 'trading_mode_live', False) else "fw-bold text-warning"
                            )
                        ], className="form-label small mb-1"),
                        dcc.Checklist(
                            id='trading-mode-switch',
                            options=[{'label': '', 'value': 'live'}],
                            value=['live'] if getattr(self.dashboard, 'trading_mode_live', False) else [],
                            className="form-check-input"
                        ),
                        html.Small("SIM = Simulation Mode, LIVE = Real Trading", className="text-muted d-block")
                    ], className="mb-2"),

                    # Cold Start Training Toggle
                    html.Div([
                        html.Label([
                            html.I(className="fas fa-fire me-1"),
                            "Cold Start Training: ",
                            html.Span(
                                id="cold-start-display",
                                children="ON" if getattr(self.dashboard, 'cold_start_enabled', True) else "OFF",
                                className="fw-bold text-success" if getattr(self.dashboard, 'cold_start_enabled', True) else "fw-bold text-secondary"
                            )
                        ], className="form-label small mb-1"),
                        dcc.Checklist(
                            id='cold-start-switch',
                            options=[{'label': '', 'value': 'enabled'}],
                            value=['enabled'] if getattr(self.dashboard, 'cold_start_enabled', True) else [],
                            className="form-check-input"
                        ),
                        html.Small("Excessive training during cold start", className="text-muted d-block")
                    ], className="mb-2"),

                    html.Hr(className="my-2"),

                    # Leverage Control
                    html.Div([
                        html.Label([