code structure
@ -1,6 +0,0 @@
# Trading environments for reinforcement learning
# This module contains environments for training trading agents

from NN.environments.trading_env import TradingEnvironment

__all__ = ['TradingEnvironment']
@ -1,532 +0,0 @@
import numpy as np
import pandas as pd
from typing import Dict, Tuple, List, Any, Optional
import logging
import gym
from gym import spaces
import random

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TradingEnvironment(gym.Env):
    """
    Trading environment implementing gym interface for reinforcement learning

    2-Action System:
    - 0: SELL (or close long position)
    - 1: BUY (or close short position)

    Intelligent Position Management:
    - When neutral: Actions enter positions
    - When positioned: Actions can close or flip positions
    - Different thresholds for entry vs exit decisions

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data and unrealized PnL
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,   # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment with 2-action system.

        Args:
            data_interface: DataInterface instance to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for primary timeframe (assuming the first one is primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces for 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # For observation space, we consider multiple timeframes with OHLCV data
        # and additional features like technical indicators, position info, etc.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Add additional features for position, balance, unrealized_pnl, etc.
        additional_features = 5  # position, balance, unrealized_pnl, entry_price, position_duration

        # Calculate total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )

        # Use tuple for state_shape that EnhancedCNN expects
        self.state_shape = (total_features,)

        # Position tracking for 2-action system
        self.position = 0.0     # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which position was entered
        self.entry_step = 0     # Step at which position was entered

        # Initialize state
        self.reset()
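    # Illustrative dimensionality check (not in the original file): with 3 timeframes,
    # window_size=20 and 5 OHLCV features per candle, total_features = 3 * 5 * 20 + 5 = 305,
    # so each observation is a flat float32 vector of length 305.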
    def reset_data(self):
        """Reset data and generate a new set of price data for training"""
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1
    def reset(self):
        """Reset the environment to initial state"""
        # Reset trading variables
        self.balance = self.initial_balance
        self.trades = []
        self.rewards = []
        self.total_pnl = 0.0     # Cumulative realized PnL; read by step(), render()
        self.winning_trades = 0  # Counters updated in _close_position()
        self.losing_trades = 0

        # Reset position tracking so state does not leak across episodes
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        # Reset step counter
        self.current_step = self.window_size

        # Get initial observation
        observation = self._get_observation()

        return observation
    def step(self, action):
        """
        Take a step in the environment using 2-action system with intelligent position management.

        Args:
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current state before taking action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        # Take action with intelligent position management
        info = {}
        reward = 0
        last_position_info = None

        # Get current price
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Implement 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:  # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position > 0:  # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position < 0:  # Already short - potentially flip to long if very strong signal
                # For now, just hold the short position (no action)
                pass

        elif action == 1:  # BUY action
            if self.position == 0:  # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position < 0:  # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position > 0:  # Already long - potentially flip to short if very strong signal
                # For now, just hold the long position (no action)
                pass

        # Calculate unrealized PnL and add to reward if holding position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty

        # Reward staying neutral when uncertain (no clear setup)
        else:
            reward += 0.0001  # Small reward for not trading without clear signals

        # Move to next step
        self.current_step += 1

        # Get new observation
        observation = self._get_observation()

        # Check if episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining positions
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")

        # Track trade result if position changed or position was closed
        if prev_position != self.position or last_position_info is not None:
            # Calculate realized PnL if position was closed
            realized_pnl = 0
            position_info = {}

            if last_position_info is not None:
                # Use the position information from closing
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Calculate manually based on balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL (realized for this step)
                'balance_before': prev_balance,
                'balance_after': self.balance,
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
            }
            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store reward
        self.rewards.append(reward)

        # Update info dict with current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance/self.initial_balance-1)*100
        })

        return observation, reward, done, info
    def _calculate_unrealized_pnl(self, current_price):
        """Calculate unrealized PnL for current position"""
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        if self.position > 0:  # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)
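    # Worked example of the PnL convention above (illustrative, not part of the original file):
    # a long of +1.0 entered at 100.0 and marked at 105.0 gives 1.0 * (105/100 - 1.0) = +0.05,
    # while a short of -1.0 over the same move gives -(-1.0) * (1.0 - 105/100) = -0.05,
    # i.e. a symmetric 5% gain or loss expressed as a fraction of the position.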
    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position"""
        self.position = position_size
        self.entry_price = entry_price
        self.entry_step = self.current_step

        # Calculate position value
        position_value = abs(position_size) * entry_price

        # Apply transaction fee
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")
    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close current position and return PnL"""
        if self.position == 0:
            return 0.0, {}

        # Calculate PnL
        if self.position > 0:  # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:  # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply transaction fees (entry + exit)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee  # Entry fee already applied when opening

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update balance
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Track trade
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step
        }

        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")

        # Reset position
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        return net_pnl, position_info
    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)

                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLCV data
                    last_close = ohlcv[-1, 3]  # Last close price
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume (relative to moving average of volume)
                    if 'volume' in window.columns:
                        volume_ma = ohlcv[:, 4].mean()
                        if volume_ma > 0:
                            ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                        else:
                            ohlcv_normalized[:, 4] = 0.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if not enough data
                    observations.append(np.zeros(self.window_size * 5))
            else:
                # Timeframe has no data - fill with zeros so the vector length stays fixed
                observations.append(np.zeros(self.window_size * 5))

        # Add position and balance information
        # (5 features, matching additional_features in __init__ so the vector length
        #  agrees with the declared observation_space)
        current_price = self.prices[self.current_step]
        position_duration = self.current_step - self.entry_step if self.position != 0 else 0
        position_info = np.array([
            self.position / self.max_position,                # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,        # Normalized balance change
            self._calculate_unrealized_pnl(current_price),    # Unrealized PnL
            self.entry_price / current_price - 1.0 if self.entry_price > 0 else 0.0,  # Entry vs current price
            position_duration / self.window_size              # Position duration relative to window
        ])

        observations.append(position_info)

        # Concatenate all observations
        observation = np.concatenate(observations)
        return observation
    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index in the higher timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Get timestamps for current primary timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find closest index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Adjust to get the starting index
        start_idx = max(0, idx - self.window_size)
        return start_idx
    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n positions.

        Args:
            n: Number of last positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Filter trades to only include those that closed positions
        position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]

        positions = []
        last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades

        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A')
            }
            positions.append(position_info)

        return positions
    def render(self, mode='human'):
        """Render the environment"""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print(f"\nTrading Environment Status:")
        print(f"============================")
        print(f"Step: {current_step}/{len(self.prices)-1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)
            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)

        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return
    def close(self):
        """Close the environment"""
        pass
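For orientation, a minimal driver loop for the environment above; this is a sketch only, and it assumes a project DataInterface object exposing the timeframes and dataframes attributes the class relies on:

import numpy as np
from NN.environments.trading_env import TradingEnvironment

def run_random_episode(data_interface):
    """Drive the environment with random SELL/BUY actions (illustrative only)."""
    env = TradingEnvironment(data_interface, initial_balance=10000.0, window_size=20)
    obs = env.reset()
    done = False
    info = {}
    while not done:
        action = np.random.randint(0, 2)  # 0 = SELL, 1 = BUY
        obs, reward, done, info = env.step(action)
    env.render()
    return info.get('final_balance', env.balance)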
@ -26,6 +26,14 @@ import torch
import torch.nn as nn
import torch.optim as optim

logger = logging.getLogger(__name__)

# Import checkpoint management (logger must already be defined for the fallback warning)
try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")

class EnhancedRealtimeTrainingSystem:
@ -50,6 +58,12 @@ class EnhancedRealtimeTrainingSystem:
        # Experience buffers
        self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
        self.validation_buffer = deque(maxlen=1000)

        # Training counters - CRITICAL for checkpoint management
        self.training_iteration = 0
        self.dqn_training_count = 0
        self.cnn_training_count = 0
        self.cob_training_count = 0

        self.priority_buffer = deque(maxlen=2000)  # High-priority experiences

        # Performance tracking
@ -1071,6 +1085,10 @@ class EnhancedRealtimeTrainingSystem:
            self.dqn_training_count += 1

            # Save checkpoint after training
            if training_iterations > 0 and avg_loss > 0:
                self._save_model_checkpoint('dqn_agent', rl_agent, avg_loss)

            # Log progress every 10 training sessions
            if self.dqn_training_count % 10 == 0:
                logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "
@ -2523,4 +2541,56 @@ class EnhancedRealtimeTrainingSystem:
        except Exception as e:
            logger.debug(f"Error estimating price change: {e}")
            return 0.0

    def _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
        """
        Save model checkpoint after training if performance improved

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not CHECKPOINT_MANAGER_AVAILABLE:
                return

            # Get checkpoint manager
            checkpoint_manager = get_checkpoint_manager()
            if not checkpoint_manager:
                return

            # Prepare performance metrics
            performance_metrics = {
                'loss': loss,
                'training_samples': len(self.experience_buffer),
                'timestamp': datetime.now().isoformat()
            }

            # Prepare training metadata
            training_metadata = {
                'timestamp': datetime.now().isoformat(),
                'training_iteration': self.training_iteration,
                'model_type': model_name
            }

            # Determine model type based on model name
            model_type = model_name
            if 'dqn' in model_name.lower():
                model_type = 'dqn'
            elif 'cnn' in model_name.lower():
                model_type = 'cnn'
            elif 'cob' in model_name.lower():
                model_type = 'cob_rl'

            # Save checkpoint
            checkpoint_path = save_checkpoint(
                model=model_obj,
                model_name=model_name,
                model_type=model_type,
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if checkpoint_path:
                logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
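Only the DQN training path above calls this helper; the other models are presumably wired the same way. Sketched here as comments, with names that are assumptions rather than code from this commit:

            # Assumed analogous call sites (illustrative; only the DQN call appears in this commit):
            # self._save_model_checkpoint('cnn_model', cnn_model, avg_cnn_loss)
            # self._save_model_checkpoint('cob_rl_model', cob_rl_agent, avg_cob_loss)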
@ -2246,9 +2246,144 @@ class TradingOrchestrator:
            try:
                model_obj = None
                current_loss = None
                model_type = model_name

                # Get model object and calculate current performance
                if model_name == 'dqn' and self.rl_agent:
                    model_obj = self.rl_agent
                    # Use current loss from model state or estimate from performance
                    current_loss = self.model_states['dqn'].get('current_loss')
                    if current_loss is None:
                        # Estimate loss from performance score (inverse relationship)
                        current_loss = max(0.001, 1.0 - performance_score)

                    # Update model state tracking
                    self.model_states['dqn']['current_loss'] = current_loss

                    # If this is the first loss value, set it as initial and best
                    if self.model_states['dqn']['initial_loss'] is None:
                        self.model_states['dqn']['initial_loss'] = current_loss
                    if self.model_states['dqn']['best_loss'] is None or current_loss < self.model_states['dqn']['best_loss']:
                        self.model_states['dqn']['best_loss'] = current_loss

                elif model_name == 'cnn' and self.cnn_model:
                    model_obj = self.cnn_model
                    # Use current loss from model state or estimate from performance
                    current_loss = self.model_states['cnn'].get('current_loss')
                    if current_loss is None:
                        # Estimate loss from performance score (inverse relationship)
                        current_loss = max(0.001, 1.0 - performance_score)

                    # Update model state tracking
                    self.model_states['cnn']['current_loss'] = current_loss

                    # If this is the first loss value, set it as initial and best
                    if self.model_states['cnn']['initial_loss'] is None:
                        self.model_states['cnn']['initial_loss'] = current_loss
                    if self.model_states['cnn']['best_loss'] is None or current_loss < self.model_states['cnn']['best_loss']:
                        self.model_states['cnn']['best_loss'] = current_loss

                elif model_name == 'cob_rl' and self.cob_rl_agent:
                    model_obj = self.cob_rl_agent
                    # Use current loss from model state or estimate from performance
                    current_loss = self.model_states['cob_rl'].get('current_loss')
                    if current_loss is None:
                        # Estimate loss from performance score (inverse relationship)
                        current_loss = max(0.001, 1.0 - performance_score)

                    # Update model state tracking
                    self.model_states['cob_rl']['current_loss'] = current_loss

                    # If this is the first loss value, set it as initial and best
                    if self.model_states['cob_rl']['initial_loss'] is None:
                        self.model_states['cob_rl']['initial_loss'] = current_loss
                    if self.model_states['cob_rl']['best_loss'] is None or current_loss < self.model_states['cob_rl']['best_loss']:
                        self.model_states['cob_rl']['best_loss'] = current_loss

                elif model_name == 'extrema' and hasattr(self, 'extrema_trainer') and self.extrema_trainer:
                    model_obj = self.extrema_trainer
                    # Use current loss from model state or estimate from performance
                    current_loss = self.model_states['extrema'].get('current_loss')
                    if current_loss is None:
                        # Estimate loss from performance score (inverse relationship)
                        current_loss = max(0.001, 1.0 - performance_score)

                    # Update model state tracking
                    self.model_states['extrema']['current_loss'] = current_loss

                    # If this is the first loss value, set it as initial and best
                    if self.model_states['extrema']['initial_loss'] is None:
                        self.model_states['extrema']['initial_loss'] = current_loss
                    if self.model_states['extrema']['best_loss'] is None or current_loss < self.model_states['extrema']['best_loss']:
                        self.model_states['extrema']['best_loss'] = current_loss

                # Skip if we couldn't get a model object
                if model_obj is None:
                    continue

                # Prepare performance metrics for checkpoint
                performance_metrics = {
                    'loss': current_loss,
                    'accuracy': performance_score,  # Use confidence as a proxy for accuracy
                }

                # Prepare training metadata
                training_metadata = {
                    'training_iteration': self.training_iterations,
                    'timestamp': datetime.now().isoformat()
                }

                # Save checkpoint using checkpoint manager
                from utils.checkpoint_manager import save_checkpoint
                checkpoint_metadata = save_checkpoint(
                    model=model_obj,
                    model_name=model_name,
                    model_type=model_type,
                    performance_metrics=performance_metrics,
                    training_metadata=training_metadata
                )

                if checkpoint_metadata:
                    logger.info(f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})")

                # Also save periodically based on training iterations
                if self.training_iterations % 100 == 0:
                    # Force save every 100 training iterations regardless of performance
                    checkpoint_metadata = save_checkpoint(
                        model=model_obj,
                        model_name=model_name,
                        model_type=model_type,
                        performance_metrics=performance_metrics,
                        training_metadata=training_metadata,
                        force_save=True
                    )
                    if checkpoint_metadata:
                        logger.info(f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}")

            except Exception as e:
                logger.error(f"Error saving checkpoint for {model_name}: {e}")

        except Exception as e:
            logger.error(f"Error in _save_training_checkpoints: {e}")

    def _initialize_checkpoint_manager(self):
        """Initialize the checkpoint manager for model persistence"""
        try:
            from utils.checkpoint_manager import get_checkpoint_manager
            self.checkpoint_manager = get_checkpoint_manager()

            # Initialize model states dictionary to track performance
            self.model_states = {
                'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
                'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
                'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
                'extrema': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False}
            }

            logger.info("Checkpoint manager initialized for model persistence")
        except Exception as e:
            logger.error(f"Error initializing checkpoint manager: {e}")
            self.checkpoint_manager = None

                    model_obj = self.rl_agent
                    # Use negative performance score as loss (higher confidence = lower loss)
                    current_loss = 1.0 - performance_score
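The four per-model branches in this hunk repeat the same loss bookkeeping; a hedged sketch of how that logic could be factored into a single helper, illustrative only and not part of the commit:

    def _update_loss_state(self, key: str, performance_score: float) -> float:
        """Shared loss tracking for the dqn/cnn/cob_rl/extrema branches (illustrative sketch)."""
        state = self.model_states[key]
        current_loss = state.get('current_loss')
        if current_loss is None:
            # Estimate loss from performance score (inverse relationship)
            current_loss = max(0.001, 1.0 - performance_score)
        state['current_loss'] = current_loss
        if state['initial_loss'] is None:
            state['initial_loss'] = current_loss
        if state['best_loss'] is None or current_loss < state['best_loss']:
            state['best_loss'] = current_loss
        return current_loss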
@ -1 +0,0 @@
@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""
Fix Dashboard Metrics Script

This script fixes the incomplete code in the update_metrics function
of the web/clean_dashboard.py file.
"""

import re
import os

def fix_dashboard_metrics():
    """Fix the incomplete code in the update_metrics function"""
    file_path = 'web/clean_dashboard.py'

    # Read the file content
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    # Find and replace the incomplete code
    pattern = r"# Add unrealized P&L from current position \(adjustable leverage\)\s+if self\.curr"
    replacement = """# Add unrealized P&L from current position (adjustable leverage)
            if self.current_position and current_price:
                side = self.current_position.get('side', 'UNKNOWN')
                size = self.current_position.get('size', 0)
                entry_price = self.current_position.get('price', 0)

                if entry_price and size > 0:
                    # Calculate unrealized P&L with current leverage
                    if side.upper() == 'LONG' or side.upper() == 'BUY':
                        raw_pnl_per_unit = current_price - entry_price
                    else:  # SHORT or SELL
                        raw_pnl_per_unit = entry_price - current_price

                    # Apply current leverage to unrealized P&L
                    leveraged_unrealized_pnl = raw_pnl_per_unit * size * self.current_leverage
                    total_session_pnl += leveraged_unrealized_pnl"""

    # Replace the pattern
    fixed_content = re.sub(pattern, replacement, content)

    # Write the fixed content back to the file
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(fixed_content)

    print(f"Fixed dashboard metrics in {file_path}")

if __name__ == "__main__":
    fix_dashboard_metrics()
@ -1,283 +0,0 @@
#!/usr/bin/env python3
"""
Fix RL Training Issues - Comprehensive Solution

This script addresses the critical RL training audit issues:
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
2. Disconnected Training Pipeline - Fixes data flow between components
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
5. Williams Market Structure Integration - Proper feature extraction
6. Real-time Data Integration - Live market data to RL

Usage:
    python fix_rl_training_issues.py
"""

import os
import sys
import logging
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

logger = logging.getLogger(__name__)

def fix_orchestrator_missing_methods():
    """Fix missing methods in enhanced orchestrator"""
    try:
        logger.info("Checking enhanced orchestrator...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator

        # Test if methods exist
        test_orchestrator = EnhancedTradingOrchestrator()

        methods_to_check = [
            '_get_symbol_correlation',
            'build_comprehensive_rl_state',
            'calculate_enhanced_pivot_reward'
        ]

        missing_methods = []
        for method in methods_to_check:
            if not hasattr(test_orchestrator, method):
                missing_methods.append(method)

        if missing_methods:
            logger.error(f"Missing methods in enhanced orchestrator: {missing_methods}")
            return False
        else:
            logger.info("✅ All required methods present in enhanced orchestrator")
            return True

    except Exception as e:
        logger.error(f"Error checking orchestrator: {e}")
        return False

def test_comprehensive_state_building():
    """Test comprehensive RL state building"""
    try:
        logger.info("Testing comprehensive state building...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from core.data_provider import DataProvider

        # Create test instances
        data_provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)

        # Test comprehensive state building
        state = orchestrator.build_comprehensive_rl_state('ETH/USDT')

        if state is not None:
            logger.info(f"✅ Comprehensive state built: {len(state)} features")

            if len(state) == 13400:
                logger.info("✅ PERFECT: Exactly 13,400 features as required!")
            else:
                logger.warning(f"⚠️ Expected 13,400 features, got {len(state)}")

            # Check feature distribution
            import numpy as np
            non_zero = np.count_nonzero(state)
            logger.info(f"Non-zero features: {non_zero} ({non_zero/len(state)*100:.1f}%)")

            return True
        else:
            logger.error("❌ Comprehensive state building failed")
            return False

    except Exception as e:
        logger.error(f"Error testing state building: {e}")
        return False

def test_enhanced_reward_calculation():
    """Test enhanced reward calculation"""
    try:
        logger.info("Testing enhanced reward calculation...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from datetime import datetime, timedelta

        orchestrator = EnhancedTradingOrchestrator()

        # Test data
        trade_decision = {
            'action': 'BUY',
            'confidence': 0.75,
            'price': 2500.0,
            'timestamp': datetime.now()
        }

        trade_outcome = {
            'net_pnl': 50.0,
            'exit_price': 2550.0,
            'duration': timedelta(minutes=15)
        }

        market_data = {
            'volatility': 0.03,
            'order_flow_direction': 'bullish',
            'order_flow_strength': 0.8
        }

        # Test enhanced reward
        enhanced_reward = orchestrator.calculate_enhanced_pivot_reward(
            trade_decision, market_data, trade_outcome
        )

        logger.info(f"✅ Enhanced reward calculated: {enhanced_reward:.3f}")
        return True

    except Exception as e:
        logger.error(f"Error testing reward calculation: {e}")
        return False

def test_williams_integration():
    """Test Williams market structure integration"""
    try:
        logger.info("Testing Williams market structure integration...")

        from training.williams_market_structure import extract_pivot_features, analyze_pivot_context
        from core.data_provider import DataProvider
        from datetime import datetime  # needed for analyze_pivot_context below
        import pandas as pd
        import numpy as np

        # Create test data
        test_data = {
            'open': np.random.uniform(2400, 2600, 100),
            'high': np.random.uniform(2500, 2700, 100),
            'low': np.random.uniform(2300, 2500, 100),
            'close': np.random.uniform(2400, 2600, 100),
            'volume': np.random.uniform(1000, 5000, 100)
        }
        df = pd.DataFrame(test_data)

        # Test pivot features
        pivot_features = extract_pivot_features(df)

        if pivot_features is not None:
            logger.info(f"✅ Williams pivot features extracted: {len(pivot_features)} features")

            # Test pivot context analysis
            market_data = {'ohlcv_data': df}
            context = analyze_pivot_context(market_data, datetime.now(), 'BUY')

            if context is not None:
                logger.info("✅ Williams pivot context analysis working")
                return True
            else:
                logger.warning("⚠️ Pivot context analysis returned None")
                return False
        else:
            logger.error("❌ Williams pivot feature extraction failed")
            return False

    except Exception as e:
        logger.error(f"Error testing Williams integration: {e}")
        return False

def test_dashboard_integration():
    """Test dashboard integration with enhanced features"""
    try:
        logger.info("Testing dashboard integration...")

        from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from core.data_provider import DataProvider
        from core.trading_executor import TradingExecutor

        # Create components
        data_provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)
        executor = TradingExecutor()

        # Create dashboard
        dashboard = TradingDashboard(
            data_provider=data_provider,
            orchestrator=orchestrator,
            trading_executor=executor
        )

        # Check if dashboard has access to enhanced features
        has_comprehensive_builder = hasattr(dashboard, '_build_comprehensive_rl_state')
        has_enhanced_orchestrator = hasattr(dashboard.orchestrator, 'build_comprehensive_rl_state')

        if has_comprehensive_builder and has_enhanced_orchestrator:
            logger.info("✅ Dashboard properly integrated with enhanced features")
            return True
        else:
            logger.warning("⚠️ Dashboard missing some enhanced features")
            logger.info(f"Comprehensive builder: {has_comprehensive_builder}")
            logger.info(f"Enhanced orchestrator: {has_enhanced_orchestrator}")
            return False

    except Exception as e:
        logger.error(f"Error testing dashboard integration: {e}")
        return False

def main():
    """Main function to run all fixes and tests"""
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    logger.info("=" * 70)
    logger.info("COMPREHENSIVE RL TRAINING FIX - AUDIT ISSUE RESOLUTION")
    logger.info("=" * 70)

    # Track results
    test_results = {}

    # Run all tests
    tests = [
        ("Enhanced Orchestrator Methods", fix_orchestrator_missing_methods),
        ("Comprehensive State Building", test_comprehensive_state_building),
        ("Enhanced Reward Calculation", test_enhanced_reward_calculation),
        ("Williams Market Structure", test_williams_integration),
        ("Dashboard Integration", test_dashboard_integration)
    ]

    for test_name, test_func in tests:
        logger.info(f"\n🔧 {test_name}...")
        try:
            result = test_func()
            test_results[test_name] = result
        except Exception as e:
            logger.error(f"❌ {test_name} failed: {e}")
            test_results[test_name] = False

    # Summary
    logger.info("\n" + "=" * 70)
    logger.info("COMPREHENSIVE RL TRAINING FIX RESULTS")
    logger.info("=" * 70)

    passed = sum(test_results.values())
    total = len(test_results)

    for test_name, result in test_results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        logger.info(f"{test_name}: {status}")

    logger.info(f"\nOverall: {passed}/{total} tests passed")

    if passed == total:
        logger.info("🎉 ALL RL TRAINING ISSUES FIXED!")
        logger.info("The system now supports:")
        logger.info("  - 13,400 comprehensive RL features")
        logger.info("  - Enhanced pivot-based rewards")
        logger.info("  - Williams market structure integration")
        logger.info("  - Proper data flow between components")
        logger.info("  - Real-time data integration")
    else:
        logger.warning("⚠️ Some issues remain - check logs above")

    return 0 if passed == total else 1

if __name__ == "__main__":
    sys.exit(main())