Compare commits
3 Commits: c63dc11c14 ... 1d224e5b8c

| Author | SHA1 | Date |
|---|---|---|
| | 1d224e5b8c | |
| | a68df64b83 | |
| | cc0c783411 | |
@ -1,124 +0,0 @@

# Requirements Document

## Introduction

The Checkpoint Persistence Fix addresses a critical system flaw where model training progress is not being saved during training, causing all learning progress to be lost when the system restarts. Despite having a well-implemented CheckpointManager and proper checkpoint loading at startup, the system lacks checkpoint saving during training operations. This creates a fundamental issue where models train continuously but never persist their improved weights, making continuous improvement impossible and wasting computational resources.

## Requirements

### Requirement 1: Real-time Checkpoint Saving During Training

**User Story:** As a system operator, I want model improvements to be automatically saved during training, so that training progress is never lost when the system restarts.

#### Acceptance Criteria

1. WHEN the DQN model is trained in `_train_models_on_decision` THEN the system SHALL save a checkpoint if the loss improves.
2. WHEN the CNN model is trained THEN the system SHALL save a checkpoint if the loss improves.
3. WHEN the COB RL model is trained THEN the system SHALL save a checkpoint if the loss improves.
4. WHEN the Extrema trainer is trained THEN the system SHALL save a checkpoint if the loss improves.
5. WHEN any model training completes THEN the system SHALL compare current performance to best performance and save if improved.
6. WHEN checkpoint saving occurs THEN the system SHALL update the `model_states` dictionary with new performance metrics.
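
A minimal sketch of the save hook these criteria imply (the function and parameter names here are illustrative, not from the codebase; the hook actually added appears in the diffs further below):

```python
def maybe_save_after_training(model_name, model, loss, model_states, save_checkpoint):
    """Save a checkpoint only when the new loss beats the best seen so far."""
    state = model_states.setdefault(model_name, {'best_loss': float('inf')})
    if loss < state['best_loss']:
        path = save_checkpoint(model=model, model_name=model_name, loss=loss)
        state['best_loss'] = loss
        state['last_checkpoint'] = path
    state['current_loss'] = loss  # criterion 6: keep model_states current either way
```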

### Requirement 2: Performance-Based Checkpoint Management

**User Story:** As a developer, I want checkpoints to be saved only when model performance improves, so that storage is used efficiently and only the best models are preserved.

#### Acceptance Criteria

1. WHEN evaluating whether to save a checkpoint THEN the system SHALL compare current loss to the best recorded loss.
2. WHEN loss decreases by a configurable threshold THEN the system SHALL trigger checkpoint saving.
3. WHEN multiple models are trained simultaneously THEN each model SHALL have independent performance tracking.
4. WHEN checkpoint rotation occurs THEN the system SHALL keep only the best performing checkpoints.
5. WHEN performance metrics are updated THEN the system SHALL log the improvement for monitoring.
6. WHEN no improvement is detected THEN the system SHALL skip checkpoint saving to avoid unnecessary I/O.
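
Criteria 1, 2, and 6 reduce to a single gate; a minimal sketch with the improvement threshold exposed as a parameter (names illustrative):

```python
def should_save_checkpoint(current_loss: float, best_loss: float,
                           min_improvement: float = 0.01) -> bool:
    """Return True only if loss improved by at least min_improvement (absolute)."""
    if best_loss == float('inf'):  # no checkpoint yet: always save the first one
        return True
    return (best_loss - current_loss) >= min_improvement
```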

### Requirement 3: Periodic Checkpoint Saving

**User Story:** As a system administrator, I want checkpoints to be saved periodically regardless of performance, so that progress is preserved even during long training sessions without significant improvement.

#### Acceptance Criteria

1. WHEN a configurable number of training iterations have passed THEN the system SHALL save a checkpoint regardless of performance.
2. WHEN periodic saving occurs THEN the system SHALL use a separate checkpoint category to distinguish from performance-based saves.
3. WHEN the system runs for extended periods THEN periodic checkpoints SHALL ensure no more than X minutes of training progress can be lost.
4. WHEN periodic checkpoints accumulate THEN the system SHALL maintain a rolling window of recent saves.
5. WHEN storage space is limited THEN periodic checkpoints SHALL be cleaned up while preserving performance-based checkpoints.
6. WHEN the system restarts THEN it SHALL load the most recent checkpoint (either performance-based or periodic).
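
A minimal sketch of the periodic fallback and rolling window, assuming a hypothetical `save_fn`/`delete_fn` pair for checkpoint I/O:

```python
from collections import deque

class PeriodicSaver:
    """Save every N iterations regardless of loss; keep only the last K periodic files."""

    def __init__(self, every_n_iterations: int = 100, keep_last: int = 5):
        self.every_n = every_n_iterations
        self.keep_last = keep_last
        self.recent = deque()  # paths of periodic saves, oldest first
        self.iteration = 0

    def tick(self, save_fn, delete_fn):
        """Call once per training iteration; save_fn() returns a path."""
        self.iteration += 1
        if self.iteration % self.every_n == 0:
            if len(self.recent) >= self.keep_last:
                delete_fn(self.recent.popleft())  # criterion 5: drop oldest periodic save
            self.recent.append(save_fn())
```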

### Requirement 4: Enhanced Training System Integration

**User Story:** As a developer, I want the EnhancedRealtimeTrainingSystem to properly save checkpoints, so that continuous learning progress is preserved across system restarts.

#### Acceptance Criteria

1. WHEN the EnhancedRealtimeTrainingSystem trains models THEN it SHALL integrate with the CheckpointManager.
2. WHEN training episodes complete THEN the system SHALL evaluate and save improved models.
3. WHEN the training system initializes THEN it SHALL load the best available checkpoints.
4. WHEN training data is collected THEN the system SHALL track performance metrics for checkpoint decisions.
5. WHEN the training system shuts down THEN it SHALL save final checkpoints before termination.
6. WHEN training resumes THEN the system SHALL continue from the last saved checkpoint state.

### Requirement 5: Complete Training Data Storage

**User Story:** As a developer, I want complete training episodes to be stored with full input dataframes, so that training can be replayed and analyzed with all original context.

#### Acceptance Criteria

1. WHEN training episodes are saved THEN the system SHALL store the complete input dataframe with all model inputs (price data, indicators, market structure, etc.).
2. WHEN model actions are recorded THEN the system SHALL store the full context that led to the decision, not just the action result.
3. WHEN training cases are saved THEN they SHALL include timestamps, market conditions, and all feature vectors used by the models.
4. WHEN storing training data THEN the system SHALL preserve the exact state that can be used to reproduce the model's decision.
5. WHEN training episodes are replayed THEN the system SHALL be able to reconstruct the exact same inputs that were originally used.
6. WHEN analyzing training performance THEN complete dataframes SHALL be available for debugging and improvement.
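
A sketch of a self-contained episode record that would satisfy these criteria (field names are illustrative, not from the codebase):

```python
from dataclasses import dataclass, field
from datetime import datetime
import pandas as pd

@dataclass
class TrainingEpisode:
    """Everything needed to replay one decision exactly as the model saw it."""
    timestamp: datetime
    symbol: str
    input_dataframe: pd.DataFrame  # full OHLCV + indicators + market structure
    feature_vector: list           # the exact features fed to the model
    market_conditions: dict        # volatility, order flow, etc.
    action: str
    outcome: dict = field(default_factory=dict)  # filled in once the trade resolves
```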

### Requirement 6: Comprehensive Performance Tracking

**User Story:** As a system operator, I want detailed performance metrics to be tracked and persisted, so that I can monitor training progress and model improvement over time.

#### Acceptance Criteria

1. WHEN models are trained THEN the system SHALL track loss values, accuracy metrics, and training timestamps.
2. WHEN performance improves THEN the system SHALL log the improvement amount and save metadata.
3. WHEN checkpoints are saved THEN the system SHALL store performance metrics alongside model weights.
4. WHEN the system starts THEN it SHALL display the performance history of loaded checkpoints.
5. WHEN multiple training sessions occur THEN the system SHALL maintain a continuous performance history.
6. WHEN performance degrades THEN the system SHALL provide alerts and revert to better checkpoints if configured.
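
The tracking record these criteria describe is visible in the orchestrator diff further below; distilled, one per-model entry looks roughly like this (values illustrative):

```python
model_states = {}

# One entry of the model_states dict the orchestrator maintains per model.
model_states['dqn'] = {
    'current_loss': 0.042,                         # updated on every training pass
    'best_loss': 0.038,                            # best loss seen so far
    'last_training': '2025-01-01T12:00:00',        # timestamp of the last train step
    'best_checkpoint': 'checkpoints/dqn_best.pt',  # path of the best-performing save
    'last_checkpoint': 'checkpoints/dqn_best.pt',
    'checkpoints_saved': 17,                       # running count of saves
}
```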

### Requirement 7: Robust Error Handling and Recovery

**User Story:** As a system administrator, I want checkpoint operations to be resilient to failures, so that training can continue even if individual checkpoint saves fail.

#### Acceptance Criteria

1. WHEN checkpoint saving fails THEN the system SHALL log the error and continue training without crashing.
2. WHEN disk space is insufficient THEN the system SHALL clean up old checkpoints and retry saving.
3. WHEN checkpoint files are corrupted THEN the system SHALL fall back to previous valid checkpoints.
4. WHEN concurrent access conflicts occur THEN the system SHALL use proper locking mechanisms.
5. WHEN the system recovers from failures THEN it SHALL validate checkpoint integrity before loading.
6. WHEN critical checkpoint operations fail repeatedly THEN the system SHALL alert administrators.
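
A sketch of the retry behavior in criteria 1 and 2, with `save_fn` and `cleanup_old_checkpoints` as hypothetical callables:

```python
import logging

logger = logging.getLogger(__name__)

def save_with_recovery(save_fn, cleanup_old_checkpoints, retries: int = 1):
    """Try to save; on failure, free space and retry, but never crash training."""
    for attempt in range(retries + 1):
        try:
            return save_fn()
        except Exception as e:            # e.g. OSError when the disk is full
            logger.error("Checkpoint save failed (attempt %d): %s", attempt + 1, e)
            cleanup_old_checkpoints()     # criterion 2: reclaim space, then retry
    return None                           # criterion 1: training continues regardless
```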

### Requirement 8: Configuration and Monitoring

**User Story:** As a developer, I want configurable checkpoint settings and monitoring capabilities, so that I can optimize checkpoint behavior for different training scenarios.

#### Acceptance Criteria

1. WHEN configuring the system THEN checkpoint saving frequency SHALL be adjustable.
2. WHEN setting performance thresholds THEN the minimum improvement required for saving SHALL be configurable.
3. WHEN monitoring training THEN checkpoint save events SHALL be visible in logs and dashboards.
4. WHEN analyzing performance THEN checkpoint metadata SHALL be accessible for review.
5. WHEN tuning the system THEN checkpoint storage limits SHALL be configurable.
6. WHEN debugging issues THEN detailed checkpoint operation logs SHALL be available.
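
The tunables listed here collapse into a small config block; a plausible shape (keys illustrative):

```python
checkpoint_config = {
    'save_every_n_iterations': 100,  # periodic save frequency (Requirement 3)
    'min_loss_improvement': 0.01,    # threshold for performance-based saves (Requirement 2)
    'max_checkpoints_per_model': 5,  # storage limit / rotation window size
    'log_level': 'INFO',             # detailed operation logs for debugging
}
```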

### Requirement 9: Backward Compatibility and Migration

**User Story:** As a user, I want existing checkpoints to remain compatible, so that current model progress is preserved when the checkpoint system is enhanced.

#### Acceptance Criteria

1. WHEN the enhanced checkpoint system starts THEN it SHALL load existing checkpoints without issues.
2. WHEN checkpoint formats are updated THEN migration utilities SHALL convert old formats.
3. WHEN new metadata is added THEN existing checkpoints SHALL work with default values.
4. WHEN the system upgrades THEN checkpoint directories SHALL be preserved and enhanced.
5. WHEN rollback is needed THEN the system SHALL support reverting to previous checkpoint versions.
6. WHEN compatibility issues arise THEN clear error messages SHALL guide resolution.
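
For criterion 3 specifically, one plausible shape (hypothetical helper; the real migration utilities are not shown in this diff):

```python
def load_with_defaults(checkpoint: dict) -> dict:
    """Let checkpoints saved before the newer metadata fields existed keep working."""
    metadata = checkpoint.setdefault('metadata', {})
    metadata.setdefault('model_type', 'unknown')
    metadata.setdefault('training_iteration', 0)
    metadata.setdefault('performance_metrics', {})
    return checkpoint
```
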
@ -1,6 +0,0 @@
# Trading environments for reinforcement learning
# This module contains environments for training trading agents

from NN.environments.trading_env import TradingEnvironment

__all__ = ['TradingEnvironment']
@ -1,532 +0,0 @@
import numpy as np
import pandas as pd
from typing import Dict, Tuple, List, Any, Optional
import logging
import gym
from gym import spaces
import random

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TradingEnvironment(gym.Env):
    """
    Trading environment implementing gym interface for reinforcement learning

    2-Action System:
    - 0: SELL (or close long position)
    - 1: BUY (or close short position)

    Intelligent Position Management:
    - When neutral: Actions enter positions
    - When positioned: Actions can close or flip positions
    - Different thresholds for entry vs exit decisions

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data and unrealized PnL
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,   # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment with 2-action system.

        Args:
            data_interface: DataInterface instance to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for primary timeframe (assuming the first one is primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces for 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # For observation space, we consider multiple timeframes with OHLCV data
        # and additional features like technical indicators, position info, etc.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Add additional features for position, balance, unrealized_pnl, etc.
        additional_features = 5  # position, balance, unrealized_pnl, entry_price, position_duration

        # Calculate total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )

        # Use tuple for state_shape that EnhancedCNN expects
        self.state_shape = (total_features,)

        # Position tracking for 2-action system
        self.position = 0.0     # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which position was entered
        self.entry_step = 0     # Step at which position was entered

        # Initialize state
        self.reset()
    def reset_data(self):
        """Reset data and generate a new set of price data for training"""
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to initial state"""
        # Reset trading variables
        self.balance = self.initial_balance
        self.trades = []
        self.rewards = []

        # Reset position state and cumulative statistics; _close_position and
        # step() read these, so they must exist before the first trade
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0
        self.total_pnl = 0.0
        self.winning_trades = 0
        self.losing_trades = 0

        # Reset step counter
        self.current_step = self.window_size

        # Get initial observation
        observation = self._get_observation()

        return observation
    def step(self, action):
        """
        Take a step in the environment using 2-action system with intelligent position management.

        Args:
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current state before taking action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        # Take action with intelligent position management
        info = {}
        reward = 0
        last_position_info = None

        # Get current price
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Implement 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:  # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position > 0:  # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position < 0:  # Already short - potentially flip to long if very strong signal
                # For now, just hold the short position (no action)
                pass

        elif action == 1:  # BUY action
            if self.position == 0:  # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position < 0:  # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position > 0:  # Already long - potentially flip to short if very strong signal
                # For now, just hold the long position (no action)
                pass

        # Calculate unrealized PnL and add to reward if holding position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty

        # Reward staying neutral when uncertain (no clear setup)
        else:
            reward += 0.0001  # Small reward for not trading without clear signals

        # Move to next step
        self.current_step += 1

        # Get new observation
        observation = self._get_observation()

        # Check if episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining positions
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")

        # Track trade result if position changed or position was closed
        if prev_position != self.position or last_position_info is not None:
            # Calculate realized PnL if position was closed
            realized_pnl = 0
            position_info = {}

            if last_position_info is not None:
                # Use the position information from closing
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Calculate manually based on balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL (realized for this step)
                'balance_before': prev_balance,
                'balance_after': self.balance,
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
            }
            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store reward
        self.rewards.append(reward)

        # Update info dict with current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance/self.initial_balance-1)*100
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate unrealized PnL for current position"""
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        if self.position > 0:  # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)

    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position"""
        self.position = position_size
        self.entry_price = entry_price
        self.entry_step = self.current_step

        # Calculate position value
        position_value = abs(position_size) * entry_price

        # Apply transaction fee
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")

    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close current position and return PnL"""
        if self.position == 0:
            return 0.0, {}

        # Calculate PnL
        if self.position > 0:  # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:  # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply transaction fees (entry + exit)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee  # Entry fee already applied when opening

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update balance
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Track trade
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step
        }

        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")

        # Reset position
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        return net_pnl, position_info
    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)

                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLCV data
                    last_close = ohlcv[-1, 3]  # Last close price
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume (relative to moving average of volume)
                    if 'volume' in window.columns:
                        volume_ma = ohlcv[:, 4].mean()
                        if volume_ma > 0:
                            ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                        else:
                            ohlcv_normalized[:, 4] = 0.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if not enough data
                    observations.append(np.zeros(self.window_size * 5))

        # Add position and balance information; five values here must match
        # additional_features = 5 declared in __init__, or the observation
        # will not fit the declared observation_space
        current_price = self.prices[self.current_step]
        position_duration = (self.current_step - self.entry_step) if self.position != 0 else 0
        position_info = np.array([
            self.position / self.max_position,              # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,      # Normalized balance change
            self._calculate_unrealized_pnl(current_price),  # Unrealized PnL
            (self.entry_price / current_price - 1.0) if self.entry_price > 0 else 0.0,  # Relative entry price
            position_duration / max(self.window_size, 1)    # Normalized position duration
        ])

        observations.append(position_info)

        # Concatenate all observations
        observation = np.concatenate(observations)
        return observation
    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index in the higher timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Get timestamps for current primary timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find closest index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Adjust to get the starting index
        start_idx = max(0, idx - self.window_size)
        return start_idx

    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n positions.

        Args:
            n: Number of last positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Filter trades to only include those that closed positions
        position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]

        positions = []
        last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades

        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A')
            }
            positions.append(position_info)

        return positions

    def render(self, mode='human'):
        """Render the environment"""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print(f"\nTrading Environment Status:")
        print(f"============================")
        print(f"Step: {current_step}/{len(self.prices)-1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)
            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)

        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return

    def close(self):
        """Close the environment"""
        pass
@ -26,6 +26,14 @@ import torch
import torch.nn as nn
import torch.optim as optim

logger = logging.getLogger(__name__)

# Import checkpoint management; logger must be defined above, since the
# except branch logs before the rest of the module runs
try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")

class EnhancedRealtimeTrainingSystem:
@ -50,6 +58,12 @@ class EnhancedRealtimeTrainingSystem:
        # Experience buffers
        self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
        self.validation_buffer = deque(maxlen=1000)

        # Training counters - CRITICAL for checkpoint management
        self.training_iteration = 0
        self.dqn_training_count = 0
        self.cnn_training_count = 0
        self.cob_training_count = 0
        self.priority_buffer = deque(maxlen=2000)  # High-priority experiences

        # Performance tracking
@ -1071,6 +1085,10 @@ class EnhancedRealtimeTrainingSystem:

            self.dqn_training_count += 1

            # Save checkpoint after training
            if training_iterations > 0 and avg_loss > 0:
                self._save_model_checkpoint('dqn_agent', rl_agent, avg_loss)

            # Log progress every 10 training sessions
            if self.dqn_training_count % 10 == 0:
                logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "
@ -2523,4 +2541,56 @@ class EnhancedRealtimeTrainingSystem:

        except Exception as e:
            logger.debug(f"Error estimating price change: {e}")
            return 0.0

    def _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
        """
        Save model checkpoint after training if performance improved.

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not CHECKPOINT_MANAGER_AVAILABLE:
                return

            # Get checkpoint manager
            checkpoint_manager = get_checkpoint_manager()
            if not checkpoint_manager:
                return

            # Prepare performance metrics
            performance_metrics = {
                'loss': loss,
                'training_samples': len(self.experience_buffer),
                'timestamp': datetime.now().isoformat()
            }

            # Prepare training metadata
            training_metadata = {
                'timestamp': datetime.now().isoformat(),
                'training_iteration': self.training_iteration,
                'model_type': model_name
            }

            # Determine model type based on model name
            model_type = model_name
            if 'dqn' in model_name.lower():
                model_type = 'dqn'
            elif 'cnn' in model_name.lower():
                model_type = 'cnn'
            elif 'cob' in model_name.lower():
                model_type = 'cob_rl'

            # Save checkpoint
            checkpoint_path = save_checkpoint(
                model=model_obj,
                model_name=model_name,
                model_type=model_type,
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if checkpoint_path:
                logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
@ -220,6 +220,11 @@ class TradingOrchestrator:
        self.data_provider.start_centralized_data_collection()
        logger.info("Centralized data collection started - all models and dashboard will receive data")

        # CRITICAL: Initialize checkpoint manager for saving training progress
        self.checkpoint_manager = None
        self.training_iterations = 0  # Track training iterations for periodic saves
        self._initialize_checkpoint_manager()

        # Initialize models, COB integration, and training system
        self._initialize_ml_models()
        self._initialize_cob_integration()
@ -2145,6 +2150,9 @@
            if not market_data:
                return

            # Track if any model was trained for checkpoint saving
            models_trained = []

            # Train DQN agent if available
            if self.rl_agent and hasattr(self.rl_agent, 'add_experience'):
                try:
@ -2167,6 +2175,7 @@
                        done=False
                    )

                    models_trained.append('dqn')
                    logger.debug(f"🧠 Added DQN experience: {action} {symbol} (reward: {immediate_reward:.3f})")

                except Exception as e:
@ -2185,6 +2194,7 @@
                    # Add training sample
                    self.cnn_model.add_training_sample(cnn_features, target, weight=confidence)

                    models_trained.append('cnn')
                    logger.debug(f"🔍 Added CNN training sample: {action} {symbol}")

                except Exception as e:
@ -2206,14 +2216,105 @@
                        symbol=symbol
                    )

                    models_trained.append('cob_rl')
                    logger.debug(f"📊 Added COB RL experience: {action} {symbol}")

                except Exception as e:
                    logger.debug(f"Error training COB RL on decision: {e}")

            # CRITICAL FIX: Save checkpoints after training
            if models_trained:
                self._save_training_checkpoints(models_trained, confidence)

        except Exception as e:
            logger.error(f"Error training models on decision: {e}")

    def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
        """Save checkpoints for trained models if performance improved.

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not self.checkpoint_manager:
                return

            # Increment training counter
            self.training_iterations += 1

            # Save checkpoints for each trained model
            for model_name in models_trained:
                try:
                    model_obj = None
                    current_loss = None

                    # Get model object and calculate current performance
                    if model_name == 'dqn' and self.rl_agent:
                        model_obj = self.rl_agent
                        # Use the complement of the performance score as a loss
                        # proxy (higher confidence -> lower loss)
                        current_loss = 1.0 - performance_score

                    elif model_name == 'cnn' and self.cnn_model:
                        model_obj = self.cnn_model
                        current_loss = 1.0 - performance_score

                    elif model_name == 'cob_rl' and self.cob_rl_agent:
                        model_obj = self.cob_rl_agent
                        current_loss = 1.0 - performance_score

                    if model_obj and current_loss is not None:
                        # Check if this is the best performance so far
                        model_state = self.model_states.get(model_name, {})
                        best_loss = model_state.get('best_loss', float('inf'))

                        # Update current loss
                        model_state['current_loss'] = current_loss
                        model_state['last_training'] = datetime.now()

                        # Save checkpoint if performance improved or periodic save
                        should_save = (
                            current_loss < best_loss or  # Performance improved
                            self.training_iterations % 100 == 0  # Periodic save every 100 iterations
                        )

                        if should_save:
                            # Prepare metadata
                            metadata = {
                                'loss': current_loss,
                                'performance_score': performance_score,
                                'training_iterations': self.training_iterations,
                                'timestamp': datetime.now().isoformat(),
                                'model_type': model_name
                            }

                            # Save checkpoint
                            checkpoint_path = self.checkpoint_manager.save_checkpoint(
                                model=model_obj,
                                model_name=model_name,
                                performance=current_loss,
                                metadata=metadata
                            )

                            if checkpoint_path:
                                # Update best performance
                                if current_loss < best_loss:
                                    model_state['best_loss'] = current_loss
                                    model_state['best_checkpoint'] = checkpoint_path
                                    logger.info(f"💾 Saved BEST checkpoint for {model_name}: {checkpoint_path} (loss: {current_loss:.4f})")
                                else:
                                    logger.debug(f"💾 Saved periodic checkpoint for {model_name}: {checkpoint_path}")

                                model_state['last_checkpoint'] = checkpoint_path
                                model_state['checkpoints_saved'] = model_state.get('checkpoints_saved', 0) + 1

                        # Update model state
                        self.model_states[model_name] = model_state

                except Exception as e:
                    logger.error(f"Error saving checkpoint for {model_name}: {e}")

        except Exception as e:
            logger.error(f"Error saving training checkpoints: {e}")

    def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
        """Get current market data for training context"""
        try:
@ -22,8 +22,8 @@ import sys
 # Add NN directory to path for exchange interfaces
 sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'NN'))

-from NN.exchanges.exchange_factory import ExchangeFactory
-from NN.exchanges.exchange_interface import ExchangeInterface
+from core.exchanges.exchange_factory import ExchangeFactory
+from core.exchanges.exchange_interface import ExchangeInterface
 from .config import get_config
 from .config_sync import ConfigSynchronizer
@ -1 +0,0 @@
@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""
Fix Dashboard Metrics Script

This script fixes the incomplete code in the update_metrics function
of the web/clean_dashboard.py file.
"""

import re
import os

def fix_dashboard_metrics():
    """Fix the incomplete code in the update_metrics function"""
    file_path = 'web/clean_dashboard.py'

    # Read the file content
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    # Find and replace the incomplete code
    pattern = r"# Add unrealized P&L from current position \(adjustable leverage\)\s+if self\.curr"
    replacement = """# Add unrealized P&L from current position (adjustable leverage)
                    if self.current_position and current_price:
                        side = self.current_position.get('side', 'UNKNOWN')
                        size = self.current_position.get('size', 0)
                        entry_price = self.current_position.get('price', 0)

                        if entry_price and size > 0:
                            # Calculate unrealized P&L with current leverage
                            if side.upper() == 'LONG' or side.upper() == 'BUY':
                                raw_pnl_per_unit = current_price - entry_price
                            else:  # SHORT or SELL
                                raw_pnl_per_unit = entry_price - current_price

                            # Apply current leverage to unrealized P&L
                            leveraged_unrealized_pnl = raw_pnl_per_unit * size * self.current_leverage
                            total_session_pnl += leveraged_unrealized_pnl"""

    # Replace the pattern
    fixed_content = re.sub(pattern, replacement, content)

    # Write the fixed content back to the file
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(fixed_content)

    print(f"Fixed dashboard metrics in {file_path}")

if __name__ == "__main__":
    fix_dashboard_metrics()
@ -1,283 +0,0 @@
#!/usr/bin/env python3
"""
Fix RL Training Issues - Comprehensive Solution

This script addresses the critical RL training audit issues:
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements full 13,400 feature state
2. Disconnected Training Pipeline - Fixes data flow between components
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
5. Williams Market Structure Integration - Proper feature extraction
6. Real-time Data Integration - Live market data to RL

Usage:
    python fix_rl_training_issues.py
"""

import os
import sys
import logging
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

logger = logging.getLogger(__name__)

def fix_orchestrator_missing_methods():
    """Fix missing methods in enhanced orchestrator"""
    try:
        logger.info("Checking enhanced orchestrator...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator

        # Test if methods exist
        test_orchestrator = EnhancedTradingOrchestrator()

        methods_to_check = [
            '_get_symbol_correlation',
            'build_comprehensive_rl_state',
            'calculate_enhanced_pivot_reward'
        ]

        missing_methods = []
        for method in methods_to_check:
            if not hasattr(test_orchestrator, method):
                missing_methods.append(method)

        if missing_methods:
            logger.error(f"Missing methods in enhanced orchestrator: {missing_methods}")
            return False
        else:
            logger.info("✅ All required methods present in enhanced orchestrator")
            return True

    except Exception as e:
        logger.error(f"Error checking orchestrator: {e}")
        return False

def test_comprehensive_state_building():
    """Test comprehensive RL state building"""
    try:
        logger.info("Testing comprehensive state building...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from core.data_provider import DataProvider

        # Create test instances
        data_provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)

        # Test comprehensive state building
        state = orchestrator.build_comprehensive_rl_state('ETH/USDT')

        if state is not None:
            logger.info(f"✅ Comprehensive state built: {len(state)} features")

            if len(state) == 13400:
                logger.info("✅ PERFECT: Exactly 13,400 features as required!")
            else:
                logger.warning(f"⚠️ Expected 13,400 features, got {len(state)}")

            # Check feature distribution
            import numpy as np
            non_zero = np.count_nonzero(state)
            logger.info(f"Non-zero features: {non_zero} ({non_zero/len(state)*100:.1f}%)")

            return True
        else:
            logger.error("❌ Comprehensive state building failed")
            return False

    except Exception as e:
        logger.error(f"Error testing state building: {e}")
        return False

def test_enhanced_reward_calculation():
    """Test enhanced reward calculation"""
    try:
        logger.info("Testing enhanced reward calculation...")

        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from datetime import datetime, timedelta

        orchestrator = EnhancedTradingOrchestrator()

        # Test data
        trade_decision = {
            'action': 'BUY',
            'confidence': 0.75,
            'price': 2500.0,
            'timestamp': datetime.now()
        }

        trade_outcome = {
            'net_pnl': 50.0,
            'exit_price': 2550.0,
            'duration': timedelta(minutes=15)
        }

        market_data = {
            'volatility': 0.03,
            'order_flow_direction': 'bullish',
            'order_flow_strength': 0.8
        }

        # Test enhanced reward
        enhanced_reward = orchestrator.calculate_enhanced_pivot_reward(
            trade_decision, market_data, trade_outcome
        )

        logger.info(f"✅ Enhanced reward calculated: {enhanced_reward:.3f}")
        return True

    except Exception as e:
        logger.error(f"Error testing reward calculation: {e}")
        return False

def test_williams_integration():
    """Test Williams market structure integration"""
    try:
        logger.info("Testing Williams market structure integration...")

        from training.williams_market_structure import extract_pivot_features, analyze_pivot_context
        from core.data_provider import DataProvider
        from datetime import datetime  # needed for analyze_pivot_context below
        import pandas as pd
        import numpy as np

        # Create test data
        test_data = {
            'open': np.random.uniform(2400, 2600, 100),
            'high': np.random.uniform(2500, 2700, 100),
            'low': np.random.uniform(2300, 2500, 100),
            'close': np.random.uniform(2400, 2600, 100),
            'volume': np.random.uniform(1000, 5000, 100)
        }
        df = pd.DataFrame(test_data)

        # Test pivot features
        pivot_features = extract_pivot_features(df)

        if pivot_features is not None:
            logger.info(f"✅ Williams pivot features extracted: {len(pivot_features)} features")

            # Test pivot context analysis
            market_data = {'ohlcv_data': df}
            context = analyze_pivot_context(market_data, datetime.now(), 'BUY')

            if context is not None:
                logger.info("✅ Williams pivot context analysis working")
                return True
            else:
                logger.warning("⚠️ Pivot context analysis returned None")
                return False
        else:
            logger.error("❌ Williams pivot feature extraction failed")
            return False

    except Exception as e:
        logger.error(f"Error testing Williams integration: {e}")
        return False

def test_dashboard_integration():
    """Test dashboard integration with enhanced features"""
    try:
        logger.info("Testing dashboard integration...")

        from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from core.data_provider import DataProvider
        from core.trading_executor import TradingExecutor

        # Create components
        data_provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(data_provider=data_provider)
        executor = TradingExecutor()

        # Create dashboard
        dashboard = TradingDashboard(
            data_provider=data_provider,
            orchestrator=orchestrator,
            trading_executor=executor
        )

        # Check if dashboard has access to enhanced features
        has_comprehensive_builder = hasattr(dashboard, '_build_comprehensive_rl_state')
        has_enhanced_orchestrator = hasattr(dashboard.orchestrator, 'build_comprehensive_rl_state')

        if has_comprehensive_builder and has_enhanced_orchestrator:
            logger.info("✅ Dashboard properly integrated with enhanced features")
            return True
        else:
            logger.warning("⚠️ Dashboard missing some enhanced features")
            logger.info(f"Comprehensive builder: {has_comprehensive_builder}")
            logger.info(f"Enhanced orchestrator: {has_enhanced_orchestrator}")
            return False

    except Exception as e:
        logger.error(f"Error testing dashboard integration: {e}")
        return False

def main():
    """Main function to run all fixes and tests"""
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    logger.info("=" * 70)
    logger.info("COMPREHENSIVE RL TRAINING FIX - AUDIT ISSUE RESOLUTION")
    logger.info("=" * 70)

    # Track results
    test_results = {}

    # Run all tests
    tests = [
        ("Enhanced Orchestrator Methods", fix_orchestrator_missing_methods),
        ("Comprehensive State Building", test_comprehensive_state_building),
        ("Enhanced Reward Calculation", test_enhanced_reward_calculation),
        ("Williams Market Structure", test_williams_integration),
        ("Dashboard Integration", test_dashboard_integration)
    ]

    for test_name, test_func in tests:
        logger.info(f"\n🔧 {test_name}...")
        try:
            result = test_func()
            test_results[test_name] = result
        except Exception as e:
            logger.error(f"❌ {test_name} failed: {e}")
            test_results[test_name] = False

    # Summary
    logger.info("\n" + "=" * 70)
    logger.info("COMPREHENSIVE RL TRAINING FIX RESULTS")
    logger.info("=" * 70)

    passed = sum(test_results.values())
    total = len(test_results)

    for test_name, result in test_results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        logger.info(f"{test_name}: {status}")

    logger.info(f"\nOverall: {passed}/{total} tests passed")

    if passed == total:
        logger.info("🎉 ALL RL TRAINING ISSUES FIXED!")
        logger.info("The system now supports:")
        logger.info("  - 13,400 comprehensive RL features")
        logger.info("  - Enhanced pivot-based rewards")
        logger.info("  - Williams market structure integration")
        logger.info("  - Proper data flow between components")
        logger.info("  - Real-time data integration")
    else:
        logger.warning("⚠️ Some issues remain - check logs above")

    return 0 if passed == total else 1

if __name__ == "__main__":
    sys.exit(main())
@ -29,7 +29,7 @@ def test_mexc_order_fix():

     # Import after path setup
     try:
-        from NN.exchanges.mexc_interface import MEXCInterface
+        from core.exchanges.mexc_interface import MEXCInterface
     except ImportError as e:
         print(f"❌ Import error: {e}")
         return False