code structure
This commit is contained in:
@ -1,6 +0,0 @@
|
||||
# Trading environments for reinforcement learning
|
||||
# This module contains environments for training trading agents
|
||||
|
||||
from NN.environments.trading_env import TradingEnvironment
|
||||
|
||||
__all__ = ['TradingEnvironment']
|
@ -1,532 +0,0 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Dict, Tuple, List, Any, Optional
|
||||
import logging
|
||||
import gym
|
||||
from gym import spaces
|
||||
import random
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingEnvironment(gym.Env):
|
||||
"""
|
||||
Trading environment implementing gym interface for reinforcement learning
|
||||
|
||||
2-Action System:
|
||||
- 0: SELL (or close long position)
|
||||
- 1: BUY (or close short position)
|
||||
|
||||
Intelligent Position Management:
|
||||
- When neutral: Actions enter positions
|
||||
- When positioned: Actions can close or flip positions
|
||||
- Different thresholds for entry vs exit decisions
|
||||
|
||||
State:
|
||||
- OHLCV data from multiple timeframes
|
||||
- Technical indicators
|
||||
- Position data and unrealized PnL
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_interface,
|
||||
initial_balance: float = 10000.0,
|
||||
transaction_fee: float = 0.0002,
|
||||
window_size: int = 20,
|
||||
max_position: float = 1.0,
|
||||
reward_scaling: float = 1.0,
|
||||
entry_threshold: float = 0.6, # Higher threshold for entering positions
|
||||
exit_threshold: float = 0.3, # Lower threshold for exiting positions
|
||||
):
|
||||
"""
|
||||
Initialize the trading environment with 2-action system.
|
||||
|
||||
Args:
|
||||
data_interface: DataInterface instance to get market data
|
||||
initial_balance: Initial balance in the base currency
|
||||
transaction_fee: Fee for each transaction as a fraction of trade value
|
||||
window_size: Number of candles in the observation window
|
||||
max_position: Maximum position size as a fraction of balance
|
||||
reward_scaling: Scale factor for rewards
|
||||
entry_threshold: Confidence threshold for entering new positions
|
||||
exit_threshold: Confidence threshold for exiting positions
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.data_interface = data_interface
|
||||
self.initial_balance = initial_balance
|
||||
self.transaction_fee = transaction_fee
|
||||
self.window_size = window_size
|
||||
self.max_position = max_position
|
||||
self.reward_scaling = reward_scaling
|
||||
self.entry_threshold = entry_threshold
|
||||
self.exit_threshold = exit_threshold
|
||||
|
||||
# Load data for primary timeframe (assuming the first one is primary)
|
||||
self.timeframe = self.data_interface.timeframes[0]
|
||||
self.reset_data()
|
||||
|
||||
# Define action and observation spaces for 2-action system
|
||||
self.action_space = spaces.Discrete(2) # 0=SELL, 1=BUY
|
||||
|
||||
# For observation space, we consider multiple timeframes with OHLCV data
|
||||
# and additional features like technical indicators, position info, etc.
|
||||
n_timeframes = len(self.data_interface.timeframes)
|
||||
n_features = 5 # OHLCV data by default
|
||||
|
||||
# Add additional features for position, balance, unrealized_pnl, etc.
|
||||
additional_features = 5 # position, balance, unrealized_pnl, entry_price, position_duration
|
||||
|
||||
# Calculate total feature dimension
|
||||
total_features = (n_timeframes * n_features * self.window_size) + additional_features
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
|
||||
)
|
||||
|
||||
# Use tuple for state_shape that EnhancedCNN expects
|
||||
self.state_shape = (total_features,)
|
||||
|
||||
# Position tracking for 2-action system
|
||||
self.position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||
self.entry_price = 0.0 # Price at which position was entered
|
||||
self.entry_step = 0 # Step at which position was entered
|
||||
|
||||
# Initialize state
|
||||
self.reset()
|
||||
|
||||
def reset_data(self):
|
||||
"""Reset data and generate a new set of price data for training"""
|
||||
# Get data for each timeframe
|
||||
self.data = {}
|
||||
for tf in self.data_interface.timeframes:
|
||||
df = self.data_interface.dataframes[tf]
|
||||
if df is not None and not df.empty:
|
||||
self.data[tf] = df
|
||||
|
||||
if not self.data:
|
||||
raise ValueError("No data available for training")
|
||||
|
||||
# Use the primary timeframe for step count
|
||||
self.prices = self.data[self.timeframe]['close'].values
|
||||
self.timestamps = self.data[self.timeframe].index.values
|
||||
self.max_steps = len(self.prices) - self.window_size - 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment to initial state"""
|
||||
# Reset trading variables
|
||||
self.balance = self.initial_balance
|
||||
self.trades = []
|
||||
self.rewards = []
|
||||
|
||||
# Reset step counter
|
||||
self.current_step = self.window_size
|
||||
|
||||
# Get initial observation
|
||||
observation = self._get_observation()
|
||||
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Take a step in the environment using 2-action system with intelligent position management.
|
||||
|
||||
Args:
|
||||
action: Action to take (0: SELL, 1: BUY)
|
||||
|
||||
Returns:
|
||||
tuple: (observation, reward, done, info)
|
||||
"""
|
||||
# Get current state before taking action
|
||||
prev_balance = self.balance
|
||||
prev_position = self.position
|
||||
prev_price = self.prices[self.current_step]
|
||||
|
||||
# Take action with intelligent position management
|
||||
info = {}
|
||||
reward = 0
|
||||
last_position_info = None
|
||||
|
||||
# Get current price
|
||||
current_price = self.prices[self.current_step]
|
||||
next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price
|
||||
|
||||
# Implement 2-action system with position management
|
||||
if action == 0: # SELL action
|
||||
if self.position == 0: # No position - enter short
|
||||
self._open_position(-1.0 * self.max_position, current_price)
|
||||
logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
|
||||
reward = -self.transaction_fee # Entry cost
|
||||
|
||||
elif self.position > 0: # Long position - close it
|
||||
close_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += close_pnl * self.reward_scaling
|
||||
logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
|
||||
|
||||
elif self.position < 0: # Already short - potentially flip to long if very strong signal
|
||||
# For now, just hold the short position (no action)
|
||||
pass
|
||||
|
||||
elif action == 1: # BUY action
|
||||
if self.position == 0: # No position - enter long
|
||||
self._open_position(1.0 * self.max_position, current_price)
|
||||
logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
|
||||
reward = -self.transaction_fee # Entry cost
|
||||
|
||||
elif self.position < 0: # Short position - close it
|
||||
close_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += close_pnl * self.reward_scaling
|
||||
logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
|
||||
|
||||
elif self.position > 0: # Already long - potentially flip to short if very strong signal
|
||||
# For now, just hold the long position (no action)
|
||||
pass
|
||||
|
||||
# Calculate unrealized PnL and add to reward if holding position
|
||||
if self.position != 0:
|
||||
unrealized_pnl = self._calculate_unrealized_pnl(next_price)
|
||||
reward += unrealized_pnl * self.reward_scaling * 0.1 # Scale down unrealized PnL
|
||||
|
||||
# Apply time-based holding penalty to encourage decisive actions
|
||||
position_duration = self.current_step - self.entry_step
|
||||
holding_penalty = min(position_duration * 0.0001, 0.01) # Max 1% penalty
|
||||
reward -= holding_penalty
|
||||
|
||||
# Reward staying neutral when uncertain (no clear setup)
|
||||
else:
|
||||
reward += 0.0001 # Small reward for not trading without clear signals
|
||||
|
||||
# Move to next step
|
||||
self.current_step += 1
|
||||
|
||||
# Get new observation
|
||||
observation = self._get_observation()
|
||||
|
||||
# Check if episode is done
|
||||
done = self.current_step >= len(self.prices) - 1
|
||||
|
||||
# If done, close any remaining positions
|
||||
if done and self.position != 0:
|
||||
final_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += final_pnl * self.reward_scaling
|
||||
info['final_pnl'] = final_pnl
|
||||
info['final_balance'] = self.balance
|
||||
logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")
|
||||
|
||||
# Track trade result if position changed or position was closed
|
||||
if prev_position != self.position or last_position_info is not None:
|
||||
# Calculate realized PnL if position was closed
|
||||
realized_pnl = 0
|
||||
position_info = {}
|
||||
|
||||
if last_position_info is not None:
|
||||
# Use the position information from closing
|
||||
realized_pnl = last_position_info['pnl']
|
||||
position_info = last_position_info
|
||||
else:
|
||||
# Calculate manually based on balance change
|
||||
realized_pnl = self.balance - prev_balance if prev_position != 0 else 0
|
||||
|
||||
# Record detailed trade information
|
||||
trade_result = {
|
||||
'step': self.current_step,
|
||||
'timestamp': self.timestamps[self.current_step],
|
||||
'action': action,
|
||||
'action_name': ['SELL', 'BUY'][action],
|
||||
'price': current_price,
|
||||
'position_changed': prev_position != self.position,
|
||||
'prev_position': prev_position,
|
||||
'new_position': self.position,
|
||||
'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
|
||||
'entry_price': position_info.get('entry_price', self.entry_price),
|
||||
'exit_price': position_info.get('exit_price', current_price),
|
||||
'realized_pnl': realized_pnl,
|
||||
'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
|
||||
'pnl': realized_pnl, # Total PnL (realized for this step)
|
||||
'balance_before': prev_balance,
|
||||
'balance_after': self.balance,
|
||||
'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
|
||||
}
|
||||
info['trade_result'] = trade_result
|
||||
self.trades.append(trade_result)
|
||||
|
||||
# Log trade details
|
||||
logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
|
||||
f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
|
||||
f"Balance: {self.balance:.4f}")
|
||||
|
||||
# Store reward
|
||||
self.rewards.append(reward)
|
||||
|
||||
# Update info dict with current state
|
||||
info.update({
|
||||
'step': self.current_step,
|
||||
'price': current_price,
|
||||
'prev_price': prev_price,
|
||||
'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
|
||||
'balance': self.balance,
|
||||
'position': self.position,
|
||||
'entry_price': self.entry_price,
|
||||
'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
|
||||
'total_trades': len(self.trades),
|
||||
'total_pnl': self.total_pnl,
|
||||
'return_pct': (self.balance/self.initial_balance-1)*100
|
||||
})
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def _calculate_unrealized_pnl(self, current_price):
|
||||
"""Calculate unrealized PnL for current position"""
|
||||
if self.position == 0 or self.entry_price == 0:
|
||||
return 0.0
|
||||
|
||||
if self.position > 0: # Long position
|
||||
return self.position * (current_price / self.entry_price - 1.0)
|
||||
else: # Short position
|
||||
return -self.position * (1.0 - current_price / self.entry_price)
|
||||
|
||||
def _open_position(self, position_size: float, entry_price: float):
|
||||
"""Open a new position"""
|
||||
self.position = position_size
|
||||
self.entry_price = entry_price
|
||||
self.entry_step = self.current_step
|
||||
|
||||
# Calculate position value
|
||||
position_value = abs(position_size) * entry_price
|
||||
|
||||
# Apply transaction fee
|
||||
fee = position_value * self.transaction_fee
|
||||
self.balance -= fee
|
||||
|
||||
logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")
|
||||
|
||||
def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
|
||||
"""Close current position and return PnL"""
|
||||
if self.position == 0:
|
||||
return 0.0, {}
|
||||
|
||||
# Calculate PnL
|
||||
if self.position > 0: # Long position
|
||||
pnl = (exit_price - self.entry_price) / self.entry_price
|
||||
else: # Short position
|
||||
pnl = (self.entry_price - exit_price) / self.entry_price
|
||||
|
||||
# Apply transaction fees (entry + exit)
|
||||
position_value = abs(self.position) * exit_price
|
||||
exit_fee = position_value * self.transaction_fee
|
||||
total_fees = exit_fee # Entry fee already applied when opening
|
||||
|
||||
# Net PnL after fees
|
||||
net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))
|
||||
|
||||
# Update balance
|
||||
self.balance *= (1 + net_pnl)
|
||||
self.total_pnl += net_pnl
|
||||
|
||||
# Track trade
|
||||
position_info = {
|
||||
'position_size': self.position,
|
||||
'entry_price': self.entry_price,
|
||||
'exit_price': exit_price,
|
||||
'pnl': net_pnl,
|
||||
'duration': self.current_step - self.entry_step,
|
||||
'entry_step': self.entry_step,
|
||||
'exit_step': self.current_step
|
||||
}
|
||||
|
||||
self.trades.append(position_info)
|
||||
|
||||
# Update trade statistics
|
||||
if net_pnl > 0:
|
||||
self.winning_trades += 1
|
||||
else:
|
||||
self.losing_trades += 1
|
||||
|
||||
logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")
|
||||
|
||||
# Reset position
|
||||
self.position = 0.0
|
||||
self.entry_price = 0.0
|
||||
self.entry_step = 0
|
||||
|
||||
return net_pnl, position_info
|
||||
|
||||
def _get_observation(self):
|
||||
"""
|
||||
Get the current observation.
|
||||
|
||||
Returns:
|
||||
np.array: The observation vector
|
||||
"""
|
||||
observations = []
|
||||
|
||||
# Get data from each timeframe
|
||||
for tf in self.data_interface.timeframes:
|
||||
if tf in self.data:
|
||||
# Get the window of data for this timeframe
|
||||
df = self.data[tf]
|
||||
start_idx = self._align_timeframe_index(tf)
|
||||
|
||||
if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
|
||||
window = df.iloc[start_idx:start_idx + self.window_size]
|
||||
|
||||
# Extract OHLCV data
|
||||
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values
|
||||
|
||||
# Normalize OHLCV data
|
||||
last_close = ohlcv[-1, 3] # Last close price
|
||||
ohlcv_normalized = np.zeros_like(ohlcv)
|
||||
ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0 # open
|
||||
ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0 # high
|
||||
ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0 # low
|
||||
ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0 # close
|
||||
|
||||
# Normalize volume (relative to moving average of volume)
|
||||
if 'volume' in window.columns:
|
||||
volume_ma = ohlcv[:, 4].mean()
|
||||
if volume_ma > 0:
|
||||
ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
|
||||
else:
|
||||
ohlcv_normalized[:, 4] = 0.0
|
||||
else:
|
||||
ohlcv_normalized[:, 4] = 0.0
|
||||
|
||||
# Flatten and add to observations
|
||||
observations.append(ohlcv_normalized.flatten())
|
||||
else:
|
||||
# Fill with zeros if not enough data
|
||||
observations.append(np.zeros(self.window_size * 5))
|
||||
|
||||
# Add position and balance information
|
||||
current_price = self.prices[self.current_step]
|
||||
position_info = np.array([
|
||||
self.position / self.max_position, # Normalized position (-1 to 1)
|
||||
self.balance / self.initial_balance - 1.0, # Normalized balance change
|
||||
self._calculate_unrealized_pnl(current_price) # Unrealized PnL
|
||||
])
|
||||
|
||||
observations.append(position_info)
|
||||
|
||||
# Concatenate all observations
|
||||
observation = np.concatenate(observations)
|
||||
return observation
|
||||
|
||||
def _align_timeframe_index(self, timeframe):
|
||||
"""
|
||||
Align the index of a higher timeframe with the current step in the primary timeframe.
|
||||
|
||||
Args:
|
||||
timeframe: The timeframe to align
|
||||
|
||||
Returns:
|
||||
int: The starting index in the higher timeframe
|
||||
"""
|
||||
if timeframe == self.timeframe:
|
||||
return self.current_step - self.window_size
|
||||
|
||||
# Get timestamps for current primary timeframe step
|
||||
primary_ts = self.timestamps[self.current_step]
|
||||
|
||||
# Find closest index in the higher timeframe
|
||||
higher_ts = self.data[timeframe].index.values
|
||||
idx = np.searchsorted(higher_ts, primary_ts)
|
||||
|
||||
# Adjust to get the starting index
|
||||
start_idx = max(0, idx - self.window_size)
|
||||
return start_idx
|
||||
|
||||
def get_last_positions(self, n=5):
|
||||
"""
|
||||
Get detailed information about the last n positions.
|
||||
|
||||
Args:
|
||||
n: Number of last positions to return
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing position details
|
||||
"""
|
||||
if not self.trades:
|
||||
return []
|
||||
|
||||
# Filter trades to only include those that closed positions
|
||||
position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]
|
||||
|
||||
positions = []
|
||||
last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades
|
||||
|
||||
for trade in last_n_trades:
|
||||
position_info = {
|
||||
'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
|
||||
'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
|
||||
'entry_price': trade.get('entry_price', 0.0),
|
||||
'exit_price': trade.get('exit_price', trade['price']),
|
||||
'position_size': trade.get('position_size', self.max_position),
|
||||
'realized_pnl': trade.get('realized_pnl', 0.0),
|
||||
'fee': trade.get('trade_fee', 0.0),
|
||||
'pnl': trade.get('pnl', 0.0),
|
||||
'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
|
||||
'balance_before': trade.get('balance_before', 0.0),
|
||||
'balance_after': trade.get('balance_after', 0.0),
|
||||
'duration': trade.get('duration', 'N/A')
|
||||
}
|
||||
positions.append(position_info)
|
||||
|
||||
return positions
|
||||
|
||||
def render(self, mode='human'):
|
||||
"""Render the environment"""
|
||||
current_step = self.current_step
|
||||
current_price = self.prices[current_step]
|
||||
|
||||
# Display basic information
|
||||
print(f"\nTrading Environment Status:")
|
||||
print(f"============================")
|
||||
print(f"Step: {current_step}/{len(self.prices)-1}")
|
||||
print(f"Current Price: {current_price:.4f}")
|
||||
print(f"Current Balance: {self.balance:.4f}")
|
||||
print(f"Current Position: {self.position:.4f}")
|
||||
|
||||
if self.position != 0:
|
||||
unrealized_pnl = self._calculate_unrealized_pnl(current_price)
|
||||
print(f"Entry Price: {self.entry_price:.4f}")
|
||||
print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")
|
||||
|
||||
print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
|
||||
print(f"Total Trades: {len(self.trades)}")
|
||||
|
||||
if len(self.trades) > 0:
|
||||
win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
|
||||
win_count = len(win_trades)
|
||||
# Count trades that closed positions (not just changed them)
|
||||
closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
|
||||
closed_count = len(closed_positions)
|
||||
win_rate = win_count / closed_count if closed_count > 0 else 0
|
||||
print(f"Positions Closed: {closed_count}")
|
||||
print(f"Winning Positions: {win_count}")
|
||||
print(f"Win Rate: {win_rate:.2f}")
|
||||
|
||||
# Display last 5 positions
|
||||
print("\nLast 5 Positions:")
|
||||
print("================")
|
||||
last_positions = self.get_last_positions(5)
|
||||
|
||||
if not last_positions:
|
||||
print("No closed positions yet.")
|
||||
|
||||
for pos in last_positions:
|
||||
print(f"Time: {pos['timestamp']}")
|
||||
print(f"Action: {pos['action']}")
|
||||
print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
|
||||
print(f"Size: {pos['position_size']:.4f}")
|
||||
print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
|
||||
print(f"Fee: {pos['fee']:.4f}")
|
||||
print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
|
||||
print("----------------")
|
||||
|
||||
return
|
||||
|
||||
def close(self):
|
||||
"""Close the environment"""
|
||||
pass
|
@ -26,6 +26,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Import checkpoint management
|
||||
try:
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
|
||||
CHECKPOINT_MANAGER_AVAILABLE = True
|
||||
except ImportError:
|
||||
CHECKPOINT_MANAGER_AVAILABLE = False
|
||||
logger.warning("Checkpoint manager not available. Model persistence will be disabled.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedRealtimeTrainingSystem:
|
||||
@ -50,6 +58,12 @@ class EnhancedRealtimeTrainingSystem:
|
||||
# Experience buffers
|
||||
self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
|
||||
self.validation_buffer = deque(maxlen=1000)
|
||||
|
||||
# Training counters - CRITICAL for checkpoint management
|
||||
self.training_iteration = 0
|
||||
self.dqn_training_count = 0
|
||||
self.cnn_training_count = 0
|
||||
self.cob_training_count = 0
|
||||
self.priority_buffer = deque(maxlen=2000) # High-priority experiences
|
||||
|
||||
# Performance tracking
|
||||
@ -1071,6 +1085,10 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
self.dqn_training_count += 1
|
||||
|
||||
# Save checkpoint after training
|
||||
if training_iterations > 0 and avg_loss > 0:
|
||||
self._save_model_checkpoint('dqn_agent', rl_agent, avg_loss)
|
||||
|
||||
# Log progress every 10 training sessions
|
||||
if self.dqn_training_count % 10 == 0:
|
||||
logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "
|
||||
@ -2523,4 +2541,56 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error estimating price change: {e}")
|
||||
return 0.0
|
||||
return 0.0 d
|
||||
ef _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
|
||||
"""
|
||||
Save model checkpoint after training if performance improved
|
||||
|
||||
This is CRITICAL for preserving training progress across restarts.
|
||||
"""
|
||||
try:
|
||||
if not CHECKPOINT_MANAGER_AVAILABLE:
|
||||
return
|
||||
|
||||
# Get checkpoint manager
|
||||
checkpoint_manager = get_checkpoint_manager()
|
||||
if not checkpoint_manager:
|
||||
return
|
||||
|
||||
# Prepare performance metrics
|
||||
performance_metrics = {
|
||||
'loss': loss,
|
||||
'training_samples': len(self.experience_buffer),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Prepare training metadata
|
||||
training_metadata = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'training_iteration': self.training_iteration,
|
||||
'model_type': model_name
|
||||
}
|
||||
|
||||
# Determine model type based on model name
|
||||
model_type = model_name
|
||||
if 'dqn' in model_name.lower():
|
||||
model_type = 'dqn'
|
||||
elif 'cnn' in model_name.lower():
|
||||
model_type = 'cnn'
|
||||
elif 'cob' in model_name.lower():
|
||||
model_type = 'cob_rl'
|
||||
|
||||
# Save checkpoint
|
||||
checkpoint_path = save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
@ -2246,9 +2246,144 @@ class TradingOrchestrator:
|
||||
try:
|
||||
model_obj = None
|
||||
current_loss = None
|
||||
model_type = model_name
|
||||
|
||||
# Get model object and calculate current performance
|
||||
if model_name == 'dqn' and self.rl_agent:
|
||||
model_obj = self.rl_agent
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['dqn'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['dqn']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['dqn']['initial_loss'] is None:
|
||||
self.model_states['dqn']['initial_loss'] = current_loss
|
||||
if self.model_states['dqn']['best_loss'] is None or current_loss < self.model_states['dqn']['best_loss']:
|
||||
self.model_states['dqn']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'cnn' and self.cnn_model:
|
||||
model_obj = self.cnn_model
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['cnn'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['cnn']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['cnn']['initial_loss'] is None:
|
||||
self.model_states['cnn']['initial_loss'] = current_loss
|
||||
if self.model_states['cnn']['best_loss'] is None or current_loss < self.model_states['cnn']['best_loss']:
|
||||
self.model_states['cnn']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'cob_rl' and self.cob_rl_agent:
|
||||
model_obj = self.cob_rl_agent
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['cob_rl'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['cob_rl']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['cob_rl']['initial_loss'] is None:
|
||||
self.model_states['cob_rl']['initial_loss'] = current_loss
|
||||
if self.model_states['cob_rl']['best_loss'] is None or current_loss < self.model_states['cob_rl']['best_loss']:
|
||||
self.model_states['cob_rl']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'extrema' and hasattr(self, 'extrema_trainer') and self.extrema_trainer:
|
||||
model_obj = self.extrema_trainer
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['extrema'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['extrema']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['extrema']['initial_loss'] is None:
|
||||
self.model_states['extrema']['initial_loss'] = current_loss
|
||||
if self.model_states['extrema']['best_loss'] is None or current_loss < self.model_states['extrema']['best_loss']:
|
||||
self.model_states['extrema']['best_loss'] = current_loss
|
||||
|
||||
# Skip if we couldn't get a model object
|
||||
if model_obj is None:
|
||||
continue
|
||||
|
||||
# Prepare performance metrics for checkpoint
|
||||
performance_metrics = {
|
||||
'loss': current_loss,
|
||||
'accuracy': performance_score, # Use confidence as a proxy for accuracy
|
||||
}
|
||||
|
||||
# Prepare training metadata
|
||||
training_metadata = {
|
||||
'training_iteration': self.training_iterations,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Save checkpoint using checkpoint manager
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
checkpoint_metadata = save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
if checkpoint_metadata:
|
||||
logger.info(f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})")
|
||||
|
||||
# Also save periodically based on training iterations
|
||||
if self.training_iterations % 100 == 0:
|
||||
# Force save every 100 training iterations regardless of performance
|
||||
checkpoint_metadata = save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata,
|
||||
force_save=True
|
||||
)
|
||||
if checkpoint_metadata:
|
||||
logger.info(f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
||||
|
||||
def _initialize_checkpoint_manager(self):
|
||||
"""Initialize the checkpoint manager for model persistence"""
|
||||
try:
|
||||
from utils.checkpoint_manager import get_checkpoint_manager
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
|
||||
# Initialize model states dictionary to track performance
|
||||
self.model_states = {
|
||||
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||
'extrema': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
logger.info("Checkpoint manager initialized for model persistence")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||
self.checkpoint_manager = None
|
||||
model_obj = self.rl_agent
|
||||
# Use negative performance score as loss (higher confidence = lower loss)
|
||||
current_loss = 1.0 - performance_score
|
||||
|
@ -1 +0,0 @@
|
||||
|
@ -1,49 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix Dashboard Metrics Script
|
||||
|
||||
This script fixes the incomplete code in the update_metrics function
|
||||
of the web/clean_dashboard.py file.
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
|
||||
def fix_dashboard_metrics():
|
||||
"""Fix the incomplete code in the update_metrics function"""
|
||||
file_path = 'web/clean_dashboard.py'
|
||||
|
||||
# Read the file content
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
|
||||
# Find and replace the incomplete code
|
||||
pattern = r"# Add unrealized P&L from current position \(adjustable leverage\)\s+if self\.curr"
|
||||
replacement = """# Add unrealized P&L from current position (adjustable leverage)
|
||||
if self.current_position and current_price:
|
||||
side = self.current_position.get('side', 'UNKNOWN')
|
||||
size = self.current_position.get('size', 0)
|
||||
entry_price = self.current_position.get('price', 0)
|
||||
|
||||
if entry_price and size > 0:
|
||||
# Calculate unrealized P&L with current leverage
|
||||
if side.upper() == 'LONG' or side.upper() == 'BUY':
|
||||
raw_pnl_per_unit = current_price - entry_price
|
||||
else: # SHORT or SELL
|
||||
raw_pnl_per_unit = entry_price - current_price
|
||||
|
||||
# Apply current leverage to unrealized P&L
|
||||
leveraged_unrealized_pnl = raw_pnl_per_unit * size * self.current_leverage
|
||||
total_session_pnl += leveraged_unrealized_pnl"""
|
||||
|
||||
# Replace the pattern
|
||||
fixed_content = re.sub(pattern, replacement, content)
|
||||
|
||||
# Write the fixed content back to the file
|
||||
with open(file_path, 'w', encoding='utf-8') as file:
|
||||
file.write(fixed_content)
|
||||
|
||||
print(f"Fixed dashboard metrics in {file_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
fix_dashboard_metrics()
|
@ -1,283 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix RL Training Issues - Comprehensive Solution
|
||||
|
||||
This script addresses the critical RL training audit issues:
|
||||
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
|
||||
2. Disconnected Training Pipeline - Fixes data flow between components
|
||||
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
|
||||
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
|
||||
5. Williams Market Structure Integration - Proper feature extraction
|
||||
6. Real-time Data Integration - Live market data to RL
|
||||
|
||||
Usage:
|
||||
python fix_rl_training_issues.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def fix_orchestrator_missing_methods():
|
||||
"""Fix missing methods in enhanced orchestrator"""
|
||||
try:
|
||||
logger.info("Checking enhanced orchestrator...")
|
||||
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
|
||||
# Test if methods exist
|
||||
test_orchestrator = EnhancedTradingOrchestrator()
|
||||
|
||||
methods_to_check = [
|
||||
'_get_symbol_correlation',
|
||||
'build_comprehensive_rl_state',
|
||||
'calculate_enhanced_pivot_reward'
|
||||
]
|
||||
|
||||
missing_methods = []
|
||||
for method in methods_to_check:
|
||||
if not hasattr(test_orchestrator, method):
|
||||
missing_methods.append(method)
|
||||
|
||||
if missing_methods:
|
||||
logger.error(f"Missing methods in enhanced orchestrator: {missing_methods}")
|
||||
return False
|
||||
else:
|
||||
logger.info("✅ All required methods present in enhanced orchestrator")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking orchestrator: {e}")
|
||||
return False
|
||||
|
||||
def test_comprehensive_state_building():
|
||||
"""Test comprehensive RL state building"""
|
||||
try:
|
||||
logger.info("Testing comprehensive state building...")
|
||||
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Create test instances
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Test comprehensive state building
|
||||
state = orchestrator.build_comprehensive_rl_state('ETH/USDT')
|
||||
|
||||
if state is not None:
|
||||
logger.info(f"✅ Comprehensive state built: {len(state)} features")
|
||||
|
||||
if len(state) == 13400:
|
||||
logger.info("✅ PERFECT: Exactly 13,400 features as required!")
|
||||
else:
|
||||
logger.warning(f"⚠️ Expected 13,400 features, got {len(state)}")
|
||||
|
||||
# Check feature distribution
|
||||
import numpy as np
|
||||
non_zero = np.count_nonzero(state)
|
||||
logger.info(f"Non-zero features: {non_zero} ({non_zero/len(state)*100:.1f}%)")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error("❌ Comprehensive state building failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing state building: {e}")
|
||||
return False
|
||||
|
||||
def test_enhanced_reward_calculation():
|
||||
"""Test enhanced reward calculation"""
|
||||
try:
|
||||
logger.info("Testing enhanced reward calculation...")
|
||||
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
orchestrator = EnhancedTradingOrchestrator()
|
||||
|
||||
# Test data
|
||||
trade_decision = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.75,
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
trade_outcome = {
|
||||
'net_pnl': 50.0,
|
||||
'exit_price': 2550.0,
|
||||
'duration': timedelta(minutes=15)
|
||||
}
|
||||
|
||||
market_data = {
|
||||
'volatility': 0.03,
|
||||
'order_flow_direction': 'bullish',
|
||||
'order_flow_strength': 0.8
|
||||
}
|
||||
|
||||
# Test enhanced reward
|
||||
enhanced_reward = orchestrator.calculate_enhanced_pivot_reward(
|
||||
trade_decision, market_data, trade_outcome
|
||||
)
|
||||
|
||||
logger.info(f"✅ Enhanced reward calculated: {enhanced_reward:.3f}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing reward calculation: {e}")
|
||||
return False
|
||||
|
||||
def test_williams_integration():
|
||||
"""Test Williams market structure integration"""
|
||||
try:
|
||||
logger.info("Testing Williams market structure integration...")
|
||||
|
||||
from training.williams_market_structure import extract_pivot_features, analyze_pivot_context
|
||||
from core.data_provider import DataProvider
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Create test data
|
||||
test_data = {
|
||||
'open': np.random.uniform(2400, 2600, 100),
|
||||
'high': np.random.uniform(2500, 2700, 100),
|
||||
'low': np.random.uniform(2300, 2500, 100),
|
||||
'close': np.random.uniform(2400, 2600, 100),
|
||||
'volume': np.random.uniform(1000, 5000, 100)
|
||||
}
|
||||
df = pd.DataFrame(test_data)
|
||||
|
||||
# Test pivot features
|
||||
pivot_features = extract_pivot_features(df)
|
||||
|
||||
if pivot_features is not None:
|
||||
logger.info(f"✅ Williams pivot features extracted: {len(pivot_features)} features")
|
||||
|
||||
# Test pivot context analysis
|
||||
market_data = {'ohlcv_data': df}
|
||||
context = analyze_pivot_context(market_data, datetime.now(), 'BUY')
|
||||
|
||||
if context is not None:
|
||||
logger.info("✅ Williams pivot context analysis working")
|
||||
return True
|
||||
else:
|
||||
logger.warning("⚠️ Pivot context analysis returned None")
|
||||
return False
|
||||
else:
|
||||
logger.error("❌ Williams pivot feature extraction failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing Williams integration: {e}")
|
||||
return False
|
||||
|
||||
def test_dashboard_integration():
|
||||
"""Test dashboard integration with enhanced features"""
|
||||
try:
|
||||
logger.info("Testing dashboard integration...")
|
||||
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)
|
||||
executor = TradingExecutor()
|
||||
|
||||
# Create dashboard
|
||||
dashboard = TradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=executor
|
||||
)
|
||||
|
||||
# Check if dashboard has access to enhanced features
|
||||
has_comprehensive_builder = hasattr(dashboard, '_build_comprehensive_rl_state')
|
||||
has_enhanced_orchestrator = hasattr(dashboard.orchestrator, 'build_comprehensive_rl_state')
|
||||
|
||||
if has_comprehensive_builder and has_enhanced_orchestrator:
|
||||
logger.info("✅ Dashboard properly integrated with enhanced features")
|
||||
return True
|
||||
else:
|
||||
logger.warning("⚠️ Dashboard missing some enhanced features")
|
||||
logger.info(f"Comprehensive builder: {has_comprehensive_builder}")
|
||||
logger.info(f"Enhanced orchestrator: {has_enhanced_orchestrator}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing dashboard integration: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main function to run all fixes and tests"""
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("COMPREHENSIVE RL TRAINING FIX - AUDIT ISSUE RESOLUTION")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Track results
|
||||
test_results = {}
|
||||
|
||||
# Run all tests
|
||||
tests = [
|
||||
("Enhanced Orchestrator Methods", fix_orchestrator_missing_methods),
|
||||
("Comprehensive State Building", test_comprehensive_state_building),
|
||||
("Enhanced Reward Calculation", test_enhanced_reward_calculation),
|
||||
("Williams Market Structure", test_williams_integration),
|
||||
("Dashboard Integration", test_dashboard_integration)
|
||||
]
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"\n🔧 {test_name}...")
|
||||
try:
|
||||
result = test_func()
|
||||
test_results[test_name] = result
|
||||
except Exception as e:
|
||||
logger.error(f"❌ {test_name} failed: {e}")
|
||||
test_results[test_name] = False
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "=" * 70)
|
||||
logger.info("COMPREHENSIVE RL TRAINING FIX RESULTS")
|
||||
logger.info("=" * 70)
|
||||
|
||||
passed = sum(test_results.values())
|
||||
total = len(test_results)
|
||||
|
||||
for test_name, result in test_results.items():
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
logger.info(f"{test_name}: {status}")
|
||||
|
||||
logger.info(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 ALL RL TRAINING ISSUES FIXED!")
|
||||
logger.info("The system now supports:")
|
||||
logger.info(" - 13,400 comprehensive RL features")
|
||||
logger.info(" - Enhanced pivot-based rewards")
|
||||
logger.info(" - Williams market structure integration")
|
||||
logger.info(" - Proper data flow between components")
|
||||
logger.info(" - Real-time data integration")
|
||||
else:
|
||||
logger.warning("⚠️ Some issues remain - check logs above")
|
||||
|
||||
return 0 if passed == total else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
Reference in New Issue
Block a user